From 8099d6465caa26ce3f8ec4e638f989058d4c93f7 Mon Sep 17 00:00:00 2001 From: Jon Chambers <63609320+jon-signal@users.noreply.github.com> Date: Fri, 18 Apr 2025 15:44:21 -0400 Subject: [PATCH] Clarify guarantees around remote channnel/request attribute presence --- .../AbstractAuthenticationInterceptor.java | 22 +- .../ProhibitAuthenticationInterceptor.java | 20 +- .../RequireAuthenticationInterceptor.java | 23 ++- .../filters/RemoteDeprecationFilter.java | 11 +- .../grpc/ChannelNotFoundException.java | 12 ++ .../grpc/MessagesAnonymousGrpcService.java | 2 +- .../grpc/MessagesGrpcHelper.java | 2 +- .../textsecuregcm/grpc/RequestAttributes.java | 16 ++ .../grpc/RequestAttributesInterceptor.java | 71 ++----- .../grpc/RequestAttributesUtil.java | 30 +-- .../grpc/ServerInterceptorUtil.java | 39 ++++ .../EstablishLocalGrpcConnectionHandler.java | 8 +- .../grpc/net/GrpcClientConnectionManager.java | 164 +++++++-------- .../WebsocketHandshakeCompleteHandler.java | 2 +- ...ProhibitAuthenticationInterceptorTest.java | 9 +- .../RequireAuthenticationInterceptorTest.java | 9 +- .../filters/ExternalRequestFilterTest.java | 5 +- .../filters/RemoteDeprecationFilterTest.java | 10 +- .../AccountsAnonymousGrpcServiceTest.java | 3 +- .../MockRequestAttributesInterceptor.java | 39 +--- .../grpc/ProfileAnonymousGrpcServiceTest.java | 13 +- .../grpc/ProfileGrpcServiceTest.java | 13 +- .../grpc/RequestAttributesServiceImpl.java | 14 +- .../grpc/RequestAttributesUtilTest.java | 186 +++++------------ .../net/GrpcClientConnectionManagerTest.java | 192 +++++++----------- ...eWebSocketTunnelServerIntegrationTest.java | 5 +- ...WebsocketHandshakeCompleteHandlerTest.java | 46 +++-- .../proto/request_attributes_service.proto | 9 +- 28 files changed, 405 insertions(+), 570 deletions(-) create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/ChannelNotFoundException.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/RequestAttributes.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/ServerInterceptorUtil.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AbstractAuthenticationInterceptor.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AbstractAuthenticationInterceptor.java index 38c59e246..93530a002 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AbstractAuthenticationInterceptor.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AbstractAuthenticationInterceptor.java @@ -1,34 +1,22 @@ package org.whispersystems.textsecuregcm.auth.grpc; -import io.grpc.Grpc; -import io.grpc.Metadata; import io.grpc.ServerCall; import io.grpc.ServerInterceptor; -import io.grpc.Status; -import io.netty.channel.local.LocalAddress; -import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; import java.util.Optional; +import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException; +import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; abstract class AbstractAuthenticationInterceptor implements ServerInterceptor { private final GrpcClientConnectionManager grpcClientConnectionManager; - private static final Metadata EMPTY_TRAILERS = new Metadata(); - AbstractAuthenticationInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) { this.grpcClientConnectionManager = grpcClientConnectionManager; } - protected Optional getAuthenticatedDevice(final ServerCall call) { - if (call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress) { - return grpcClientConnectionManager.getAuthenticatedDevice(localAddress); - } else { - throw new AssertionError("Unexpected channel type: " + call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR)); - } - } + protected Optional getAuthenticatedDevice(final ServerCall call) + throws ChannelNotFoundException { - protected ServerCall.Listener closeAsUnauthenticated(final ServerCall call) { - call.close(Status.UNAUTHENTICATED, EMPTY_TRAILERS); - return new ServerCall.Listener<>() {}; + return grpcClientConnectionManager.getAuthenticatedDevice(call); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/ProhibitAuthenticationInterceptor.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/ProhibitAuthenticationInterceptor.java index 5a44ebe3f..b14465d51 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/ProhibitAuthenticationInterceptor.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/ProhibitAuthenticationInterceptor.java @@ -3,12 +3,17 @@ package org.whispersystems.textsecuregcm.auth.grpc; import io.grpc.Metadata; import io.grpc.ServerCall; import io.grpc.ServerCallHandler; +import io.grpc.Status; +import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException; +import org.whispersystems.textsecuregcm.grpc.ServerInterceptorUtil; import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; /** * A "prohibit authentication" interceptor ensures that requests to endpoints that should be invoked anonymously do not * originate from a channel that is associated with an authenticated device. Calls with an associated authenticated - * device are closed with an {@code UNAUTHENTICATED} status. + * device are closed with an {@code UNAUTHENTICATED} status. If a call's authentication status cannot be determined + * (i.e. because the underlying remote channel closed before the {@code ServerCall} started), the interceptor will + * reject the call with a status of {@code UNAVAILABLE}. */ public class ProhibitAuthenticationInterceptor extends AbstractAuthenticationInterceptor { @@ -21,8 +26,15 @@ public class ProhibitAuthenticationInterceptor extends AbstractAuthenticationInt final Metadata headers, final ServerCallHandler next) { - return getAuthenticatedDevice(call) - .map(ignored -> closeAsUnauthenticated(call)) - .orElseGet(() -> next.startCall(call, headers)); + try { + return getAuthenticatedDevice(call) + // Status.INTERNAL may seem a little surprising here, but if a caller is reaching an authentication-prohibited + // service via an authenticated connection, then that's actually a server configuration issue and not a + // problem with the client's request. + .map(ignored -> ServerInterceptorUtil.closeWithStatus(call, Status.INTERNAL)) + .orElseGet(() -> next.startCall(call, headers)); + } catch (final ChannelNotFoundException e) { + return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE); + } } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/RequireAuthenticationInterceptor.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/RequireAuthenticationInterceptor.java index c24b0af77..f03052bab 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/RequireAuthenticationInterceptor.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/RequireAuthenticationInterceptor.java @@ -5,12 +5,16 @@ import io.grpc.Contexts; import io.grpc.Metadata; import io.grpc.ServerCall; import io.grpc.ServerCallHandler; +import io.grpc.Status; +import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException; +import org.whispersystems.textsecuregcm.grpc.ServerInterceptorUtil; import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; /** * A "require authentication" interceptor requires that requests be issued from a connection that is associated with an * authenticated device. Calls without an associated authenticated device are closed with an {@code UNAUTHENTICATED} - * status. + * status. If a call's authentication status cannot be determined (i.e. because the underlying remote channel closed + * before the {@code ServerCall} started), the interceptor will reject the call with a status of {@code UNAVAILABLE}. */ public class RequireAuthenticationInterceptor extends AbstractAuthenticationInterceptor { @@ -23,10 +27,17 @@ public class RequireAuthenticationInterceptor extends AbstractAuthenticationInte final Metadata headers, final ServerCallHandler next) { - return getAuthenticatedDevice(call) - .map(authenticatedDevice -> Contexts.interceptCall(Context.current() - .withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE, authenticatedDevice), - call, headers, next)) - .orElseGet(() -> closeAsUnauthenticated(call)); + try { + return getAuthenticatedDevice(call) + .map(authenticatedDevice -> Contexts.interceptCall(Context.current() + .withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE, authenticatedDevice), + call, headers, next)) + // Status.INTERNAL may seem a little surprising here, but if a caller is reaching an authentication-required + // service via an unauthenticated connection, then that's actually a server configuration issue and not a + // problem with the client's request. + .orElseGet(() -> ServerInterceptorUtil.closeWithStatus(call, Status.INTERNAL)); + } catch (final ChannelNotFoundException e) { + return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE); + } } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteDeprecationFilter.java b/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteDeprecationFilter.java index b48709923..3aea5b38c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteDeprecationFilter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteDeprecationFilter.java @@ -81,7 +81,16 @@ public class RemoteDeprecationFilter implements Filter, ServerInterceptor { final Metadata headers, final ServerCallHandler next) { - if (shouldBlock(RequestAttributesUtil.getUserAgent().orElse(null))) { + @Nullable final UserAgent userAgent = RequestAttributesUtil.getUserAgent() + .map(userAgentString -> { + try { + return UserAgentUtil.parseUserAgentString(userAgentString); + } catch (final UnrecognizedUserAgentException e) { + return null; + } + }).orElse(null); + + if (shouldBlock(userAgent)) { call.close(StatusConstants.UPGRADE_NEEDED_STATUS, new Metadata()); return new ServerCall.Listener<>() {}; } else { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ChannelNotFoundException.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ChannelNotFoundException.java new file mode 100644 index 000000000..29211b907 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ChannelNotFoundException.java @@ -0,0 +1,12 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.grpc; + +/** + * Indicates that a remote channel was not found for a given server call or remote address. + */ +public class ChannelNotFoundException extends Exception { +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcService.java index b6b5265cf..c66ea8a92 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcService.java @@ -253,7 +253,7 @@ public class MessagesAnonymousGrpcService extends SimpleMessagesAnonymousGrpc.Me story, ephemeral, urgent, - RequestAttributesUtil.getRawUserAgent().orElse(null)); + RequestAttributesUtil.getUserAgent().orElse(null)); final SendMultiRecipientMessageResponse.Builder responseBuilder = SendMultiRecipientMessageResponse.newBuilder(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesGrpcHelper.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesGrpcHelper.java index 1e1f73fd0..cabeca5be 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesGrpcHelper.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesGrpcHelper.java @@ -55,7 +55,7 @@ public class MessagesGrpcHelper { messagesByDeviceId, registrationIdsByDeviceId, syncMessageSenderDeviceId, - RequestAttributesUtil.getRawUserAgent().orElse(null)); + RequestAttributesUtil.getUserAgent().orElse(null)); return SEND_MESSAGE_SUCCESS_RESPONSE; } catch (final MismatchedDevicesException e) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RequestAttributes.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RequestAttributes.java new file mode 100644 index 000000000..23ba732eb --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RequestAttributes.java @@ -0,0 +1,16 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.grpc; + +import java.net.InetAddress; +import java.util.List; +import java.util.Locale; +import javax.annotation.Nullable; + +public record RequestAttributes(InetAddress remoteAddress, + @Nullable String userAgent, + List acceptLanguage) { +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesInterceptor.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesInterceptor.java index 96b8ef9f8..b4fb0d169 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesInterceptor.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesInterceptor.java @@ -2,28 +2,25 @@ package org.whispersystems.textsecuregcm.grpc; import io.grpc.Context; import io.grpc.Contexts; -import io.grpc.Grpc; import io.grpc.Metadata; import io.grpc.ServerCall; import io.grpc.ServerCallHandler; import io.grpc.ServerInterceptor; import io.grpc.Status; -import io.netty.channel.local.LocalAddress; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; -import org.whispersystems.textsecuregcm.util.ua.UserAgent; -import java.net.InetAddress; -import java.util.List; -import java.util.Locale; -import java.util.Optional; +/** + * The request attributes interceptor makes request attributes from the underlying remote channel available to service + * implementations by attaching them to a {@link Context} attribute that can be read via {@link RequestAttributesUtil}. + * All server calls should have request attributes, and calls will be rejected with a status of {@code UNAVAILABLE} if + * request attributes are unavailable (i.e. the underlying channel closed before the {@code ServerCall} started). + * + * @see RequestAttributesUtil + */ public class RequestAttributesInterceptor implements ServerInterceptor { private final GrpcClientConnectionManager grpcClientConnectionManager; - private static final Logger log = LoggerFactory.getLogger(RequestAttributesInterceptor.class); - public RequestAttributesInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) { this.grpcClientConnectionManager = grpcClientConnectionManager; } @@ -33,52 +30,12 @@ public class RequestAttributesInterceptor implements ServerInterceptor { final Metadata headers, final ServerCallHandler next) { - if (call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress) { - Context context = Context.current(); - - { - final Optional maybeRemoteAddress = grpcClientConnectionManager.getRemoteAddress(localAddress); - - if (maybeRemoteAddress.isEmpty()) { - // We should never have a call from a party whose remote address we can't identify - log.warn("No remote address available"); - - call.close(Status.INTERNAL, new Metadata()); - return new ServerCall.Listener<>() {}; - } - - context = context.withValue(RequestAttributesUtil.REMOTE_ADDRESS_CONTEXT_KEY, maybeRemoteAddress.get()); - } - - { - final Optional> maybeAcceptLanguage = - grpcClientConnectionManager.getAcceptableLanguages(localAddress); - - if (maybeAcceptLanguage.isPresent()) { - context = context.withValue(RequestAttributesUtil.ACCEPT_LANGUAGE_CONTEXT_KEY, maybeAcceptLanguage.get()); - } - } - - { - final Optional maybeRawUserAgent = - grpcClientConnectionManager.getRawUserAgent(localAddress); - - if (maybeRawUserAgent.isPresent()) { - context = context.withValue(RequestAttributesUtil.RAW_USER_AGENT_CONTEXT_KEY, maybeRawUserAgent.get()); - } - } - - { - final Optional maybeUserAgent = grpcClientConnectionManager.getUserAgent(localAddress); - - if (maybeUserAgent.isPresent()) { - context = context.withValue(RequestAttributesUtil.USER_AGENT_CONTEXT_KEY, maybeUserAgent.get()); - } - } - - return Contexts.interceptCall(context, call, headers, next); - } else { - throw new AssertionError("Unexpected channel type: " + call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR)); + try { + return Contexts.interceptCall(Context.current() + .withValue(RequestAttributesUtil.REQUEST_ATTRIBUTES_CONTEXT_KEY, + grpcClientConnectionManager.getRequestAttributes(call)), call, headers, next); + } catch (final ChannelNotFoundException e) { + return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE); } } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesUtil.java index 76d3109a1..25673cfbf 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesUtil.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesUtil.java @@ -3,18 +3,13 @@ package org.whispersystems.textsecuregcm.grpc; import io.grpc.Context; import java.net.InetAddress; import java.util.Arrays; -import java.util.Collections; import java.util.List; import java.util.Locale; import java.util.Optional; -import org.whispersystems.textsecuregcm.util.ua.UserAgent; public class RequestAttributesUtil { - static final Context.Key> ACCEPT_LANGUAGE_CONTEXT_KEY = Context.key("accept-language"); - static final Context.Key REMOTE_ADDRESS_CONTEXT_KEY = Context.key("remote-address"); - static final Context.Key RAW_USER_AGENT_CONTEXT_KEY = Context.key("unparsed-user-agent"); - static final Context.Key USER_AGENT_CONTEXT_KEY = Context.key("parsed-user-agent"); + static final Context.Key REQUEST_ATTRIBUTES_CONTEXT_KEY = Context.key("request-attributes"); private static final List AVAILABLE_LOCALES = Arrays.asList(Locale.getAvailableLocales()); @@ -23,8 +18,8 @@ public class RequestAttributesUtil { * * @return the acceptable languages listed by the remote client; may be empty if unparseable or not specified */ - public static Optional> getAcceptableLanguages() { - return Optional.ofNullable(ACCEPT_LANGUAGE_CONTEXT_KEY.get()); + public static List getAcceptableLanguages() { + return REQUEST_ATTRIBUTES_CONTEXT_KEY.get().acceptLanguage(); } /** @@ -35,9 +30,7 @@ public class RequestAttributesUtil { * @return a list of distinct locales acceptable to the remote client and available in this JVM */ public static List getAvailableAcceptedLocales() { - return getAcceptableLanguages() - .map(languageRanges -> Locale.filter(languageRanges, AVAILABLE_LOCALES)) - .orElseGet(Collections::emptyList); + return Locale.filter(getAcceptableLanguages(), AVAILABLE_LOCALES); } /** @@ -46,16 +39,7 @@ public class RequestAttributesUtil { * @return the remote address of the remote client */ public static InetAddress getRemoteAddress() { - return REMOTE_ADDRESS_CONTEXT_KEY.get(); - } - - /** - * Returns the parsed user-agent of the remote client in the current gRPC request context. - * - * @return the parsed user-agent of the remote client; may be empty if unparseable or not specified - */ - public static Optional getUserAgent() { - return Optional.ofNullable(USER_AGENT_CONTEXT_KEY.get()); + return REQUEST_ATTRIBUTES_CONTEXT_KEY.get().remoteAddress(); } /** @@ -63,7 +47,7 @@ public class RequestAttributesUtil { * * @return the unparsed user-agent of the remote client; may be empty if not specified */ - public static Optional getRawUserAgent() { - return Optional.ofNullable(RAW_USER_AGENT_CONTEXT_KEY.get()); + public static Optional getUserAgent() { + return Optional.ofNullable(REQUEST_ATTRIBUTES_CONTEXT_KEY.get().userAgent()); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ServerInterceptorUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ServerInterceptorUtil.java new file mode 100644 index 000000000..816319b4c --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ServerInterceptorUtil.java @@ -0,0 +1,39 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.grpc; + +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.Status; + +public class ServerInterceptorUtil { + + @SuppressWarnings("rawtypes") + private static final ServerCall.Listener NO_OP_LISTENER = new ServerCall.Listener<>() {}; + + private static final Metadata EMPTY_TRAILERS = new Metadata(); + + private ServerInterceptorUtil() { + } + + /** + * Closes the given server call with the given status, returning a no-op listener. + * + * @param call the server call to close + * @param status the status with which to close the call + * + * @return a no-op server call listener + * + * @param the type of request object handled by the server call + * @param the type of response object returned by the server call + */ + public static ServerCall.Listener closeWithStatus(final ServerCall call, final Status status) { + call.close(status, EMPTY_TRAILERS); + + //noinspection unchecked + return NO_OP_LISTENER; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/EstablishLocalGrpcConnectionHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/EstablishLocalGrpcConnectionHandler.java index 2c95252d6..c5433358c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/EstablishLocalGrpcConnectionHandler.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/EstablishLocalGrpcConnectionHandler.java @@ -12,8 +12,10 @@ import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus; import io.netty.util.ReferenceCountUtil; import java.util.ArrayList; import java.util.List; +import java.util.Optional; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; /** * An "establish local connection" handler waits for a Noise handshake to complete upstream in the pipeline, buffering @@ -48,12 +50,12 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter { @Override public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) { - if (event instanceof NoiseIdentityDeterminedEvent noiseIdentityDeterminedEvent) { + if (event instanceof NoiseIdentityDeterminedEvent(final Optional authenticatedDevice)) { // We assume that we'll only get a completed handshake event if the handshake met all authentication requirements // for the requested service. If the handshake doesn't have an authenticated device, we assume we're trying to // connect to the anonymous service. If it does have an authenticated device, we assume we're aiming for the // authenticated service. - final LocalAddress grpcServerAddress = noiseIdentityDeterminedEvent.authenticatedDevice().isPresent() + final LocalAddress grpcServerAddress = authenticatedDevice.isPresent() ? authenticatedGrpcServerAddress : anonymousGrpcServerAddress; @@ -72,7 +74,7 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter { if (localChannelFuture.isSuccess()) { grpcClientConnectionManager.handleConnectionEstablished((LocalChannel) localChannelFuture.channel(), remoteChannelContext.channel(), - noiseIdentityDeterminedEvent.authenticatedDevice()); + authenticatedDevice); // Close the local connection if the remote channel closes and vice versa remoteChannelContext.channel().closeFuture().addListener(closeFuture -> localChannelFuture.channel().close()); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManager.java index 50c32086b..f3eec095c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManager.java @@ -1,6 +1,8 @@ package org.whispersystems.textsecuregcm.grpc.net; import com.google.common.annotations.VisibleForTesting; +import io.grpc.Grpc; +import io.grpc.ServerCall; import io.netty.channel.Channel; import io.netty.channel.ChannelFutureListener; import io.netty.channel.local.LocalAddress; @@ -23,15 +25,25 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; -import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; -import org.whispersystems.textsecuregcm.util.ua.UserAgent; -import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; +import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException; +import org.whispersystems.textsecuregcm.grpc.RequestAttributes; /** * A client connection manager associates a local connection to a local gRPC server with a remote connection through a - * Noise-over-WebSocket tunnel. It provides access to metadata associated with the remote connection, including the - * authenticated identity of the device that opened the connection (for non-anonymous connections). It can also close - * connections associated with a given device if that device's credentials have changed and clients must reauthenticate. + * Noise tunnel. It provides access to metadata associated with the remote connection, including the authenticated + * identity of the device that opened the connection (for non-anonymous connections). It can also close connections + * associated with a given device if that device's credentials have changed and clients must reauthenticate. + *

+ * In general, all {@link ServerCall}s must have a local address that in turn should be resolvable to + * a remote channel, which must have associated request attributes and authentication status. It is possible + * that a server call's local address may not be resolvable to a remote channel if the remote channel closed in the + * narrow window between a server call being created and the start of call execution, in which case accessor methods + * in this class will throw a {@link ChannelNotFoundException}. + *

+ * A gRPC client connection manager's methods for getting request attributes accept {@link ServerCall} entities to + * identify connections. In general, these methods should only be called from {@link io.grpc.ServerInterceptor}s. + * Methods for requesting connection closure accept an {@link AuthenticatedDevice} to identify the connection and may + * be called from any application code. */ public class GrpcClientConnectionManager implements DisconnectionRequestListener { @@ -43,94 +55,56 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener AttributeKey.valueOf(GrpcClientConnectionManager.class, "authenticatedDevice"); @VisibleForTesting - static final AttributeKey REMOTE_ADDRESS_ATTRIBUTE_KEY = - AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "remoteAddress"); - - @VisibleForTesting - static final AttributeKey RAW_USER_AGENT_ATTRIBUTE_KEY = - AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "rawUserAgent"); - - @VisibleForTesting - static final AttributeKey PARSED_USER_AGENT_ATTRIBUTE_KEY = - AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "userAgent"); - - @VisibleForTesting - static final AttributeKey> ACCEPT_LANGUAGE_ATTRIBUTE_KEY = - AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "acceptLanguage"); + public static final AttributeKey REQUEST_ATTRIBUTES_KEY = + AttributeKey.valueOf(GrpcClientConnectionManager.class, "requestAttributes"); private static final Logger log = LoggerFactory.getLogger(GrpcClientConnectionManager.class); /** - * Returns the authenticated device associated with the given local address, if any. An authenticated device is - * available if and only if the given local address maps to an active local connection and that connection is - * authenticated (i.e. not anonymous). + * Returns the authenticated device associated with the given server call, if any. If the connection is anonymous + * (i.e. unauthenticated), the returned value will be empty. * - * @param localAddress the local address for which to find an authenticated device + * @param serverCall the gRPC server call for which to find an authenticated device * * @return the authenticated device associated with the given local address, if any + * + * @throws ChannelNotFoundException if the server call is not associated with a known channel; in practice, this + * generally indicates that the channel has closed while request processing is still in progress */ - public Optional getAuthenticatedDevice(final LocalAddress localAddress) { - return getAuthenticatedDevice(remoteChannelsByLocalAddress.get(localAddress)); + public Optional getAuthenticatedDevice(final ServerCall serverCall) + throws ChannelNotFoundException { + + return getAuthenticatedDevice(getRemoteChannel(serverCall)); } - private Optional getAuthenticatedDevice(@Nullable final Channel remoteChannel) { - return Optional.ofNullable(remoteChannel) - .map(channel -> channel.attr(AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).get()); + @VisibleForTesting + Optional getAuthenticatedDevice(final Channel remoteChannel) { + return Optional.ofNullable(remoteChannel.attr(AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).get()); } /** - * Returns the parsed acceptable languages associated with the given local address, if any. Acceptable languages may - * be unavailable if the local connection associated with the given local address has already closed, if the client - * did not provide a list of acceptable languages, or the list provided by the client could not be parsed. + * Returns the request attributes associated with the given server call. * - * @param localAddress the local address for which to find acceptable languages + * @param serverCall the gRPC server call for which to retrieve request attributes * - * @return the acceptable languages associated with the given local address, if any + * @return the request attributes associated with the given server call + * + * @throws ChannelNotFoundException if the server call is not associated with a known channel; in practice, this + * generally indicates that the channel has closed while request processing is still in progress */ - public Optional> getAcceptableLanguages(final LocalAddress localAddress) { - return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress)) - .map(remoteChannel -> remoteChannel.attr(ACCEPT_LANGUAGE_ATTRIBUTE_KEY).get()); + public RequestAttributes getRequestAttributes(final ServerCall serverCall) throws ChannelNotFoundException { + return getRequestAttributes(getRemoteChannel(serverCall)); } - /** - * Returns the remote address associated with the given local address, if any. A remote address may be unavailable if - * the local connection associated with the given local address has already closed. - * - * @param localAddress the local address for which to find a remote address - * - * @return the remote address associated with the given local address, if any - */ - public Optional getRemoteAddress(final LocalAddress localAddress) { - return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress)) - .map(remoteChannel -> remoteChannel.attr(REMOTE_ADDRESS_ATTRIBUTE_KEY).get()); - } + @VisibleForTesting + RequestAttributes getRequestAttributes(final Channel remoteChannel) { + final RequestAttributes requestAttributes = remoteChannel.attr(REQUEST_ATTRIBUTES_KEY).get(); - /** - * Returns the unparsed user agent provided by the client that opened the connection associated with the given local - * address. This method may return an empty value if no active local connection is associated with the given local - * address. - * - * @param localAddress the local address for which to find a User-Agent string - * - * @return the user agent string associated with the given local address - */ - public Optional getRawUserAgent(final LocalAddress localAddress) { - return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress)) - .map(remoteChannel -> remoteChannel.attr(RAW_USER_AGENT_ATTRIBUTE_KEY).get()); - } + if (requestAttributes == null) { + throw new IllegalStateException("Channel does not have request attributes"); + } - /** - * Returns the parsed user agent provided by the client that opened the connection associated with the given local - * address. This method may return an empty value if no active local connection is associated with the given local - * address or if the client's user-agent string was not recognized. - * - * @param localAddress the local address for which to find a User-Agent string - * - * @return the user agent associated with the given local address - */ - public Optional getUserAgent(final LocalAddress localAddress) { - return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress)) - .map(remoteChannel -> remoteChannel.attr(PARSED_USER_AGENT_ATTRIBUTE_KEY).get()); + return requestAttributes; } /** @@ -156,11 +130,32 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener return remoteChannelsByAuthenticatedDevice.get(authenticatedDevice); } + private Channel getRemoteChannel(final ServerCall serverCall) throws ChannelNotFoundException { + return getRemoteChannel(getLocalAddress(serverCall)); + } + @VisibleForTesting - Channel getRemoteChannelByLocalAddress(final LocalAddress localAddress) { + Channel getRemoteChannel(final LocalAddress localAddress) throws ChannelNotFoundException { + final Channel remoteChannel = remoteChannelsByLocalAddress.get(localAddress); + + if (remoteChannel == null) { + throw new ChannelNotFoundException(); + } + return remoteChannelsByLocalAddress.get(localAddress); } + private static LocalAddress getLocalAddress(final ServerCall serverCall) { + // In this server, gRPC's "remote" channel is actually a local channel that proxies to a distinct Noise channel. + // The gRPC "remote" address is the "local address" for the proxy connection, and the local address uniquely maps to + // a proxied Noise channel. + if (!(serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress)) { + throw new IllegalArgumentException("Unexpected channel type: " + serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR)); + } + + return localAddress; + } + /** * Handles successful completion of a WebSocket handshake and associates attributes and headers from the handshake * request with the channel via which the handshake took place. @@ -171,30 +166,23 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener * @param acceptLanguageHeader the value of the Accept-Language header provided in the handshake request; may be * {@code null} */ - static void handleWebSocketHandshakeComplete(final Channel channel, + static void handleHandshakeComplete(final Channel channel, final InetAddress preferredRemoteAddress, @Nullable final String userAgentHeader, @Nullable final String acceptLanguageHeader) { - channel.attr(GrpcClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).set(preferredRemoteAddress); - - if (StringUtils.isNotBlank(userAgentHeader)) { - channel.attr(GrpcClientConnectionManager.RAW_USER_AGENT_ATTRIBUTE_KEY).set(userAgentHeader); - - try { - channel.attr(GrpcClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY) - .set(UserAgentUtil.parseUserAgentString(userAgentHeader)); - } catch (final UnrecognizedUserAgentException ignored) { - } - } + @Nullable List acceptLanguages = Collections.emptyList(); if (StringUtils.isNotBlank(acceptLanguageHeader)) { try { - channel.attr(GrpcClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).set(Locale.LanguageRange.parse(acceptLanguageHeader)); + acceptLanguages = Locale.LanguageRange.parse(acceptLanguageHeader); } catch (final IllegalArgumentException e) { log.debug("Invalid Accept-Language header from User-Agent {}: {}", userAgentHeader, acceptLanguageHeader, e); } } + + channel.attr(REQUEST_ATTRIBUTES_KEY) + .set(new RequestAttributes(preferredRemoteAddress, userAgentHeader, acceptLanguages)); } /** diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteHandler.java index 87b09d486..f847314b0 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteHandler.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteHandler.java @@ -74,7 +74,7 @@ class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter { preferredRemoteAddress = maybePreferredRemoteAddress.get(); } - GrpcClientConnectionManager.handleWebSocketHandshakeComplete(context.channel(), + GrpcClientConnectionManager.handleHandshakeComplete(context.channel(), preferredRemoteAddress, handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.USER_AGENT), handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.ACCEPT_LANGUAGE)); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/ProhibitAuthenticationInterceptorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/ProhibitAuthenticationInterceptorTest.java index 2a02d9dae..a0d8c6688 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/ProhibitAuthenticationInterceptorTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/ProhibitAuthenticationInterceptorTest.java @@ -3,6 +3,7 @@ package org.whispersystems.textsecuregcm.auth.grpc; import io.grpc.Status; import org.junit.jupiter.api.Test; import org.signal.chat.rpc.GetAuthenticatedDeviceResponse; +import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException; import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils; import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; import org.whispersystems.textsecuregcm.storage.Device; @@ -22,7 +23,7 @@ class ProhibitAuthenticationInterceptorTest extends AbstractAuthenticationInterc } @Test - void interceptCall() { + void interceptCall() throws ChannelNotFoundException { final GrpcClientConnectionManager grpcClientConnectionManager = getClientConnectionManager(); when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty()); @@ -34,6 +35,10 @@ class ProhibitAuthenticationInterceptorTest extends AbstractAuthenticationInterc final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID); when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice)); - GrpcTestUtils.assertStatusException(Status.UNAUTHENTICATED, this::getAuthenticatedDevice); + GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getAuthenticatedDevice); + + when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenThrow(ChannelNotFoundException.class); + + GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, this::getAuthenticatedDevice); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/RequireAuthenticationInterceptorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/RequireAuthenticationInterceptorTest.java index e30d1735a..442b38cc8 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/RequireAuthenticationInterceptorTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/RequireAuthenticationInterceptorTest.java @@ -9,6 +9,7 @@ import java.util.Optional; import java.util.UUID; import org.junit.jupiter.api.Test; import org.signal.chat.rpc.GetAuthenticatedDeviceResponse; +import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException; import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils; import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; import org.whispersystems.textsecuregcm.storage.Device; @@ -22,12 +23,12 @@ class RequireAuthenticationInterceptorTest extends AbstractAuthenticationInterce } @Test - void interceptCall() { + void interceptCall() throws ChannelNotFoundException { final GrpcClientConnectionManager grpcClientConnectionManager = getClientConnectionManager(); when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty()); - GrpcTestUtils.assertStatusException(Status.UNAUTHENTICATED, this::getAuthenticatedDevice); + GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getAuthenticatedDevice); final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID); when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice)); @@ -35,5 +36,9 @@ class RequireAuthenticationInterceptorTest extends AbstractAuthenticationInterce final GetAuthenticatedDeviceResponse response = getAuthenticatedDevice(); assertEquals(UUIDUtil.toByteString(authenticatedDevice.accountIdentifier()), response.getAccountIdentifier()); assertEquals(authenticatedDevice.deviceId(), response.getDeviceId()); + + when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenThrow(ChannelNotFoundException.class); + + GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, this::getAuthenticatedDevice); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/filters/ExternalRequestFilterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/filters/ExternalRequestFilterTest.java index 152ab62e7..e3eb3240f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/filters/ExternalRequestFilterTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/filters/ExternalRequestFilterTest.java @@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.filters; import static org.junit.jupiter.api.Assertions.assertEquals; +import com.google.common.net.InetAddresses; import com.google.protobuf.ByteString; import io.dropwizard.core.Application; import io.dropwizard.core.Configuration; @@ -24,7 +25,6 @@ import jakarta.ws.rs.GET; import jakarta.ws.rs.Path; import jakarta.ws.rs.client.Client; import jakarta.ws.rs.core.Response; -import java.net.InetAddress; import java.util.Collections; import java.util.EnumSet; import java.util.Set; @@ -39,6 +39,7 @@ import org.signal.chat.rpc.EchoServiceGrpc; import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl; import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils; import org.whispersystems.textsecuregcm.grpc.MockRequestAttributesInterceptor; +import org.whispersystems.textsecuregcm.grpc.RequestAttributes; import org.whispersystems.textsecuregcm.util.InetAddressRange; @ExtendWith(DropwizardExtensionsSupport.class) @@ -157,7 +158,7 @@ class ExternalRequestFilterTest { @BeforeEach void setUp() throws Exception { final MockRequestAttributesInterceptor mockRequestAttributesInterceptor = new MockRequestAttributesInterceptor(); - mockRequestAttributesInterceptor.setRemoteAddress(InetAddress.getByName("127.0.0.1")); + mockRequestAttributesInterceptor.setRequestAttributes(new RequestAttributes(InetAddresses.forString("127.0.0.1"), null, null)); testServer = InProcessServerBuilder.forName("ExternalRequestFilterTest") .directExecutor() diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/filters/RemoteDeprecationFilterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/filters/RemoteDeprecationFilterTest.java index 12a08dbdd..54422b47f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/filters/RemoteDeprecationFilterTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/filters/RemoteDeprecationFilterTest.java @@ -15,6 +15,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import com.google.common.net.HttpHeaders; +import com.google.common.net.InetAddresses; import com.google.protobuf.ByteString; import com.vdurmont.semver4j.Semver; import io.grpc.ManagedChannel; @@ -40,11 +41,10 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRemoteDeprecationConfiguration; import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl; import org.whispersystems.textsecuregcm.grpc.MockRequestAttributesInterceptor; +import org.whispersystems.textsecuregcm.grpc.RequestAttributes; import org.whispersystems.textsecuregcm.grpc.StatusConstants; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; -import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; -import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; class RemoteDeprecationFilterTest { @@ -130,11 +130,7 @@ class RemoteDeprecationFilterTest { @MethodSource(value="testFilter") void testGrpcFilter(final String userAgentString, final boolean expectDeprecation) throws IOException, InterruptedException { final MockRequestAttributesInterceptor mockRequestAttributesInterceptor = new MockRequestAttributesInterceptor(); - - try { - mockRequestAttributesInterceptor.setUserAgent(UserAgentUtil.parseUserAgentString(userAgentString)); - } catch (UnrecognizedUserAgentException ignored) { - } + mockRequestAttributesInterceptor.setRequestAttributes(new RequestAttributes(InetAddresses.forString("127.0.0.1"), userAgentString, null)); final Server testServer = InProcessServerBuilder.forName("RemoteDeprecationFilterTest") .directExecutor() diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/AccountsAnonymousGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/AccountsAnonymousGrpcServiceTest.java index 9c9b24646..12eeec8ba 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/AccountsAnonymousGrpcServiceTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/AccountsAnonymousGrpcServiceTest.java @@ -72,7 +72,8 @@ class AccountsAnonymousGrpcServiceTest extends when(rateLimiter.validateReactive(anyString())).thenReturn(Mono.empty()); - getMockRequestAttributesInterceptor().setRemoteAddress(InetAddresses.forString("127.0.0.1")); + getMockRequestAttributesInterceptor().setRequestAttributes( + new RequestAttributes(InetAddresses.forString("127.0.0.1"), null, null)); return new AccountsAnonymousGrpcService(accountsManager, rateLimiters); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MockRequestAttributesInterceptor.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MockRequestAttributesInterceptor.java index 7d662f3a0..01172355a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MockRequestAttributesInterceptor.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MockRequestAttributesInterceptor.java @@ -5,6 +5,7 @@ package org.whispersystems.textsecuregcm.grpc; +import com.google.common.net.InetAddresses; import io.grpc.Context; import io.grpc.Contexts; import io.grpc.Metadata; @@ -19,25 +20,10 @@ import org.whispersystems.textsecuregcm.util.ua.UserAgent; public class MockRequestAttributesInterceptor implements ServerInterceptor { - @Nullable - private InetAddress remoteAddress; + private RequestAttributes requestAttributes = new RequestAttributes(InetAddresses.forString("127.0.0.1"), null, null); - @Nullable - private UserAgent userAgent; - - @Nullable - private List acceptLanguage; - - public void setRemoteAddress(@Nullable final InetAddress remoteAddress) { - this.remoteAddress = remoteAddress; - } - - public void setUserAgent(@Nullable final UserAgent userAgent) { - this.userAgent = userAgent; - } - - public void setAcceptLanguage(@Nullable final List acceptLanguage) { - this.acceptLanguage = acceptLanguage; + public void setRequestAttributes(final RequestAttributes requestAttributes) { + this.requestAttributes = requestAttributes; } @Override @@ -45,20 +31,7 @@ public class MockRequestAttributesInterceptor implements ServerInterceptor { final Metadata headers, final ServerCallHandler next) { - Context context = Context.current(); - - if (remoteAddress != null) { - context = context.withValue(RequestAttributesUtil.REMOTE_ADDRESS_CONTEXT_KEY, remoteAddress); - } - - if (userAgent != null) { - context = context.withValue(RequestAttributesUtil.USER_AGENT_CONTEXT_KEY, userAgent); - } - - if (acceptLanguage != null) { - context = context.withValue(RequestAttributesUtil.ACCEPT_LANGUAGE_CONTEXT_KEY, acceptLanguage); - } - - return Contexts.interceptCall(context, serverCall, headers, next); + return Contexts.interceptCall(Context.current() + .withValue(RequestAttributesUtil.REQUEST_ATTRIBUTES_CONTEXT_KEY, requestAttributes), serverCall, headers, next); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ProfileAnonymousGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ProfileAnonymousGrpcServiceTest.java index b57edfe0b..c9ef6a359 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ProfileAnonymousGrpcServiceTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ProfileAnonymousGrpcServiceTest.java @@ -15,6 +15,7 @@ import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException; +import com.google.common.net.InetAddresses; import com.google.protobuf.ByteString; import io.grpc.Status; import java.nio.charset.StandardCharsets; @@ -75,8 +76,6 @@ import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper; import org.whispersystems.textsecuregcm.util.TestRandomUtil; import org.whispersystems.textsecuregcm.util.UUIDUtil; -import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; -import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest { @@ -96,13 +95,9 @@ public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest - acceptableLanguages.forEach(languageRange -> responseBuilder.addAcceptableLanguages(languageRange.toString()))); + RequestAttributesUtil.getAcceptableLanguages() + .forEach(languageRange -> responseBuilder.addAcceptableLanguages(languageRange.toString())); RequestAttributesUtil.getAvailableAcceptedLocales().forEach(locale -> responseBuilder.addAvailableAcceptedLocales(locale.toLanguageTag())); responseBuilder.setRemoteAddress(RequestAttributesUtil.getRemoteAddress().getHostAddress()); - RequestAttributesUtil.getUserAgent().ifPresent(userAgent -> responseBuilder.setUserAgent(UserAgent.newBuilder() - .setPlatform(userAgent.platform().toString()) - .setVersion(userAgent.version().toString()) - .setAdditionalSpecifiers(StringUtils.stripToEmpty(userAgent.additionalSpecifiers())) - .build())); - - RequestAttributesUtil.getRawUserAgent().ifPresent(responseBuilder::setRawUserAgent); + RequestAttributesUtil.getUserAgent().ifPresent(responseBuilder::setUserAgent); responseObserver.onNext(responseBuilder.build()); responseObserver.onCompleted(); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesUtilTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesUtilTest.java index 3bbcb22be..1564e94c0 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesUtilTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesUtilTest.java @@ -3,172 +3,84 @@ package org.whispersystems.textsecuregcm.grpc; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; import com.google.common.net.InetAddresses; -import io.grpc.ManagedChannel; -import io.grpc.Server; -import io.grpc.Status; -import io.grpc.netty.NettyChannelBuilder; -import io.grpc.netty.NettyServerBuilder; -import io.netty.channel.DefaultEventLoopGroup; -import io.netty.channel.local.LocalAddress; -import io.netty.channel.local.LocalChannel; -import io.netty.channel.local.LocalServerChannel; -import java.io.IOException; +import io.grpc.Context; +import java.net.InetAddress; +import java.util.Collections; import java.util.List; import java.util.Locale; import java.util.Optional; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; +import java.util.concurrent.Callable; +import javax.annotation.Nullable; import org.junit.jupiter.api.Test; -import org.signal.chat.rpc.GetRequestAttributesRequest; -import org.signal.chat.rpc.GetRequestAttributesResponse; -import org.signal.chat.rpc.RequestAttributesGrpc; -import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; -import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; -import org.whispersystems.textsecuregcm.util.ua.UserAgent; -import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; class RequestAttributesUtilTest { - private static DefaultEventLoopGroup eventLoopGroup; + private static final InetAddress REMOTE_ADDRESS = InetAddresses.forString("127.0.0.1"); - private GrpcClientConnectionManager grpcClientConnectionManager; + @Test + void getAcceptableLanguages() throws Exception { + assertEquals(Collections.emptyList(), + callWithRequestAttributes(buildRequestAttributes(Collections.emptyList()), + RequestAttributesUtil::getAcceptableLanguages)); - private Server server; - private ManagedChannel managedChannel; - - @BeforeAll - static void setUpBeforeAll() { - eventLoopGroup = new DefaultEventLoopGroup(); - } - - @BeforeEach - void setUp() throws IOException { - final LocalAddress serverAddress = new LocalAddress("test-request-metadata-server"); - - grpcClientConnectionManager = mock(GrpcClientConnectionManager.class); - - when(grpcClientConnectionManager.getRemoteAddress(any())) - .thenReturn(Optional.of(InetAddresses.forString("127.0.0.1"))); - - // `RequestAttributesInterceptor` operates on `LocalAddresses`, so we need to do some slightly fancy plumbing to make - // sure that we're using local channels and addresses - server = NettyServerBuilder.forAddress(serverAddress) - .channelType(LocalServerChannel.class) - .bossEventLoopGroup(eventLoopGroup) - .workerEventLoopGroup(eventLoopGroup) - .intercept(new RequestAttributesInterceptor(grpcClientConnectionManager)) - .addService(new RequestAttributesServiceImpl()) - .build() - .start(); - - managedChannel = NettyChannelBuilder.forAddress(serverAddress) - .channelType(LocalChannel.class) - .eventLoopGroup(eventLoopGroup) - .usePlaintext() - .build(); - } - - @AfterEach - void tearDown() { - managedChannel.shutdown(); - server.shutdown(); - } - - @AfterAll - static void tearDownAfterAll() throws InterruptedException { - eventLoopGroup.shutdownGracefully().await(); + assertEquals(Locale.LanguageRange.parse("en,ja"), + callWithRequestAttributes(buildRequestAttributes(Locale.LanguageRange.parse("en,ja")), + RequestAttributesUtil::getAcceptableLanguages)); } @Test - void getAcceptableLanguages() { - when(grpcClientConnectionManager.getAcceptableLanguages(any())) - .thenReturn(Optional.empty()); + void getAvailableAcceptedLocales() throws Exception { + assertEquals(Collections.emptyList(), + callWithRequestAttributes(buildRequestAttributes(Collections.emptyList()), + RequestAttributesUtil::getAvailableAcceptedLocales)); - assertTrue(getRequestAttributes().getAcceptableLanguagesList().isEmpty()); + final List availableAcceptedLocales = + callWithRequestAttributes(buildRequestAttributes(Locale.LanguageRange.parse("en,ja")), + RequestAttributesUtil::getAvailableAcceptedLocales); - when(grpcClientConnectionManager.getAcceptableLanguages(any())) - .thenReturn(Optional.of(Locale.LanguageRange.parse("en,ja"))); + assertFalse(availableAcceptedLocales.isEmpty()); - assertEquals(List.of("en", "ja"), getRequestAttributes().getAcceptableLanguagesList()); + availableAcceptedLocales.forEach(locale -> + assertTrue("en".equals(locale.getLanguage()) || "ja".equals(locale.getLanguage()))); } @Test - void getAvailableAcceptedLocales() { - when(grpcClientConnectionManager.getAcceptableLanguages(any())) - .thenReturn(Optional.empty()); - - assertTrue(getRequestAttributes().getAvailableAcceptedLocalesList().isEmpty()); - - when(grpcClientConnectionManager.getAcceptableLanguages(any())) - .thenReturn(Optional.of(Locale.LanguageRange.parse("en,ja"))); - - final GetRequestAttributesResponse response = getRequestAttributes(); - - assertFalse(response.getAvailableAcceptedLocalesList().isEmpty()); - response.getAvailableAcceptedLocalesList().forEach(languageTag -> { - final Locale locale = Locale.forLanguageTag(languageTag); - assertTrue("en".equals(locale.getLanguage()) || "ja".equals(locale.getLanguage())); - }); + void getRemoteAddress() throws Exception { + assertEquals(REMOTE_ADDRESS, + callWithRequestAttributes(new RequestAttributes(REMOTE_ADDRESS, null, null), + RequestAttributesUtil::getRemoteAddress)); } @Test - void getRemoteAddress() { - when(grpcClientConnectionManager.getRemoteAddress(any())) - .thenReturn(Optional.empty()); + void getUserAgent() throws Exception { + assertEquals(Optional.empty(), + callWithRequestAttributes(buildRequestAttributes((String) null), + RequestAttributesUtil::getUserAgent)); - GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getRequestAttributes); - - final String remoteAddressString = "6.7.8.9"; - - when(grpcClientConnectionManager.getRemoteAddress(any())) - .thenReturn(Optional.of(InetAddresses.forString(remoteAddressString))); - - assertEquals(remoteAddressString, getRequestAttributes().getRemoteAddress()); + assertEquals(Optional.of("Signal-Desktop/1.2.3 Linux"), + callWithRequestAttributes(buildRequestAttributes("Signal-Desktop/1.2.3 Linux"), + RequestAttributesUtil::getUserAgent)); } - @Test - void getUserAgent() throws UnrecognizedUserAgentException { - when(grpcClientConnectionManager.getUserAgent(any())) - .thenReturn(Optional.empty()); - - assertFalse(getRequestAttributes().hasUserAgent()); - - final UserAgent userAgent = UserAgentUtil.parseUserAgentString("Signal-Desktop/1.2.3 Linux"); - - when(grpcClientConnectionManager.getUserAgent(any())) - .thenReturn(Optional.of(userAgent)); - - final GetRequestAttributesResponse response = getRequestAttributes(); - assertTrue(response.hasUserAgent()); - assertEquals("DESKTOP", response.getUserAgent().getPlatform()); - assertEquals("1.2.3", response.getUserAgent().getVersion()); - assertEquals("Linux", response.getUserAgent().getAdditionalSpecifiers()); + private static V callWithRequestAttributes(final RequestAttributes requestAttributes, final Callable callable) throws Exception { + return Context.current() + .withValue(RequestAttributesUtil.REQUEST_ATTRIBUTES_CONTEXT_KEY, requestAttributes) + .call(callable); } - @Test - void getRawUserAgent() { - when(grpcClientConnectionManager.getRawUserAgent(any())) - .thenReturn(Optional.empty()); - - assertTrue(getRequestAttributes().getRawUserAgent().isBlank()); - - final String userAgentString = "Signal-Desktop/1.2.3 Linux"; - - when(grpcClientConnectionManager.getRawUserAgent(any())) - .thenReturn(Optional.of(userAgentString)); - - assertEquals(userAgentString, getRequestAttributes().getRawUserAgent()); + private static RequestAttributes buildRequestAttributes(final String userAgent) { + return buildRequestAttributes(userAgent, Collections.emptyList()); } - private GetRequestAttributesResponse getRequestAttributes() { - return RequestAttributesGrpc.newBlockingStub(managedChannel) - .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()); + private static RequestAttributes buildRequestAttributes(final List acceptLanguage) { + return buildRequestAttributes(null, acceptLanguage); + } + + private static RequestAttributes buildRequestAttributes(@Nullable final String userAgent, + final List acceptLanguage) { + + return new RequestAttributes(REMOTE_ADDRESS, userAgent, acceptLanguage); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManagerTest.java index ef132bf25..1cba6d4b1 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManagerTest.java @@ -1,7 +1,11 @@ package org.whispersystems.textsecuregcm.grpc.net; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + import com.google.common.net.InetAddresses; -import com.vdurmont.semver4j.Semver; import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; import io.netty.channel.Channel; @@ -12,6 +16,12 @@ import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalChannel; import io.netty.channel.local.LocalServerChannel; +import java.net.InetAddress; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Optional; +import java.util.UUID; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; @@ -21,20 +31,9 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException; +import org.whispersystems.textsecuregcm.grpc.RequestAttributes; import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; -import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; -import org.whispersystems.textsecuregcm.util.ua.UserAgent; -import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; - -import javax.annotation.Nullable; -import java.net.InetAddress; -import java.util.List; -import java.util.Locale; -import java.util.Optional; -import java.util.UUID; - -import static org.junit.jupiter.api.Assertions.*; class GrpcClientConnectionManagerTest { @@ -103,7 +102,7 @@ class GrpcClientConnectionManagerTest { grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, maybeAuthenticatedDevice); assertEquals(maybeAuthenticatedDevice, - grpcClientConnectionManager.getAuthenticatedDevice(localChannel.localAddress())); + grpcClientConnectionManager.getAuthenticatedDevice(remoteChannel)); } private static List> getAuthenticatedDevice() { @@ -114,170 +113,115 @@ class GrpcClientConnectionManagerTest { } @Test - void getAcceptableLanguages() { + void getRequestAttributes() { grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty()); - assertEquals(Optional.empty(), - grpcClientConnectionManager.getAcceptableLanguages(localChannel.localAddress())); + assertThrows(IllegalStateException.class, () -> grpcClientConnectionManager.getRequestAttributes(remoteChannel)); - final List acceptLanguageRanges = Locale.LanguageRange.parse("en,ja"); - remoteChannel.attr(GrpcClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).set(acceptLanguageRanges); + final RequestAttributes requestAttributes = new RequestAttributes(InetAddresses.forString("6.7.8.9"), null, null); + remoteChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY).set(requestAttributes); - assertEquals(Optional.of(acceptLanguageRanges), - grpcClientConnectionManager.getAcceptableLanguages(localChannel.localAddress())); + assertEquals(requestAttributes, grpcClientConnectionManager.getRequestAttributes(remoteChannel)); } @Test - void getRemoteAddress() { - grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty()); - - assertEquals(Optional.empty(), - grpcClientConnectionManager.getRemoteAddress(localChannel.localAddress())); - - final InetAddress remoteAddress = InetAddresses.forString("6.7.8.9"); - remoteChannel.attr(GrpcClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).set(remoteAddress); - - assertEquals(Optional.of(remoteAddress), - grpcClientConnectionManager.getRemoteAddress(localChannel.localAddress())); - } - - @Test - void getUserAgent() throws UnrecognizedUserAgentException { - grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty()); - - assertEquals(Optional.empty(), - grpcClientConnectionManager.getUserAgent(localChannel.localAddress())); - - final UserAgent userAgent = UserAgentUtil.parseUserAgentString("Signal-Desktop/1.2.3 Linux"); - remoteChannel.attr(GrpcClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY).set(userAgent); - - assertEquals(Optional.of(userAgent), - grpcClientConnectionManager.getUserAgent(localChannel.localAddress())); - } - - @Test - void closeConnection() throws InterruptedException { + void closeConnection() throws InterruptedException, ChannelNotFoundException { final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID); grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice)); assertTrue(remoteChannel.isOpen()); - assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); + assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress())); assertEquals(List.of(remoteChannel), grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); remoteChannel.close().await(); - assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); + assertThrows(ChannelNotFoundException.class, + () -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress())); + assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); } - @Test - void handleWebSocketHandshakeCompleteRemoteAddress() { + @ParameterizedTest + @MethodSource + void handleHandshakeCompleteRequestAttributes(final InetAddress preferredRemoteAddress, + final String userAgentHeader, + final String acceptLanguageHeader, + final RequestAttributes expectedRequestAttributes) { + final EmbeddedChannel embeddedChannel = new EmbeddedChannel(); - final InetAddress preferredRemoteAddress = InetAddresses.forString("192.168.1.1"); - - GrpcClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel, + GrpcClientConnectionManager.handleHandshakeComplete(embeddedChannel, preferredRemoteAddress, - null, - null); - - assertEquals(preferredRemoteAddress, - embeddedChannel.attr(GrpcClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).get()); - } - - @ParameterizedTest - @MethodSource - void handleWebSocketHandshakeCompleteUserAgent(@Nullable final String userAgentHeader, - @Nullable final UserAgent expectedParsedUserAgent) { - - final EmbeddedChannel embeddedChannel = new EmbeddedChannel(); - - GrpcClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel, - InetAddresses.forString("127.0.0.1"), userAgentHeader, - null); - - assertEquals(userAgentHeader, - embeddedChannel.attr(GrpcClientConnectionManager.RAW_USER_AGENT_ATTRIBUTE_KEY).get()); - - assertEquals(expectedParsedUserAgent, - embeddedChannel.attr(GrpcClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY).get()); - } - - private static List handleWebSocketHandshakeCompleteUserAgent() { - return List.of( - // Recognized user-agent - Arguments.of("Signal-Desktop/1.2.3 Linux", new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "Linux")), - - // Unrecognized user-agent - Arguments.of("Not a valid user-agent string", null), - - // Missing user-agent - Arguments.of(null, null) - ); - } - - - @ParameterizedTest - @MethodSource - void handleWebSocketHandshakeCompleteAcceptLanguage(@Nullable final String acceptLanguageHeader, - @Nullable final List expectedLanguageRanges) { - - final EmbeddedChannel embeddedChannel = new EmbeddedChannel(); - - GrpcClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel, - InetAddresses.forString("127.0.0.1"), - null, acceptLanguageHeader); - assertEquals(expectedLanguageRanges, - embeddedChannel.attr(GrpcClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).get()); + assertEquals(expectedRequestAttributes, + embeddedChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY).get()); } - private static List handleWebSocketHandshakeCompleteAcceptLanguage() { + private static List handleHandshakeCompleteRequestAttributes() { + final InetAddress preferredRemoteAddress = InetAddresses.forString("192.168.1.1"); + return List.of( - // Parseable list - Arguments.of("ja,en;q=0.4", Locale.LanguageRange.parse("ja,en;q=0.4")), + Arguments.argumentSet("Null User-Agent and Accept-Language headers", + preferredRemoteAddress, null, null, + new RequestAttributes(preferredRemoteAddress, null, Collections.emptyList())), - // Unparsable list - Arguments.of("This is not a valid language preference list", null), + Arguments.argumentSet("Recognized User-Agent and null Accept-Language header", + preferredRemoteAddress, "Signal-Desktop/1.2.3 Linux", null, + new RequestAttributes(preferredRemoteAddress, "Signal-Desktop/1.2.3 Linux", Collections.emptyList())), - // Missing list - Arguments.of(null, null) + Arguments.argumentSet("Unparsable User-Agent and null Accept-Language header", + preferredRemoteAddress, "Not a valid user-agent string", null, + new RequestAttributes(preferredRemoteAddress, "Not a valid user-agent string", Collections.emptyList())), + + Arguments.argumentSet("Null User-Agent and parsable Accept-Language header", + preferredRemoteAddress, null, "ja,en;q=0.4", + new RequestAttributes(preferredRemoteAddress, null, Locale.LanguageRange.parse("ja,en;q=0.4"))), + + Arguments.argumentSet("Null User-Agent and unparsable Accept-Language header", + preferredRemoteAddress, null, "This is not a valid language preference list", + new RequestAttributes(preferredRemoteAddress, null, Collections.emptyList())) ); } @Test - void handleConnectionEstablishedAuthenticated() throws InterruptedException { + void handleConnectionEstablishedAuthenticated() throws InterruptedException, ChannelNotFoundException { final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID); - assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); + assertThrows(ChannelNotFoundException.class, + () -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress())); + assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice)); - assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); + assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress())); assertEquals(List.of(remoteChannel), grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); remoteChannel.close().await(); - assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); + assertThrows(ChannelNotFoundException.class, + () -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress())); + assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); } @Test - void handleConnectionEstablishedAnonymous() throws InterruptedException { - assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); + void handleConnectionEstablishedAnonymous() throws InterruptedException, ChannelNotFoundException { + assertThrows(ChannelNotFoundException.class, + () -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress())); grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty()); - assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); + assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress())); remoteChannel.close().await(); - assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); + assertThrows(ChannelNotFoundException.class, + () -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress())); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServerIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServerIntegrationTest.java index 5a28b617a..603ab94fa 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServerIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServerIntegrationTest.java @@ -523,10 +523,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes assertEquals(remoteAddress, response.getRemoteAddress()); assertEquals(List.of(acceptLanguage), response.getAcceptableLanguagesList()); - - assertEquals("DESKTOP", response.getUserAgent().getPlatform()); - assertEquals("1.2.3", response.getUserAgent().getVersion()); - assertEquals("Linux", response.getUserAgent().getAdditionalSpecifiers()); + assertEquals(userAgent, response.getUserAgent()); } finally { channel.shutdown(); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteHandlerTest.java index 24436a593..54b796ced 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteHandlerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteHandlerTest.java @@ -4,6 +4,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.params.provider.Arguments.argumentSet; import static org.junit.jupiter.params.provider.Arguments.arguments; import static org.mockito.Mockito.mock; @@ -16,6 +17,7 @@ import io.netty.channel.local.LocalAddress; import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; +import io.netty.util.Attribute; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.SocketAddress; @@ -31,6 +33,7 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.signal.libsignal.protocol.ecc.Curve; +import org.whispersystems.textsecuregcm.grpc.RequestAttributes; import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest { @@ -134,8 +137,13 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest { embeddedChannel.setRemoteAddress(remoteAddress); embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent); + + assertEquals(expectedRemoteAddress, - embeddedChannel.attr(GrpcClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).get()); + Optional.ofNullable(embeddedChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY)) + .map(Attribute::get) + .map(RequestAttributes::remoteAddress) + .orElse(null)); } private static List getRemoteAddress() { @@ -144,53 +152,53 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest { final InetAddress proxyAddress = InetAddresses.forString("4.3.2.1"); return List.of( - // Recognized proxy, single forwarded-for address - Arguments.of(new DefaultHttpHeaders() - .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET) - .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()), + argumentSet("Recognized proxy, single forwarded-for address", + new DefaultHttpHeaders() + .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET) + .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()), remoteAddress, clientAddress), - // Recognized proxy, multiple forwarded-for addresses - Arguments.of(new DefaultHttpHeaders() + argumentSet("Recognized proxy, multiple forwarded-for addresses", + new DefaultHttpHeaders() .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET) .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress() + "," + proxyAddress.getHostAddress()), remoteAddress, proxyAddress), - // No recognized proxy header, single forwarded-for address - Arguments.of(new DefaultHttpHeaders() + argumentSet("No recognized proxy header, single forwarded-for address", + new DefaultHttpHeaders() .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()), remoteAddress, remoteAddress.getAddress()), - // No recognized proxy header, no forwarded-for address - Arguments.of(new DefaultHttpHeaders(), + argumentSet("No recognized proxy header, no forwarded-for address", + new DefaultHttpHeaders(), remoteAddress, remoteAddress.getAddress()), - // Incorrect proxy header, single forwarded-for address - Arguments.of(new DefaultHttpHeaders() + argumentSet("Incorrect proxy header, single forwarded-for address", + new DefaultHttpHeaders() .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET + "-incorrect") .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()), remoteAddress, remoteAddress.getAddress()), - // Recognized proxy, no forwarded-for address - Arguments.of(new DefaultHttpHeaders() + argumentSet("Recognized proxy, no forwarded-for address", + new DefaultHttpHeaders() .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET), remoteAddress, remoteAddress.getAddress()), - // Recognized proxy, bogus forwarded-for address - Arguments.of(new DefaultHttpHeaders() + argumentSet("Recognized proxy, bogus forwarded-for address", + new DefaultHttpHeaders() .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET) .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, "not a valid address"), remoteAddress, null), - // No forwarded-for address, non-InetSocketAddress remote address - Arguments.of(new DefaultHttpHeaders() + argumentSet("No forwarded-for address, non-InetSocketAddress remote address", + new DefaultHttpHeaders() .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET), new LocalAddress("local-address"), null) diff --git a/service/src/test/proto/request_attributes_service.proto b/service/src/test/proto/request_attributes_service.proto index e52e9873e..6f89fdb45 100644 --- a/service/src/test/proto/request_attributes_service.proto +++ b/service/src/test/proto/request_attributes_service.proto @@ -23,14 +23,7 @@ message GetRequestAttributesResponse { repeated string acceptable_languages = 1; repeated string available_accepted_locales = 2; string remote_address = 3; - string raw_user_agent = 4; - UserAgent user_agent = 5; -} - -message UserAgent { - string platform = 1; - string version = 2; - string additional_specifiers = 3; + string user_agent = 4; } message GetAuthenticatedDeviceRequest {