Clarify guarantees around remote channnel/request attribute presence

This commit is contained in:
Jon Chambers 2025-04-18 15:44:21 -04:00 committed by GitHub
parent 28a0b9e84e
commit 8099d6465c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
28 changed files with 405 additions and 570 deletions

View File

@ -1,34 +1,22 @@
package org.whispersystems.textsecuregcm.auth.grpc; package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall; import io.grpc.ServerCall;
import io.grpc.ServerInterceptor; import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.netty.channel.local.LocalAddress;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import java.util.Optional; import java.util.Optional;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
abstract class AbstractAuthenticationInterceptor implements ServerInterceptor { abstract class AbstractAuthenticationInterceptor implements ServerInterceptor {
private final GrpcClientConnectionManager grpcClientConnectionManager; private final GrpcClientConnectionManager grpcClientConnectionManager;
private static final Metadata EMPTY_TRAILERS = new Metadata();
AbstractAuthenticationInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) { AbstractAuthenticationInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
this.grpcClientConnectionManager = grpcClientConnectionManager; this.grpcClientConnectionManager = grpcClientConnectionManager;
} }
protected Optional<AuthenticatedDevice> getAuthenticatedDevice(final ServerCall<?, ?> call) { protected Optional<AuthenticatedDevice> getAuthenticatedDevice(final ServerCall<?, ?> call)
if (call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress) { throws ChannelNotFoundException {
return grpcClientConnectionManager.getAuthenticatedDevice(localAddress);
} else {
throw new AssertionError("Unexpected channel type: " + call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
}
}
protected <ReqT, RespT> ServerCall.Listener<ReqT> closeAsUnauthenticated(final ServerCall<ReqT, RespT> call) { return grpcClientConnectionManager.getAuthenticatedDevice(call);
call.close(Status.UNAUTHENTICATED, EMPTY_TRAILERS);
return new ServerCall.Listener<>() {};
} }
} }

View File

@ -3,12 +3,17 @@ package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.ServerCall; import io.grpc.ServerCall;
import io.grpc.ServerCallHandler; import io.grpc.ServerCallHandler;
import io.grpc.Status;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.ServerInterceptorUtil;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
/** /**
* A "prohibit authentication" interceptor ensures that requests to endpoints that should be invoked anonymously do not * A "prohibit authentication" interceptor ensures that requests to endpoints that should be invoked anonymously do not
* originate from a channel that is associated with an authenticated device. Calls with an associated authenticated * originate from a channel that is associated with an authenticated device. Calls with an associated authenticated
* device are closed with an {@code UNAUTHENTICATED} status. * device are closed with an {@code UNAUTHENTICATED} status. If a call's authentication status cannot be determined
* (i.e. because the underlying remote channel closed before the {@code ServerCall} started), the interceptor will
* reject the call with a status of {@code UNAVAILABLE}.
*/ */
public class ProhibitAuthenticationInterceptor extends AbstractAuthenticationInterceptor { public class ProhibitAuthenticationInterceptor extends AbstractAuthenticationInterceptor {
@ -21,8 +26,15 @@ public class ProhibitAuthenticationInterceptor extends AbstractAuthenticationInt
final Metadata headers, final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) { final ServerCallHandler<ReqT, RespT> next) {
return getAuthenticatedDevice(call) try {
.map(ignored -> closeAsUnauthenticated(call)) return getAuthenticatedDevice(call)
.orElseGet(() -> next.startCall(call, headers)); // Status.INTERNAL may seem a little surprising here, but if a caller is reaching an authentication-prohibited
// service via an authenticated connection, then that's actually a server configuration issue and not a
// problem with the client's request.
.map(ignored -> ServerInterceptorUtil.closeWithStatus(call, Status.INTERNAL))
.orElseGet(() -> next.startCall(call, headers));
} catch (final ChannelNotFoundException e) {
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
}
} }
} }

View File

@ -5,12 +5,16 @@ import io.grpc.Contexts;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.ServerCall; import io.grpc.ServerCall;
import io.grpc.ServerCallHandler; import io.grpc.ServerCallHandler;
import io.grpc.Status;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.ServerInterceptorUtil;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
/** /**
* A "require authentication" interceptor requires that requests be issued from a connection that is associated with an * A "require authentication" interceptor requires that requests be issued from a connection that is associated with an
* authenticated device. Calls without an associated authenticated device are closed with an {@code UNAUTHENTICATED} * authenticated device. Calls without an associated authenticated device are closed with an {@code UNAUTHENTICATED}
* status. * status. If a call's authentication status cannot be determined (i.e. because the underlying remote channel closed
* before the {@code ServerCall} started), the interceptor will reject the call with a status of {@code UNAVAILABLE}.
*/ */
public class RequireAuthenticationInterceptor extends AbstractAuthenticationInterceptor { public class RequireAuthenticationInterceptor extends AbstractAuthenticationInterceptor {
@ -23,10 +27,17 @@ public class RequireAuthenticationInterceptor extends AbstractAuthenticationInte
final Metadata headers, final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) { final ServerCallHandler<ReqT, RespT> next) {
return getAuthenticatedDevice(call) try {
.map(authenticatedDevice -> Contexts.interceptCall(Context.current() return getAuthenticatedDevice(call)
.withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE, authenticatedDevice), .map(authenticatedDevice -> Contexts.interceptCall(Context.current()
call, headers, next)) .withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE, authenticatedDevice),
.orElseGet(() -> closeAsUnauthenticated(call)); call, headers, next))
// Status.INTERNAL may seem a little surprising here, but if a caller is reaching an authentication-required
// service via an unauthenticated connection, then that's actually a server configuration issue and not a
// problem with the client's request.
.orElseGet(() -> ServerInterceptorUtil.closeWithStatus(call, Status.INTERNAL));
} catch (final ChannelNotFoundException e) {
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
}
} }
} }

View File

@ -81,7 +81,16 @@ public class RemoteDeprecationFilter implements Filter, ServerInterceptor {
final Metadata headers, final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) { final ServerCallHandler<ReqT, RespT> next) {
if (shouldBlock(RequestAttributesUtil.getUserAgent().orElse(null))) { @Nullable final UserAgent userAgent = RequestAttributesUtil.getUserAgent()
.map(userAgentString -> {
try {
return UserAgentUtil.parseUserAgentString(userAgentString);
} catch (final UnrecognizedUserAgentException e) {
return null;
}
}).orElse(null);
if (shouldBlock(userAgent)) {
call.close(StatusConstants.UPGRADE_NEEDED_STATUS, new Metadata()); call.close(StatusConstants.UPGRADE_NEEDED_STATUS, new Metadata());
return new ServerCall.Listener<>() {}; return new ServerCall.Listener<>() {};
} else { } else {

View File

@ -0,0 +1,12 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
/**
* Indicates that a remote channel was not found for a given server call or remote address.
*/
public class ChannelNotFoundException extends Exception {
}

View File

@ -253,7 +253,7 @@ public class MessagesAnonymousGrpcService extends SimpleMessagesAnonymousGrpc.Me
story, story,
ephemeral, ephemeral,
urgent, urgent,
RequestAttributesUtil.getRawUserAgent().orElse(null)); RequestAttributesUtil.getUserAgent().orElse(null));
final SendMultiRecipientMessageResponse.Builder responseBuilder = SendMultiRecipientMessageResponse.newBuilder(); final SendMultiRecipientMessageResponse.Builder responseBuilder = SendMultiRecipientMessageResponse.newBuilder();

View File

@ -55,7 +55,7 @@ public class MessagesGrpcHelper {
messagesByDeviceId, messagesByDeviceId,
registrationIdsByDeviceId, registrationIdsByDeviceId,
syncMessageSenderDeviceId, syncMessageSenderDeviceId,
RequestAttributesUtil.getRawUserAgent().orElse(null)); RequestAttributesUtil.getUserAgent().orElse(null));
return SEND_MESSAGE_SUCCESS_RESPONSE; return SEND_MESSAGE_SUCCESS_RESPONSE;
} catch (final MismatchedDevicesException e) { } catch (final MismatchedDevicesException e) {

View File

@ -0,0 +1,16 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import java.net.InetAddress;
import java.util.List;
import java.util.Locale;
import javax.annotation.Nullable;
public record RequestAttributes(InetAddress remoteAddress,
@Nullable String userAgent,
List<Locale.LanguageRange> acceptLanguage) {
}

View File

@ -2,28 +2,25 @@ package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Context; import io.grpc.Context;
import io.grpc.Contexts; import io.grpc.Contexts;
import io.grpc.Grpc;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.ServerCall; import io.grpc.ServerCall;
import io.grpc.ServerCallHandler; import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor; import io.grpc.ServerInterceptor;
import io.grpc.Status; import io.grpc.Status;
import io.netty.channel.local.LocalAddress;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import java.net.InetAddress;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
/**
* The request attributes interceptor makes request attributes from the underlying remote channel available to service
* implementations by attaching them to a {@link Context} attribute that can be read via {@link RequestAttributesUtil}.
* All server calls should have request attributes, and calls will be rejected with a status of {@code UNAVAILABLE} if
* request attributes are unavailable (i.e. the underlying channel closed before the {@code ServerCall} started).
*
* @see RequestAttributesUtil
*/
public class RequestAttributesInterceptor implements ServerInterceptor { public class RequestAttributesInterceptor implements ServerInterceptor {
private final GrpcClientConnectionManager grpcClientConnectionManager; private final GrpcClientConnectionManager grpcClientConnectionManager;
private static final Logger log = LoggerFactory.getLogger(RequestAttributesInterceptor.class);
public RequestAttributesInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) { public RequestAttributesInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
this.grpcClientConnectionManager = grpcClientConnectionManager; this.grpcClientConnectionManager = grpcClientConnectionManager;
} }
@ -33,52 +30,12 @@ public class RequestAttributesInterceptor implements ServerInterceptor {
final Metadata headers, final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) { final ServerCallHandler<ReqT, RespT> next) {
if (call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress) { try {
Context context = Context.current(); return Contexts.interceptCall(Context.current()
.withValue(RequestAttributesUtil.REQUEST_ATTRIBUTES_CONTEXT_KEY,
{ grpcClientConnectionManager.getRequestAttributes(call)), call, headers, next);
final Optional<InetAddress> maybeRemoteAddress = grpcClientConnectionManager.getRemoteAddress(localAddress); } catch (final ChannelNotFoundException e) {
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
if (maybeRemoteAddress.isEmpty()) {
// We should never have a call from a party whose remote address we can't identify
log.warn("No remote address available");
call.close(Status.INTERNAL, new Metadata());
return new ServerCall.Listener<>() {};
}
context = context.withValue(RequestAttributesUtil.REMOTE_ADDRESS_CONTEXT_KEY, maybeRemoteAddress.get());
}
{
final Optional<List<Locale.LanguageRange>> maybeAcceptLanguage =
grpcClientConnectionManager.getAcceptableLanguages(localAddress);
if (maybeAcceptLanguage.isPresent()) {
context = context.withValue(RequestAttributesUtil.ACCEPT_LANGUAGE_CONTEXT_KEY, maybeAcceptLanguage.get());
}
}
{
final Optional<String> maybeRawUserAgent =
grpcClientConnectionManager.getRawUserAgent(localAddress);
if (maybeRawUserAgent.isPresent()) {
context = context.withValue(RequestAttributesUtil.RAW_USER_AGENT_CONTEXT_KEY, maybeRawUserAgent.get());
}
}
{
final Optional<UserAgent> maybeUserAgent = grpcClientConnectionManager.getUserAgent(localAddress);
if (maybeUserAgent.isPresent()) {
context = context.withValue(RequestAttributesUtil.USER_AGENT_CONTEXT_KEY, maybeUserAgent.get());
}
}
return Contexts.interceptCall(context, call, headers, next);
} else {
throw new AssertionError("Unexpected channel type: " + call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
} }
} }
} }

View File

@ -3,18 +3,13 @@ package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Context; import io.grpc.Context;
import java.net.InetAddress; import java.net.InetAddress;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Locale; import java.util.Locale;
import java.util.Optional; import java.util.Optional;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
public class RequestAttributesUtil { public class RequestAttributesUtil {
static final Context.Key<List<Locale.LanguageRange>> ACCEPT_LANGUAGE_CONTEXT_KEY = Context.key("accept-language"); static final Context.Key<RequestAttributes> REQUEST_ATTRIBUTES_CONTEXT_KEY = Context.key("request-attributes");
static final Context.Key<InetAddress> REMOTE_ADDRESS_CONTEXT_KEY = Context.key("remote-address");
static final Context.Key<String> RAW_USER_AGENT_CONTEXT_KEY = Context.key("unparsed-user-agent");
static final Context.Key<UserAgent> USER_AGENT_CONTEXT_KEY = Context.key("parsed-user-agent");
private static final List<Locale> AVAILABLE_LOCALES = Arrays.asList(Locale.getAvailableLocales()); private static final List<Locale> AVAILABLE_LOCALES = Arrays.asList(Locale.getAvailableLocales());
@ -23,8 +18,8 @@ public class RequestAttributesUtil {
* *
* @return the acceptable languages listed by the remote client; may be empty if unparseable or not specified * @return the acceptable languages listed by the remote client; may be empty if unparseable or not specified
*/ */
public static Optional<List<Locale.LanguageRange>> getAcceptableLanguages() { public static List<Locale.LanguageRange> getAcceptableLanguages() {
return Optional.ofNullable(ACCEPT_LANGUAGE_CONTEXT_KEY.get()); return REQUEST_ATTRIBUTES_CONTEXT_KEY.get().acceptLanguage();
} }
/** /**
@ -35,9 +30,7 @@ public class RequestAttributesUtil {
* @return a list of distinct locales acceptable to the remote client and available in this JVM * @return a list of distinct locales acceptable to the remote client and available in this JVM
*/ */
public static List<Locale> getAvailableAcceptedLocales() { public static List<Locale> getAvailableAcceptedLocales() {
return getAcceptableLanguages() return Locale.filter(getAcceptableLanguages(), AVAILABLE_LOCALES);
.map(languageRanges -> Locale.filter(languageRanges, AVAILABLE_LOCALES))
.orElseGet(Collections::emptyList);
} }
/** /**
@ -46,16 +39,7 @@ public class RequestAttributesUtil {
* @return the remote address of the remote client * @return the remote address of the remote client
*/ */
public static InetAddress getRemoteAddress() { public static InetAddress getRemoteAddress() {
return REMOTE_ADDRESS_CONTEXT_KEY.get(); return REQUEST_ATTRIBUTES_CONTEXT_KEY.get().remoteAddress();
}
/**
* Returns the parsed user-agent of the remote client in the current gRPC request context.
*
* @return the parsed user-agent of the remote client; may be empty if unparseable or not specified
*/
public static Optional<UserAgent> getUserAgent() {
return Optional.ofNullable(USER_AGENT_CONTEXT_KEY.get());
} }
/** /**
@ -63,7 +47,7 @@ public class RequestAttributesUtil {
* *
* @return the unparsed user-agent of the remote client; may be empty if not specified * @return the unparsed user-agent of the remote client; may be empty if not specified
*/ */
public static Optional<String> getRawUserAgent() { public static Optional<String> getUserAgent() {
return Optional.ofNullable(RAW_USER_AGENT_CONTEXT_KEY.get()); return Optional.ofNullable(REQUEST_ATTRIBUTES_CONTEXT_KEY.get().userAgent());
} }
} }

View File

@ -0,0 +1,39 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.Status;
public class ServerInterceptorUtil {
@SuppressWarnings("rawtypes")
private static final ServerCall.Listener NO_OP_LISTENER = new ServerCall.Listener<>() {};
private static final Metadata EMPTY_TRAILERS = new Metadata();
private ServerInterceptorUtil() {
}
/**
* Closes the given server call with the given status, returning a no-op listener.
*
* @param call the server call to close
* @param status the status with which to close the call
*
* @return a no-op server call listener
*
* @param <ReqT> the type of request object handled by the server call
* @param <RespT> the type of response object returned by the server call
*/
public static <ReqT, RespT> ServerCall.Listener<ReqT> closeWithStatus(final ServerCall<ReqT, RespT> call, final Status status) {
call.close(status, EMPTY_TRAILERS);
//noinspection unchecked
return NO_OP_LISTENER;
}
}

View File

@ -12,8 +12,10 @@ import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCountUtil;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Optional;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
/** /**
* An "establish local connection" handler waits for a Noise handshake to complete upstream in the pipeline, buffering * An "establish local connection" handler waits for a Noise handshake to complete upstream in the pipeline, buffering
@ -48,12 +50,12 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
@Override @Override
public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) { public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) {
if (event instanceof NoiseIdentityDeterminedEvent noiseIdentityDeterminedEvent) { if (event instanceof NoiseIdentityDeterminedEvent(final Optional<AuthenticatedDevice> authenticatedDevice)) {
// We assume that we'll only get a completed handshake event if the handshake met all authentication requirements // We assume that we'll only get a completed handshake event if the handshake met all authentication requirements
// for the requested service. If the handshake doesn't have an authenticated device, we assume we're trying to // for the requested service. If the handshake doesn't have an authenticated device, we assume we're trying to
// connect to the anonymous service. If it does have an authenticated device, we assume we're aiming for the // connect to the anonymous service. If it does have an authenticated device, we assume we're aiming for the
// authenticated service. // authenticated service.
final LocalAddress grpcServerAddress = noiseIdentityDeterminedEvent.authenticatedDevice().isPresent() final LocalAddress grpcServerAddress = authenticatedDevice.isPresent()
? authenticatedGrpcServerAddress ? authenticatedGrpcServerAddress
: anonymousGrpcServerAddress; : anonymousGrpcServerAddress;
@ -72,7 +74,7 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
if (localChannelFuture.isSuccess()) { if (localChannelFuture.isSuccess()) {
grpcClientConnectionManager.handleConnectionEstablished((LocalChannel) localChannelFuture.channel(), grpcClientConnectionManager.handleConnectionEstablished((LocalChannel) localChannelFuture.channel(),
remoteChannelContext.channel(), remoteChannelContext.channel(),
noiseIdentityDeterminedEvent.authenticatedDevice()); authenticatedDevice);
// Close the local connection if the remote channel closes and vice versa // Close the local connection if the remote channel closes and vice versa
remoteChannelContext.channel().closeFuture().addListener(closeFuture -> localChannelFuture.channel().close()); remoteChannelContext.channel().closeFuture().addListener(closeFuture -> localChannelFuture.channel().close());

View File

@ -1,6 +1,8 @@
package org.whispersystems.textsecuregcm.grpc.net; package org.whispersystems.textsecuregcm.grpc.net;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import io.grpc.Grpc;
import io.grpc.ServerCall;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelFutureListener;
import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalAddress;
@ -23,15 +25,25 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener; import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent; import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
/** /**
* A client connection manager associates a local connection to a local gRPC server with a remote connection through a * A client connection manager associates a local connection to a local gRPC server with a remote connection through a
* Noise-over-WebSocket tunnel. It provides access to metadata associated with the remote connection, including the * Noise tunnel. It provides access to metadata associated with the remote connection, including the authenticated
* authenticated identity of the device that opened the connection (for non-anonymous connections). It can also close * identity of the device that opened the connection (for non-anonymous connections). It can also close connections
* connections associated with a given device if that device's credentials have changed and clients must reauthenticate. * associated with a given device if that device's credentials have changed and clients must reauthenticate.
* <p>
* In general, all {@link ServerCall}s <em>must</em> have a local address that in turn <em>should</em> be resolvable to
* a remote channel, which <em>must</em> have associated request attributes and authentication status. It is possible
* that a server call's local address may not be resolvable to a remote channel if the remote channel closed in the
* narrow window between a server call being created and the start of call execution, in which case accessor methods
* in this class will throw a {@link ChannelNotFoundException}.
* <p>
* A gRPC client connection manager's methods for getting request attributes accept {@link ServerCall} entities to
* identify connections. In general, these methods should only be called from {@link io.grpc.ServerInterceptor}s.
* Methods for requesting connection closure accept an {@link AuthenticatedDevice} to identify the connection and may
* be called from any application code.
*/ */
public class GrpcClientConnectionManager implements DisconnectionRequestListener { public class GrpcClientConnectionManager implements DisconnectionRequestListener {
@ -43,94 +55,56 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
AttributeKey.valueOf(GrpcClientConnectionManager.class, "authenticatedDevice"); AttributeKey.valueOf(GrpcClientConnectionManager.class, "authenticatedDevice");
@VisibleForTesting @VisibleForTesting
static final AttributeKey<InetAddress> REMOTE_ADDRESS_ATTRIBUTE_KEY = public static final AttributeKey<RequestAttributes> REQUEST_ATTRIBUTES_KEY =
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "remoteAddress"); AttributeKey.valueOf(GrpcClientConnectionManager.class, "requestAttributes");
@VisibleForTesting
static final AttributeKey<String> RAW_USER_AGENT_ATTRIBUTE_KEY =
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "rawUserAgent");
@VisibleForTesting
static final AttributeKey<UserAgent> PARSED_USER_AGENT_ATTRIBUTE_KEY =
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "userAgent");
@VisibleForTesting
static final AttributeKey<List<Locale.LanguageRange>> ACCEPT_LANGUAGE_ATTRIBUTE_KEY =
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "acceptLanguage");
private static final Logger log = LoggerFactory.getLogger(GrpcClientConnectionManager.class); private static final Logger log = LoggerFactory.getLogger(GrpcClientConnectionManager.class);
/** /**
* Returns the authenticated device associated with the given local address, if any. An authenticated device is * Returns the authenticated device associated with the given server call, if any. If the connection is anonymous
* available if and only if the given local address maps to an active local connection and that connection is * (i.e. unauthenticated), the returned value will be empty.
* authenticated (i.e. not anonymous).
* *
* @param localAddress the local address for which to find an authenticated device * @param serverCall the gRPC server call for which to find an authenticated device
* *
* @return the authenticated device associated with the given local address, if any * @return the authenticated device associated with the given local address, if any
*
* @throws ChannelNotFoundException if the server call is not associated with a known channel; in practice, this
* generally indicates that the channel has closed while request processing is still in progress
*/ */
public Optional<AuthenticatedDevice> getAuthenticatedDevice(final LocalAddress localAddress) { public Optional<AuthenticatedDevice> getAuthenticatedDevice(final ServerCall<?, ?> serverCall)
return getAuthenticatedDevice(remoteChannelsByLocalAddress.get(localAddress)); throws ChannelNotFoundException {
return getAuthenticatedDevice(getRemoteChannel(serverCall));
} }
private Optional<AuthenticatedDevice> getAuthenticatedDevice(@Nullable final Channel remoteChannel) { @VisibleForTesting
return Optional.ofNullable(remoteChannel) Optional<AuthenticatedDevice> getAuthenticatedDevice(final Channel remoteChannel) {
.map(channel -> channel.attr(AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).get()); return Optional.ofNullable(remoteChannel.attr(AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).get());
} }
/** /**
* Returns the parsed acceptable languages associated with the given local address, if any. Acceptable languages may * Returns the request attributes associated with the given server call.
* be unavailable if the local connection associated with the given local address has already closed, if the client
* did not provide a list of acceptable languages, or the list provided by the client could not be parsed.
* *
* @param localAddress the local address for which to find acceptable languages * @param serverCall the gRPC server call for which to retrieve request attributes
* *
* @return the acceptable languages associated with the given local address, if any * @return the request attributes associated with the given server call
*
* @throws ChannelNotFoundException if the server call is not associated with a known channel; in practice, this
* generally indicates that the channel has closed while request processing is still in progress
*/ */
public Optional<List<Locale.LanguageRange>> getAcceptableLanguages(final LocalAddress localAddress) { public RequestAttributes getRequestAttributes(final ServerCall<?, ?> serverCall) throws ChannelNotFoundException {
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress)) return getRequestAttributes(getRemoteChannel(serverCall));
.map(remoteChannel -> remoteChannel.attr(ACCEPT_LANGUAGE_ATTRIBUTE_KEY).get());
} }
/** @VisibleForTesting
* Returns the remote address associated with the given local address, if any. A remote address may be unavailable if RequestAttributes getRequestAttributes(final Channel remoteChannel) {
* the local connection associated with the given local address has already closed. final RequestAttributes requestAttributes = remoteChannel.attr(REQUEST_ATTRIBUTES_KEY).get();
*
* @param localAddress the local address for which to find a remote address
*
* @return the remote address associated with the given local address, if any
*/
public Optional<InetAddress> getRemoteAddress(final LocalAddress localAddress) {
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
.map(remoteChannel -> remoteChannel.attr(REMOTE_ADDRESS_ATTRIBUTE_KEY).get());
}
/** if (requestAttributes == null) {
* Returns the unparsed user agent provided by the client that opened the connection associated with the given local throw new IllegalStateException("Channel does not have request attributes");
* address. This method may return an empty value if no active local connection is associated with the given local }
* address.
*
* @param localAddress the local address for which to find a User-Agent string
*
* @return the user agent string associated with the given local address
*/
public Optional<String> getRawUserAgent(final LocalAddress localAddress) {
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
.map(remoteChannel -> remoteChannel.attr(RAW_USER_AGENT_ATTRIBUTE_KEY).get());
}
/** return requestAttributes;
* Returns the parsed user agent provided by the client that opened the connection associated with the given local
* address. This method may return an empty value if no active local connection is associated with the given local
* address or if the client's user-agent string was not recognized.
*
* @param localAddress the local address for which to find a User-Agent string
*
* @return the user agent associated with the given local address
*/
public Optional<UserAgent> getUserAgent(final LocalAddress localAddress) {
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
.map(remoteChannel -> remoteChannel.attr(PARSED_USER_AGENT_ATTRIBUTE_KEY).get());
} }
/** /**
@ -156,11 +130,32 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
return remoteChannelsByAuthenticatedDevice.get(authenticatedDevice); return remoteChannelsByAuthenticatedDevice.get(authenticatedDevice);
} }
private Channel getRemoteChannel(final ServerCall<?, ?> serverCall) throws ChannelNotFoundException {
return getRemoteChannel(getLocalAddress(serverCall));
}
@VisibleForTesting @VisibleForTesting
Channel getRemoteChannelByLocalAddress(final LocalAddress localAddress) { Channel getRemoteChannel(final LocalAddress localAddress) throws ChannelNotFoundException {
final Channel remoteChannel = remoteChannelsByLocalAddress.get(localAddress);
if (remoteChannel == null) {
throw new ChannelNotFoundException();
}
return remoteChannelsByLocalAddress.get(localAddress); return remoteChannelsByLocalAddress.get(localAddress);
} }
private static LocalAddress getLocalAddress(final ServerCall<?, ?> serverCall) {
// In this server, gRPC's "remote" channel is actually a local channel that proxies to a distinct Noise channel.
// The gRPC "remote" address is the "local address" for the proxy connection, and the local address uniquely maps to
// a proxied Noise channel.
if (!(serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress)) {
throw new IllegalArgumentException("Unexpected channel type: " + serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
}
return localAddress;
}
/** /**
* Handles successful completion of a WebSocket handshake and associates attributes and headers from the handshake * Handles successful completion of a WebSocket handshake and associates attributes and headers from the handshake
* request with the channel via which the handshake took place. * request with the channel via which the handshake took place.
@ -171,30 +166,23 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
* @param acceptLanguageHeader the value of the Accept-Language header provided in the handshake request; may be * @param acceptLanguageHeader the value of the Accept-Language header provided in the handshake request; may be
* {@code null} * {@code null}
*/ */
static void handleWebSocketHandshakeComplete(final Channel channel, static void handleHandshakeComplete(final Channel channel,
final InetAddress preferredRemoteAddress, final InetAddress preferredRemoteAddress,
@Nullable final String userAgentHeader, @Nullable final String userAgentHeader,
@Nullable final String acceptLanguageHeader) { @Nullable final String acceptLanguageHeader) {
channel.attr(GrpcClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).set(preferredRemoteAddress); @Nullable List<Locale.LanguageRange> acceptLanguages = Collections.emptyList();
if (StringUtils.isNotBlank(userAgentHeader)) {
channel.attr(GrpcClientConnectionManager.RAW_USER_AGENT_ATTRIBUTE_KEY).set(userAgentHeader);
try {
channel.attr(GrpcClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY)
.set(UserAgentUtil.parseUserAgentString(userAgentHeader));
} catch (final UnrecognizedUserAgentException ignored) {
}
}
if (StringUtils.isNotBlank(acceptLanguageHeader)) { if (StringUtils.isNotBlank(acceptLanguageHeader)) {
try { try {
channel.attr(GrpcClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).set(Locale.LanguageRange.parse(acceptLanguageHeader)); acceptLanguages = Locale.LanguageRange.parse(acceptLanguageHeader);
} catch (final IllegalArgumentException e) { } catch (final IllegalArgumentException e) {
log.debug("Invalid Accept-Language header from User-Agent {}: {}", userAgentHeader, acceptLanguageHeader, e); log.debug("Invalid Accept-Language header from User-Agent {}: {}", userAgentHeader, acceptLanguageHeader, e);
} }
} }
channel.attr(REQUEST_ATTRIBUTES_KEY)
.set(new RequestAttributes(preferredRemoteAddress, userAgentHeader, acceptLanguages));
} }
/** /**

View File

@ -74,7 +74,7 @@ class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
preferredRemoteAddress = maybePreferredRemoteAddress.get(); preferredRemoteAddress = maybePreferredRemoteAddress.get();
} }
GrpcClientConnectionManager.handleWebSocketHandshakeComplete(context.channel(), GrpcClientConnectionManager.handleHandshakeComplete(context.channel(),
preferredRemoteAddress, preferredRemoteAddress,
handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.USER_AGENT), handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.USER_AGENT),
handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.ACCEPT_LANGUAGE)); handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.ACCEPT_LANGUAGE));

View File

@ -3,6 +3,7 @@ package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.Status; import io.grpc.Status;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse; import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils; import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
@ -22,7 +23,7 @@ class ProhibitAuthenticationInterceptorTest extends AbstractAuthenticationInterc
} }
@Test @Test
void interceptCall() { void interceptCall() throws ChannelNotFoundException {
final GrpcClientConnectionManager grpcClientConnectionManager = getClientConnectionManager(); final GrpcClientConnectionManager grpcClientConnectionManager = getClientConnectionManager();
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty()); when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty());
@ -34,6 +35,10 @@ class ProhibitAuthenticationInterceptorTest extends AbstractAuthenticationInterc
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID); final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice)); when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice));
GrpcTestUtils.assertStatusException(Status.UNAUTHENTICATED, this::getAuthenticatedDevice); GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getAuthenticatedDevice);
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenThrow(ChannelNotFoundException.class);
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, this::getAuthenticatedDevice);
} }
} }

View File

@ -9,6 +9,7 @@ import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse; import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils; import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
@ -22,12 +23,12 @@ class RequireAuthenticationInterceptorTest extends AbstractAuthenticationInterce
} }
@Test @Test
void interceptCall() { void interceptCall() throws ChannelNotFoundException {
final GrpcClientConnectionManager grpcClientConnectionManager = getClientConnectionManager(); final GrpcClientConnectionManager grpcClientConnectionManager = getClientConnectionManager();
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty()); when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty());
GrpcTestUtils.assertStatusException(Status.UNAUTHENTICATED, this::getAuthenticatedDevice); GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getAuthenticatedDevice);
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID); final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice)); when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice));
@ -35,5 +36,9 @@ class RequireAuthenticationInterceptorTest extends AbstractAuthenticationInterce
final GetAuthenticatedDeviceResponse response = getAuthenticatedDevice(); final GetAuthenticatedDeviceResponse response = getAuthenticatedDevice();
assertEquals(UUIDUtil.toByteString(authenticatedDevice.accountIdentifier()), response.getAccountIdentifier()); assertEquals(UUIDUtil.toByteString(authenticatedDevice.accountIdentifier()), response.getAccountIdentifier());
assertEquals(authenticatedDevice.deviceId(), response.getDeviceId()); assertEquals(authenticatedDevice.deviceId(), response.getDeviceId());
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenThrow(ChannelNotFoundException.class);
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, this::getAuthenticatedDevice);
} }
} }

View File

@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.filters;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import com.google.common.net.InetAddresses;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import io.dropwizard.core.Application; import io.dropwizard.core.Application;
import io.dropwizard.core.Configuration; import io.dropwizard.core.Configuration;
@ -24,7 +25,6 @@ import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path; import jakarta.ws.rs.Path;
import jakarta.ws.rs.client.Client; import jakarta.ws.rs.client.Client;
import jakarta.ws.rs.core.Response; import jakarta.ws.rs.core.Response;
import java.net.InetAddress;
import java.util.Collections; import java.util.Collections;
import java.util.EnumSet; import java.util.EnumSet;
import java.util.Set; import java.util.Set;
@ -39,6 +39,7 @@ import org.signal.chat.rpc.EchoServiceGrpc;
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl; import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils; import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.MockRequestAttributesInterceptor; import org.whispersystems.textsecuregcm.grpc.MockRequestAttributesInterceptor;
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.util.InetAddressRange; import org.whispersystems.textsecuregcm.util.InetAddressRange;
@ExtendWith(DropwizardExtensionsSupport.class) @ExtendWith(DropwizardExtensionsSupport.class)
@ -157,7 +158,7 @@ class ExternalRequestFilterTest {
@BeforeEach @BeforeEach
void setUp() throws Exception { void setUp() throws Exception {
final MockRequestAttributesInterceptor mockRequestAttributesInterceptor = new MockRequestAttributesInterceptor(); final MockRequestAttributesInterceptor mockRequestAttributesInterceptor = new MockRequestAttributesInterceptor();
mockRequestAttributesInterceptor.setRemoteAddress(InetAddress.getByName("127.0.0.1")); mockRequestAttributesInterceptor.setRequestAttributes(new RequestAttributes(InetAddresses.forString("127.0.0.1"), null, null));
testServer = InProcessServerBuilder.forName("ExternalRequestFilterTest") testServer = InProcessServerBuilder.forName("ExternalRequestFilterTest")
.directExecutor() .directExecutor()

View File

@ -15,6 +15,7 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.google.common.net.HttpHeaders; import com.google.common.net.HttpHeaders;
import com.google.common.net.InetAddresses;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import com.vdurmont.semver4j.Semver; import com.vdurmont.semver4j.Semver;
import io.grpc.ManagedChannel; import io.grpc.ManagedChannel;
@ -40,11 +41,10 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRemoteDeprecationConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRemoteDeprecationConfiguration;
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl; import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
import org.whispersystems.textsecuregcm.grpc.MockRequestAttributesInterceptor; import org.whispersystems.textsecuregcm.grpc.MockRequestAttributesInterceptor;
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.grpc.StatusConstants; import org.whispersystems.textsecuregcm.grpc.StatusConstants;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
class RemoteDeprecationFilterTest { class RemoteDeprecationFilterTest {
@ -130,11 +130,7 @@ class RemoteDeprecationFilterTest {
@MethodSource(value="testFilter") @MethodSource(value="testFilter")
void testGrpcFilter(final String userAgentString, final boolean expectDeprecation) throws IOException, InterruptedException { void testGrpcFilter(final String userAgentString, final boolean expectDeprecation) throws IOException, InterruptedException {
final MockRequestAttributesInterceptor mockRequestAttributesInterceptor = new MockRequestAttributesInterceptor(); final MockRequestAttributesInterceptor mockRequestAttributesInterceptor = new MockRequestAttributesInterceptor();
mockRequestAttributesInterceptor.setRequestAttributes(new RequestAttributes(InetAddresses.forString("127.0.0.1"), userAgentString, null));
try {
mockRequestAttributesInterceptor.setUserAgent(UserAgentUtil.parseUserAgentString(userAgentString));
} catch (UnrecognizedUserAgentException ignored) {
}
final Server testServer = InProcessServerBuilder.forName("RemoteDeprecationFilterTest") final Server testServer = InProcessServerBuilder.forName("RemoteDeprecationFilterTest")
.directExecutor() .directExecutor()

View File

@ -72,7 +72,8 @@ class AccountsAnonymousGrpcServiceTest extends
when(rateLimiter.validateReactive(anyString())).thenReturn(Mono.empty()); when(rateLimiter.validateReactive(anyString())).thenReturn(Mono.empty());
getMockRequestAttributesInterceptor().setRemoteAddress(InetAddresses.forString("127.0.0.1")); getMockRequestAttributesInterceptor().setRequestAttributes(
new RequestAttributes(InetAddresses.forString("127.0.0.1"), null, null));
return new AccountsAnonymousGrpcService(accountsManager, rateLimiters); return new AccountsAnonymousGrpcService(accountsManager, rateLimiters);
} }

View File

@ -5,6 +5,7 @@
package org.whispersystems.textsecuregcm.grpc; package org.whispersystems.textsecuregcm.grpc;
import com.google.common.net.InetAddresses;
import io.grpc.Context; import io.grpc.Context;
import io.grpc.Contexts; import io.grpc.Contexts;
import io.grpc.Metadata; import io.grpc.Metadata;
@ -19,25 +20,10 @@ import org.whispersystems.textsecuregcm.util.ua.UserAgent;
public class MockRequestAttributesInterceptor implements ServerInterceptor { public class MockRequestAttributesInterceptor implements ServerInterceptor {
@Nullable private RequestAttributes requestAttributes = new RequestAttributes(InetAddresses.forString("127.0.0.1"), null, null);
private InetAddress remoteAddress;
@Nullable public void setRequestAttributes(final RequestAttributes requestAttributes) {
private UserAgent userAgent; this.requestAttributes = requestAttributes;
@Nullable
private List<Locale.LanguageRange> acceptLanguage;
public void setRemoteAddress(@Nullable final InetAddress remoteAddress) {
this.remoteAddress = remoteAddress;
}
public void setUserAgent(@Nullable final UserAgent userAgent) {
this.userAgent = userAgent;
}
public void setAcceptLanguage(@Nullable final List<Locale.LanguageRange> acceptLanguage) {
this.acceptLanguage = acceptLanguage;
} }
@Override @Override
@ -45,20 +31,7 @@ public class MockRequestAttributesInterceptor implements ServerInterceptor {
final Metadata headers, final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) { final ServerCallHandler<ReqT, RespT> next) {
Context context = Context.current(); return Contexts.interceptCall(Context.current()
.withValue(RequestAttributesUtil.REQUEST_ATTRIBUTES_CONTEXT_KEY, requestAttributes), serverCall, headers, next);
if (remoteAddress != null) {
context = context.withValue(RequestAttributesUtil.REMOTE_ADDRESS_CONTEXT_KEY, remoteAddress);
}
if (userAgent != null) {
context = context.withValue(RequestAttributesUtil.USER_AGENT_CONTEXT_KEY, userAgent);
}
if (acceptLanguage != null) {
context = context.withValue(RequestAttributesUtil.ACCEPT_LANGUAGE_CONTEXT_KEY, acceptLanguage);
}
return Contexts.interceptCall(context, serverCall, headers, next);
} }
} }

View File

@ -15,6 +15,7 @@ import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException; import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
import com.google.common.net.InetAddresses;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import io.grpc.Status; import io.grpc.Status;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
@ -75,8 +76,6 @@ import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper; import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper;
import org.whispersystems.textsecuregcm.util.TestRandomUtil; import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil; import org.whispersystems.textsecuregcm.util.UUIDUtil;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<ProfileAnonymousGrpcService, ProfileAnonymousGrpc.ProfileAnonymousBlockingStub> { public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<ProfileAnonymousGrpcService, ProfileAnonymousGrpc.ProfileAnonymousBlockingStub> {
@ -96,13 +95,9 @@ public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<ProfileA
@Override @Override
protected ProfileAnonymousGrpcService createServiceBeforeEachTest() { protected ProfileAnonymousGrpcService createServiceBeforeEachTest() {
getMockRequestAttributesInterceptor().setAcceptLanguage(Locale.LanguageRange.parse("en-us")); getMockRequestAttributesInterceptor().setRequestAttributes(new RequestAttributes(InetAddresses.forString("127.0.0.1"),
"Signal-Android/1.2.3",
try { Locale.LanguageRange.parse("en-us")));
getMockRequestAttributesInterceptor().setUserAgent(UserAgentUtil.parseUserAgentString("Signal-Android/1.2.3"));
} catch (final UnrecognizedUserAgentException e) {
throw new IllegalArgumentException(e);
}
return new ProfileAnonymousGrpcService( return new ProfileAnonymousGrpcService(
accountsManager, accountsManager,

View File

@ -24,6 +24,7 @@ import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded; import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException; import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
import com.google.common.net.InetAddresses;
import com.google.i18n.phonenumbers.PhoneNumberUtil; import com.google.i18n.phonenumbers.PhoneNumberUtil;
import com.google.i18n.phonenumbers.Phonenumber; import com.google.i18n.phonenumbers.Phonenumber;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
@ -108,8 +109,6 @@ import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper;
import org.whispersystems.textsecuregcm.util.MockUtils; import org.whispersystems.textsecuregcm.util.MockUtils;
import org.whispersystems.textsecuregcm.util.TestRandomUtil; import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil; import org.whispersystems.textsecuregcm.util.UUIDUtil;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
@ -177,13 +176,9 @@ public class ProfileGrpcServiceTest extends SimpleBaseGrpcTest<ProfileGrpcServic
PhoneNumberUtil.getInstance().getExampleNumber("US"), PhoneNumberUtil.getInstance().getExampleNumber("US"),
PhoneNumberUtil.PhoneNumberFormat.E164); PhoneNumberUtil.PhoneNumberFormat.E164);
getMockRequestAttributesInterceptor().setAcceptLanguage(Locale.LanguageRange.parse("en-us")); getMockRequestAttributesInterceptor().setRequestAttributes(new RequestAttributes(InetAddresses.forString("127.0.0.1"),
"Signal-Android/1.2.3",
try { Locale.LanguageRange.parse("en-us")));
getMockRequestAttributesInterceptor().setUserAgent(UserAgentUtil.parseUserAgentString("Signal-Android/1.2.3"));
} catch (final UnrecognizedUserAgentException e) {
throw new IllegalArgumentException(e);
}
when(rateLimiters.getProfileLimiter()).thenReturn(rateLimiter); when(rateLimiters.getProfileLimiter()).thenReturn(rateLimiter);
when(rateLimiter.validateReactive(any(UUID.class))).thenReturn(Mono.empty()); when(rateLimiter.validateReactive(any(UUID.class))).thenReturn(Mono.empty());

View File

@ -1,13 +1,11 @@
package org.whispersystems.textsecuregcm.grpc; package org.whispersystems.textsecuregcm.grpc;
import io.grpc.stub.StreamObserver; import io.grpc.stub.StreamObserver;
import org.apache.commons.lang3.StringUtils;
import org.signal.chat.rpc.GetAuthenticatedDeviceRequest; import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse; import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.signal.chat.rpc.GetRequestAttributesRequest; import org.signal.chat.rpc.GetRequestAttributesRequest;
import org.signal.chat.rpc.GetRequestAttributesResponse; import org.signal.chat.rpc.GetRequestAttributesResponse;
import org.signal.chat.rpc.RequestAttributesGrpc; import org.signal.chat.rpc.RequestAttributesGrpc;
import org.signal.chat.rpc.UserAgent;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticationUtil; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticationUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil; import org.whispersystems.textsecuregcm.util.UUIDUtil;
@ -20,21 +18,15 @@ public class RequestAttributesServiceImpl extends RequestAttributesGrpc.RequestA
final GetRequestAttributesResponse.Builder responseBuilder = GetRequestAttributesResponse.newBuilder(); final GetRequestAttributesResponse.Builder responseBuilder = GetRequestAttributesResponse.newBuilder();
RequestAttributesUtil.getAcceptableLanguages().ifPresent(acceptableLanguages -> RequestAttributesUtil.getAcceptableLanguages()
acceptableLanguages.forEach(languageRange -> responseBuilder.addAcceptableLanguages(languageRange.toString()))); .forEach(languageRange -> responseBuilder.addAcceptableLanguages(languageRange.toString()));
RequestAttributesUtil.getAvailableAcceptedLocales().forEach(locale -> RequestAttributesUtil.getAvailableAcceptedLocales().forEach(locale ->
responseBuilder.addAvailableAcceptedLocales(locale.toLanguageTag())); responseBuilder.addAvailableAcceptedLocales(locale.toLanguageTag()));
responseBuilder.setRemoteAddress(RequestAttributesUtil.getRemoteAddress().getHostAddress()); responseBuilder.setRemoteAddress(RequestAttributesUtil.getRemoteAddress().getHostAddress());
RequestAttributesUtil.getUserAgent().ifPresent(userAgent -> responseBuilder.setUserAgent(UserAgent.newBuilder() RequestAttributesUtil.getUserAgent().ifPresent(responseBuilder::setUserAgent);
.setPlatform(userAgent.platform().toString())
.setVersion(userAgent.version().toString())
.setAdditionalSpecifiers(StringUtils.stripToEmpty(userAgent.additionalSpecifiers()))
.build()));
RequestAttributesUtil.getRawUserAgent().ifPresent(responseBuilder::setRawUserAgent);
responseObserver.onNext(responseBuilder.build()); responseObserver.onNext(responseBuilder.build());
responseObserver.onCompleted(); responseObserver.onCompleted();

View File

@ -3,172 +3,84 @@ package org.whispersystems.textsecuregcm.grpc;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.common.net.InetAddresses; import com.google.common.net.InetAddresses;
import io.grpc.ManagedChannel; import io.grpc.Context;
import io.grpc.Server; import java.net.InetAddress;
import io.grpc.Status; import java.util.Collections;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.NettyServerBuilder;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.Locale; import java.util.Locale;
import java.util.Optional; import java.util.Optional;
import org.junit.jupiter.api.AfterAll; import java.util.concurrent.Callable;
import org.junit.jupiter.api.AfterEach; import javax.annotation.Nullable;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.signal.chat.rpc.GetRequestAttributesRequest;
import org.signal.chat.rpc.GetRequestAttributesResponse;
import org.signal.chat.rpc.RequestAttributesGrpc;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
class RequestAttributesUtilTest { class RequestAttributesUtilTest {
private static DefaultEventLoopGroup eventLoopGroup; private static final InetAddress REMOTE_ADDRESS = InetAddresses.forString("127.0.0.1");
private GrpcClientConnectionManager grpcClientConnectionManager; @Test
void getAcceptableLanguages() throws Exception {
assertEquals(Collections.emptyList(),
callWithRequestAttributes(buildRequestAttributes(Collections.emptyList()),
RequestAttributesUtil::getAcceptableLanguages));
private Server server; assertEquals(Locale.LanguageRange.parse("en,ja"),
private ManagedChannel managedChannel; callWithRequestAttributes(buildRequestAttributes(Locale.LanguageRange.parse("en,ja")),
RequestAttributesUtil::getAcceptableLanguages));
@BeforeAll
static void setUpBeforeAll() {
eventLoopGroup = new DefaultEventLoopGroup();
}
@BeforeEach
void setUp() throws IOException {
final LocalAddress serverAddress = new LocalAddress("test-request-metadata-server");
grpcClientConnectionManager = mock(GrpcClientConnectionManager.class);
when(grpcClientConnectionManager.getRemoteAddress(any()))
.thenReturn(Optional.of(InetAddresses.forString("127.0.0.1")));
// `RequestAttributesInterceptor` operates on `LocalAddresses`, so we need to do some slightly fancy plumbing to make
// sure that we're using local channels and addresses
server = NettyServerBuilder.forAddress(serverAddress)
.channelType(LocalServerChannel.class)
.bossEventLoopGroup(eventLoopGroup)
.workerEventLoopGroup(eventLoopGroup)
.intercept(new RequestAttributesInterceptor(grpcClientConnectionManager))
.addService(new RequestAttributesServiceImpl())
.build()
.start();
managedChannel = NettyChannelBuilder.forAddress(serverAddress)
.channelType(LocalChannel.class)
.eventLoopGroup(eventLoopGroup)
.usePlaintext()
.build();
}
@AfterEach
void tearDown() {
managedChannel.shutdown();
server.shutdown();
}
@AfterAll
static void tearDownAfterAll() throws InterruptedException {
eventLoopGroup.shutdownGracefully().await();
} }
@Test @Test
void getAcceptableLanguages() { void getAvailableAcceptedLocales() throws Exception {
when(grpcClientConnectionManager.getAcceptableLanguages(any())) assertEquals(Collections.emptyList(),
.thenReturn(Optional.empty()); callWithRequestAttributes(buildRequestAttributes(Collections.emptyList()),
RequestAttributesUtil::getAvailableAcceptedLocales));
assertTrue(getRequestAttributes().getAcceptableLanguagesList().isEmpty()); final List<Locale> availableAcceptedLocales =
callWithRequestAttributes(buildRequestAttributes(Locale.LanguageRange.parse("en,ja")),
RequestAttributesUtil::getAvailableAcceptedLocales);
when(grpcClientConnectionManager.getAcceptableLanguages(any())) assertFalse(availableAcceptedLocales.isEmpty());
.thenReturn(Optional.of(Locale.LanguageRange.parse("en,ja")));
assertEquals(List.of("en", "ja"), getRequestAttributes().getAcceptableLanguagesList()); availableAcceptedLocales.forEach(locale ->
assertTrue("en".equals(locale.getLanguage()) || "ja".equals(locale.getLanguage())));
} }
@Test @Test
void getAvailableAcceptedLocales() { void getRemoteAddress() throws Exception {
when(grpcClientConnectionManager.getAcceptableLanguages(any())) assertEquals(REMOTE_ADDRESS,
.thenReturn(Optional.empty()); callWithRequestAttributes(new RequestAttributes(REMOTE_ADDRESS, null, null),
RequestAttributesUtil::getRemoteAddress));
assertTrue(getRequestAttributes().getAvailableAcceptedLocalesList().isEmpty());
when(grpcClientConnectionManager.getAcceptableLanguages(any()))
.thenReturn(Optional.of(Locale.LanguageRange.parse("en,ja")));
final GetRequestAttributesResponse response = getRequestAttributes();
assertFalse(response.getAvailableAcceptedLocalesList().isEmpty());
response.getAvailableAcceptedLocalesList().forEach(languageTag -> {
final Locale locale = Locale.forLanguageTag(languageTag);
assertTrue("en".equals(locale.getLanguage()) || "ja".equals(locale.getLanguage()));
});
} }
@Test @Test
void getRemoteAddress() { void getUserAgent() throws Exception {
when(grpcClientConnectionManager.getRemoteAddress(any())) assertEquals(Optional.empty(),
.thenReturn(Optional.empty()); callWithRequestAttributes(buildRequestAttributes((String) null),
RequestAttributesUtil::getUserAgent));
GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getRequestAttributes); assertEquals(Optional.of("Signal-Desktop/1.2.3 Linux"),
callWithRequestAttributes(buildRequestAttributes("Signal-Desktop/1.2.3 Linux"),
final String remoteAddressString = "6.7.8.9"; RequestAttributesUtil::getUserAgent));
when(grpcClientConnectionManager.getRemoteAddress(any()))
.thenReturn(Optional.of(InetAddresses.forString(remoteAddressString)));
assertEquals(remoteAddressString, getRequestAttributes().getRemoteAddress());
} }
@Test private static <V> V callWithRequestAttributes(final RequestAttributes requestAttributes, final Callable<V> callable) throws Exception {
void getUserAgent() throws UnrecognizedUserAgentException { return Context.current()
when(grpcClientConnectionManager.getUserAgent(any())) .withValue(RequestAttributesUtil.REQUEST_ATTRIBUTES_CONTEXT_KEY, requestAttributes)
.thenReturn(Optional.empty()); .call(callable);
assertFalse(getRequestAttributes().hasUserAgent());
final UserAgent userAgent = UserAgentUtil.parseUserAgentString("Signal-Desktop/1.2.3 Linux");
when(grpcClientConnectionManager.getUserAgent(any()))
.thenReturn(Optional.of(userAgent));
final GetRequestAttributesResponse response = getRequestAttributes();
assertTrue(response.hasUserAgent());
assertEquals("DESKTOP", response.getUserAgent().getPlatform());
assertEquals("1.2.3", response.getUserAgent().getVersion());
assertEquals("Linux", response.getUserAgent().getAdditionalSpecifiers());
} }
@Test private static RequestAttributes buildRequestAttributes(final String userAgent) {
void getRawUserAgent() { return buildRequestAttributes(userAgent, Collections.emptyList());
when(grpcClientConnectionManager.getRawUserAgent(any()))
.thenReturn(Optional.empty());
assertTrue(getRequestAttributes().getRawUserAgent().isBlank());
final String userAgentString = "Signal-Desktop/1.2.3 Linux";
when(grpcClientConnectionManager.getRawUserAgent(any()))
.thenReturn(Optional.of(userAgentString));
assertEquals(userAgentString, getRequestAttributes().getRawUserAgent());
} }
private GetRequestAttributesResponse getRequestAttributes() { private static RequestAttributes buildRequestAttributes(final List<Locale.LanguageRange> acceptLanguage) {
return RequestAttributesGrpc.newBlockingStub(managedChannel) return buildRequestAttributes(null, acceptLanguage);
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()); }
private static RequestAttributes buildRequestAttributes(@Nullable final String userAgent,
final List<Locale.LanguageRange> acceptLanguage) {
return new RequestAttributes(REMOTE_ADDRESS, userAgent, acceptLanguage);
} }
} }

View File

@ -1,7 +1,11 @@
package org.whispersystems.textsecuregcm.grpc.net; package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import com.google.common.net.InetAddresses; import com.google.common.net.InetAddresses;
import com.vdurmont.semver4j.Semver;
import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap; import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel; import io.netty.channel.Channel;
@ -12,6 +16,12 @@ import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel; import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel; import io.netty.channel.local.LocalServerChannel;
import java.net.InetAddress;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.UUID;
import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeAll;
@ -21,20 +31,9 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import javax.annotation.Nullable;
import java.net.InetAddress;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.UUID;
import static org.junit.jupiter.api.Assertions.*;
class GrpcClientConnectionManagerTest { class GrpcClientConnectionManagerTest {
@ -103,7 +102,7 @@ class GrpcClientConnectionManagerTest {
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, maybeAuthenticatedDevice); grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, maybeAuthenticatedDevice);
assertEquals(maybeAuthenticatedDevice, assertEquals(maybeAuthenticatedDevice,
grpcClientConnectionManager.getAuthenticatedDevice(localChannel.localAddress())); grpcClientConnectionManager.getAuthenticatedDevice(remoteChannel));
} }
private static List<Optional<AuthenticatedDevice>> getAuthenticatedDevice() { private static List<Optional<AuthenticatedDevice>> getAuthenticatedDevice() {
@ -114,170 +113,115 @@ class GrpcClientConnectionManagerTest {
} }
@Test @Test
void getAcceptableLanguages() { void getRequestAttributes() {
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty()); grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertEquals(Optional.empty(), assertThrows(IllegalStateException.class, () -> grpcClientConnectionManager.getRequestAttributes(remoteChannel));
grpcClientConnectionManager.getAcceptableLanguages(localChannel.localAddress()));
final List<Locale.LanguageRange> acceptLanguageRanges = Locale.LanguageRange.parse("en,ja"); final RequestAttributes requestAttributes = new RequestAttributes(InetAddresses.forString("6.7.8.9"), null, null);
remoteChannel.attr(GrpcClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).set(acceptLanguageRanges); remoteChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY).set(requestAttributes);
assertEquals(Optional.of(acceptLanguageRanges), assertEquals(requestAttributes, grpcClientConnectionManager.getRequestAttributes(remoteChannel));
grpcClientConnectionManager.getAcceptableLanguages(localChannel.localAddress()));
} }
@Test @Test
void getRemoteAddress() { void closeConnection() throws InterruptedException, ChannelNotFoundException {
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertEquals(Optional.empty(),
grpcClientConnectionManager.getRemoteAddress(localChannel.localAddress()));
final InetAddress remoteAddress = InetAddresses.forString("6.7.8.9");
remoteChannel.attr(GrpcClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).set(remoteAddress);
assertEquals(Optional.of(remoteAddress),
grpcClientConnectionManager.getRemoteAddress(localChannel.localAddress()));
}
@Test
void getUserAgent() throws UnrecognizedUserAgentException {
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertEquals(Optional.empty(),
grpcClientConnectionManager.getUserAgent(localChannel.localAddress()));
final UserAgent userAgent = UserAgentUtil.parseUserAgentString("Signal-Desktop/1.2.3 Linux");
remoteChannel.attr(GrpcClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY).set(userAgent);
assertEquals(Optional.of(userAgent),
grpcClientConnectionManager.getUserAgent(localChannel.localAddress()));
}
@Test
void closeConnection() throws InterruptedException {
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID); final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice)); grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice));
assertTrue(remoteChannel.isOpen()); assertTrue(remoteChannel.isOpen());
assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertEquals(List.of(remoteChannel), assertEquals(List.of(remoteChannel),
grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
remoteChannel.close().await(); remoteChannel.close().await();
assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); assertThrows(ChannelNotFoundException.class,
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
} }
@Test @ParameterizedTest
void handleWebSocketHandshakeCompleteRemoteAddress() { @MethodSource
void handleHandshakeCompleteRequestAttributes(final InetAddress preferredRemoteAddress,
final String userAgentHeader,
final String acceptLanguageHeader,
final RequestAttributes expectedRequestAttributes) {
final EmbeddedChannel embeddedChannel = new EmbeddedChannel(); final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
final InetAddress preferredRemoteAddress = InetAddresses.forString("192.168.1.1"); GrpcClientConnectionManager.handleHandshakeComplete(embeddedChannel,
GrpcClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel,
preferredRemoteAddress, preferredRemoteAddress,
null,
null);
assertEquals(preferredRemoteAddress,
embeddedChannel.attr(GrpcClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).get());
}
@ParameterizedTest
@MethodSource
void handleWebSocketHandshakeCompleteUserAgent(@Nullable final String userAgentHeader,
@Nullable final UserAgent expectedParsedUserAgent) {
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
GrpcClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel,
InetAddresses.forString("127.0.0.1"),
userAgentHeader, userAgentHeader,
null);
assertEquals(userAgentHeader,
embeddedChannel.attr(GrpcClientConnectionManager.RAW_USER_AGENT_ATTRIBUTE_KEY).get());
assertEquals(expectedParsedUserAgent,
embeddedChannel.attr(GrpcClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY).get());
}
private static List<Arguments> handleWebSocketHandshakeCompleteUserAgent() {
return List.of(
// Recognized user-agent
Arguments.of("Signal-Desktop/1.2.3 Linux", new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "Linux")),
// Unrecognized user-agent
Arguments.of("Not a valid user-agent string", null),
// Missing user-agent
Arguments.of(null, null)
);
}
@ParameterizedTest
@MethodSource
void handleWebSocketHandshakeCompleteAcceptLanguage(@Nullable final String acceptLanguageHeader,
@Nullable final List<Locale.LanguageRange> expectedLanguageRanges) {
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
GrpcClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel,
InetAddresses.forString("127.0.0.1"),
null,
acceptLanguageHeader); acceptLanguageHeader);
assertEquals(expectedLanguageRanges, assertEquals(expectedRequestAttributes,
embeddedChannel.attr(GrpcClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).get()); embeddedChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY).get());
} }
private static List<Arguments> handleWebSocketHandshakeCompleteAcceptLanguage() { private static List<Arguments> handleHandshakeCompleteRequestAttributes() {
final InetAddress preferredRemoteAddress = InetAddresses.forString("192.168.1.1");
return List.of( return List.of(
// Parseable list Arguments.argumentSet("Null User-Agent and Accept-Language headers",
Arguments.of("ja,en;q=0.4", Locale.LanguageRange.parse("ja,en;q=0.4")), preferredRemoteAddress, null, null,
new RequestAttributes(preferredRemoteAddress, null, Collections.emptyList())),
// Unparsable list Arguments.argumentSet("Recognized User-Agent and null Accept-Language header",
Arguments.of("This is not a valid language preference list", null), preferredRemoteAddress, "Signal-Desktop/1.2.3 Linux", null,
new RequestAttributes(preferredRemoteAddress, "Signal-Desktop/1.2.3 Linux", Collections.emptyList())),
// Missing list Arguments.argumentSet("Unparsable User-Agent and null Accept-Language header",
Arguments.of(null, null) preferredRemoteAddress, "Not a valid user-agent string", null,
new RequestAttributes(preferredRemoteAddress, "Not a valid user-agent string", Collections.emptyList())),
Arguments.argumentSet("Null User-Agent and parsable Accept-Language header",
preferredRemoteAddress, null, "ja,en;q=0.4",
new RequestAttributes(preferredRemoteAddress, null, Locale.LanguageRange.parse("ja,en;q=0.4"))),
Arguments.argumentSet("Null User-Agent and unparsable Accept-Language header",
preferredRemoteAddress, null, "This is not a valid language preference list",
new RequestAttributes(preferredRemoteAddress, null, Collections.emptyList()))
); );
} }
@Test @Test
void handleConnectionEstablishedAuthenticated() throws InterruptedException { void handleConnectionEstablishedAuthenticated() throws InterruptedException, ChannelNotFoundException {
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID); final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); assertThrows(ChannelNotFoundException.class,
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice)); grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice));
assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertEquals(List.of(remoteChannel), grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); assertEquals(List.of(remoteChannel), grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
remoteChannel.close().await(); remoteChannel.close().await();
assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); assertThrows(ChannelNotFoundException.class,
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
} }
@Test @Test
void handleConnectionEstablishedAnonymous() throws InterruptedException { void handleConnectionEstablishedAnonymous() throws InterruptedException, ChannelNotFoundException {
assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); assertThrows(ChannelNotFoundException.class,
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty()); grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
remoteChannel.close().await(); remoteChannel.close().await();
assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); assertThrows(ChannelNotFoundException.class,
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
} }
} }

View File

@ -523,10 +523,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
assertEquals(remoteAddress, response.getRemoteAddress()); assertEquals(remoteAddress, response.getRemoteAddress());
assertEquals(List.of(acceptLanguage), response.getAcceptableLanguagesList()); assertEquals(List.of(acceptLanguage), response.getAcceptableLanguagesList());
assertEquals(userAgent, response.getUserAgent());
assertEquals("DESKTOP", response.getUserAgent().getPlatform());
assertEquals("1.2.3", response.getUserAgent().getVersion());
assertEquals("Linux", response.getUserAgent().getAdditionalSpecifiers());
} finally { } finally {
channel.shutdown(); channel.shutdown();
} }

View File

@ -4,6 +4,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.params.provider.Arguments.argumentSet;
import static org.junit.jupiter.params.provider.Arguments.arguments; import static org.junit.jupiter.params.provider.Arguments.arguments;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
@ -16,6 +17,7 @@ import io.netty.channel.local.LocalAddress;
import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.util.Attribute;
import java.net.InetAddress; import java.net.InetAddress;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.SocketAddress; import java.net.SocketAddress;
@ -31,6 +33,7 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest { class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
@ -134,8 +137,13 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
embeddedChannel.setRemoteAddress(remoteAddress); embeddedChannel.setRemoteAddress(remoteAddress);
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent); embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
assertEquals(expectedRemoteAddress, assertEquals(expectedRemoteAddress,
embeddedChannel.attr(GrpcClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).get()); Optional.ofNullable(embeddedChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY))
.map(Attribute::get)
.map(RequestAttributes::remoteAddress)
.orElse(null));
} }
private static List<Arguments> getRemoteAddress() { private static List<Arguments> getRemoteAddress() {
@ -144,53 +152,53 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
final InetAddress proxyAddress = InetAddresses.forString("4.3.2.1"); final InetAddress proxyAddress = InetAddresses.forString("4.3.2.1");
return List.of( return List.of(
// Recognized proxy, single forwarded-for address argumentSet("Recognized proxy, single forwarded-for address",
Arguments.of(new DefaultHttpHeaders() new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET) .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()), .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
remoteAddress, remoteAddress,
clientAddress), clientAddress),
// Recognized proxy, multiple forwarded-for addresses argumentSet("Recognized proxy, multiple forwarded-for addresses",
Arguments.of(new DefaultHttpHeaders() new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET) .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress() + "," + proxyAddress.getHostAddress()), .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress() + "," + proxyAddress.getHostAddress()),
remoteAddress, remoteAddress,
proxyAddress), proxyAddress),
// No recognized proxy header, single forwarded-for address argumentSet("No recognized proxy header, single forwarded-for address",
Arguments.of(new DefaultHttpHeaders() new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()), .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
remoteAddress, remoteAddress,
remoteAddress.getAddress()), remoteAddress.getAddress()),
// No recognized proxy header, no forwarded-for address argumentSet("No recognized proxy header, no forwarded-for address",
Arguments.of(new DefaultHttpHeaders(), new DefaultHttpHeaders(),
remoteAddress, remoteAddress,
remoteAddress.getAddress()), remoteAddress.getAddress()),
// Incorrect proxy header, single forwarded-for address argumentSet("Incorrect proxy header, single forwarded-for address",
Arguments.of(new DefaultHttpHeaders() new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET + "-incorrect") .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET + "-incorrect")
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()), .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
remoteAddress, remoteAddress,
remoteAddress.getAddress()), remoteAddress.getAddress()),
// Recognized proxy, no forwarded-for address argumentSet("Recognized proxy, no forwarded-for address",
Arguments.of(new DefaultHttpHeaders() new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET), .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET),
remoteAddress, remoteAddress,
remoteAddress.getAddress()), remoteAddress.getAddress()),
// Recognized proxy, bogus forwarded-for address argumentSet("Recognized proxy, bogus forwarded-for address",
Arguments.of(new DefaultHttpHeaders() new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET) .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, "not a valid address"), .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, "not a valid address"),
remoteAddress, remoteAddress,
null), null),
// No forwarded-for address, non-InetSocketAddress remote address argumentSet("No forwarded-for address, non-InetSocketAddress remote address",
Arguments.of(new DefaultHttpHeaders() new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET), .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET),
new LocalAddress("local-address"), new LocalAddress("local-address"),
null) null)

View File

@ -23,14 +23,7 @@ message GetRequestAttributesResponse {
repeated string acceptable_languages = 1; repeated string acceptable_languages = 1;
repeated string available_accepted_locales = 2; repeated string available_accepted_locales = 2;
string remote_address = 3; string remote_address = 3;
string raw_user_agent = 4; string user_agent = 4;
UserAgent user_agent = 5;
}
message UserAgent {
string platform = 1;
string version = 2;
string additional_specifiers = 3;
} }
message GetAuthenticatedDeviceRequest { message GetAuthenticatedDeviceRequest {