diff --git a/service/config/sample-secrets-bundle.yml b/service/config/sample-secrets-bundle.yml index 34068424e..b33000665 100644 --- a/service/config/sample-secrets-bundle.yml +++ b/service/config/sample-secrets-bundle.yml @@ -95,3 +95,5 @@ turn.secret: AAAAAAAAAAA= linkDevice.secret: AAAAAAAAAAA= tlsKeyStore.password: unset + +noiseTunnel.recognizedProxySecret: ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789AAAAAAA diff --git a/service/config/sample.yml b/service/config/sample.yml index 0c280ca6b..65e41fc13 100644 --- a/service/config/sample.yml +++ b/service/config/sample.yml @@ -468,3 +468,7 @@ callingTurnManualTable: s3Bucket: a-bucket objectKey: an-object.tar.gz maxSize: 32777216 + +noiseTunnel: + port: 8443 + recognizedProxySecret: secret://noiseTunnel.recognizedProxySecret diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java index 2d4175f81..9ae54f59b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java @@ -39,6 +39,7 @@ import org.whispersystems.textsecuregcm.configuration.MaxDeviceConfiguration; import org.whispersystems.textsecuregcm.configuration.MessageByteLimitCardinalityEstimatorConfiguration; import org.whispersystems.textsecuregcm.configuration.MessageCacheConfiguration; import org.whispersystems.textsecuregcm.configuration.MonitoredS3ObjectConfiguration; +import org.whispersystems.textsecuregcm.configuration.NoiseWebSocketTunnelConfiguration; import org.whispersystems.textsecuregcm.configuration.OneTimeDonationConfiguration; import org.whispersystems.textsecuregcm.configuration.PaymentsServiceConfiguration; import org.whispersystems.textsecuregcm.configuration.RedisClusterConfiguration; @@ -333,6 +334,11 @@ public class WhisperServerConfiguration extends Configuration { @JsonProperty private MonitoredS3ObjectConfiguration callingTurnManualTable; + @Valid + @NotNull + @JsonProperty + private NoiseWebSocketTunnelConfiguration noiseTunnel; + public TlsKeyStoreConfiguration getTlsKeyStoreConfiguration() { return tlsKeyStore; } @@ -555,4 +561,8 @@ public class WhisperServerConfiguration extends Configuration { public MonitoredS3ObjectConfiguration getCallingTurnManualTable() { return callingTurnManualTable; } + + public NoiseWebSocketTunnelConfiguration getNoiseWebSocketTunnelConfiguration() { + return noiseTunnel; + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 11fdb6ce9..f4fde4014 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -71,7 +71,8 @@ import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager; import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager; import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator; import org.whispersystems.textsecuregcm.auth.WebsocketRefreshApplicationEventListener; -import org.whispersystems.textsecuregcm.auth.grpc.BasicCredentialAuthenticationInterceptor; +import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor; +import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor; import org.whispersystems.textsecuregcm.backup.BackupAuthManager; import org.whispersystems.textsecuregcm.backup.BackupManager; import org.whispersystems.textsecuregcm.backup.BackupsDb; @@ -127,7 +128,6 @@ import org.whispersystems.textsecuregcm.filters.RemoteDeprecationFilter; import org.whispersystems.textsecuregcm.filters.RequestStatisticsFilter; import org.whispersystems.textsecuregcm.filters.TimestampResponseFilter; import org.whispersystems.textsecuregcm.geo.MaxMindDatabaseManager; -import org.whispersystems.textsecuregcm.grpc.AcceptLanguageInterceptor; import org.whispersystems.textsecuregcm.grpc.AccountsAnonymousGrpcService; import org.whispersystems.textsecuregcm.grpc.AccountsGrpcService; import org.whispersystems.textsecuregcm.grpc.ErrorMappingInterceptor; @@ -138,7 +138,8 @@ import org.whispersystems.textsecuregcm.grpc.KeysGrpcService; import org.whispersystems.textsecuregcm.grpc.PaymentsGrpcService; import org.whispersystems.textsecuregcm.grpc.ProfileAnonymousGrpcService; import org.whispersystems.textsecuregcm.grpc.ProfileGrpcService; -import org.whispersystems.textsecuregcm.grpc.UserAgentInterceptor; +import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor; +import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager; import org.whispersystems.textsecuregcm.grpc.net.ManagedDefaultEventLoopGroup; import org.whispersystems.textsecuregcm.grpc.net.ManagedLocalGrpcServer; import org.whispersystems.textsecuregcm.jetty.JettyHttpConfigurationCustomizer; @@ -752,8 +753,7 @@ public class WhisperServerService extends Application getAuthenticatedDevice(final ServerCall call) { + if (call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress) { + return clientConnectionManager.getAuthenticatedDevice(localAddress); + } else { + throw new AssertionError("Unexpected channel type: " + call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR)); + } + } + + protected ServerCall.Listener closeAsUnauthenticated(final ServerCall call) { + call.close(Status.UNAUTHENTICATED, EMPTY_TRAILERS); + return new ServerCall.Listener<>() {}; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AuthenticationUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AuthenticationUtil.java index 26577d50e..dacfe5c43 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AuthenticationUtil.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AuthenticationUtil.java @@ -7,7 +7,6 @@ package org.whispersystems.textsecuregcm.auth.grpc; import io.grpc.Context; import io.grpc.Status; -import java.util.UUID; import javax.annotation.Nullable; import org.whispersystems.textsecuregcm.storage.Device; @@ -16,8 +15,7 @@ import org.whispersystems.textsecuregcm.storage.Device; */ public class AuthenticationUtil { - static final Context.Key CONTEXT_AUTHENTICATED_ACCOUNT_IDENTIFIER_KEY = Context.key("authenticated-aci"); - static final Context.Key CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY = Context.key("authenticated-device-id"); + static final Context.Key CONTEXT_AUTHENTICATED_DEVICE = Context.key("authenticated-device"); /** * Returns the account/device authenticated in the current gRPC context or throws an "unauthenticated" exception if @@ -29,11 +27,10 @@ public class AuthenticationUtil { * could be retrieved from the current gRPC context */ public static AuthenticatedDevice requireAuthenticatedDevice() { - @Nullable final UUID accountIdentifier = CONTEXT_AUTHENTICATED_ACCOUNT_IDENTIFIER_KEY.get(); - @Nullable final Byte deviceId = CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY.get(); + @Nullable final AuthenticatedDevice authenticatedDevice = CONTEXT_AUTHENTICATED_DEVICE.get(); - if (accountIdentifier != null && deviceId != null) { - return new AuthenticatedDevice(accountIdentifier, deviceId); + if (authenticatedDevice != null) { + return authenticatedDevice; } throw Status.UNAUTHENTICATED.asRuntimeException(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/BasicCredentialAuthenticationInterceptor.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/BasicCredentialAuthenticationInterceptor.java deleted file mode 100644 index 635c3497b..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/BasicCredentialAuthenticationInterceptor.java +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Copyright 2023 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.auth.grpc; - -import com.google.common.annotations.VisibleForTesting; -import io.dropwizard.auth.basic.BasicCredentials; -import io.grpc.Context; -import io.grpc.Contexts; -import io.grpc.Metadata; -import io.grpc.ServerCall; -import io.grpc.ServerCallHandler; -import io.grpc.ServerInterceptor; -import io.grpc.Status; -import java.util.Optional; -import org.apache.commons.lang3.StringUtils; -import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; -import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; -import org.whispersystems.textsecuregcm.util.HeaderUtils; - -/** - * A basic credential authentication interceptor enforces the presence of a valid username and password on every call. - * Callers supply credentials by providing a username (UUID and optional device ID) and password pair in the - * {@code x-signal-basic-auth-credentials} call header. - *

- * Downstream services can retrieve the identity of the authenticated caller using methods in - * {@link AuthenticationUtil}. - *

- * Note that this authentication, while fully functional, is intended only for development and testing purposes and is - * intended to be replaced with a more robust and efficient strategy before widespread client adoption. - * - * @see AuthenticationUtil - * @see AccountAuthenticator - */ -public class BasicCredentialAuthenticationInterceptor implements ServerInterceptor { - - private final AccountAuthenticator accountAuthenticator; - - @VisibleForTesting - static final Metadata.Key BASIC_CREDENTIALS = - Metadata.Key.of("x-signal-auth", Metadata.ASCII_STRING_MARSHALLER); - - private static final Metadata EMPTY_TRAILERS = new Metadata(); - - public BasicCredentialAuthenticationInterceptor(final AccountAuthenticator accountAuthenticator) { - this.accountAuthenticator = accountAuthenticator; - } - - @Override - public ServerCall.Listener interceptCall( - final ServerCall call, - final Metadata headers, - final ServerCallHandler next) { - - final String authHeader = headers.get(BASIC_CREDENTIALS); - - if (StringUtils.isNotBlank(authHeader)) { - final Optional maybeCredentials = HeaderUtils.basicCredentialsFromAuthHeader(authHeader); - if (maybeCredentials.isEmpty()) { - call.close(Status.UNAUTHENTICATED.withDescription("Could not parse credentials"), EMPTY_TRAILERS); - } else { - final Optional maybeAuthenticatedAccount = - accountAuthenticator.authenticate(maybeCredentials.get()); - - if (maybeAuthenticatedAccount.isPresent()) { - final AuthenticatedAccount authenticatedAccount = maybeAuthenticatedAccount.get(); - - final Context context = Context.current() - .withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_ACCOUNT_IDENTIFIER_KEY, authenticatedAccount.getAccount().getUuid()) - .withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY, authenticatedAccount.getAuthenticatedDevice().getId()); - - return Contexts.interceptCall(context, call, headers, next); - } else { - call.close(Status.UNAUTHENTICATED.withDescription("Credentials not accepted"), EMPTY_TRAILERS); - } - } - } else { - call.close(Status.UNAUTHENTICATED.withDescription("No credentials provided"), EMPTY_TRAILERS); - } - - return new ServerCall.Listener<>() {}; - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/ProhibitAuthenticationInterceptor.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/ProhibitAuthenticationInterceptor.java new file mode 100644 index 000000000..4de5c4d35 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/ProhibitAuthenticationInterceptor.java @@ -0,0 +1,28 @@ +package org.whispersystems.textsecuregcm.auth.grpc; + +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager; + +/** + * A "prohibit authentication" interceptor ensures that requests to endpoints that should be invoked anonymously do not + * originate from a channel that is associated with an authenticated device. Calls with an associated authenticated + * device are closed with an {@code UNAUTHENTICATED} status. + */ +public class ProhibitAuthenticationInterceptor extends AbstractAuthenticationInterceptor { + + public ProhibitAuthenticationInterceptor(final ClientConnectionManager clientConnectionManager) { + super(clientConnectionManager); + } + + @Override + public ServerCall.Listener interceptCall(final ServerCall call, + final Metadata headers, + final ServerCallHandler next) { + + return getAuthenticatedDevice(call) + .map(ignored -> closeAsUnauthenticated(call)) + .orElseGet(() -> next.startCall(call, headers)); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/RequireAuthenticationInterceptor.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/RequireAuthenticationInterceptor.java new file mode 100644 index 000000000..9ef160f1a --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/RequireAuthenticationInterceptor.java @@ -0,0 +1,32 @@ +package org.whispersystems.textsecuregcm.auth.grpc; + +import io.grpc.Context; +import io.grpc.Contexts; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager; + +/** + * A "require authentication" interceptor requires that requests be issued from a connection that is associated with an + * authenticated device. Calls without an associated authenticated device are closed with an {@code UNAUTHENTICATED} + * status. + */ +public class RequireAuthenticationInterceptor extends AbstractAuthenticationInterceptor { + + public RequireAuthenticationInterceptor(final ClientConnectionManager clientConnectionManager) { + super(clientConnectionManager); + } + + @Override + public ServerCall.Listener interceptCall(final ServerCall call, + final Metadata headers, + final ServerCallHandler next) { + + return getAuthenticatedDevice(call) + .map(authenticatedDevice -> Contexts.interceptCall(Context.current() + .withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE, authenticatedDevice), + call, headers, next)) + .orElseGet(() -> closeAsUnauthenticated(call)); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/NoiseWebSocketTunnelConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/NoiseWebSocketTunnelConfiguration.java new file mode 100644 index 000000000..45e0832d8 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/NoiseWebSocketTunnelConfiguration.java @@ -0,0 +1,8 @@ +package org.whispersystems.textsecuregcm.configuration; + +import javax.validation.constraints.NotNull; +import javax.validation.constraints.Positive; +import org.whispersystems.textsecuregcm.configuration.secrets.SecretString; + +public record NoiseWebSocketTunnelConfiguration(@Positive int port, @NotNull SecretString recognizedProxySecret) { +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilter.java b/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilter.java index d1f94cf41..a72697b06 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilter.java @@ -5,7 +5,6 @@ package org.whispersystems.textsecuregcm.filters; -import com.google.common.annotations.VisibleForTesting; import java.io.IOException; import java.util.Optional; import javax.annotation.Nullable; @@ -72,8 +71,7 @@ public class RemoteAddressFilter implements Filter { * @see X-Forwarded-For - HTTP | * MDN */ - @VisibleForTesting - static Optional getMostRecentProxy(@Nullable final String forwardedFor) { + public static Optional getMostRecentProxy(@Nullable final String forwardedFor) { return Optional.ofNullable(forwardedFor) .map(ff -> { final int idx = forwardedFor.lastIndexOf(',') + 1; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteDeprecationFilter.java b/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteDeprecationFilter.java index 49169526c..d640a8b6b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteDeprecationFilter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteDeprecationFilter.java @@ -17,6 +17,7 @@ import io.micrometer.core.instrument.Metrics; import java.io.IOException; import java.util.Map; import java.util.Set; +import javax.annotation.Nullable; import javax.servlet.Filter; import javax.servlet.FilterChain; import javax.servlet.ServletException; @@ -26,6 +27,7 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRemoteDeprecationConfiguration; +import org.whispersystems.textsecuregcm.grpc.RequestAttributesUtil; import org.whispersystems.textsecuregcm.grpc.StatusConstants; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; @@ -79,7 +81,7 @@ public class RemoteDeprecationFilter implements Filter, ServerInterceptor { final Metadata headers, final ServerCallHandler next) { - if (shouldBlock(UserAgentUtil.userAgentFromGrpcContext())) { + if (shouldBlock(RequestAttributesUtil.getUserAgent().orElse(null))) { call.close(StatusConstants.UPGRADE_NEEDED_STATUS, new Metadata()); return new ServerCall.Listener<>() {}; } else { @@ -87,7 +89,7 @@ public class RemoteDeprecationFilter implements Filter, ServerInterceptor { } } - private boolean shouldBlock(final UserAgent userAgent) { + private boolean shouldBlock(@Nullable final UserAgent userAgent) { final DynamicRemoteDeprecationConfiguration configuration = dynamicConfigurationManager .getConfiguration().getRemoteDeprecationConfiguration(); final Map minimumVersionsByPlatform = configuration.getMinimumVersions(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/AcceptLanguageInterceptor.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/AcceptLanguageInterceptor.java deleted file mode 100644 index a80c2eb32..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/AcceptLanguageInterceptor.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright 2023 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.grpc; - -import com.google.common.annotations.VisibleForTesting; -import io.grpc.Context; -import io.grpc.Contexts; -import io.grpc.Metadata; -import io.grpc.ServerCall; -import io.grpc.ServerCallHandler; -import io.grpc.ServerInterceptor; -import io.micrometer.core.instrument.Metrics; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.whispersystems.textsecuregcm.util.ua.UserAgent; -import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; -import javax.annotation.Nullable; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.Locale; - -import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; - -public class AcceptLanguageInterceptor implements ServerInterceptor { - private static final Logger logger = LoggerFactory.getLogger(AcceptLanguageInterceptor.class); - private static final String INVALID_ACCEPT_LANGUAGE_COUNTER_NAME = name(AcceptLanguageInterceptor.class, "invalidAcceptLanguage"); - - @VisibleForTesting - public static final Metadata.Key ACCEPTABLE_LANGUAGES_GRPC_HEADER = - Metadata.Key.of("accept-language", Metadata.ASCII_STRING_MARSHALLER); - - @Override - public ServerCall.Listener interceptCall(final ServerCall call, - final Metadata headers, - final ServerCallHandler next) { - - final List locales = parseLocales(headers.get(ACCEPTABLE_LANGUAGES_GRPC_HEADER)); - - return Contexts.interceptCall( - Context.current().withValue(AcceptLanguageUtil.ACCEPTABLE_LANGUAGES_CONTEXT_KEY, locales), - call, - headers, - next); - } - - static List parseLocales(@Nullable final String acceptableLanguagesHeader) { - if (acceptableLanguagesHeader == null) { - return Collections.emptyList(); - } - try { - final List languageRanges = Locale.LanguageRange.parse(acceptableLanguagesHeader); - return Locale.filter(languageRanges, Arrays.asList(Locale.getAvailableLocales())); - } catch (final IllegalArgumentException e) { - final UserAgent userAgent = UserAgentUtil.userAgentFromGrpcContext(); - Metrics.counter(INVALID_ACCEPT_LANGUAGE_COUNTER_NAME, "platform", userAgent.getPlatform().name().toLowerCase()).increment(); - logger.debug("Could not get acceptable languages; Accept-Language: {}; User-Agent: {}", - acceptableLanguagesHeader, - userAgent, - e); - return Collections.emptyList(); - } - } -} - diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/AcceptLanguageUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/AcceptLanguageUtil.java deleted file mode 100644 index a60e0c133..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/AcceptLanguageUtil.java +++ /dev/null @@ -1,17 +0,0 @@ -/* - * Copyright 2023 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.grpc; - -import io.grpc.Context; -import java.util.List; -import java.util.Locale; - -public class AcceptLanguageUtil { - static final Context.Key> ACCEPTABLE_LANGUAGES_CONTEXT_KEY = Context.key("accept-language"); - public static List localeFromGrpcContext() { - return ACCEPTABLE_LANGUAGES_CONTEXT_KEY.get(); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ProfileGrpcHelper.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ProfileGrpcHelper.java index 4469556a2..3f135b8ca 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ProfileGrpcHelper.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ProfileGrpcHelper.java @@ -111,7 +111,7 @@ public class ProfileGrpcHelper { case ACI -> { responseBuilder.setUnrestrictedUnidentifiedAccess(targetAccount.isUnrestrictedUnidentifiedAccess()) .addAllBadges(buildBadges(profileBadgeConverter.convert( - AcceptLanguageUtil.localeFromGrpcContext(), + RequestAttributesUtil.getAvailableAcceptedLocales(), targetAccount.getBadges(), ProfileHelper.isSelfProfileRequest(requesterUuid, (AciServiceIdentifier) targetIdentifier)))); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RateLimitUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RateLimitUtil.java index ee33e73bd..e3b12f690 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RateLimitUtil.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RateLimitUtil.java @@ -5,28 +5,12 @@ package org.whispersystems.textsecuregcm.grpc; -import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.limits.RateLimiter; import reactor.core.publisher.Mono; -import java.net.SocketAddress; -import java.time.Duration; class RateLimitUtil { - private static final RateLimitExceededException UNKNOWN_REMOTE_ADDRESS_EXCEPTION = - new RateLimitExceededException(Duration.ofHours(1), true); - static Mono rateLimitByRemoteAddress(final RateLimiter rateLimiter) { - return rateLimitByRemoteAddress(rateLimiter, true); - } - - static Mono rateLimitByRemoteAddress(final RateLimiter rateLimiter, final boolean failOnUnknownRemoteAddress) { - final SocketAddress remoteAddress = RemoteAddressUtil.getRemoteAddress(); - - if (remoteAddress != null) { - return rateLimiter.validateReactive(remoteAddress.toString()); - } else { - return failOnUnknownRemoteAddress ? Mono.error(UNKNOWN_REMOTE_ADDRESS_EXCEPTION) : Mono.empty(); - } + return rateLimiter.validateReactive(RequestAttributesUtil.getRemoteAddress().getHostAddress()); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RemoteAddressInterceptor.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RemoteAddressInterceptor.java deleted file mode 100644 index 0fe21bdb5..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RemoteAddressInterceptor.java +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright 2023 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.grpc; - -import io.grpc.Context; -import io.grpc.Contexts; -import io.grpc.Grpc; -import io.grpc.Metadata; -import io.grpc.ServerCall; -import io.grpc.ServerCallHandler; -import io.grpc.ServerInterceptor; -import java.net.SocketAddress; - -public class RemoteAddressInterceptor implements ServerInterceptor { - - @Override - public ServerCall.Listener interceptCall(final ServerCall serverCall, - final Metadata headers, - final ServerCallHandler next) { - - // Note: the specific implementation for getting a remote client address may change depending on the client - // connection strategy. The important thing is that the remote address wind up in the context for the current - // call so it can be retrieved by `RemoteAddressUtil`. - final SocketAddress remoteAddress = serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR); - - return Contexts.interceptCall( - Context.current().withValue(RemoteAddressUtil.REMOTE_ADDRESS_CONTEXT_KEY, remoteAddress), - serverCall, headers, next); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RemoteAddressUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RemoteAddressUtil.java deleted file mode 100644 index 946a6a3ba..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RemoteAddressUtil.java +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright 2023 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.grpc; - -import io.grpc.Context; -import java.net.SocketAddress; - -public class RemoteAddressUtil { - - static final Context.Key REMOTE_ADDRESS_CONTEXT_KEY = Context.key("remote-address"); - - /** - * Returns the socket address of the remote client in the current gRPC request context. - * - * @return the socket address of the remote client - */ - public static SocketAddress getRemoteAddress() { - return REMOTE_ADDRESS_CONTEXT_KEY.get(); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesInterceptor.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesInterceptor.java new file mode 100644 index 000000000..d9cdbd2d6 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesInterceptor.java @@ -0,0 +1,75 @@ +package org.whispersystems.textsecuregcm.grpc; + +import io.grpc.Context; +import io.grpc.Contexts; +import io.grpc.Grpc; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.Status; +import io.netty.channel.local.LocalAddress; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager; +import org.whispersystems.textsecuregcm.util.ua.UserAgent; +import java.net.InetAddress; +import java.util.List; +import java.util.Locale; +import java.util.Optional; + +public class RequestAttributesInterceptor implements ServerInterceptor { + + private final ClientConnectionManager clientConnectionManager; + + private static final Logger log = LoggerFactory.getLogger(RequestAttributesInterceptor.class); + + public RequestAttributesInterceptor(final ClientConnectionManager clientConnectionManager) { + this.clientConnectionManager = clientConnectionManager; + } + + @Override + public ServerCall.Listener interceptCall(final ServerCall call, + final Metadata headers, + final ServerCallHandler next) { + + if (call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress) { + Context context = Context.current(); + + { + final Optional maybeRemoteAddress = clientConnectionManager.getRemoteAddress(localAddress); + + if (maybeRemoteAddress.isEmpty()) { + // We should never have a call from a party whose remote address we can't identify + log.warn("No remote address available"); + + call.close(Status.INTERNAL, new Metadata()); + return new ServerCall.Listener<>() {}; + } + + context = context.withValue(RequestAttributesUtil.REMOTE_ADDRESS_CONTEXT_KEY, maybeRemoteAddress.get()); + } + + { + final Optional> maybeAcceptLanguage = + clientConnectionManager.getAcceptableLanguages(localAddress); + + if (maybeAcceptLanguage.isPresent()) { + context = context.withValue(RequestAttributesUtil.ACCEPT_LANGUAGE_CONTEXT_KEY, maybeAcceptLanguage.get()); + } + } + + { + final Optional maybeUserAgent = clientConnectionManager.getUserAgent(localAddress); + + if (maybeUserAgent.isPresent()) { + context = context.withValue(RequestAttributesUtil.USER_AGENT_CONTEXT_KEY, maybeUserAgent.get()); + } + } + + return Contexts.interceptCall(context, call, headers, next); + } else { + throw new AssertionError("Unexpected channel type: " + call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR)); + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesUtil.java new file mode 100644 index 000000000..f55f5d687 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesUtil.java @@ -0,0 +1,59 @@ +package org.whispersystems.textsecuregcm.grpc; + +import io.grpc.Context; +import java.net.InetAddress; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Optional; +import org.whispersystems.textsecuregcm.util.ua.UserAgent; + +public class RequestAttributesUtil { + + static final Context.Key> ACCEPT_LANGUAGE_CONTEXT_KEY = Context.key("accept-language"); + static final Context.Key REMOTE_ADDRESS_CONTEXT_KEY = Context.key("remote-address"); + static final Context.Key USER_AGENT_CONTEXT_KEY = Context.key("user-agent"); + + private static final List AVAILABLE_LOCALES = Arrays.asList(Locale.getAvailableLocales()); + + /** + * Returns the acceptable languages listed by the remote client in the current gRPC request context. + * + * @return the acceptable languages listed by the remote client; may be empty if unparseable or not specified + */ + public static Optional> getAcceptableLanguages() { + return Optional.ofNullable(ACCEPT_LANGUAGE_CONTEXT_KEY.get()); + } + + /** + * Returns a list of distinct locales supported by the JVM and accepted by the remote client in the current gRPC + * context. May be empty if the client did not supply a list of acceptable languages, if the list of acceptable + * languages could not be parsed, or if none of the acceptable languages are available in the current JVM. + * + * @return a list of distinct locales acceptable to the remote client and available in this JVM + */ + public static List getAvailableAcceptedLocales() { + return getAcceptableLanguages() + .map(languageRanges -> Locale.filter(languageRanges, AVAILABLE_LOCALES)) + .orElseGet(Collections::emptyList); + } + + /** + * Returns the remote address of the remote client in the current gRPC request context. + * + * @return the remote address of the remote client + */ + public static InetAddress getRemoteAddress() { + return REMOTE_ADDRESS_CONTEXT_KEY.get(); + } + + /** + * Returns the parsed user-agent of the remote client in the current gRPC request context. + * + * @return the parsed user-agent of the remote client; may be null if unparseable or not specified + */ + public static Optional getUserAgent() { + return Optional.ofNullable(USER_AGENT_CONTEXT_KEY.get()); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/UserAgentInterceptor.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/UserAgentInterceptor.java deleted file mode 100644 index 4a0d71ef4..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/UserAgentInterceptor.java +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright 2023 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.grpc; - -import com.google.common.annotations.VisibleForTesting; - -import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; -import org.whispersystems.textsecuregcm.util.ua.UserAgent; -import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; - -import io.grpc.Context; -import io.grpc.Contexts; -import io.grpc.Metadata; -import io.grpc.ServerCall; -import io.grpc.ServerCallHandler; -import io.grpc.ServerInterceptor; -import io.grpc.Status; - -public class UserAgentInterceptor implements ServerInterceptor { - @VisibleForTesting - public static final Metadata.Key USER_AGENT_GRPC_HEADER = - Metadata.Key.of("user-agent", Metadata.ASCII_STRING_MARSHALLER); - - @Override - public ServerCall.Listener interceptCall(final ServerCall call, - final Metadata headers, - final ServerCallHandler next) { - - UserAgent userAgent; - try { - userAgent = UserAgentUtil.parseUserAgentString(headers.get(USER_AGENT_GRPC_HEADER)); - } catch (final UnrecognizedUserAgentException e) { - userAgent = null; - } - - final Context context = Context.current().withValue(UserAgentUtil.USER_AGENT_CONTEXT_KEY, userAgent); - return Contexts.interceptCall(context, call, headers, next); - } - -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ClientConnectionManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ClientConnectionManager.java new file mode 100644 index 000000000..73aa895c2 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ClientConnectionManager.java @@ -0,0 +1,218 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import com.google.common.annotations.VisibleForTesting; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; +import io.netty.util.AttributeKey; +import java.net.InetAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import javax.annotation.Nullable; +import org.apache.commons.lang3.StringUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; +import org.whispersystems.textsecuregcm.util.ua.UserAgent; +import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; + +/** + * A client connection manager associates a local connection to a local gRPC server with a remote connection through a + * Noise-over-WebSocket tunnel. It provides access to metadata associated with the remote connection, including the + * authenticated identity of the device that opened the connection (for non-anonymous connections). It can also close + * connections associated with a given device if that device's credentials have changed and clients must reauthenticate. + */ +public class ClientConnectionManager { + + private final Map remoteChannelsByLocalAddress = new ConcurrentHashMap<>(); + private final Map> remoteChannelsByAuthenticatedDevice = new ConcurrentHashMap<>(); + + @VisibleForTesting + static final AttributeKey AUTHENTICATED_DEVICE_ATTRIBUTE_KEY = + AttributeKey.valueOf(ClientConnectionManager.class, "authenticatedDevice"); + + @VisibleForTesting + static final AttributeKey REMOTE_ADDRESS_ATTRIBUTE_KEY = + AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "remoteAddress"); + + @VisibleForTesting + static final AttributeKey RAW_USER_AGENT_ATTRIBUTE_KEY = + AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "rawUserAgent"); + + @VisibleForTesting + static final AttributeKey PARSED_USER_AGENT_ATTRIBUTE_KEY = + AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "userAgent"); + + @VisibleForTesting + static final AttributeKey> ACCEPT_LANGUAGE_ATTRIBUTE_KEY = + AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "acceptLanguage"); + + private static final Logger log = LoggerFactory.getLogger(ClientConnectionManager.class); + + /** + * Returns the authenticated device associated with the given local address, if any. An authenticated device is + * available if and only if the given local address maps to an active local connection and that connection is + * authenticated (i.e. not anonymous). + * + * @param localAddress the local address for which to find an authenticated device + * + * @return the authenticated device associated with the given local address, if any + */ + public Optional getAuthenticatedDevice(final LocalAddress localAddress) { + return getAuthenticatedDevice(remoteChannelsByLocalAddress.get(localAddress)); + } + + private Optional getAuthenticatedDevice(@Nullable final Channel remoteChannel) { + return Optional.ofNullable(remoteChannel) + .map(channel -> channel.attr(AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).get()); + } + + /** + * Returns the parsed acceptable languages associated with the given local address, if any. Acceptable languages may + * be unavailable if the local connection associated with the given local address has already closed, if the client + * did not provide a list of acceptable languages, or the list provided by the client could not be parsed. + * + * @param localAddress the local address for which to find acceptable languages + * + * @return the acceptable languages associated with the given local address, if any + */ + public Optional> getAcceptableLanguages(final LocalAddress localAddress) { + return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress)) + .map(remoteChannel -> remoteChannel.attr(ACCEPT_LANGUAGE_ATTRIBUTE_KEY).get()); + } + + /** + * Returns the remote address associated with the given local address, if any. A remote address may be unavailable if + * the local connection associated with the given local address has already closed. + * + * @param localAddress the local address for which to find a remote address + * + * @return the remote address associated with the given local address, if any + */ + public Optional getRemoteAddress(final LocalAddress localAddress) { + return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress)) + .map(remoteChannel -> remoteChannel.attr(REMOTE_ADDRESS_ATTRIBUTE_KEY).get()); + } + + /** + * Returns the parsed user agent provided by the client the opened the connection associated with the given local + * address. This method may return an empty value if no active local connection is associated with the given local + * address or if the client's user-agent string was not recognized. + * + * @param localAddress the local address for which to find a User-Agent string + * + * @return the user agent associated with the given local address + */ + public Optional getUserAgent(final LocalAddress localAddress) { + return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress)) + .map(remoteChannel -> remoteChannel.attr(PARSED_USER_AGENT_ATTRIBUTE_KEY).get()); + } + + /** + * Closes any client connections to this host associated with the given authenticated device. + * + * @param authenticatedDevice the authenticated device for which to close connections + */ + public void closeConnection(final AuthenticatedDevice authenticatedDevice) { + // Channels will actually get removed from the list/map by their closeFuture listeners + remoteChannelsByAuthenticatedDevice.getOrDefault(authenticatedDevice, Collections.emptyList()).forEach(channel -> + channel.writeAndFlush(new CloseWebSocketFrame(ApplicationWebSocketCloseReason.REAUTHENTICATION_REQUIRED + .toWebSocketCloseStatus("Reauthentication required"))) + .addListener(ChannelFutureListener.CLOSE_ON_FAILURE)); + } + + @VisibleForTesting + @Nullable List getRemoteChannelsByAuthenticatedDevice(final AuthenticatedDevice authenticatedDevice) { + return remoteChannelsByAuthenticatedDevice.get(authenticatedDevice); + } + + @VisibleForTesting + Channel getRemoteChannelByLocalAddress(final LocalAddress localAddress) { + return remoteChannelsByLocalAddress.get(localAddress); + } + + /** + * Handles successful completion of a WebSocket handshake and associates attributes and headers from the handshake + * request with the channel via which the handshake took place. + * + * @param channel the channel that completed a WebSocket handshake + * @param preferredRemoteAddress the preferred remote address (potentially from a request header) for the handshake + * @param userAgentHeader the value of the User-Agent header provided in the handshake request; may be {@code null} + * @param acceptLanguageHeader the value of the Accept-Language header provided in the handshake request; may be + * {@code null} + */ + static void handleWebSocketHandshakeComplete(final Channel channel, + final InetAddress preferredRemoteAddress, + @Nullable final String userAgentHeader, + @Nullable final String acceptLanguageHeader) { + + channel.attr(ClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).set(preferredRemoteAddress); + + if (StringUtils.isNotBlank(userAgentHeader)) { + channel.attr(ClientConnectionManager.RAW_USER_AGENT_ATTRIBUTE_KEY).set(userAgentHeader); + + try { + channel.attr(ClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY) + .set(UserAgentUtil.parseUserAgentString(userAgentHeader)); + } catch (final UnrecognizedUserAgentException ignored) { + } + } + + if (StringUtils.isNotBlank(acceptLanguageHeader)) { + try { + channel.attr(ClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).set(Locale.LanguageRange.parse(acceptLanguageHeader)); + } catch (final IllegalArgumentException e) { + log.debug("Invalid Accept-Language header from User-Agent {}: {}", userAgentHeader, acceptLanguageHeader, e); + } + } + } + + /** + * Handles successful establishment of a Noise-over-WebSocket connection from a remote client to a local gRPC server. + * + * @param localChannel the newly-opened local channel between the Noise-over-WebSocket tunnel and the local gRPC + * server + * @param remoteChannel the channel from the remote client to the Noise-over-WebSocket tunnel + * @param maybeAuthenticatedDevice the authenticated device (if any) associated with the new connection + */ + void handleConnectionEstablished(final LocalChannel localChannel, + final Channel remoteChannel, + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional maybeAuthenticatedDevice) { + + maybeAuthenticatedDevice.ifPresent(authenticatedDevice -> + remoteChannel.attr(ClientConnectionManager.AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).set(authenticatedDevice)); + + remoteChannelsByLocalAddress.put(localChannel.localAddress(), remoteChannel); + + getAuthenticatedDevice(remoteChannel).ifPresent(authenticatedDevice -> + remoteChannelsByAuthenticatedDevice.compute(authenticatedDevice, (ignored, existingChannelList) -> { + final List channels = existingChannelList != null ? existingChannelList : new ArrayList<>(); + channels.add(remoteChannel); + + return channels; + })); + + remoteChannel.closeFuture().addListener(closeFuture -> { + remoteChannelsByLocalAddress.remove(localChannel.localAddress()); + + getAuthenticatedDevice(remoteChannel).ifPresent(authenticatedDevice -> + remoteChannelsByAuthenticatedDevice.compute(authenticatedDevice, (ignored, existingChannelList) -> { + if (existingChannelList == null) { + return null; + } + + existingChannelList.remove(remoteChannel); + + return existingChannelList.isEmpty() ? null : existingChannelList; + })); + }); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/EstablishLocalGrpcConnectionHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/EstablishLocalGrpcConnectionHandler.java index 37be75df0..c8295a1c7 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/EstablishLocalGrpcConnectionHandler.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/EstablishLocalGrpcConnectionHandler.java @@ -22,15 +22,21 @@ import org.slf4j.LoggerFactory; */ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter { + private final ClientConnectionManager clientConnectionManager; + private final LocalAddress authenticatedGrpcServerAddress; private final LocalAddress anonymousGrpcServerAddress; + private final List pendingReads = new ArrayList<>(); private static final Logger log = LoggerFactory.getLogger(EstablishLocalGrpcConnectionHandler.class); - public EstablishLocalGrpcConnectionHandler(final LocalAddress authenticatedGrpcServerAddress, + public EstablishLocalGrpcConnectionHandler(final ClientConnectionManager clientConnectionManager, + final LocalAddress authenticatedGrpcServerAddress, final LocalAddress anonymousGrpcServerAddress) { + this.clientConnectionManager = clientConnectionManager; + this.authenticatedGrpcServerAddress = authenticatedGrpcServerAddress; this.anonymousGrpcServerAddress = anonymousGrpcServerAddress; } @@ -41,7 +47,7 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter { } @Override - public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) throws Exception { + public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) { if (event instanceof NoiseHandshakeCompleteEvent noiseHandshakeCompleteEvent) { // We assume that we'll only get a completed handshake event if the handshake met all authentication requirements // for the requested service. If the handshake doesn't have an authenticated device, we assume we're trying to @@ -53,7 +59,6 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter { new Bootstrap() .remoteAddress(grpcServerAddress) - // TODO Set local address .channel(LocalChannel.class) .group(remoteChannelContext.channel().eventLoop()) .handler(new ChannelInitializer() { @@ -63,15 +68,19 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter { } }) .connect() - .addListener((ChannelFutureListener) future -> { - if (future.isSuccess()) { + .addListener((ChannelFutureListener) localChannelFuture -> { + if (localChannelFuture.isSuccess()) { + clientConnectionManager.handleConnectionEstablished((LocalChannel) localChannelFuture.channel(), + remoteChannelContext.channel(), + noiseHandshakeCompleteEvent.authenticatedDevice()); + // Close the local connection if the remote channel closes and vice versa - remoteChannelContext.channel().closeFuture().addListener(closeFuture -> future.channel().close()); - future.channel().closeFuture().addListener(closeFuture -> + remoteChannelContext.channel().closeFuture().addListener(closeFuture -> localChannelFuture.channel().close()); + localChannelFuture.channel().closeFuture().addListener(closeFuture -> remoteChannelContext.write(new CloseWebSocketFrame(WebSocketCloseStatus.SERVICE_RESTART))); remoteChannelContext.pipeline() - .addAfter(remoteChannelContext.name(), null, new ProxyHandler(future.channel())); + .addAfter(remoteChannelContext.name(), null, new ProxyHandler(localChannelFuture.channel())); // Flush any buffered reads we accumulated while waiting to open the connection pendingReads.forEach(remoteChannelContext::fireChannelRead); @@ -79,7 +88,7 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter { remoteChannelContext.pipeline().remove(EstablishLocalGrpcConnectionHandler.this); } else { - log.warn("Failed to establish local connection to gRPC server", future.cause()); + log.warn("Failed to establish local connection to gRPC server", localChannelFuture.cause()); remoteChannelContext.close(); } }); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteHandler.java new file mode 100644 index 000000000..d31ea931e --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteHandler.java @@ -0,0 +1,134 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.net.InetAddresses; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus; +import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.util.Optional; +import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; +import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; + +/** + * A WebSocket handshake handler waits for a WebSocket handshake to complete, then replaces itself with the appropriate + * Noise handshake handler for the requested path. + */ +class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter { + + private final ClientPublicKeysManager clientPublicKeysManager; + + private final ECKeyPair ecKeyPair; + private final byte[] publicKeySignature; + + private final byte[] recognizedProxySecret; + + private static final Logger log = LoggerFactory.getLogger(WebsocketHandshakeCompleteHandler.class); + + @VisibleForTesting + static final String RECOGNIZED_PROXY_SECRET_HEADER = "X-Signal-Recognized-Proxy"; + + @VisibleForTesting + static final String FORWARDED_FOR_HEADER = "X-Forwarded-For"; + + WebsocketHandshakeCompleteHandler(final ClientPublicKeysManager clientPublicKeysManager, + final ECKeyPair ecKeyPair, + final byte[] publicKeySignature, + final String recognizedProxySecret) { + + this.clientPublicKeysManager = clientPublicKeysManager; + this.ecKeyPair = ecKeyPair; + this.publicKeySignature = publicKeySignature; + + // The recognized proxy secret is an arbitrary string, and not an encoded byte sequence (i.e. a base64- or hex- + // encoded value). We convert it into a byte array here for easier constant-time comparisons via + // MessageDigest.equals() later. + this.recognizedProxySecret = recognizedProxySecret.getBytes(StandardCharsets.UTF_8); + } + + @Override + public void userEventTriggered(final ChannelHandlerContext context, final Object event) { + if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) { + final InetAddress preferredRemoteAddress; + { + final Optional maybePreferredRemoteAddress = + getPreferredRemoteAddress(context, handshakeCompleteEvent); + + if (maybePreferredRemoteAddress.isEmpty()) { + context.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INTERNAL_SERVER_ERROR, + "Could not determine remote address")) + .addListener(ChannelFutureListener.CLOSE_ON_FAILURE); + + return; + } + + preferredRemoteAddress = maybePreferredRemoteAddress.get(); + } + + ClientConnectionManager.handleWebSocketHandshakeComplete(context.channel(), + preferredRemoteAddress, + handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.USER_AGENT), + handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.ACCEPT_LANGUAGE)); + + final ChannelHandler noiseHandshakeHandler = switch (handshakeCompleteEvent.requestUri()) { + case WebsocketNoiseTunnelServer.AUTHENTICATED_SERVICE_PATH -> + new NoiseXXHandshakeHandler(clientPublicKeysManager, ecKeyPair, publicKeySignature); + + case WebsocketNoiseTunnelServer.ANONYMOUS_SERVICE_PATH -> + new NoiseNXHandshakeHandler(ecKeyPair, publicKeySignature); + + default -> { + // The WebSocketOpeningHandshakeHandler should have caught all of these cases already; we'll consider it an + // internal error if something slipped through. + throw new IllegalArgumentException("Unexpected URI: " + handshakeCompleteEvent.requestUri()); + } + }; + + context.pipeline().replace(WebsocketHandshakeCompleteHandler.this, null, noiseHandshakeHandler); + } + + context.fireUserEventTriggered(event); + } + + private Optional getPreferredRemoteAddress(final ChannelHandlerContext context, + final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) { + + final byte[] recognizedProxySecretFromHeader = + handshakeCompleteEvent.requestHeaders().get(RECOGNIZED_PROXY_SECRET_HEADER, "") + .getBytes(StandardCharsets.UTF_8); + + final boolean trustForwardedFor = MessageDigest.isEqual(recognizedProxySecret, recognizedProxySecretFromHeader); + + if (trustForwardedFor && handshakeCompleteEvent.requestHeaders().contains(FORWARDED_FOR_HEADER)) { + final String forwardedFor = handshakeCompleteEvent.requestHeaders().get(FORWARDED_FOR_HEADER); + + return RemoteAddressFilter.getMostRecentProxy(forwardedFor).map(mostRecentProxy -> { + try { + return InetAddresses.forString(mostRecentProxy); + } catch (final IllegalArgumentException e) { + log.warn("Failed to parse forwarded-for address: {}", forwardedFor, e); + return null; + } + }); + } else { + // Either we don't trust the forwarded-for header or it's not present + if (context.channel().remoteAddress() instanceof InetSocketAddress inetSocketAddress) { + return Optional.of(inetSocketAddress.getAddress()); + } else { + log.warn("Channel's remote address was not an InetSocketAddress"); + return Optional.empty(); + } + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteListener.java deleted file mode 100644 index 5b743a61c..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteListener.java +++ /dev/null @@ -1,52 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import io.netty.channel.ChannelHandler; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; -import org.signal.libsignal.protocol.ecc.ECKeyPair; -import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; - -/** - * A WebSocket handshake listener waits for a WebSocket handshake to complete, then replaces itself with the appropriate - * Noise handshake handler for the requested path. - */ -class WebsocketHandshakeCompleteListener extends ChannelInboundHandlerAdapter { - - private final ClientPublicKeysManager clientPublicKeysManager; - - private final ECKeyPair ecKeyPair; - private final byte[] publicKeySignature; - - WebsocketHandshakeCompleteListener(final ClientPublicKeysManager clientPublicKeysManager, - final ECKeyPair ecKeyPair, - final byte[] publicKeySignature) { - - this.clientPublicKeysManager = clientPublicKeysManager; - this.ecKeyPair = ecKeyPair; - this.publicKeySignature = publicKeySignature; - } - - @Override - public void userEventTriggered(final ChannelHandlerContext context, final Object event) { - if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) { - final ChannelHandler noiseHandshakeHandler = switch (handshakeCompleteEvent.requestUri()) { - case WebsocketNoiseTunnelServer.AUTHENTICATED_SERVICE_PATH -> - new NoiseXXHandshakeHandler(clientPublicKeysManager, ecKeyPair, publicKeySignature); - - case WebsocketNoiseTunnelServer.ANONYMOUS_SERVICE_PATH -> - new NoiseNXHandshakeHandler(ecKeyPair, publicKeySignature); - - default -> { - // The HttpHandler should have caught all of these cases already; we'll consider it an internal error if - // something slipped through. - throw new IllegalArgumentException("Unexpected URI: " + handshakeCompleteEvent.requestUri()); - } - }; - - context.pipeline().replace(WebsocketHandshakeCompleteListener.this, null, noiseHandshakeHandler); - } - - context.fireUserEventTriggered(event); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketNoiseTunnelServer.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketNoiseTunnelServer.java index f60e0f838..2f44150eb 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketNoiseTunnelServer.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketNoiseTunnelServer.java @@ -48,11 +48,13 @@ public class WebsocketNoiseTunnelServer implements Managed { final PrivateKey tlsPrivateKey, final NioEventLoopGroup eventLoopGroup, final Executor delegatedTaskExecutor, + final ClientConnectionManager clientConnectionManager, final ClientPublicKeysManager clientPublicKeysManager, final ECKeyPair ecKeyPair, final byte[] publicKeySignature, final LocalAddress authenticatedGrpcServerAddress, - final LocalAddress anonymousGrpcServerAddress) throws SSLException { + final LocalAddress anonymousGrpcServerAddress, + final String recognizedProxySecret) throws SSLException { final SslProvider sslProvider; @@ -88,10 +90,10 @@ public class WebsocketNoiseTunnelServer implements Managed { .addLast(new RejectUnsupportedMessagesHandler()) // The WebSocket handshake complete listener will replace itself with an appropriate Noise handshake handler once // a WebSocket handshake has been completed - .addLast(new WebsocketHandshakeCompleteListener(clientPublicKeysManager, ecKeyPair, publicKeySignature)) + .addLast(new WebsocketHandshakeCompleteHandler(clientPublicKeysManager, ecKeyPair, publicKeySignature, recognizedProxySecret)) // This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler // once the Noise handshake has completed - .addLast(new EstablishLocalGrpcConnectionHandler(authenticatedGrpcServerAddress, anonymousGrpcServerAddress)) + .addLast(new EstablishLocalGrpcConnectionHandler(clientConnectionManager, authenticatedGrpcServerAddress, anonymousGrpcServerAddress)) .addLast(new ErrorHandler()); } }); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/ua/UserAgentUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/ua/UserAgentUtil.java index fe04b9f91..c9bea4101 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/ua/UserAgentUtil.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/ua/UserAgentUtil.java @@ -7,15 +7,12 @@ package org.whispersystems.textsecuregcm.util.ua; import com.google.common.annotations.VisibleForTesting; import com.vdurmont.semver4j.Semver; -import io.grpc.Context; import java.util.regex.Matcher; import java.util.regex.Pattern; import org.apache.commons.lang3.StringUtils; public class UserAgentUtil { - public static final Context.Key USER_AGENT_CONTEXT_KEY = Context.key("x-signal-user-agent"); - private static final Pattern STANDARD_UA_PATTERN = Pattern.compile("^Signal-(Android|Desktop|iOS)/([^ ]+)( (.+))?$", Pattern.CASE_INSENSITIVE); public static UserAgent parseUserAgentString(final String userAgentString) throws UnrecognizedUserAgentException { @@ -36,10 +33,6 @@ public class UserAgentUtil { throw new UnrecognizedUserAgentException(); } - public static UserAgent userAgentFromGrpcContext() { - return USER_AGENT_CONTEXT_KEY.get(); - } - @VisibleForTesting static UserAgent parseStandardUserAgentString(final String userAgentString) { final Matcher matcher = STANDARD_UA_PATTERN.matcher(userAgentString); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/AbstractAuthenticationInterceptorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/AbstractAuthenticationInterceptorTest.java new file mode 100644 index 000000000..81baf8686 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/AbstractAuthenticationInterceptorTest.java @@ -0,0 +1,77 @@ +package org.whispersystems.textsecuregcm.auth.grpc; + +import static org.mockito.Mockito.mock; + +import io.grpc.ManagedChannel; +import io.grpc.Server; +import io.grpc.netty.NettyChannelBuilder; +import io.grpc.netty.NettyServerBuilder; +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalServerChannel; +import java.io.IOException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.signal.chat.rpc.GetAuthenticatedDeviceRequest; +import org.signal.chat.rpc.GetAuthenticatedDeviceResponse; +import org.signal.chat.rpc.RequestAttributesGrpc; +import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl; +import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager; + +abstract class AbstractAuthenticationInterceptorTest { + + private static DefaultEventLoopGroup eventLoopGroup; + + private ClientConnectionManager clientConnectionManager; + + private Server server; + private ManagedChannel managedChannel; + + @BeforeAll + static void setUpBeforeAll() { + eventLoopGroup = new DefaultEventLoopGroup(); + } + + @BeforeEach + void setUp() throws IOException { + final LocalAddress serverAddress = new LocalAddress("test-authentication-interceptor-server"); + + clientConnectionManager = mock(ClientConnectionManager.class); + + // `RequestAttributesInterceptor` operates on `LocalAddresses`, so we need to do some slightly fancy plumbing to make + // sure that we're using local channels and addresses + server = NettyServerBuilder.forAddress(serverAddress) + .channelType(LocalServerChannel.class) + .bossEventLoopGroup(eventLoopGroup) + .workerEventLoopGroup(eventLoopGroup) + .intercept(getInterceptor()) + .addService(new RequestAttributesServiceImpl()) + .build() + .start(); + + managedChannel = NettyChannelBuilder.forAddress(serverAddress) + .channelType(LocalChannel.class) + .eventLoopGroup(eventLoopGroup) + .usePlaintext() + .build(); + } + + @AfterEach + void tearDown() { + managedChannel.shutdown(); + server.shutdown(); + } + + protected abstract AbstractAuthenticationInterceptor getInterceptor(); + + protected ClientConnectionManager getClientConnectionManager() { + return clientConnectionManager; + } + + protected GetAuthenticatedDeviceResponse getAuthenticatedDevice() { + return RequestAttributesGrpc.newBlockingStub(managedChannel) + .getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build()); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/BasicCredentialAuthenticationInterceptorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/BasicCredentialAuthenticationInterceptorTest.java deleted file mode 100644 index 8bf1a184a..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/BasicCredentialAuthenticationInterceptorTest.java +++ /dev/null @@ -1,135 +0,0 @@ -/* - * Copyright 2023 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.auth.grpc; - -import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -import io.grpc.CallCredentials; -import io.grpc.ManagedChannel; -import io.grpc.Metadata; -import io.grpc.Server; -import io.grpc.Status; -import io.grpc.StatusRuntimeException; -import io.grpc.inprocess.InProcessChannelBuilder; -import io.grpc.inprocess.InProcessServerBuilder; -import java.io.IOException; -import java.util.Optional; -import java.util.UUID; -import java.util.concurrent.Executor; -import java.util.stream.Stream; -import org.apache.commons.lang3.RandomStringUtils; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; -import org.signal.chat.rpc.EchoRequest; -import org.signal.chat.rpc.EchoServiceGrpc; -import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; -import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; -import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl; -import org.whispersystems.textsecuregcm.storage.Account; -import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.textsecuregcm.util.HeaderUtils; -import org.whispersystems.textsecuregcm.util.Pair; - -class BasicCredentialAuthenticationInterceptorTest { - - private Server server; - private ManagedChannel managedChannel; - - private AccountAuthenticator accountAuthenticator; - - - @BeforeEach - void setUp() throws IOException { - accountAuthenticator = mock(AccountAuthenticator.class); - - final BasicCredentialAuthenticationInterceptor authenticationInterceptor = - new BasicCredentialAuthenticationInterceptor(accountAuthenticator); - - final String serverName = InProcessServerBuilder.generateName(); - - server = InProcessServerBuilder.forName(serverName) - .directExecutor() - .intercept(authenticationInterceptor) - .addService(new EchoServiceImpl()) - .build() - .start(); - - managedChannel = InProcessChannelBuilder.forName(serverName) - .directExecutor() - .build(); - } - - @AfterEach - void tearDown() { - managedChannel.shutdown(); - server.shutdown(); - } - - @ParameterizedTest - @MethodSource - void interceptCall(final Metadata headers, final boolean acceptCredentials, final boolean expectAuthentication) { - if (acceptCredentials) { - final Account account = mock(Account.class); - when(account.getUuid()).thenReturn(UUID.randomUUID()); - - final Device device = mock(Device.class); - when(device.getId()).thenReturn(Device.PRIMARY_ID); - - when(accountAuthenticator.authenticate(any())) - .thenReturn(Optional.of(new AuthenticatedAccount(account, device))); - } else { - when(accountAuthenticator.authenticate(any())) - .thenReturn(Optional.empty()); - } - - final EchoServiceGrpc.EchoServiceBlockingStub stub = EchoServiceGrpc.newBlockingStub(managedChannel) - .withCallCredentials(new CallCredentials() { - @Override - public void applyRequestMetadata(final RequestInfo requestInfo, final Executor appExecutor, final MetadataApplier applier) { - applier.apply(headers); - } - - @Override - public void thisUsesUnstableApi() { - } - }); - - if (expectAuthentication) { - assertDoesNotThrow(() -> stub.echo(EchoRequest.newBuilder().build())); - } else { - final StatusRuntimeException exception = - assertThrows(StatusRuntimeException.class, () -> stub.echo(EchoRequest.newBuilder().build())); - - assertEquals(Status.UNAUTHENTICATED.getCode(), exception.getStatus().getCode()); - } - } - - private static Stream interceptCall() { - final Metadata malformedCredentialHeaders = new Metadata(); - malformedCredentialHeaders.put(BasicCredentialAuthenticationInterceptor.BASIC_CREDENTIALS, "Incorrect"); - - final Metadata structurallyValidCredentialHeaders = new Metadata(); - structurallyValidCredentialHeaders.put( - BasicCredentialAuthenticationInterceptor.BASIC_CREDENTIALS, - HeaderUtils.basicAuthHeader(UUID.randomUUID().toString(), RandomStringUtils.randomAlphanumeric(16)) - ); - - return Stream.of( - Arguments.of(new Metadata(), true, false), - Arguments.of(malformedCredentialHeaders, true, false), - Arguments.of(structurallyValidCredentialHeaders, false, false), - Arguments.of(structurallyValidCredentialHeaders, true, true) - ); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/MockAuthenticationInterceptor.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/MockAuthenticationInterceptor.java index 1db4e8c45..de8da8350 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/MockAuthenticationInterceptor.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/MockAuthenticationInterceptor.java @@ -13,15 +13,14 @@ import io.grpc.ServerCallHandler; import io.grpc.ServerInterceptor; import java.util.UUID; import javax.annotation.Nullable; -import org.whispersystems.textsecuregcm.util.Pair; public class MockAuthenticationInterceptor implements ServerInterceptor { @Nullable - private Pair authenticatedDevice; + private AuthenticatedDevice authenticatedDevice; public void setAuthenticatedDevice(final UUID accountIdentifier, final byte deviceId) { - authenticatedDevice = new Pair<>(accountIdentifier, deviceId); + authenticatedDevice = new AuthenticatedDevice(accountIdentifier, deviceId); } public void clearAuthenticatedDevice() { @@ -33,14 +32,10 @@ public class MockAuthenticationInterceptor implements ServerInterceptor { final Metadata headers, final ServerCallHandler next) { - if (authenticatedDevice != null) { - final Context context = Context.current() - .withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_ACCOUNT_IDENTIFIER_KEY, authenticatedDevice.first()) - .withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY, authenticatedDevice.second()); - - return Contexts.interceptCall(context, call, headers, next); - } - - return next.startCall(call, headers); + return authenticatedDevice != null + ? Contexts.interceptCall( + Context.current().withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE, authenticatedDevice), + call, headers, next) + : next.startCall(call, headers); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/ProhibitAuthenticationInterceptorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/ProhibitAuthenticationInterceptorTest.java new file mode 100644 index 000000000..50f74d980 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/ProhibitAuthenticationInterceptorTest.java @@ -0,0 +1,40 @@ +package org.whispersystems.textsecuregcm.auth.grpc; + +import io.grpc.Status; +import org.junit.jupiter.api.Test; +import org.signal.chat.rpc.GetAuthenticatedDeviceResponse; +import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils; +import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.util.UUIDUtil; + +import java.util.Optional; +import java.util.UUID; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +class ProhibitAuthenticationInterceptorTest extends AbstractAuthenticationInterceptorTest { + + @Override + protected AbstractAuthenticationInterceptor getInterceptor() { + return new ProhibitAuthenticationInterceptor(getClientConnectionManager()); + } + + @Test + void interceptCall() { + final ClientConnectionManager clientConnectionManager = getClientConnectionManager(); + + when(clientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty()); + + final GetAuthenticatedDeviceResponse response = getAuthenticatedDevice(); + assertTrue(response.getAccountIdentifier().isEmpty()); + assertEquals(0, response.getDeviceId()); + + final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID); + when(clientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice)); + + GrpcTestUtils.assertStatusException(Status.UNAUTHENTICATED, this::getAuthenticatedDevice); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/RequireAuthenticationInterceptorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/RequireAuthenticationInterceptorTest.java new file mode 100644 index 000000000..a8b2f5e8f --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/RequireAuthenticationInterceptorTest.java @@ -0,0 +1,39 @@ +package org.whispersystems.textsecuregcm.auth.grpc; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +import io.grpc.Status; +import java.util.Optional; +import java.util.UUID; +import org.junit.jupiter.api.Test; +import org.signal.chat.rpc.GetAuthenticatedDeviceResponse; +import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils; +import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.util.UUIDUtil; + +class RequireAuthenticationInterceptorTest extends AbstractAuthenticationInterceptorTest { + + @Override + protected AbstractAuthenticationInterceptor getInterceptor() { + return new RequireAuthenticationInterceptor(getClientConnectionManager()); + } + + @Test + void interceptCall() { + final ClientConnectionManager clientConnectionManager = getClientConnectionManager(); + + when(clientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty()); + + GrpcTestUtils.assertStatusException(Status.UNAUTHENTICATED, this::getAuthenticatedDevice); + + final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID); + when(clientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice)); + + final GetAuthenticatedDeviceResponse response = getAuthenticatedDevice(); + assertEquals(UUIDUtil.toByteString(authenticatedDevice.accountIdentifier()), response.getAccountIdentifier()); + assertEquals(authenticatedDevice.deviceId(), response.getDeviceId()); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/filters/RemoteDeprecationFilterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/filters/RemoteDeprecationFilterTest.java index 11db14dca..77382640a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/filters/RemoteDeprecationFilterTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/filters/RemoteDeprecationFilterTest.java @@ -39,10 +39,12 @@ import org.signal.chat.rpc.EchoServiceGrpc; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRemoteDeprecationConfiguration; import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl; +import org.whispersystems.textsecuregcm.grpc.MockRequestAttributesInterceptor; import org.whispersystems.textsecuregcm.grpc.StatusConstants; -import org.whispersystems.textsecuregcm.grpc.UserAgentInterceptor; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; +import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; +import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; class RemoteDeprecationFilterTest { @@ -126,17 +128,25 @@ class RemoteDeprecationFilterTest { @ParameterizedTest @MethodSource(value="testFilter") - void testGrpcFilter(final String userAgent, final boolean expectDeprecation) throws Exception { + void testGrpcFilter(final String userAgentString, final boolean expectDeprecation) throws IOException, InterruptedException { + final MockRequestAttributesInterceptor mockRequestAttributesInterceptor = new MockRequestAttributesInterceptor(); + + try { + mockRequestAttributesInterceptor.setUserAgent(UserAgentUtil.parseUserAgentString(userAgentString)); + } catch (UnrecognizedUserAgentException ignored) { + } + final Server testServer = InProcessServerBuilder.forName("RemoteDeprecationFilterTest") .directExecutor() .addService(new EchoServiceImpl()) .intercept(filterConfiguredForTest()) - .intercept(new UserAgentInterceptor()) + .intercept(mockRequestAttributesInterceptor) .build() .start(); + final ManagedChannel channel = InProcessChannelBuilder.forName("RemoteDeprecationFilterTest") .directExecutor() - .userAgent(userAgent) + .userAgent(userAgentString) .build(); try { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/AcceptLanguageInterceptorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/AcceptLanguageInterceptorTest.java deleted file mode 100644 index c56ee22f7..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/AcceptLanguageInterceptorTest.java +++ /dev/null @@ -1,79 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc; - -import com.google.protobuf.ByteString; -import io.grpc.ManagedChannel; -import io.grpc.Metadata; -import io.grpc.Server; -import io.grpc.inprocess.InProcessChannelBuilder; -import io.grpc.inprocess.InProcessServerBuilder; -import io.grpc.stub.MetadataUtils; -import io.grpc.stub.StreamObserver; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; -import org.signal.chat.rpc.EchoRequest; -import org.signal.chat.rpc.EchoResponse; -import org.signal.chat.rpc.EchoServiceGrpc; -import java.io.IOException; -import java.util.Collections; -import java.util.List; -import java.util.Locale; -import java.util.concurrent.atomic.AtomicReference; -import java.util.stream.Stream; - -import static org.junit.jupiter.api.Assertions.assertEquals; - -public class AcceptLanguageInterceptorTest { - @ParameterizedTest - @MethodSource - void parseLocale(final String header, final List expectedLocales) throws IOException, InterruptedException { - final AtomicReference> observedLocales = new AtomicReference<>(null); - final EchoServiceImpl serviceImpl = new EchoServiceImpl() { - @Override - public void echo(EchoRequest req, StreamObserver responseObserver) { - observedLocales.set(AcceptLanguageUtil.localeFromGrpcContext()); - super.echo(req, responseObserver); - } - }; - - final Server testServer = InProcessServerBuilder.forName("AcceptLanguageTest") - .directExecutor() - .addService(serviceImpl) - .intercept(new AcceptLanguageInterceptor()) - .intercept(new UserAgentInterceptor()) - .build() - .start(); - - try { - final ManagedChannel channel = InProcessChannelBuilder.forName("AcceptLanguageTest") - .directExecutor() - .userAgent("Signal-Android/1.2.3") - .build(); - - final Metadata metadata = new Metadata(); - metadata.put(AcceptLanguageInterceptor.ACCEPTABLE_LANGUAGES_GRPC_HEADER, header); - - final EchoServiceGrpc.EchoServiceBlockingStub client = EchoServiceGrpc.newBlockingStub(channel) - .withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata)); - - final EchoRequest request = EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("test request")).build(); - client.echo(request); - assertEquals(expectedLocales, observedLocales.get()); - } finally { - testServer.shutdownNow(); - testServer.awaitTermination(); - } - } - - private static Stream parseLocale() { - return Stream.of( - // en-US-POSIX is a special locale that exists alongside en-US. It matches because of the definition of - // basic filtering in RFC 4647 (https://datatracker.ietf.org/doc/html/rfc4647#section-3.3.1) - Arguments.of("en-US,fr-CA", List.of(Locale.forLanguageTag("en-US-POSIX"), Locale.forLanguageTag("en-US"), Locale.forLanguageTag("fr-CA"))), - Arguments.of("en-US; q=0.9, fr-CA", List.of(Locale.forLanguageTag("fr-CA"), Locale.forLanguageTag("en-US-POSIX"), Locale.forLanguageTag("en-US"))), - Arguments.of("invalid-locale,fr-CA", List.of(Locale.forLanguageTag("fr-CA"))), - Arguments.of("", Collections.emptyList()), - Arguments.of("acompletely,unexpectedfor , mat", Collections.emptyList()) - ); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/AccountsAnonymousGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/AccountsAnonymousGrpcServiceTest.java index 2feb4d802..c0d80af7c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/AccountsAnonymousGrpcServiceTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/AccountsAnonymousGrpcServiceTest.java @@ -13,9 +13,9 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import com.google.common.net.InetAddresses; import com.google.protobuf.ByteString; import io.grpc.Status; -import java.net.InetSocketAddress; import java.time.Duration; import java.util.Optional; import java.util.UUID; @@ -72,7 +72,7 @@ class AccountsAnonymousGrpcServiceTest extends when(rateLimiter.validateReactive(anyString())).thenReturn(Mono.empty()); - getMockRemoteAddressInterceptor().setRemoteAddress(new InetSocketAddress("127.0.0.1", 12345)); + getMockRequestAttributesInterceptor().setRemoteAddress(InetAddresses.forString("127.0.0.1")); return new AccountsAnonymousGrpcService(accountsManager, rateLimiters); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/GrpcTestUtils.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/GrpcTestUtils.java index 6f0a5b0c4..a2724b1bd 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/GrpcTestUtils.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/GrpcTestUtils.java @@ -29,21 +29,21 @@ public final class GrpcTestUtils { public static void setupAuthenticatedExtension( final GrpcServerExtension extension, final MockAuthenticationInterceptor mockAuthenticationInterceptor, - final MockRemoteAddressInterceptor mockRemoteAddressInterceptor, + final MockRequestAttributesInterceptor mockRequestAttributesInterceptor, final UUID authenticatedAci, final byte authenticatedDeviceId, final BindableService service) { mockAuthenticationInterceptor.setAuthenticatedDevice(authenticatedAci, authenticatedDeviceId); extension.getServiceRegistry() - .addService(ServerInterceptors.intercept(service, mockRemoteAddressInterceptor, mockAuthenticationInterceptor, new ErrorMappingInterceptor())); + .addService(ServerInterceptors.intercept(service, mockRequestAttributesInterceptor, mockAuthenticationInterceptor, new ErrorMappingInterceptor())); } public static void setupUnauthenticatedExtension( final GrpcServerExtension extension, - final MockRemoteAddressInterceptor mockRemoteAddressInterceptor, + final MockRequestAttributesInterceptor mockRequestAttributesInterceptor, final BindableService service) { extension.getServiceRegistry() - .addService(ServerInterceptors.intercept(service, mockRemoteAddressInterceptor, new ErrorMappingInterceptor())); + .addService(ServerInterceptors.intercept(service, mockRequestAttributesInterceptor, new ErrorMappingInterceptor())); } public static void assertStatusException(final Status expected, final Executable serviceCall) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MockRemoteAddressInterceptor.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MockRemoteAddressInterceptor.java deleted file mode 100644 index a30aa79db..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MockRemoteAddressInterceptor.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright 2023 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.grpc; - -import io.grpc.Context; -import io.grpc.Contexts; -import io.grpc.Metadata; -import io.grpc.ServerCall; -import io.grpc.ServerCallHandler; -import io.grpc.ServerInterceptor; -import java.net.SocketAddress; -import javax.annotation.Nullable; - -public class MockRemoteAddressInterceptor implements ServerInterceptor { - - @Nullable - private SocketAddress remoteAddress; - - public void setRemoteAddress(@Nullable final SocketAddress remoteAddress) { - this.remoteAddress = remoteAddress; - } - - @Override - public ServerCall.Listener interceptCall(final ServerCall serverCall, - final Metadata headers, - final ServerCallHandler next) { - - return remoteAddress == null - ? next.startCall(serverCall, headers) - : Contexts.interceptCall( - Context.current().withValue(RemoteAddressUtil.REMOTE_ADDRESS_CONTEXT_KEY, remoteAddress), - serverCall, headers, next); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MockRequestAttributesInterceptor.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MockRequestAttributesInterceptor.java new file mode 100644 index 000000000..7d662f3a0 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MockRequestAttributesInterceptor.java @@ -0,0 +1,64 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.grpc; + +import io.grpc.Context; +import io.grpc.Contexts; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import java.net.InetAddress; +import java.util.List; +import java.util.Locale; +import javax.annotation.Nullable; +import org.whispersystems.textsecuregcm.util.ua.UserAgent; + +public class MockRequestAttributesInterceptor implements ServerInterceptor { + + @Nullable + private InetAddress remoteAddress; + + @Nullable + private UserAgent userAgent; + + @Nullable + private List acceptLanguage; + + public void setRemoteAddress(@Nullable final InetAddress remoteAddress) { + this.remoteAddress = remoteAddress; + } + + public void setUserAgent(@Nullable final UserAgent userAgent) { + this.userAgent = userAgent; + } + + public void setAcceptLanguage(@Nullable final List acceptLanguage) { + this.acceptLanguage = acceptLanguage; + } + + @Override + public ServerCall.Listener interceptCall(final ServerCall serverCall, + final Metadata headers, + final ServerCallHandler next) { + + Context context = Context.current(); + + if (remoteAddress != null) { + context = context.withValue(RequestAttributesUtil.REMOTE_ADDRESS_CONTEXT_KEY, remoteAddress); + } + + if (userAgent != null) { + context = context.withValue(RequestAttributesUtil.USER_AGENT_CONTEXT_KEY, userAgent); + } + + if (acceptLanguage != null) { + context = context.withValue(RequestAttributesUtil.ACCEPT_LANGUAGE_CONTEXT_KEY, acceptLanguage); + } + + return Contexts.interceptCall(context, serverCall, headers, next); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ProfileAnonymousGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ProfileAnonymousGrpcServiceTest.java index c6f0e2abf..fc7488df3 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ProfileAnonymousGrpcServiceTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ProfileAnonymousGrpcServiceTest.java @@ -17,16 +17,13 @@ import static org.mockito.Mockito.when; import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException; import com.google.protobuf.ByteString; -import io.grpc.Channel; -import io.grpc.Metadata; import io.grpc.Status; -import io.grpc.stub.MetadataUtils; -import java.lang.reflect.InvocationTargetException; import java.nio.charset.StandardCharsets; import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.Collections; import java.util.List; +import java.util.Locale; import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; @@ -79,6 +76,8 @@ import org.whispersystems.textsecuregcm.storage.VersionedProfile; import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper; import org.whispersystems.textsecuregcm.util.TestRandomUtil; import org.whispersystems.textsecuregcm.util.UUIDUtil; +import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; +import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest { @@ -100,6 +99,14 @@ public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest responseObserver) { + + final GetRequestAttributesResponse.Builder responseBuilder = GetRequestAttributesResponse.newBuilder(); + + RequestAttributesUtil.getAcceptableLanguages().ifPresent(acceptableLanguages -> + acceptableLanguages.forEach(languageRange -> responseBuilder.addAcceptableLanguages(languageRange.toString()))); + + RequestAttributesUtil.getAvailableAcceptedLocales().forEach(locale -> + responseBuilder.addAvailableAcceptedLocales(locale.toLanguageTag())); + + responseBuilder.setRemoteAddress(RequestAttributesUtil.getRemoteAddress().getHostAddress()); + + RequestAttributesUtil.getUserAgent().ifPresent(userAgent -> responseBuilder.setUserAgent(UserAgent.newBuilder() + .setPlatform(userAgent.getPlatform().toString()) + .setVersion(userAgent.getVersion().toString()) + .setAdditionalSpecifiers(userAgent.getAdditionalSpecifiers().orElse("")) + .build())); + + responseObserver.onNext(responseBuilder.build()); + responseObserver.onCompleted(); + } + + @Override + public void getAuthenticatedDevice(final GetAuthenticatedDeviceRequest request, + final StreamObserver responseObserver) { + + final GetAuthenticatedDeviceResponse.Builder responseBuilder = GetAuthenticatedDeviceResponse.newBuilder(); + + try { + final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedDevice(); + + responseBuilder.setAccountIdentifier(UUIDUtil.toByteString(authenticatedDevice.accountIdentifier())); + responseBuilder.setDeviceId(authenticatedDevice.deviceId()); + } catch (final Exception ignored) { + } + + responseObserver.onNext(responseBuilder.build()); + responseObserver.onCompleted(); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesUtilTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesUtilTest.java new file mode 100644 index 000000000..ef8e64e65 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesUtilTest.java @@ -0,0 +1,160 @@ +package org.whispersystems.textsecuregcm.grpc; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.common.net.InetAddresses; +import io.grpc.ManagedChannel; +import io.grpc.Server; +import io.grpc.Status; +import io.grpc.netty.NettyChannelBuilder; +import io.grpc.netty.NettyServerBuilder; +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalServerChannel; +import java.io.IOException; +import java.util.List; +import java.util.Locale; +import java.util.Optional; +import org.apache.commons.lang3.StringUtils; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.signal.chat.rpc.GetRequestAttributesRequest; +import org.signal.chat.rpc.GetRequestAttributesResponse; +import org.signal.chat.rpc.RequestAttributesGrpc; +import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager; +import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; +import org.whispersystems.textsecuregcm.util.ua.UserAgent; +import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; + +class RequestAttributesUtilTest { + + private static DefaultEventLoopGroup eventLoopGroup; + + private ClientConnectionManager clientConnectionManager; + + private Server server; + private ManagedChannel managedChannel; + + @BeforeAll + static void setUpBeforeAll() { + eventLoopGroup = new DefaultEventLoopGroup(); + } + + @BeforeEach + void setUp() throws IOException { + final LocalAddress serverAddress = new LocalAddress("test-request-metadata-server"); + + clientConnectionManager = mock(ClientConnectionManager.class); + + when(clientConnectionManager.getRemoteAddress(any())) + .thenReturn(Optional.of(InetAddresses.forString("127.0.0.1"))); + + // `RequestAttributesInterceptor` operates on `LocalAddresses`, so we need to do some slightly fancy plumbing to make + // sure that we're using local channels and addresses + server = NettyServerBuilder.forAddress(serverAddress) + .channelType(LocalServerChannel.class) + .bossEventLoopGroup(eventLoopGroup) + .workerEventLoopGroup(eventLoopGroup) + .intercept(new RequestAttributesInterceptor(clientConnectionManager)) + .addService(new RequestAttributesServiceImpl()) + .build() + .start(); + + managedChannel = NettyChannelBuilder.forAddress(serverAddress) + .channelType(LocalChannel.class) + .eventLoopGroup(eventLoopGroup) + .usePlaintext() + .build(); + } + + @AfterEach + void tearDown() { + managedChannel.shutdown(); + server.shutdown(); + } + + @AfterAll + static void tearDownAfterAll() throws InterruptedException { + eventLoopGroup.shutdownGracefully().await(); + } + + @Test + void getAcceptableLanguages() { + when(clientConnectionManager.getAcceptableLanguages(any())) + .thenReturn(Optional.empty()); + + assertTrue(getRequestAttributes().getAcceptableLanguagesList().isEmpty()); + + when(clientConnectionManager.getAcceptableLanguages(any())) + .thenReturn(Optional.of(Locale.LanguageRange.parse("en,ja"))); + + assertEquals(List.of("en", "ja"), getRequestAttributes().getAcceptableLanguagesList()); + } + + @Test + void getAvailableAcceptedLocales() { + when(clientConnectionManager.getAcceptableLanguages(any())) + .thenReturn(Optional.empty()); + + assertTrue(getRequestAttributes().getAvailableAcceptedLocalesList().isEmpty()); + + when(clientConnectionManager.getAcceptableLanguages(any())) + .thenReturn(Optional.of(Locale.LanguageRange.parse("en,ja"))); + + final GetRequestAttributesResponse response = getRequestAttributes(); + + assertFalse(response.getAvailableAcceptedLocalesList().isEmpty()); + response.getAvailableAcceptedLocalesList().forEach(languageTag -> { + final Locale locale = Locale.forLanguageTag(languageTag); + assertTrue("en".equals(locale.getLanguage()) || "ja".equals(locale.getLanguage())); + }); + } + + @Test + void getRemoteAddress() { + when(clientConnectionManager.getRemoteAddress(any())) + .thenReturn(Optional.empty()); + + GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getRequestAttributes); + + final String remoteAddressString = "6.7.8.9"; + + when(clientConnectionManager.getRemoteAddress(any())) + .thenReturn(Optional.of(InetAddresses.forString(remoteAddressString))); + + assertEquals(remoteAddressString, getRequestAttributes().getRemoteAddress()); + } + + @Test + void getUserAgent() throws UnrecognizedUserAgentException { + when(clientConnectionManager.getUserAgent(any())) + .thenReturn(Optional.empty()); + + assertFalse(getRequestAttributes().hasUserAgent()); + + final UserAgent userAgent = UserAgentUtil.parseUserAgentString("Signal-Desktop/1.2.3 Linux"); + + when(clientConnectionManager.getUserAgent(any())) + .thenReturn(Optional.of(userAgent)); + + final GetRequestAttributesResponse response = getRequestAttributes(); + assertTrue(response.hasUserAgent()); + assertEquals("DESKTOP", response.getUserAgent().getPlatform()); + assertEquals("1.2.3", response.getUserAgent().getVersion()); + assertEquals("Linux", response.getUserAgent().getAdditionalSpecifiers()); + } + + private GetRequestAttributesResponse getRequestAttributes() { + return RequestAttributesGrpc.newBlockingStub(managedChannel) + .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/SimpleBaseGrpcTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/SimpleBaseGrpcTest.java index 246ea8638..bf24c6750 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/SimpleBaseGrpcTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/SimpleBaseGrpcTest.java @@ -60,7 +60,7 @@ public abstract class SimpleBaseGrpcTest observedUserAgent = new AtomicReference<>(null); - final EchoServiceImpl serviceImpl = new EchoServiceImpl() { - @Override - public void echo(EchoRequest req, StreamObserver responseObserver) { - observedUserAgent.set(UserAgentUtil.userAgentFromGrpcContext()); - super.echo(req, responseObserver); - } - }; - - final Server testServer = InProcessServerBuilder.forName("RemoteDeprecationFilterTest") - .directExecutor() - .addService(serviceImpl) - .intercept(new UserAgentInterceptor()) - .build() - .start(); - - try { - final ManagedChannel channel = InProcessChannelBuilder.forName("RemoteDeprecationFilterTest") - .directExecutor() - .userAgent(header) - .build(); - - final EchoServiceGrpc.EchoServiceBlockingStub client = EchoServiceGrpc.newBlockingStub(channel); - - final EchoRequest req = EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("cluck cluck, i'm a parrot")).build(); - assertEquals("cluck cluck, i'm a parrot", client.echo(req).getPayload().toStringUtf8()); - if (platform == null) { - assertNull(observedUserAgent.get()); - } else { - assertEquals(platform, observedUserAgent.get().getPlatform()); - assertEquals(new Semver(version), observedUserAgent.get().getVersion()); - // can't assert on the additional specifiers because they include internal details of the grpc in-process channel itself - } - } finally { - testServer.shutdownNow(); - testServer.awaitTermination(); - } - } - - private static Stream testInterceptor() { - return Stream.of( - Arguments.of(null, null, null), - Arguments.of("", null, null), - Arguments.of("Unrecognized UA", null, null), - Arguments.of("Signal-Android/4.68.3", ClientPlatform.ANDROID, "4.68.3"), - Arguments.of("Signal-iOS/3.9.0", ClientPlatform.IOS, "3.9.0"), - Arguments.of("Signal-Desktop/1.2.3", ClientPlatform.DESKTOP, "1.2.3"), - Arguments.of("Signal-Desktop/8.0.0-beta.2", ClientPlatform.DESKTOP, "8.0.0-beta.2"), - Arguments.of("Signal-iOS/8.0.0-beta.2", ClientPlatform.IOS, "8.0.0-beta.2")); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AuthenticationTypeService.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AuthenticationTypeService.java deleted file mode 100644 index 112c1f20f..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AuthenticationTypeService.java +++ /dev/null @@ -1,21 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import io.grpc.stub.StreamObserver; -import org.signal.chat.rpc.AuthenticationTypeGrpc; -import org.signal.chat.rpc.GetAuthenticatedRequest; -import org.signal.chat.rpc.GetAuthenticatedResponse; - -public class AuthenticationTypeService extends AuthenticationTypeGrpc.AuthenticationTypeImplBase { - - private final boolean authenticated; - - public AuthenticationTypeService(final boolean authenticated) { - this.authenticated = authenticated; - } - - @Override - public void getAuthenticated(final GetAuthenticatedRequest request, final StreamObserver responseObserver) { - responseObserver.onNext(GetAuthenticatedResponse.newBuilder().setAuthenticated(authenticated).build()); - responseObserver.onCompleted(); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/ClientConnectionManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/ClientConnectionManagerTest.java new file mode 100644 index 000000000..6921e39f2 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/ClientConnectionManagerTest.java @@ -0,0 +1,283 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import com.google.common.net.InetAddresses; +import com.vdurmont.semver4j.Semver; +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalServerChannel; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; +import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; +import org.whispersystems.textsecuregcm.util.ua.UserAgent; +import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; + +import javax.annotation.Nullable; +import java.net.InetAddress; +import java.util.List; +import java.util.Locale; +import java.util.Optional; +import java.util.UUID; + +import static org.junit.jupiter.api.Assertions.*; + +class ClientConnectionManagerTest { + + private static EventLoopGroup eventLoopGroup; + + private LocalChannel localChannel; + private LocalChannel remoteChannel; + + private LocalServerChannel localServerChannel; + + private ClientConnectionManager clientConnectionManager; + + @BeforeAll + static void setUpBeforeAll() { + eventLoopGroup = new DefaultEventLoopGroup(); + } + + @BeforeEach + void setUp() throws InterruptedException { + eventLoopGroup = new DefaultEventLoopGroup(); + + clientConnectionManager = new ClientConnectionManager(); + + // We have to jump through some hoops to get "real" LocalChannel instances to test with, and so we run a trivial + // local server to which we can open trivial local connections + localServerChannel = (LocalServerChannel) new ServerBootstrap() + .group(eventLoopGroup) + .channel(LocalServerChannel.class) + .childHandler(new ChannelInitializer<>() { + @Override + protected void initChannel(final Channel channel) { + } + }) + .bind(new LocalAddress("test-server")) + .await() + .channel(); + + final Bootstrap clientBootstrap = new Bootstrap() + .group(eventLoopGroup) + .channel(LocalChannel.class) + .handler(new ChannelInitializer<>() { + @Override + protected void initChannel(final Channel ch) { + } + }); + + localChannel = (LocalChannel) clientBootstrap.connect(localServerChannel.localAddress()).await().channel(); + remoteChannel = (LocalChannel) clientBootstrap.connect(localServerChannel.localAddress()).await().channel(); + } + + @AfterEach + void tearDown() throws InterruptedException { + localChannel.close().await(); + remoteChannel.close().await(); + localServerChannel.close().await(); + } + + @AfterAll + static void tearDownAfterAll() throws InterruptedException { + eventLoopGroup.shutdownGracefully().await(); + } + + @ParameterizedTest + @MethodSource + void getAuthenticatedDevice(@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional maybeAuthenticatedDevice) { + clientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, maybeAuthenticatedDevice); + + assertEquals(maybeAuthenticatedDevice, + clientConnectionManager.getAuthenticatedDevice(localChannel.localAddress())); + } + + private static List> getAuthenticatedDevice() { + return List.of( + Optional.of(new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID)), + Optional.empty() + ); + } + + @Test + void getAcceptableLanguages() { + clientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty()); + + assertEquals(Optional.empty(), + clientConnectionManager.getAcceptableLanguages(localChannel.localAddress())); + + final List acceptLanguageRanges = Locale.LanguageRange.parse("en,ja"); + remoteChannel.attr(ClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).set(acceptLanguageRanges); + + assertEquals(Optional.of(acceptLanguageRanges), + clientConnectionManager.getAcceptableLanguages(localChannel.localAddress())); + } + + @Test + void getRemoteAddress() { + clientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty()); + + assertEquals(Optional.empty(), + clientConnectionManager.getRemoteAddress(localChannel.localAddress())); + + final InetAddress remoteAddress = InetAddresses.forString("6.7.8.9"); + remoteChannel.attr(ClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).set(remoteAddress); + + assertEquals(Optional.of(remoteAddress), + clientConnectionManager.getRemoteAddress(localChannel.localAddress())); + } + + @Test + void getUserAgent() throws UnrecognizedUserAgentException { + clientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty()); + + assertEquals(Optional.empty(), + clientConnectionManager.getUserAgent(localChannel.localAddress())); + + final UserAgent userAgent = UserAgentUtil.parseUserAgentString("Signal-Desktop/1.2.3 Linux"); + remoteChannel.attr(ClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY).set(userAgent); + + assertEquals(Optional.of(userAgent), + clientConnectionManager.getUserAgent(localChannel.localAddress())); + } + + @Test + void closeConnection() throws InterruptedException { + final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID); + + clientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice)); + + assertTrue(remoteChannel.isOpen()); + + assertEquals(remoteChannel, clientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); + assertEquals(List.of(remoteChannel), + clientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); + + remoteChannel.close().await(); + + assertNull(clientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); + assertNull(clientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); + } + + @Test + void handleWebSocketHandshakeCompleteRemoteAddress() { + final EmbeddedChannel embeddedChannel = new EmbeddedChannel(); + + final InetAddress preferredRemoteAddress = InetAddresses.forString("192.168.1.1"); + + ClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel, + preferredRemoteAddress, + null, + null); + + assertEquals(preferredRemoteAddress, + embeddedChannel.attr(ClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).get()); + } + + @ParameterizedTest + @MethodSource + void handleWebSocketHandshakeCompleteUserAgent(@Nullable final String userAgentHeader, + @Nullable final UserAgent expectedParsedUserAgent) { + + final EmbeddedChannel embeddedChannel = new EmbeddedChannel(); + + ClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel, + InetAddresses.forString("127.0.0.1"), + userAgentHeader, + null); + + assertEquals(userAgentHeader, + embeddedChannel.attr(ClientConnectionManager.RAW_USER_AGENT_ATTRIBUTE_KEY).get()); + + assertEquals(expectedParsedUserAgent, + embeddedChannel.attr(ClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY).get()); + } + + private static List handleWebSocketHandshakeCompleteUserAgent() { + return List.of( + // Recognized user-agent + Arguments.of("Signal-Desktop/1.2.3 Linux", new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "Linux")), + + // Unrecognized user-agent + Arguments.of("Not a valid user-agent string", null), + + // Missing user-agent + Arguments.of(null, null) + ); + } + + + @ParameterizedTest + @MethodSource + void handleWebSocketHandshakeCompleteAcceptLanguage(@Nullable final String acceptLanguageHeader, + @Nullable final List expectedLanguageRanges) { + + final EmbeddedChannel embeddedChannel = new EmbeddedChannel(); + + ClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel, + InetAddresses.forString("127.0.0.1"), + null, + acceptLanguageHeader); + + assertEquals(expectedLanguageRanges, + embeddedChannel.attr(ClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).get()); + } + + private static List handleWebSocketHandshakeCompleteAcceptLanguage() { + return List.of( + // Parseable list + Arguments.of("ja,en;q=0.4", Locale.LanguageRange.parse("ja,en;q=0.4")), + + // Unparsable list + Arguments.of("This is not a valid language preference list", null), + + // Missing list + Arguments.of(null, null) + ); + } + + @Test + void handleConnectionEstablishedAuthenticated() throws InterruptedException { + final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID); + + assertNull(clientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); + assertNull(clientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); + + clientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice)); + + assertEquals(remoteChannel, clientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); + assertEquals(List.of(remoteChannel), clientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); + + remoteChannel.close().await(); + + assertNull(clientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); + assertNull(clientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); + } + + @Test + void handleConnectionEstablishedAnonymous() throws InterruptedException { + assertNull(clientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); + + clientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty()); + + assertEquals(remoteChannel, clientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); + + remoteChannel.close().await(); + + assertNull(clientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/EstablishRemoteConnectionHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/EstablishRemoteConnectionHandler.java index ec22d6af7..777d2a70f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/EstablishRemoteConnectionHandler.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/EstablishRemoteConnectionHandler.java @@ -8,8 +8,8 @@ import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInitializer; import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; -import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.HttpClientCodec; +import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler; import io.netty.handler.codec.http.websocketx.WebSocketVersion; @@ -35,6 +35,7 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter { private final ECPublicKey rootPublicKey; @Nullable private final UUID accountIdentifier; private final byte deviceId; + private final HttpHeaders headers; private final SocketAddress remoteServerAddress; private final WebSocketCloseListener webSocketCloseListener; @@ -50,6 +51,7 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter { final ECPublicKey rootPublicKey, @Nullable final UUID accountIdentifier, final byte deviceId, + final HttpHeaders headers, final SocketAddress remoteServerAddress, final WebSocketCloseListener webSocketCloseListener) { @@ -60,6 +62,7 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter { this.rootPublicKey = rootPublicKey; this.accountIdentifier = accountIdentifier; this.deviceId = deviceId; + this.headers = headers; this.remoteServerAddress = remoteServerAddress; this.webSocketCloseListener = webSocketCloseListener; } @@ -87,7 +90,7 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter { WebSocketVersion.V13, null, false, - new DefaultHttpHeaders(), + headers, Noise.MAX_PACKET_LEN, 10_000)) .addLast(new OutboundCloseWebSocketFrameHandler(webSocketCloseListener)) diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketNoiseTunnelClient.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketNoiseTunnelClient.java index fb31da1c6..ae115e9f2 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketNoiseTunnelClient.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketNoiseTunnelClient.java @@ -11,6 +11,7 @@ import java.net.SocketAddress; import java.net.URI; import java.security.cert.X509Certificate; import java.util.UUID; +import io.netty.handler.codec.http.HttpHeaders; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECPublicKey; import javax.annotation.Nullable; @@ -30,6 +31,7 @@ class WebSocketNoiseTunnelClient implements AutoCloseable { final ECPublicKey rootPublicKey, @Nullable final UUID accountIdentifier, final byte deviceId, + final HttpHeaders headers, final X509Certificate trustedServerCertificate, final NioEventLoopGroup eventLoopGroup, final WebSocketCloseListener webSocketCloseListener) { @@ -48,6 +50,7 @@ class WebSocketNoiseTunnelClient implements AutoCloseable { rootPublicKey, accountIdentifier, deviceId, + headers, remoteServerAddress, webSocketCloseListener)); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketNoiseTunnelServerIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketNoiseTunnelServerIntegrationTest.java index 431edec62..d70d059d2 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketNoiseTunnelServerIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketNoiseTunnelServerIntegrationTest.java @@ -1,7 +1,6 @@ package org.whispersystems.textsecuregcm.grpc.net; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyByte; @@ -17,6 +16,8 @@ import io.netty.channel.DefaultEventLoopGroup; import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalChannel; import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.handler.codec.http.DefaultHttpHeaders; +import io.netty.handler.codec.http.HttpHeaders; import java.io.ByteArrayInputStream; import java.io.IOException; import java.net.URI; @@ -35,28 +36,41 @@ import java.security.cert.X509Certificate; import java.security.spec.InvalidKeySpecException; import java.security.spec.PKCS8EncodedKeySpec; import java.util.Base64; +import java.util.List; import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import javax.net.ssl.SSLContext; import javax.net.ssl.TrustManagerFactory; +import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.signal.chat.rpc.AuthenticationTypeGrpc; -import org.signal.chat.rpc.GetAuthenticatedRequest; -import org.signal.chat.rpc.GetAuthenticatedResponse; +import org.signal.chat.rpc.GetAuthenticatedDeviceRequest; +import org.signal.chat.rpc.GetAuthenticatedDeviceResponse; +import org.signal.chat.rpc.GetRequestAttributesRequest; +import org.signal.chat.rpc.GetRequestAttributesResponse; +import org.signal.chat.rpc.RequestAttributesGrpc; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECPublicKey; +import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor; +import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor; import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils; +import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor; +import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl; import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.util.UUIDUtil; class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTest { @@ -66,6 +80,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes private static X509Certificate serverTlsCertificate; + private ClientConnectionManager clientConnectionManager; private ClientPublicKeysManager clientPublicKeysManager; private ECKeyPair rootKeyPair; @@ -79,6 +94,8 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes private static final UUID ACCOUNT_IDENTIFIER = UUID.randomUUID(); private static final byte DEVICE_ID = Device.PRIMARY_ID; + private static final String RECOGNIZED_PROXY_SECRET = RandomStringUtils.randomAlphanumeric(16); + // Please note that this certificate/key are used only for testing and are not used anywhere outside of this test. // They were generated with: // @@ -133,6 +150,8 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes clientKeyPair = Curve.generateKeyPair(); final ECKeyPair serverKeyPair = Curve.generateKeyPair(); + clientConnectionManager = new ClientConnectionManager(); + clientPublicKeysManager = mock(ClientPublicKeysManager.class); when(clientPublicKeysManager.findPublicKey(any(), anyByte())) .thenReturn(CompletableFuture.completedFuture(Optional.empty())); @@ -146,7 +165,9 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, defaultEventLoopGroup) { @Override protected void configureServer(final ServerBuilder serverBuilder) { - serverBuilder.addService(new AuthenticationTypeService(true)); + serverBuilder.addService(new RequestAttributesServiceImpl()) + .intercept(new RequestAttributesInterceptor(clientConnectionManager)) + .intercept(new RequireAuthenticationInterceptor(clientConnectionManager)); } }; @@ -155,7 +176,9 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, defaultEventLoopGroup) { @Override protected void configureServer(final ServerBuilder serverBuilder) { - serverBuilder.addService(new AuthenticationTypeService(false)); + serverBuilder.addService(new RequestAttributesServiceImpl()) + .intercept(new RequestAttributesInterceptor(clientConnectionManager)) + .intercept(new ProhibitAuthenticationInterceptor(clientConnectionManager)); } }; @@ -166,11 +189,13 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes serverTlsPrivateKey, nioEventLoopGroup, delegatedTaskExecutor, + clientConnectionManager, clientPublicKeysManager, serverKeyPair, rootKeyPair.getPrivateKey().calculateSignature(serverKeyPair.getPublicKey().getPublicKeyBytes()), authenticatedGrpcServerAddress, - anonymousGrpcServerAddress); + anonymousGrpcServerAddress, + RECOGNIZED_PROXY_SECRET); websocketNoiseTunnelServer.start(); } @@ -198,10 +223,11 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress()); try { - final GetAuthenticatedResponse response = AuthenticationTypeGrpc.newBlockingStub(channel) - .getAuthenticated(GetAuthenticatedRequest.newBuilder().build()); + final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel) + .getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build()); - assertTrue(response.getAuthenticated()); + assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier()); + assertEquals(DEVICE_ID, response.getDeviceId()); } finally { channel.shutdown(); } @@ -215,15 +241,15 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes // Try to verify the server's public key with something other than the key with which it was signed try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient = - buildAndStartAuthenticatedClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey())) { + buildAndStartAuthenticatedClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey(), new DefaultHttpHeaders())) { final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress()); try { //noinspection ResultOfMethodCallIgnored GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, - () -> AuthenticationTypeGrpc.newBlockingStub(channel) - .getAuthenticated(GetAuthenticatedRequest.newBuilder().build())); + () -> RequestAttributesGrpc.newBlockingStub(channel) + .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build())); } finally { channel.shutdown(); } @@ -247,8 +273,8 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes try { //noinspection ResultOfMethodCallIgnored GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, - () -> AuthenticationTypeGrpc.newBlockingStub(channel) - .getAuthenticated(GetAuthenticatedRequest.newBuilder().build())); + () -> RequestAttributesGrpc.newBlockingStub(channel) + .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build())); } finally { channel.shutdown(); } @@ -272,8 +298,8 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes try { //noinspection ResultOfMethodCallIgnored GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, - () -> AuthenticationTypeGrpc.newBlockingStub(channel) - .getAuthenticated(GetAuthenticatedRequest.newBuilder().build())); + () -> RequestAttributesGrpc.newBlockingStub(channel) + .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build())); } finally { channel.shutdown(); } @@ -294,6 +320,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes rootKeyPair.getPublicKey(), ACCOUNT_IDENTIFIER, DEVICE_ID, + new DefaultHttpHeaders(), serverTlsCertificate, nioEventLoopGroup, webSocketCloseListener) @@ -304,8 +331,8 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes try { //noinspection ResultOfMethodCallIgnored GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, - () -> AuthenticationTypeGrpc.newBlockingStub(channel) - .getAuthenticated(GetAuthenticatedRequest.newBuilder().build())); + () -> RequestAttributesGrpc.newBlockingStub(channel) + .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build())); } finally { channel.shutdown(); } @@ -320,10 +347,11 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress()); try { - final GetAuthenticatedResponse response = AuthenticationTypeGrpc.newBlockingStub(channel) - .getAuthenticated(GetAuthenticatedRequest.newBuilder().build()); + final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel) + .getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build()); - assertFalse(response.getAuthenticated()); + assertTrue(response.getAccountIdentifier().isEmpty()); + assertEquals(0, response.getDeviceId()); } finally { channel.shutdown(); } @@ -336,15 +364,15 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes // Try to verify the server's public key with something other than the key with which it was signed try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient = - buildAndStartAnonymousClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey())) { + buildAndStartAnonymousClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey(), new DefaultHttpHeaders())) { final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress()); try { //noinspection ResultOfMethodCallIgnored GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, - () -> AuthenticationTypeGrpc.newBlockingStub(channel) - .getAuthenticated(GetAuthenticatedRequest.newBuilder().build())); + () -> RequestAttributesGrpc.newBlockingStub(channel) + .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build())); } finally { channel.shutdown(); } @@ -365,6 +393,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes rootKeyPair.getPublicKey(), null, (byte) 0, + new DefaultHttpHeaders(), serverTlsCertificate, nioEventLoopGroup, webSocketCloseListener) @@ -375,8 +404,8 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes try { //noinspection ResultOfMethodCallIgnored GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, - () -> AuthenticationTypeGrpc.newBlockingStub(channel) - .getAuthenticated(GetAuthenticatedRequest.newBuilder().build())); + () -> RequestAttributesGrpc.newBlockingStub(channel) + .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build())); } finally { channel.shutdown(); } @@ -438,6 +467,86 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes } } + @Test + void getRequestAttributes() throws InterruptedException { + final String remoteAddress = "4.5.6.7"; + final String acceptLanguage = "en"; + final String userAgent = "Signal-Desktop/1.2.3 Linux"; + + final HttpHeaders headers = new DefaultHttpHeaders() + .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET) + .add("X-Forwarded-For", remoteAddress) + .add("Accept-Language", acceptLanguage) + .add("User-Agent", userAgent); + + try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient = + buildAndStartAnonymousClient(WebSocketCloseListener.NOOP_LISTENER, rootKeyPair.getPublicKey(), headers)) { + + final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress()); + + try { + final GetRequestAttributesResponse response = RequestAttributesGrpc.newBlockingStub(channel) + .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()); + + assertEquals(remoteAddress, response.getRemoteAddress()); + assertEquals(List.of(acceptLanguage), response.getAcceptableLanguagesList()); + + assertEquals("DESKTOP", response.getUserAgent().getPlatform()); + assertEquals("1.2.3", response.getUserAgent().getVersion()); + assertEquals("Linux", response.getUserAgent().getAdditionalSpecifiers()); + } finally { + channel.shutdown(); + } + } + } + + @Test + void closeForReauthentication() throws InterruptedException { + final CountDownLatch connectionCloseLatch = new CountDownLatch(1); + final AtomicInteger serverCloseStatusCode = new AtomicInteger(0); + final AtomicBoolean closedByServer = new AtomicBoolean(false); + + final WebSocketCloseListener webSocketCloseListener = new WebSocketCloseListener() { + + @Override + public void handleWebSocketClosedByClient(final int statusCode) { + serverCloseStatusCode.set(statusCode); + closedByServer.set(false); + connectionCloseLatch.countDown(); + } + + @Override + public void handleWebSocketClosedByServer(final int statusCode) { + serverCloseStatusCode.set(statusCode); + closedByServer.set(true); + connectionCloseLatch.countDown(); + } + }; + + try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient = buildAndStartAuthenticatedClient(webSocketCloseListener)) { + + final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress()); + + try { + final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel) + .getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build()); + + assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier()); + assertEquals(DEVICE_ID, response.getDeviceId()); + + clientConnectionManager.closeConnection(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID)); + assertTrue(connectionCloseLatch.await(2, TimeUnit.SECONDS)); + + assertEquals(ApplicationWebSocketCloseReason.REAUTHENTICATION_REQUIRED.getStatusCode(), + serverCloseStatusCode.get()); + + assertTrue(closedByServer.get()); + } finally { + channel.shutdown(); + } + } + } + private WebSocketNoiseTunnelClient buildAndStartAuthenticatedClient() throws InterruptedException { return buildAndStartAuthenticatedClient(WebSocketCloseListener.NOOP_LISTENER); } @@ -445,11 +554,12 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes private WebSocketNoiseTunnelClient buildAndStartAuthenticatedClient(final WebSocketCloseListener webSocketCloseListener) throws InterruptedException { - return buildAndStartAuthenticatedClient(webSocketCloseListener, rootKeyPair.getPublicKey()); + return buildAndStartAuthenticatedClient(webSocketCloseListener, rootKeyPair.getPublicKey(), new DefaultHttpHeaders()); } private WebSocketNoiseTunnelClient buildAndStartAuthenticatedClient(final WebSocketCloseListener webSocketCloseListener, - final ECPublicKey rootPublicKey) throws InterruptedException { + final ECPublicKey rootPublicKey, + final HttpHeaders headers) throws InterruptedException { return new WebSocketNoiseTunnelClient(websocketNoiseTunnelServer.getLocalAddress(), WebSocketNoiseTunnelClient.AUTHENTICATED_WEBSOCKET_URI, @@ -458,6 +568,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes rootPublicKey, ACCOUNT_IDENTIFIER, DEVICE_ID, + headers, serverTlsCertificate, nioEventLoopGroup, webSocketCloseListener) @@ -465,11 +576,12 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes } private WebSocketNoiseTunnelClient buildAndStartAnonymousClient() throws InterruptedException { - return buildAndStartAnonymousClient(WebSocketCloseListener.NOOP_LISTENER, rootKeyPair.getPublicKey()); + return buildAndStartAnonymousClient(WebSocketCloseListener.NOOP_LISTENER, rootKeyPair.getPublicKey(), new DefaultHttpHeaders()); } private WebSocketNoiseTunnelClient buildAndStartAnonymousClient(final WebSocketCloseListener webSocketCloseListener, - final ECPublicKey rootPublicKey) throws InterruptedException { + final ECPublicKey rootPublicKey, + final HttpHeaders headers) throws InterruptedException { return new WebSocketNoiseTunnelClient(websocketNoiseTunnelServer.getLocalAddress(), WebSocketNoiseTunnelClient.ANONYMOUS_WEBSOCKET_URI, @@ -478,6 +590,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes rootPublicKey, null, (byte) 0, + headers, serverTlsCertificate, nioEventLoopGroup, webSocketCloseListener) diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteHandlerTest.java new file mode 100644 index 000000000..68d8e1515 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteHandlerTest.java @@ -0,0 +1,202 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.mock; + +import com.google.common.net.InetAddresses; +import com.vdurmont.semver4j.Semver; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.channel.local.LocalAddress; +import io.netty.handler.codec.http.DefaultHttpHeaders; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import javax.annotation.Nullable; +import org.apache.commons.lang3.RandomStringUtils; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.signal.libsignal.protocol.ecc.Curve; +import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; +import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; +import org.whispersystems.textsecuregcm.util.ua.UserAgent; + +class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest { + + private UserEventRecordingHandler userEventRecordingHandler; + private MutableRemoteAddressEmbeddedChannel embeddedChannel; + + private static final String RECOGNIZED_PROXY_SECRET = RandomStringUtils.randomAlphanumeric(16); + + private static class UserEventRecordingHandler extends ChannelInboundHandlerAdapter { + + private final List receivedEvents = new ArrayList<>(); + + @Override + public void userEventTriggered(final ChannelHandlerContext context, final Object event) { + receivedEvents.add(event); + } + + public List getReceivedEvents() { + return receivedEvents; + } + } + + private static class MutableRemoteAddressEmbeddedChannel extends EmbeddedChannel { + + private SocketAddress remoteAddress; + + public MutableRemoteAddressEmbeddedChannel(final ChannelHandler... handlers) { + super(handlers); + } + + @Override + protected SocketAddress remoteAddress0() { + return isActive() ? remoteAddress : null; + } + + public void setRemoteAddress(final SocketAddress remoteAddress) { + this.remoteAddress = remoteAddress; + } + } + + @BeforeEach + void setUp() { + userEventRecordingHandler = new UserEventRecordingHandler(); + + embeddedChannel = new MutableRemoteAddressEmbeddedChannel( + new WebsocketHandshakeCompleteHandler(mock(ClientPublicKeysManager.class), + Curve.generateKeyPair(), + new byte[64], + RECOGNIZED_PROXY_SECRET), + userEventRecordingHandler); + + embeddedChannel.setRemoteAddress(new InetSocketAddress("127.0.0.1", 0)); + } + + @ParameterizedTest + @MethodSource + void handleWebSocketHandshakeComplete(final String uri, final Class expectedHandlerClass) { + final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent = + new WebSocketServerProtocolHandler.HandshakeComplete(uri, new DefaultHttpHeaders(), null); + + embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent); + + assertNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteHandler.class)); + assertNotNull(embeddedChannel.pipeline().get(expectedHandlerClass)); + + assertEquals(List.of(handshakeCompleteEvent), userEventRecordingHandler.getReceivedEvents()); + } + + private static List handleWebSocketHandshakeComplete() { + return List.of( + Arguments.of(WebsocketNoiseTunnelServer.AUTHENTICATED_SERVICE_PATH, NoiseXXHandshakeHandler.class), + Arguments.of(WebsocketNoiseTunnelServer.ANONYMOUS_SERVICE_PATH, NoiseNXHandshakeHandler.class)); + } + + @Test + void handleWebSocketHandshakeCompleteUnexpectedPath() { + final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent = + new WebSocketServerProtocolHandler.HandshakeComplete("/incorrect", new DefaultHttpHeaders(), null); + + embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent); + + assertNotNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteHandler.class)); + assertThrows(IllegalArgumentException.class, () -> embeddedChannel.checkException()); + } + + @Test + void handleUnrecognizedEvent() { + final Object unrecognizedEvent = new Object(); + + embeddedChannel.pipeline().fireUserEventTriggered(unrecognizedEvent); + assertEquals(List.of(unrecognizedEvent), userEventRecordingHandler.getReceivedEvents()); + } + + @ParameterizedTest + @MethodSource + void getRemoteAddress(final HttpHeaders headers, final SocketAddress remoteAddress, @Nullable InetAddress expectedRemoteAddress) { + final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent = + new WebSocketServerProtocolHandler.HandshakeComplete( + WebsocketNoiseTunnelServer.ANONYMOUS_SERVICE_PATH, headers, null); + + embeddedChannel.setRemoteAddress(remoteAddress); + embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent); + + assertEquals(expectedRemoteAddress, + embeddedChannel.attr(ClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).get()); + } + + private static List getRemoteAddress() { + final InetSocketAddress remoteAddress = new InetSocketAddress("5.6.7.8", 0); + final InetAddress clientAddress = InetAddresses.forString("1.2.3.4"); + final InetAddress proxyAddress = InetAddresses.forString("4.3.2.1"); + + return List.of( + // Recognized proxy, single forwarded-for address + Arguments.of(new DefaultHttpHeaders() + .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET) + .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()), + remoteAddress, + clientAddress), + + // Recognized proxy, multiple forwarded-for addresses + Arguments.of(new DefaultHttpHeaders() + .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET) + .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress() + "," + proxyAddress.getHostAddress()), + remoteAddress, + proxyAddress), + + // No recognized proxy header, single forwarded-for address + Arguments.of(new DefaultHttpHeaders() + .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()), + remoteAddress, + remoteAddress.getAddress()), + + // No recognized proxy header, no forwarded-for address + Arguments.of(new DefaultHttpHeaders(), + remoteAddress, + remoteAddress.getAddress()), + + // Incorrect proxy header, single forwarded-for address + Arguments.of(new DefaultHttpHeaders() + .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET + "-incorrect") + .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()), + remoteAddress, + remoteAddress.getAddress()), + + // Recognized proxy, no forwarded-for address + Arguments.of(new DefaultHttpHeaders() + .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET), + remoteAddress, + remoteAddress.getAddress()), + + // Recognized proxy, bogus forwarded-for address + Arguments.of(new DefaultHttpHeaders() + .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET) + .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, "not a valid address"), + remoteAddress, + null), + + // No forwarded-for address, non-InetSocketAddress remote address + Arguments.of(new DefaultHttpHeaders() + .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET), + new LocalAddress("local-address"), + null) + ); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteListenerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteListenerTest.java deleted file mode 100644 index 98f901b28..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteListenerTest.java +++ /dev/null @@ -1,91 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.channel.embedded.EmbeddedChannel; -import io.netty.handler.codec.http.DefaultHttpHeaders; -import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; -import org.signal.libsignal.protocol.ecc.Curve; -import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; - -import java.util.ArrayList; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.Mockito.mock; - -class WebsocketHandshakeCompleteListenerTest extends AbstractLeakDetectionTest { - - private UserEventRecordingHandler userEventRecordingHandler; - private EmbeddedChannel embeddedChannel; - - private static class UserEventRecordingHandler extends ChannelInboundHandlerAdapter { - - private final List receivedEvents = new ArrayList<>(); - - @Override - public void userEventTriggered(final ChannelHandlerContext context, final Object event) { - receivedEvents.add(event); - } - - public List getReceivedEvents() { - return receivedEvents; - } - } - - @BeforeEach - void setUp() { - userEventRecordingHandler = new UserEventRecordingHandler(); - - embeddedChannel = new EmbeddedChannel( - new WebsocketHandshakeCompleteListener(mock(ClientPublicKeysManager.class), Curve.generateKeyPair(), new byte[64]), - userEventRecordingHandler); - } - - @ParameterizedTest - @MethodSource - void handleWebSocketHandshakeComplete(final String uri, final Class expectedHandlerClass) { - final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent = - new WebSocketServerProtocolHandler.HandshakeComplete(uri, new DefaultHttpHeaders(), null); - - embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent); - - assertNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteListener.class)); - assertNotNull(embeddedChannel.pipeline().get(expectedHandlerClass)); - - assertEquals(List.of(handshakeCompleteEvent), userEventRecordingHandler.getReceivedEvents()); - } - - private static List handleWebSocketHandshakeComplete() { - return List.of( - Arguments.of(WebsocketNoiseTunnelServer.AUTHENTICATED_SERVICE_PATH, NoiseXXHandshakeHandler.class), - Arguments.of(WebsocketNoiseTunnelServer.ANONYMOUS_SERVICE_PATH, NoiseNXHandshakeHandler.class)); - } - - @Test - void handleWebSocketHandshakeCompleteUnexpectedPath() { - final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent = - new WebSocketServerProtocolHandler.HandshakeComplete("/incorrect", new DefaultHttpHeaders(), null); - - embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent); - - assertNotNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteListener.class)); - assertThrows(IllegalArgumentException.class, () -> embeddedChannel.checkException()); - } - - @Test - void handleUnrecognizedEvent() { - final Object unrecognizedEvent = new Object(); - - embeddedChannel.pipeline().fireUserEventTriggered(unrecognizedEvent); - assertEquals(List.of(unrecognizedEvent), userEventRecordingHandler.getReceivedEvents()); - } -} diff --git a/service/src/test/proto/authentication_type_service.proto b/service/src/test/proto/authentication_type_service.proto deleted file mode 100644 index 53a60a731..000000000 --- a/service/src/test/proto/authentication_type_service.proto +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright 2023 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -syntax = "proto3"; - -option java_multiple_files = true; - -package org.signal.chat.rpc; - -// A simple test service that identifies its authentication type to callers -service AuthenticationType { - rpc GetAuthenticated (GetAuthenticatedRequest) returns (GetAuthenticatedResponse) {} -} - -message GetAuthenticatedRequest { -} - -message GetAuthenticatedResponse { - bool authenticated = 1; -} diff --git a/service/src/test/proto/request_attributes_service.proto b/service/src/test/proto/request_attributes_service.proto new file mode 100644 index 000000000..f7df61d64 --- /dev/null +++ b/service/src/test/proto/request_attributes_service.proto @@ -0,0 +1,41 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +syntax = "proto3"; + +option java_multiple_files = true; + +package org.signal.chat.rpc; + +// A simple test service that echoes request attributes to callers +service RequestAttributes { + rpc GetRequestAttributes (GetRequestAttributesRequest) returns (GetRequestAttributesResponse) {} + + rpc GetAuthenticatedDevice (GetAuthenticatedDeviceRequest) returns (GetAuthenticatedDeviceResponse) {} +} + +message GetRequestAttributesRequest { +} + +message GetRequestAttributesResponse { + repeated string acceptable_languages = 1; + repeated string available_accepted_locales = 2; + string remote_address = 3; + UserAgent user_agent = 4; +} + +message UserAgent { + string platform = 1; + string version = 2; + string additional_specifiers = 3; +} + +message GetAuthenticatedDeviceRequest { +} + +message GetAuthenticatedDeviceResponse { + bytes account_identifier = 1; + uint32 device_id = 2; +}