Introduce a Noise-over-WebSocket client connection manager
This commit is contained in:
parent
075a08884b
commit
aec6ac019f
|
@ -95,3 +95,5 @@ turn.secret: AAAAAAAAAAA=
|
|||
linkDevice.secret: AAAAAAAAAAA=
|
||||
|
||||
tlsKeyStore.password: unset
|
||||
|
||||
noiseTunnel.recognizedProxySecret: ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789AAAAAAA
|
||||
|
|
|
@ -468,3 +468,7 @@ callingTurnManualTable:
|
|||
s3Bucket: a-bucket
|
||||
objectKey: an-object.tar.gz
|
||||
maxSize: 32777216
|
||||
|
||||
noiseTunnel:
|
||||
port: 8443
|
||||
recognizedProxySecret: secret://noiseTunnel.recognizedProxySecret
|
||||
|
|
|
@ -39,6 +39,7 @@ import org.whispersystems.textsecuregcm.configuration.MaxDeviceConfiguration;
|
|||
import org.whispersystems.textsecuregcm.configuration.MessageByteLimitCardinalityEstimatorConfiguration;
|
||||
import org.whispersystems.textsecuregcm.configuration.MessageCacheConfiguration;
|
||||
import org.whispersystems.textsecuregcm.configuration.MonitoredS3ObjectConfiguration;
|
||||
import org.whispersystems.textsecuregcm.configuration.NoiseWebSocketTunnelConfiguration;
|
||||
import org.whispersystems.textsecuregcm.configuration.OneTimeDonationConfiguration;
|
||||
import org.whispersystems.textsecuregcm.configuration.PaymentsServiceConfiguration;
|
||||
import org.whispersystems.textsecuregcm.configuration.RedisClusterConfiguration;
|
||||
|
@ -333,6 +334,11 @@ public class WhisperServerConfiguration extends Configuration {
|
|||
@JsonProperty
|
||||
private MonitoredS3ObjectConfiguration callingTurnManualTable;
|
||||
|
||||
@Valid
|
||||
@NotNull
|
||||
@JsonProperty
|
||||
private NoiseWebSocketTunnelConfiguration noiseTunnel;
|
||||
|
||||
public TlsKeyStoreConfiguration getTlsKeyStoreConfiguration() {
|
||||
return tlsKeyStore;
|
||||
}
|
||||
|
@ -555,4 +561,8 @@ public class WhisperServerConfiguration extends Configuration {
|
|||
public MonitoredS3ObjectConfiguration getCallingTurnManualTable() {
|
||||
return callingTurnManualTable;
|
||||
}
|
||||
|
||||
public NoiseWebSocketTunnelConfiguration getNoiseWebSocketTunnelConfiguration() {
|
||||
return noiseTunnel;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -71,7 +71,8 @@ import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager;
|
|||
import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager;
|
||||
import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator;
|
||||
import org.whispersystems.textsecuregcm.auth.WebsocketRefreshApplicationEventListener;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.BasicCredentialAuthenticationInterceptor;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor;
|
||||
import org.whispersystems.textsecuregcm.backup.BackupAuthManager;
|
||||
import org.whispersystems.textsecuregcm.backup.BackupManager;
|
||||
import org.whispersystems.textsecuregcm.backup.BackupsDb;
|
||||
|
@ -127,7 +128,6 @@ import org.whispersystems.textsecuregcm.filters.RemoteDeprecationFilter;
|
|||
import org.whispersystems.textsecuregcm.filters.RequestStatisticsFilter;
|
||||
import org.whispersystems.textsecuregcm.filters.TimestampResponseFilter;
|
||||
import org.whispersystems.textsecuregcm.geo.MaxMindDatabaseManager;
|
||||
import org.whispersystems.textsecuregcm.grpc.AcceptLanguageInterceptor;
|
||||
import org.whispersystems.textsecuregcm.grpc.AccountsAnonymousGrpcService;
|
||||
import org.whispersystems.textsecuregcm.grpc.AccountsGrpcService;
|
||||
import org.whispersystems.textsecuregcm.grpc.ErrorMappingInterceptor;
|
||||
|
@ -138,7 +138,8 @@ import org.whispersystems.textsecuregcm.grpc.KeysGrpcService;
|
|||
import org.whispersystems.textsecuregcm.grpc.PaymentsGrpcService;
|
||||
import org.whispersystems.textsecuregcm.grpc.ProfileAnonymousGrpcService;
|
||||
import org.whispersystems.textsecuregcm.grpc.ProfileGrpcService;
|
||||
import org.whispersystems.textsecuregcm.grpc.UserAgentInterceptor;
|
||||
import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.ManagedDefaultEventLoopGroup;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.ManagedLocalGrpcServer;
|
||||
import org.whispersystems.textsecuregcm.jetty.JettyHttpConfigurationCustomizer;
|
||||
|
@ -752,8 +753,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
|||
geoIpCityDatabaseManager
|
||||
);
|
||||
|
||||
final BasicCredentialAuthenticationInterceptor basicCredentialAuthenticationInterceptor =
|
||||
new BasicCredentialAuthenticationInterceptor(new AccountAuthenticator(accountsManager));
|
||||
final ClientConnectionManager clientConnectionManager = new ClientConnectionManager();
|
||||
|
||||
final ManagedDefaultEventLoopGroup localEventLoopGroup = new ManagedDefaultEventLoopGroup();
|
||||
|
||||
|
@ -762,8 +762,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
|||
new MetricCollectingServerInterceptor(Metrics.globalRegistry);
|
||||
|
||||
final ErrorMappingInterceptor errorMappingInterceptor = new ErrorMappingInterceptor();
|
||||
final AcceptLanguageInterceptor acceptLanguageInterceptor = new AcceptLanguageInterceptor();
|
||||
final UserAgentInterceptor userAgentInterceptor = new UserAgentInterceptor();
|
||||
final RequestAttributesInterceptor requestAttributesInterceptor =
|
||||
new RequestAttributesInterceptor(clientConnectionManager);
|
||||
|
||||
final LocalAddress anonymousGrpcServerAddress = new LocalAddress("grpc-anonymous");
|
||||
final LocalAddress authenticatedGrpcServerAddress = new LocalAddress("grpc-authenticated");
|
||||
|
@ -778,9 +778,9 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
|||
// TODO: specialize metrics with user-agent platform
|
||||
.intercept(metricCollectingServerInterceptor)
|
||||
.intercept(errorMappingInterceptor)
|
||||
.intercept(acceptLanguageInterceptor)
|
||||
.intercept(remoteDeprecationFilter)
|
||||
.intercept(userAgentInterceptor)
|
||||
.intercept(requestAttributesInterceptor)
|
||||
.intercept(new ProhibitAuthenticationInterceptor(clientConnectionManager))
|
||||
.addService(new AccountsAnonymousGrpcService(accountsManager, rateLimiters))
|
||||
.addService(new KeysAnonymousGrpcService(accountsManager, keysManager))
|
||||
.addService(new PaymentsGrpcService(currencyManager))
|
||||
|
@ -799,10 +799,9 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
|||
// TODO: specialize metrics with user-agent platform
|
||||
.intercept(metricCollectingServerInterceptor)
|
||||
.intercept(errorMappingInterceptor)
|
||||
.intercept(acceptLanguageInterceptor)
|
||||
.intercept(remoteDeprecationFilter)
|
||||
.intercept(userAgentInterceptor)
|
||||
.intercept(new BasicCredentialAuthenticationInterceptor(new AccountAuthenticator(accountsManager)))
|
||||
.intercept(requestAttributesInterceptor)
|
||||
.intercept(new RequireAuthenticationInterceptor(clientConnectionManager))
|
||||
.addService(new AccountsGrpcService(accountsManager, rateLimiters, usernameHashZkProofVerifier, registrationRecoveryPasswordsManager))
|
||||
.addService(ExternalServiceCredentialsGrpcService.createForAllExternalServices(config, rateLimiters))
|
||||
.addService(new KeysGrpcService(accountsManager, keysManager, rateLimiters))
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
package org.whispersystems.textsecuregcm.auth.grpc;
|
||||
|
||||
import io.grpc.Grpc;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerInterceptor;
|
||||
import io.grpc.Status;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager;
|
||||
import java.util.Optional;
|
||||
|
||||
abstract class AbstractAuthenticationInterceptor implements ServerInterceptor {
|
||||
|
||||
private final ClientConnectionManager clientConnectionManager;
|
||||
|
||||
private static final Metadata EMPTY_TRAILERS = new Metadata();
|
||||
|
||||
AbstractAuthenticationInterceptor(final ClientConnectionManager clientConnectionManager) {
|
||||
this.clientConnectionManager = clientConnectionManager;
|
||||
}
|
||||
|
||||
protected Optional<AuthenticatedDevice> getAuthenticatedDevice(final ServerCall<?, ?> call) {
|
||||
if (call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress) {
|
||||
return clientConnectionManager.getAuthenticatedDevice(localAddress);
|
||||
} else {
|
||||
throw new AssertionError("Unexpected channel type: " + call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
|
||||
}
|
||||
}
|
||||
|
||||
protected <ReqT, RespT> ServerCall.Listener<ReqT> closeAsUnauthenticated(final ServerCall<ReqT, RespT> call) {
|
||||
call.close(Status.UNAUTHENTICATED, EMPTY_TRAILERS);
|
||||
return new ServerCall.Listener<>() {};
|
||||
}
|
||||
}
|
|
@ -7,7 +7,6 @@ package org.whispersystems.textsecuregcm.auth.grpc;
|
|||
|
||||
import io.grpc.Context;
|
||||
import io.grpc.Status;
|
||||
import java.util.UUID;
|
||||
import javax.annotation.Nullable;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
|
||||
|
@ -16,8 +15,7 @@ import org.whispersystems.textsecuregcm.storage.Device;
|
|||
*/
|
||||
public class AuthenticationUtil {
|
||||
|
||||
static final Context.Key<UUID> CONTEXT_AUTHENTICATED_ACCOUNT_IDENTIFIER_KEY = Context.key("authenticated-aci");
|
||||
static final Context.Key<Byte> CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY = Context.key("authenticated-device-id");
|
||||
static final Context.Key<AuthenticatedDevice> CONTEXT_AUTHENTICATED_DEVICE = Context.key("authenticated-device");
|
||||
|
||||
/**
|
||||
* Returns the account/device authenticated in the current gRPC context or throws an "unauthenticated" exception if
|
||||
|
@ -29,11 +27,10 @@ public class AuthenticationUtil {
|
|||
* could be retrieved from the current gRPC context
|
||||
*/
|
||||
public static AuthenticatedDevice requireAuthenticatedDevice() {
|
||||
@Nullable final UUID accountIdentifier = CONTEXT_AUTHENTICATED_ACCOUNT_IDENTIFIER_KEY.get();
|
||||
@Nullable final Byte deviceId = CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY.get();
|
||||
@Nullable final AuthenticatedDevice authenticatedDevice = CONTEXT_AUTHENTICATED_DEVICE.get();
|
||||
|
||||
if (accountIdentifier != null && deviceId != null) {
|
||||
return new AuthenticatedDevice(accountIdentifier, deviceId);
|
||||
if (authenticatedDevice != null) {
|
||||
return authenticatedDevice;
|
||||
}
|
||||
|
||||
throw Status.UNAUTHENTICATED.asRuntimeException();
|
||||
|
|
|
@ -1,85 +0,0 @@
|
|||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.auth.grpc;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import io.dropwizard.auth.basic.BasicCredentials;
|
||||
import io.grpc.Context;
|
||||
import io.grpc.Contexts;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerCallHandler;
|
||||
import io.grpc.ServerInterceptor;
|
||||
import io.grpc.Status;
|
||||
import java.util.Optional;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
|
||||
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
|
||||
import org.whispersystems.textsecuregcm.util.HeaderUtils;
|
||||
|
||||
/**
|
||||
* A basic credential authentication interceptor enforces the presence of a valid username and password on every call.
|
||||
* Callers supply credentials by providing a username (UUID and optional device ID) and password pair in the
|
||||
* {@code x-signal-basic-auth-credentials} call header.
|
||||
* <p/>
|
||||
* Downstream services can retrieve the identity of the authenticated caller using methods in
|
||||
* {@link AuthenticationUtil}.
|
||||
* <p/>
|
||||
* Note that this authentication, while fully functional, is intended only for development and testing purposes and is
|
||||
* intended to be replaced with a more robust and efficient strategy before widespread client adoption.
|
||||
*
|
||||
* @see AuthenticationUtil
|
||||
* @see AccountAuthenticator
|
||||
*/
|
||||
public class BasicCredentialAuthenticationInterceptor implements ServerInterceptor {
|
||||
|
||||
private final AccountAuthenticator accountAuthenticator;
|
||||
|
||||
@VisibleForTesting
|
||||
static final Metadata.Key<String> BASIC_CREDENTIALS =
|
||||
Metadata.Key.of("x-signal-auth", Metadata.ASCII_STRING_MARSHALLER);
|
||||
|
||||
private static final Metadata EMPTY_TRAILERS = new Metadata();
|
||||
|
||||
public BasicCredentialAuthenticationInterceptor(final AccountAuthenticator accountAuthenticator) {
|
||||
this.accountAuthenticator = accountAuthenticator;
|
||||
}
|
||||
|
||||
@Override
|
||||
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
|
||||
final ServerCall<ReqT, RespT> call,
|
||||
final Metadata headers,
|
||||
final ServerCallHandler<ReqT, RespT> next) {
|
||||
|
||||
final String authHeader = headers.get(BASIC_CREDENTIALS);
|
||||
|
||||
if (StringUtils.isNotBlank(authHeader)) {
|
||||
final Optional<BasicCredentials> maybeCredentials = HeaderUtils.basicCredentialsFromAuthHeader(authHeader);
|
||||
if (maybeCredentials.isEmpty()) {
|
||||
call.close(Status.UNAUTHENTICATED.withDescription("Could not parse credentials"), EMPTY_TRAILERS);
|
||||
} else {
|
||||
final Optional<AuthenticatedAccount> maybeAuthenticatedAccount =
|
||||
accountAuthenticator.authenticate(maybeCredentials.get());
|
||||
|
||||
if (maybeAuthenticatedAccount.isPresent()) {
|
||||
final AuthenticatedAccount authenticatedAccount = maybeAuthenticatedAccount.get();
|
||||
|
||||
final Context context = Context.current()
|
||||
.withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_ACCOUNT_IDENTIFIER_KEY, authenticatedAccount.getAccount().getUuid())
|
||||
.withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY, authenticatedAccount.getAuthenticatedDevice().getId());
|
||||
|
||||
return Contexts.interceptCall(context, call, headers, next);
|
||||
} else {
|
||||
call.close(Status.UNAUTHENTICATED.withDescription("Credentials not accepted"), EMPTY_TRAILERS);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
call.close(Status.UNAUTHENTICATED.withDescription("No credentials provided"), EMPTY_TRAILERS);
|
||||
}
|
||||
|
||||
return new ServerCall.Listener<>() {};
|
||||
}
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
package org.whispersystems.textsecuregcm.auth.grpc;
|
||||
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerCallHandler;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager;
|
||||
|
||||
/**
|
||||
* A "prohibit authentication" interceptor ensures that requests to endpoints that should be invoked anonymously do not
|
||||
* originate from a channel that is associated with an authenticated device. Calls with an associated authenticated
|
||||
* device are closed with an {@code UNAUTHENTICATED} status.
|
||||
*/
|
||||
public class ProhibitAuthenticationInterceptor extends AbstractAuthenticationInterceptor {
|
||||
|
||||
public ProhibitAuthenticationInterceptor(final ClientConnectionManager clientConnectionManager) {
|
||||
super(clientConnectionManager);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
|
||||
final Metadata headers,
|
||||
final ServerCallHandler<ReqT, RespT> next) {
|
||||
|
||||
return getAuthenticatedDevice(call)
|
||||
.map(ignored -> closeAsUnauthenticated(call))
|
||||
.orElseGet(() -> next.startCall(call, headers));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
package org.whispersystems.textsecuregcm.auth.grpc;
|
||||
|
||||
import io.grpc.Context;
|
||||
import io.grpc.Contexts;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerCallHandler;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager;
|
||||
|
||||
/**
|
||||
* A "require authentication" interceptor requires that requests be issued from a connection that is associated with an
|
||||
* authenticated device. Calls without an associated authenticated device are closed with an {@code UNAUTHENTICATED}
|
||||
* status.
|
||||
*/
|
||||
public class RequireAuthenticationInterceptor extends AbstractAuthenticationInterceptor {
|
||||
|
||||
public RequireAuthenticationInterceptor(final ClientConnectionManager clientConnectionManager) {
|
||||
super(clientConnectionManager);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
|
||||
final Metadata headers,
|
||||
final ServerCallHandler<ReqT, RespT> next) {
|
||||
|
||||
return getAuthenticatedDevice(call)
|
||||
.map(authenticatedDevice -> Contexts.interceptCall(Context.current()
|
||||
.withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE, authenticatedDevice),
|
||||
call, headers, next))
|
||||
.orElseGet(() -> closeAsUnauthenticated(call));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,8 @@
|
|||
package org.whispersystems.textsecuregcm.configuration;
|
||||
|
||||
import javax.validation.constraints.NotNull;
|
||||
import javax.validation.constraints.Positive;
|
||||
import org.whispersystems.textsecuregcm.configuration.secrets.SecretString;
|
||||
|
||||
public record NoiseWebSocketTunnelConfiguration(@Positive int port, @NotNull SecretString recognizedProxySecret) {
|
||||
}
|
|
@ -5,7 +5,6 @@
|
|||
|
||||
package org.whispersystems.textsecuregcm.filters;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import java.io.IOException;
|
||||
import java.util.Optional;
|
||||
import javax.annotation.Nullable;
|
||||
|
@ -72,8 +71,7 @@ public class RemoteAddressFilter implements Filter {
|
|||
* @see <a href="https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For">X-Forwarded-For - HTTP |
|
||||
* MDN</a>
|
||||
*/
|
||||
@VisibleForTesting
|
||||
static Optional<String> getMostRecentProxy(@Nullable final String forwardedFor) {
|
||||
public static Optional<String> getMostRecentProxy(@Nullable final String forwardedFor) {
|
||||
return Optional.ofNullable(forwardedFor)
|
||||
.map(ff -> {
|
||||
final int idx = forwardedFor.lastIndexOf(',') + 1;
|
||||
|
|
|
@ -17,6 +17,7 @@ import io.micrometer.core.instrument.Metrics;
|
|||
import java.io.IOException;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import javax.annotation.Nullable;
|
||||
import javax.servlet.Filter;
|
||||
import javax.servlet.FilterChain;
|
||||
import javax.servlet.ServletException;
|
||||
|
@ -26,6 +27,7 @@ import javax.servlet.http.HttpServletRequest;
|
|||
import javax.servlet.http.HttpServletResponse;
|
||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRemoteDeprecationConfiguration;
|
||||
import org.whispersystems.textsecuregcm.grpc.RequestAttributesUtil;
|
||||
import org.whispersystems.textsecuregcm.grpc.StatusConstants;
|
||||
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
|
||||
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
|
||||
|
@ -79,7 +81,7 @@ public class RemoteDeprecationFilter implements Filter, ServerInterceptor {
|
|||
final Metadata headers,
|
||||
final ServerCallHandler<ReqT, RespT> next) {
|
||||
|
||||
if (shouldBlock(UserAgentUtil.userAgentFromGrpcContext())) {
|
||||
if (shouldBlock(RequestAttributesUtil.getUserAgent().orElse(null))) {
|
||||
call.close(StatusConstants.UPGRADE_NEEDED_STATUS, new Metadata());
|
||||
return new ServerCall.Listener<>() {};
|
||||
} else {
|
||||
|
@ -87,7 +89,7 @@ public class RemoteDeprecationFilter implements Filter, ServerInterceptor {
|
|||
}
|
||||
}
|
||||
|
||||
private boolean shouldBlock(final UserAgent userAgent) {
|
||||
private boolean shouldBlock(@Nullable final UserAgent userAgent) {
|
||||
final DynamicRemoteDeprecationConfiguration configuration = dynamicConfigurationManager
|
||||
.getConfiguration().getRemoteDeprecationConfiguration();
|
||||
final Map<ClientPlatform, Semver> minimumVersionsByPlatform = configuration.getMinimumVersions();
|
||||
|
|
|
@ -1,68 +0,0 @@
|
|||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import io.grpc.Context;
|
||||
import io.grpc.Contexts;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerCallHandler;
|
||||
import io.grpc.ServerInterceptor;
|
||||
import io.micrometer.core.instrument.Metrics;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
|
||||
import javax.annotation.Nullable;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
|
||||
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
|
||||
|
||||
public class AcceptLanguageInterceptor implements ServerInterceptor {
|
||||
private static final Logger logger = LoggerFactory.getLogger(AcceptLanguageInterceptor.class);
|
||||
private static final String INVALID_ACCEPT_LANGUAGE_COUNTER_NAME = name(AcceptLanguageInterceptor.class, "invalidAcceptLanguage");
|
||||
|
||||
@VisibleForTesting
|
||||
public static final Metadata.Key<String> ACCEPTABLE_LANGUAGES_GRPC_HEADER =
|
||||
Metadata.Key.of("accept-language", Metadata.ASCII_STRING_MARSHALLER);
|
||||
|
||||
@Override
|
||||
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
|
||||
final Metadata headers,
|
||||
final ServerCallHandler<ReqT, RespT> next) {
|
||||
|
||||
final List<Locale> locales = parseLocales(headers.get(ACCEPTABLE_LANGUAGES_GRPC_HEADER));
|
||||
|
||||
return Contexts.interceptCall(
|
||||
Context.current().withValue(AcceptLanguageUtil.ACCEPTABLE_LANGUAGES_CONTEXT_KEY, locales),
|
||||
call,
|
||||
headers,
|
||||
next);
|
||||
}
|
||||
|
||||
static List<Locale> parseLocales(@Nullable final String acceptableLanguagesHeader) {
|
||||
if (acceptableLanguagesHeader == null) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
try {
|
||||
final List<Locale.LanguageRange> languageRanges = Locale.LanguageRange.parse(acceptableLanguagesHeader);
|
||||
return Locale.filter(languageRanges, Arrays.asList(Locale.getAvailableLocales()));
|
||||
} catch (final IllegalArgumentException e) {
|
||||
final UserAgent userAgent = UserAgentUtil.userAgentFromGrpcContext();
|
||||
Metrics.counter(INVALID_ACCEPT_LANGUAGE_COUNTER_NAME, "platform", userAgent.getPlatform().name().toLowerCase()).increment();
|
||||
logger.debug("Could not get acceptable languages; Accept-Language: {}; User-Agent: {}",
|
||||
acceptableLanguagesHeader,
|
||||
userAgent,
|
||||
e);
|
||||
return Collections.emptyList();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,17 +0,0 @@
|
|||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import io.grpc.Context;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
|
||||
public class AcceptLanguageUtil {
|
||||
static final Context.Key<List<Locale>> ACCEPTABLE_LANGUAGES_CONTEXT_KEY = Context.key("accept-language");
|
||||
public static List<Locale> localeFromGrpcContext() {
|
||||
return ACCEPTABLE_LANGUAGES_CONTEXT_KEY.get();
|
||||
}
|
||||
}
|
|
@ -111,7 +111,7 @@ public class ProfileGrpcHelper {
|
|||
case ACI -> {
|
||||
responseBuilder.setUnrestrictedUnidentifiedAccess(targetAccount.isUnrestrictedUnidentifiedAccess())
|
||||
.addAllBadges(buildBadges(profileBadgeConverter.convert(
|
||||
AcceptLanguageUtil.localeFromGrpcContext(),
|
||||
RequestAttributesUtil.getAvailableAcceptedLocales(),
|
||||
targetAccount.getBadges(),
|
||||
ProfileHelper.isSelfProfileRequest(requesterUuid, (AciServiceIdentifier) targetIdentifier))));
|
||||
|
||||
|
|
|
@ -5,28 +5,12 @@
|
|||
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||
import org.whispersystems.textsecuregcm.limits.RateLimiter;
|
||||
import reactor.core.publisher.Mono;
|
||||
import java.net.SocketAddress;
|
||||
import java.time.Duration;
|
||||
|
||||
class RateLimitUtil {
|
||||
|
||||
private static final RateLimitExceededException UNKNOWN_REMOTE_ADDRESS_EXCEPTION =
|
||||
new RateLimitExceededException(Duration.ofHours(1), true);
|
||||
|
||||
static Mono<Void> rateLimitByRemoteAddress(final RateLimiter rateLimiter) {
|
||||
return rateLimitByRemoteAddress(rateLimiter, true);
|
||||
}
|
||||
|
||||
static Mono<Void> rateLimitByRemoteAddress(final RateLimiter rateLimiter, final boolean failOnUnknownRemoteAddress) {
|
||||
final SocketAddress remoteAddress = RemoteAddressUtil.getRemoteAddress();
|
||||
|
||||
if (remoteAddress != null) {
|
||||
return rateLimiter.validateReactive(remoteAddress.toString());
|
||||
} else {
|
||||
return failOnUnknownRemoteAddress ? Mono.error(UNKNOWN_REMOTE_ADDRESS_EXCEPTION) : Mono.empty();
|
||||
}
|
||||
return rateLimiter.validateReactive(RequestAttributesUtil.getRemoteAddress().getHostAddress());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,33 +0,0 @@
|
|||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import io.grpc.Context;
|
||||
import io.grpc.Contexts;
|
||||
import io.grpc.Grpc;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerCallHandler;
|
||||
import io.grpc.ServerInterceptor;
|
||||
import java.net.SocketAddress;
|
||||
|
||||
public class RemoteAddressInterceptor implements ServerInterceptor {
|
||||
|
||||
@Override
|
||||
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> serverCall,
|
||||
final Metadata headers,
|
||||
final ServerCallHandler<ReqT, RespT> next) {
|
||||
|
||||
// Note: the specific implementation for getting a remote client address may change depending on the client
|
||||
// connection strategy. The important thing is that the remote address wind up in the context for the current
|
||||
// call so it can be retrieved by `RemoteAddressUtil`.
|
||||
final SocketAddress remoteAddress = serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR);
|
||||
|
||||
return Contexts.interceptCall(
|
||||
Context.current().withValue(RemoteAddressUtil.REMOTE_ADDRESS_CONTEXT_KEY, remoteAddress),
|
||||
serverCall, headers, next);
|
||||
}
|
||||
}
|
|
@ -1,23 +0,0 @@
|
|||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import io.grpc.Context;
|
||||
import java.net.SocketAddress;
|
||||
|
||||
public class RemoteAddressUtil {
|
||||
|
||||
static final Context.Key<SocketAddress> REMOTE_ADDRESS_CONTEXT_KEY = Context.key("remote-address");
|
||||
|
||||
/**
|
||||
* Returns the socket address of the remote client in the current gRPC request context.
|
||||
*
|
||||
* @return the socket address of the remote client
|
||||
*/
|
||||
public static SocketAddress getRemoteAddress() {
|
||||
return REMOTE_ADDRESS_CONTEXT_KEY.get();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,75 @@
|
|||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import io.grpc.Context;
|
||||
import io.grpc.Contexts;
|
||||
import io.grpc.Grpc;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerCallHandler;
|
||||
import io.grpc.ServerInterceptor;
|
||||
import io.grpc.Status;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
|
||||
import java.net.InetAddress;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Optional;
|
||||
|
||||
public class RequestAttributesInterceptor implements ServerInterceptor {
|
||||
|
||||
private final ClientConnectionManager clientConnectionManager;
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(RequestAttributesInterceptor.class);
|
||||
|
||||
public RequestAttributesInterceptor(final ClientConnectionManager clientConnectionManager) {
|
||||
this.clientConnectionManager = clientConnectionManager;
|
||||
}
|
||||
|
||||
@Override
|
||||
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
|
||||
final Metadata headers,
|
||||
final ServerCallHandler<ReqT, RespT> next) {
|
||||
|
||||
if (call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress) {
|
||||
Context context = Context.current();
|
||||
|
||||
{
|
||||
final Optional<InetAddress> maybeRemoteAddress = clientConnectionManager.getRemoteAddress(localAddress);
|
||||
|
||||
if (maybeRemoteAddress.isEmpty()) {
|
||||
// We should never have a call from a party whose remote address we can't identify
|
||||
log.warn("No remote address available");
|
||||
|
||||
call.close(Status.INTERNAL, new Metadata());
|
||||
return new ServerCall.Listener<>() {};
|
||||
}
|
||||
|
||||
context = context.withValue(RequestAttributesUtil.REMOTE_ADDRESS_CONTEXT_KEY, maybeRemoteAddress.get());
|
||||
}
|
||||
|
||||
{
|
||||
final Optional<List<Locale.LanguageRange>> maybeAcceptLanguage =
|
||||
clientConnectionManager.getAcceptableLanguages(localAddress);
|
||||
|
||||
if (maybeAcceptLanguage.isPresent()) {
|
||||
context = context.withValue(RequestAttributesUtil.ACCEPT_LANGUAGE_CONTEXT_KEY, maybeAcceptLanguage.get());
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
final Optional<UserAgent> maybeUserAgent = clientConnectionManager.getUserAgent(localAddress);
|
||||
|
||||
if (maybeUserAgent.isPresent()) {
|
||||
context = context.withValue(RequestAttributesUtil.USER_AGENT_CONTEXT_KEY, maybeUserAgent.get());
|
||||
}
|
||||
}
|
||||
|
||||
return Contexts.interceptCall(context, call, headers, next);
|
||||
} else {
|
||||
throw new AssertionError("Unexpected channel type: " + call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,59 @@
|
|||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import io.grpc.Context;
|
||||
import java.net.InetAddress;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Optional;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
|
||||
|
||||
public class RequestAttributesUtil {
|
||||
|
||||
static final Context.Key<List<Locale.LanguageRange>> ACCEPT_LANGUAGE_CONTEXT_KEY = Context.key("accept-language");
|
||||
static final Context.Key<InetAddress> REMOTE_ADDRESS_CONTEXT_KEY = Context.key("remote-address");
|
||||
static final Context.Key<UserAgent> USER_AGENT_CONTEXT_KEY = Context.key("user-agent");
|
||||
|
||||
private static final List<Locale> AVAILABLE_LOCALES = Arrays.asList(Locale.getAvailableLocales());
|
||||
|
||||
/**
|
||||
* Returns the acceptable languages listed by the remote client in the current gRPC request context.
|
||||
*
|
||||
* @return the acceptable languages listed by the remote client; may be empty if unparseable or not specified
|
||||
*/
|
||||
public static Optional<List<Locale.LanguageRange>> getAcceptableLanguages() {
|
||||
return Optional.ofNullable(ACCEPT_LANGUAGE_CONTEXT_KEY.get());
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a list of distinct locales supported by the JVM and accepted by the remote client in the current gRPC
|
||||
* context. May be empty if the client did not supply a list of acceptable languages, if the list of acceptable
|
||||
* languages could not be parsed, or if none of the acceptable languages are available in the current JVM.
|
||||
*
|
||||
* @return a list of distinct locales acceptable to the remote client and available in this JVM
|
||||
*/
|
||||
public static List<Locale> getAvailableAcceptedLocales() {
|
||||
return getAcceptableLanguages()
|
||||
.map(languageRanges -> Locale.filter(languageRanges, AVAILABLE_LOCALES))
|
||||
.orElseGet(Collections::emptyList);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the remote address of the remote client in the current gRPC request context.
|
||||
*
|
||||
* @return the remote address of the remote client
|
||||
*/
|
||||
public static InetAddress getRemoteAddress() {
|
||||
return REMOTE_ADDRESS_CONTEXT_KEY.get();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the parsed user-agent of the remote client in the current gRPC request context.
|
||||
*
|
||||
* @return the parsed user-agent of the remote client; may be null if unparseable or not specified
|
||||
*/
|
||||
public static Optional<UserAgent> getUserAgent() {
|
||||
return Optional.ofNullable(USER_AGENT_CONTEXT_KEY.get());
|
||||
}
|
||||
}
|
|
@ -1,43 +0,0 @@
|
|||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
|
||||
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
|
||||
|
||||
import io.grpc.Context;
|
||||
import io.grpc.Contexts;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerCallHandler;
|
||||
import io.grpc.ServerInterceptor;
|
||||
import io.grpc.Status;
|
||||
|
||||
public class UserAgentInterceptor implements ServerInterceptor {
|
||||
@VisibleForTesting
|
||||
public static final Metadata.Key<String> USER_AGENT_GRPC_HEADER =
|
||||
Metadata.Key.of("user-agent", Metadata.ASCII_STRING_MARSHALLER);
|
||||
|
||||
@Override
|
||||
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
|
||||
final Metadata headers,
|
||||
final ServerCallHandler<ReqT, RespT> next) {
|
||||
|
||||
UserAgent userAgent;
|
||||
try {
|
||||
userAgent = UserAgentUtil.parseUserAgentString(headers.get(USER_AGENT_GRPC_HEADER));
|
||||
} catch (final UnrecognizedUserAgentException e) {
|
||||
userAgent = null;
|
||||
}
|
||||
|
||||
final Context context = Context.current().withValue(UserAgentUtil.USER_AGENT_CONTEXT_KEY, userAgent);
|
||||
return Contexts.interceptCall(context, call, headers, next);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,218 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import io.netty.channel.Channel;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import io.netty.channel.local.LocalChannel;
|
||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||
import io.netty.util.AttributeKey;
|
||||
import java.net.InetAddress;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import javax.annotation.Nullable;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
|
||||
|
||||
/**
|
||||
* A client connection manager associates a local connection to a local gRPC server with a remote connection through a
|
||||
* Noise-over-WebSocket tunnel. It provides access to metadata associated with the remote connection, including the
|
||||
* authenticated identity of the device that opened the connection (for non-anonymous connections). It can also close
|
||||
* connections associated with a given device if that device's credentials have changed and clients must reauthenticate.
|
||||
*/
|
||||
public class ClientConnectionManager {
|
||||
|
||||
private final Map<LocalAddress, Channel> remoteChannelsByLocalAddress = new ConcurrentHashMap<>();
|
||||
private final Map<AuthenticatedDevice, List<Channel>> remoteChannelsByAuthenticatedDevice = new ConcurrentHashMap<>();
|
||||
|
||||
@VisibleForTesting
|
||||
static final AttributeKey<AuthenticatedDevice> AUTHENTICATED_DEVICE_ATTRIBUTE_KEY =
|
||||
AttributeKey.valueOf(ClientConnectionManager.class, "authenticatedDevice");
|
||||
|
||||
@VisibleForTesting
|
||||
static final AttributeKey<InetAddress> REMOTE_ADDRESS_ATTRIBUTE_KEY =
|
||||
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "remoteAddress");
|
||||
|
||||
@VisibleForTesting
|
||||
static final AttributeKey<String> RAW_USER_AGENT_ATTRIBUTE_KEY =
|
||||
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "rawUserAgent");
|
||||
|
||||
@VisibleForTesting
|
||||
static final AttributeKey<UserAgent> PARSED_USER_AGENT_ATTRIBUTE_KEY =
|
||||
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "userAgent");
|
||||
|
||||
@VisibleForTesting
|
||||
static final AttributeKey<List<Locale.LanguageRange>> ACCEPT_LANGUAGE_ATTRIBUTE_KEY =
|
||||
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "acceptLanguage");
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(ClientConnectionManager.class);
|
||||
|
||||
/**
|
||||
* Returns the authenticated device associated with the given local address, if any. An authenticated device is
|
||||
* available if and only if the given local address maps to an active local connection and that connection is
|
||||
* authenticated (i.e. not anonymous).
|
||||
*
|
||||
* @param localAddress the local address for which to find an authenticated device
|
||||
*
|
||||
* @return the authenticated device associated with the given local address, if any
|
||||
*/
|
||||
public Optional<AuthenticatedDevice> getAuthenticatedDevice(final LocalAddress localAddress) {
|
||||
return getAuthenticatedDevice(remoteChannelsByLocalAddress.get(localAddress));
|
||||
}
|
||||
|
||||
private Optional<AuthenticatedDevice> getAuthenticatedDevice(@Nullable final Channel remoteChannel) {
|
||||
return Optional.ofNullable(remoteChannel)
|
||||
.map(channel -> channel.attr(AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).get());
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the parsed acceptable languages associated with the given local address, if any. Acceptable languages may
|
||||
* be unavailable if the local connection associated with the given local address has already closed, if the client
|
||||
* did not provide a list of acceptable languages, or the list provided by the client could not be parsed.
|
||||
*
|
||||
* @param localAddress the local address for which to find acceptable languages
|
||||
*
|
||||
* @return the acceptable languages associated with the given local address, if any
|
||||
*/
|
||||
public Optional<List<Locale.LanguageRange>> getAcceptableLanguages(final LocalAddress localAddress) {
|
||||
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
|
||||
.map(remoteChannel -> remoteChannel.attr(ACCEPT_LANGUAGE_ATTRIBUTE_KEY).get());
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the remote address associated with the given local address, if any. A remote address may be unavailable if
|
||||
* the local connection associated with the given local address has already closed.
|
||||
*
|
||||
* @param localAddress the local address for which to find a remote address
|
||||
*
|
||||
* @return the remote address associated with the given local address, if any
|
||||
*/
|
||||
public Optional<InetAddress> getRemoteAddress(final LocalAddress localAddress) {
|
||||
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
|
||||
.map(remoteChannel -> remoteChannel.attr(REMOTE_ADDRESS_ATTRIBUTE_KEY).get());
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the parsed user agent provided by the client the opened the connection associated with the given local
|
||||
* address. This method may return an empty value if no active local connection is associated with the given local
|
||||
* address or if the client's user-agent string was not recognized.
|
||||
*
|
||||
* @param localAddress the local address for which to find a User-Agent string
|
||||
*
|
||||
* @return the user agent associated with the given local address
|
||||
*/
|
||||
public Optional<UserAgent> getUserAgent(final LocalAddress localAddress) {
|
||||
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
|
||||
.map(remoteChannel -> remoteChannel.attr(PARSED_USER_AGENT_ATTRIBUTE_KEY).get());
|
||||
}
|
||||
|
||||
/**
|
||||
* Closes any client connections to this host associated with the given authenticated device.
|
||||
*
|
||||
* @param authenticatedDevice the authenticated device for which to close connections
|
||||
*/
|
||||
public void closeConnection(final AuthenticatedDevice authenticatedDevice) {
|
||||
// Channels will actually get removed from the list/map by their closeFuture listeners
|
||||
remoteChannelsByAuthenticatedDevice.getOrDefault(authenticatedDevice, Collections.emptyList()).forEach(channel ->
|
||||
channel.writeAndFlush(new CloseWebSocketFrame(ApplicationWebSocketCloseReason.REAUTHENTICATION_REQUIRED
|
||||
.toWebSocketCloseStatus("Reauthentication required")))
|
||||
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE));
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
@Nullable List<Channel> getRemoteChannelsByAuthenticatedDevice(final AuthenticatedDevice authenticatedDevice) {
|
||||
return remoteChannelsByAuthenticatedDevice.get(authenticatedDevice);
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
Channel getRemoteChannelByLocalAddress(final LocalAddress localAddress) {
|
||||
return remoteChannelsByLocalAddress.get(localAddress);
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles successful completion of a WebSocket handshake and associates attributes and headers from the handshake
|
||||
* request with the channel via which the handshake took place.
|
||||
*
|
||||
* @param channel the channel that completed a WebSocket handshake
|
||||
* @param preferredRemoteAddress the preferred remote address (potentially from a request header) for the handshake
|
||||
* @param userAgentHeader the value of the User-Agent header provided in the handshake request; may be {@code null}
|
||||
* @param acceptLanguageHeader the value of the Accept-Language header provided in the handshake request; may be
|
||||
* {@code null}
|
||||
*/
|
||||
static void handleWebSocketHandshakeComplete(final Channel channel,
|
||||
final InetAddress preferredRemoteAddress,
|
||||
@Nullable final String userAgentHeader,
|
||||
@Nullable final String acceptLanguageHeader) {
|
||||
|
||||
channel.attr(ClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).set(preferredRemoteAddress);
|
||||
|
||||
if (StringUtils.isNotBlank(userAgentHeader)) {
|
||||
channel.attr(ClientConnectionManager.RAW_USER_AGENT_ATTRIBUTE_KEY).set(userAgentHeader);
|
||||
|
||||
try {
|
||||
channel.attr(ClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY)
|
||||
.set(UserAgentUtil.parseUserAgentString(userAgentHeader));
|
||||
} catch (final UnrecognizedUserAgentException ignored) {
|
||||
}
|
||||
}
|
||||
|
||||
if (StringUtils.isNotBlank(acceptLanguageHeader)) {
|
||||
try {
|
||||
channel.attr(ClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).set(Locale.LanguageRange.parse(acceptLanguageHeader));
|
||||
} catch (final IllegalArgumentException e) {
|
||||
log.debug("Invalid Accept-Language header from User-Agent {}: {}", userAgentHeader, acceptLanguageHeader, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles successful establishment of a Noise-over-WebSocket connection from a remote client to a local gRPC server.
|
||||
*
|
||||
* @param localChannel the newly-opened local channel between the Noise-over-WebSocket tunnel and the local gRPC
|
||||
* server
|
||||
* @param remoteChannel the channel from the remote client to the Noise-over-WebSocket tunnel
|
||||
* @param maybeAuthenticatedDevice the authenticated device (if any) associated with the new connection
|
||||
*/
|
||||
void handleConnectionEstablished(final LocalChannel localChannel,
|
||||
final Channel remoteChannel,
|
||||
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<AuthenticatedDevice> maybeAuthenticatedDevice) {
|
||||
|
||||
maybeAuthenticatedDevice.ifPresent(authenticatedDevice ->
|
||||
remoteChannel.attr(ClientConnectionManager.AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).set(authenticatedDevice));
|
||||
|
||||
remoteChannelsByLocalAddress.put(localChannel.localAddress(), remoteChannel);
|
||||
|
||||
getAuthenticatedDevice(remoteChannel).ifPresent(authenticatedDevice ->
|
||||
remoteChannelsByAuthenticatedDevice.compute(authenticatedDevice, (ignored, existingChannelList) -> {
|
||||
final List<Channel> channels = existingChannelList != null ? existingChannelList : new ArrayList<>();
|
||||
channels.add(remoteChannel);
|
||||
|
||||
return channels;
|
||||
}));
|
||||
|
||||
remoteChannel.closeFuture().addListener(closeFuture -> {
|
||||
remoteChannelsByLocalAddress.remove(localChannel.localAddress());
|
||||
|
||||
getAuthenticatedDevice(remoteChannel).ifPresent(authenticatedDevice ->
|
||||
remoteChannelsByAuthenticatedDevice.compute(authenticatedDevice, (ignored, existingChannelList) -> {
|
||||
if (existingChannelList == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
existingChannelList.remove(remoteChannel);
|
||||
|
||||
return existingChannelList.isEmpty() ? null : existingChannelList;
|
||||
}));
|
||||
});
|
||||
}
|
||||
}
|
|
@ -22,15 +22,21 @@ import org.slf4j.LoggerFactory;
|
|||
*/
|
||||
class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private final ClientConnectionManager clientConnectionManager;
|
||||
|
||||
private final LocalAddress authenticatedGrpcServerAddress;
|
||||
private final LocalAddress anonymousGrpcServerAddress;
|
||||
|
||||
private final List<Object> pendingReads = new ArrayList<>();
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(EstablishLocalGrpcConnectionHandler.class);
|
||||
|
||||
public EstablishLocalGrpcConnectionHandler(final LocalAddress authenticatedGrpcServerAddress,
|
||||
public EstablishLocalGrpcConnectionHandler(final ClientConnectionManager clientConnectionManager,
|
||||
final LocalAddress authenticatedGrpcServerAddress,
|
||||
final LocalAddress anonymousGrpcServerAddress) {
|
||||
|
||||
this.clientConnectionManager = clientConnectionManager;
|
||||
|
||||
this.authenticatedGrpcServerAddress = authenticatedGrpcServerAddress;
|
||||
this.anonymousGrpcServerAddress = anonymousGrpcServerAddress;
|
||||
}
|
||||
|
@ -41,7 +47,7 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) throws Exception {
|
||||
public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) {
|
||||
if (event instanceof NoiseHandshakeCompleteEvent noiseHandshakeCompleteEvent) {
|
||||
// We assume that we'll only get a completed handshake event if the handshake met all authentication requirements
|
||||
// for the requested service. If the handshake doesn't have an authenticated device, we assume we're trying to
|
||||
|
@ -53,7 +59,6 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
|
|||
|
||||
new Bootstrap()
|
||||
.remoteAddress(grpcServerAddress)
|
||||
// TODO Set local address
|
||||
.channel(LocalChannel.class)
|
||||
.group(remoteChannelContext.channel().eventLoop())
|
||||
.handler(new ChannelInitializer<LocalChannel>() {
|
||||
|
@ -63,15 +68,19 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
|
|||
}
|
||||
})
|
||||
.connect()
|
||||
.addListener((ChannelFutureListener) future -> {
|
||||
if (future.isSuccess()) {
|
||||
.addListener((ChannelFutureListener) localChannelFuture -> {
|
||||
if (localChannelFuture.isSuccess()) {
|
||||
clientConnectionManager.handleConnectionEstablished((LocalChannel) localChannelFuture.channel(),
|
||||
remoteChannelContext.channel(),
|
||||
noiseHandshakeCompleteEvent.authenticatedDevice());
|
||||
|
||||
// Close the local connection if the remote channel closes and vice versa
|
||||
remoteChannelContext.channel().closeFuture().addListener(closeFuture -> future.channel().close());
|
||||
future.channel().closeFuture().addListener(closeFuture ->
|
||||
remoteChannelContext.channel().closeFuture().addListener(closeFuture -> localChannelFuture.channel().close());
|
||||
localChannelFuture.channel().closeFuture().addListener(closeFuture ->
|
||||
remoteChannelContext.write(new CloseWebSocketFrame(WebSocketCloseStatus.SERVICE_RESTART)));
|
||||
|
||||
remoteChannelContext.pipeline()
|
||||
.addAfter(remoteChannelContext.name(), null, new ProxyHandler(future.channel()));
|
||||
.addAfter(remoteChannelContext.name(), null, new ProxyHandler(localChannelFuture.channel()));
|
||||
|
||||
// Flush any buffered reads we accumulated while waiting to open the connection
|
||||
pendingReads.forEach(remoteChannelContext::fireChannelRead);
|
||||
|
@ -79,7 +88,7 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
|
|||
|
||||
remoteChannelContext.pipeline().remove(EstablishLocalGrpcConnectionHandler.this);
|
||||
} else {
|
||||
log.warn("Failed to establish local connection to gRPC server", future.cause());
|
||||
log.warn("Failed to establish local connection to gRPC server", localChannelFuture.cause());
|
||||
remoteChannelContext.close();
|
||||
}
|
||||
});
|
||||
|
|
|
@ -0,0 +1,134 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.common.net.InetAddresses;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.handler.codec.http.HttpHeaderNames;
|
||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
|
||||
import java.net.InetAddress;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.security.MessageDigest;
|
||||
import java.util.Optional;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||
|
||||
/**
|
||||
* A WebSocket handshake handler waits for a WebSocket handshake to complete, then replaces itself with the appropriate
|
||||
* Noise handshake handler for the requested path.
|
||||
*/
|
||||
class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private final ClientPublicKeysManager clientPublicKeysManager;
|
||||
|
||||
private final ECKeyPair ecKeyPair;
|
||||
private final byte[] publicKeySignature;
|
||||
|
||||
private final byte[] recognizedProxySecret;
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(WebsocketHandshakeCompleteHandler.class);
|
||||
|
||||
@VisibleForTesting
|
||||
static final String RECOGNIZED_PROXY_SECRET_HEADER = "X-Signal-Recognized-Proxy";
|
||||
|
||||
@VisibleForTesting
|
||||
static final String FORWARDED_FOR_HEADER = "X-Forwarded-For";
|
||||
|
||||
WebsocketHandshakeCompleteHandler(final ClientPublicKeysManager clientPublicKeysManager,
|
||||
final ECKeyPair ecKeyPair,
|
||||
final byte[] publicKeySignature,
|
||||
final String recognizedProxySecret) {
|
||||
|
||||
this.clientPublicKeysManager = clientPublicKeysManager;
|
||||
this.ecKeyPair = ecKeyPair;
|
||||
this.publicKeySignature = publicKeySignature;
|
||||
|
||||
// The recognized proxy secret is an arbitrary string, and not an encoded byte sequence (i.e. a base64- or hex-
|
||||
// encoded value). We convert it into a byte array here for easier constant-time comparisons via
|
||||
// MessageDigest.equals() later.
|
||||
this.recognizedProxySecret = recognizedProxySecret.getBytes(StandardCharsets.UTF_8);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
|
||||
if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) {
|
||||
final InetAddress preferredRemoteAddress;
|
||||
{
|
||||
final Optional<InetAddress> maybePreferredRemoteAddress =
|
||||
getPreferredRemoteAddress(context, handshakeCompleteEvent);
|
||||
|
||||
if (maybePreferredRemoteAddress.isEmpty()) {
|
||||
context.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INTERNAL_SERVER_ERROR,
|
||||
"Could not determine remote address"))
|
||||
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
preferredRemoteAddress = maybePreferredRemoteAddress.get();
|
||||
}
|
||||
|
||||
ClientConnectionManager.handleWebSocketHandshakeComplete(context.channel(),
|
||||
preferredRemoteAddress,
|
||||
handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.USER_AGENT),
|
||||
handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.ACCEPT_LANGUAGE));
|
||||
|
||||
final ChannelHandler noiseHandshakeHandler = switch (handshakeCompleteEvent.requestUri()) {
|
||||
case WebsocketNoiseTunnelServer.AUTHENTICATED_SERVICE_PATH ->
|
||||
new NoiseXXHandshakeHandler(clientPublicKeysManager, ecKeyPair, publicKeySignature);
|
||||
|
||||
case WebsocketNoiseTunnelServer.ANONYMOUS_SERVICE_PATH ->
|
||||
new NoiseNXHandshakeHandler(ecKeyPair, publicKeySignature);
|
||||
|
||||
default -> {
|
||||
// The WebSocketOpeningHandshakeHandler should have caught all of these cases already; we'll consider it an
|
||||
// internal error if something slipped through.
|
||||
throw new IllegalArgumentException("Unexpected URI: " + handshakeCompleteEvent.requestUri());
|
||||
}
|
||||
};
|
||||
|
||||
context.pipeline().replace(WebsocketHandshakeCompleteHandler.this, null, noiseHandshakeHandler);
|
||||
}
|
||||
|
||||
context.fireUserEventTriggered(event);
|
||||
}
|
||||
|
||||
private Optional<InetAddress> getPreferredRemoteAddress(final ChannelHandlerContext context,
|
||||
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) {
|
||||
|
||||
final byte[] recognizedProxySecretFromHeader =
|
||||
handshakeCompleteEvent.requestHeaders().get(RECOGNIZED_PROXY_SECRET_HEADER, "")
|
||||
.getBytes(StandardCharsets.UTF_8);
|
||||
|
||||
final boolean trustForwardedFor = MessageDigest.isEqual(recognizedProxySecret, recognizedProxySecretFromHeader);
|
||||
|
||||
if (trustForwardedFor && handshakeCompleteEvent.requestHeaders().contains(FORWARDED_FOR_HEADER)) {
|
||||
final String forwardedFor = handshakeCompleteEvent.requestHeaders().get(FORWARDED_FOR_HEADER);
|
||||
|
||||
return RemoteAddressFilter.getMostRecentProxy(forwardedFor).map(mostRecentProxy -> {
|
||||
try {
|
||||
return InetAddresses.forString(mostRecentProxy);
|
||||
} catch (final IllegalArgumentException e) {
|
||||
log.warn("Failed to parse forwarded-for address: {}", forwardedFor, e);
|
||||
return null;
|
||||
}
|
||||
});
|
||||
} else {
|
||||
// Either we don't trust the forwarded-for header or it's not present
|
||||
if (context.channel().remoteAddress() instanceof InetSocketAddress inetSocketAddress) {
|
||||
return Optional.of(inetSocketAddress.getAddress());
|
||||
} else {
|
||||
log.warn("Channel's remote address was not an InetSocketAddress");
|
||||
return Optional.empty();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,52 +0,0 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import io.netty.channel.ChannelHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||
|
||||
/**
|
||||
* A WebSocket handshake listener waits for a WebSocket handshake to complete, then replaces itself with the appropriate
|
||||
* Noise handshake handler for the requested path.
|
||||
*/
|
||||
class WebsocketHandshakeCompleteListener extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private final ClientPublicKeysManager clientPublicKeysManager;
|
||||
|
||||
private final ECKeyPair ecKeyPair;
|
||||
private final byte[] publicKeySignature;
|
||||
|
||||
WebsocketHandshakeCompleteListener(final ClientPublicKeysManager clientPublicKeysManager,
|
||||
final ECKeyPair ecKeyPair,
|
||||
final byte[] publicKeySignature) {
|
||||
|
||||
this.clientPublicKeysManager = clientPublicKeysManager;
|
||||
this.ecKeyPair = ecKeyPair;
|
||||
this.publicKeySignature = publicKeySignature;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
|
||||
if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) {
|
||||
final ChannelHandler noiseHandshakeHandler = switch (handshakeCompleteEvent.requestUri()) {
|
||||
case WebsocketNoiseTunnelServer.AUTHENTICATED_SERVICE_PATH ->
|
||||
new NoiseXXHandshakeHandler(clientPublicKeysManager, ecKeyPair, publicKeySignature);
|
||||
|
||||
case WebsocketNoiseTunnelServer.ANONYMOUS_SERVICE_PATH ->
|
||||
new NoiseNXHandshakeHandler(ecKeyPair, publicKeySignature);
|
||||
|
||||
default -> {
|
||||
// The HttpHandler should have caught all of these cases already; we'll consider it an internal error if
|
||||
// something slipped through.
|
||||
throw new IllegalArgumentException("Unexpected URI: " + handshakeCompleteEvent.requestUri());
|
||||
}
|
||||
};
|
||||
|
||||
context.pipeline().replace(WebsocketHandshakeCompleteListener.this, null, noiseHandshakeHandler);
|
||||
}
|
||||
|
||||
context.fireUserEventTriggered(event);
|
||||
}
|
||||
}
|
|
@ -48,11 +48,13 @@ public class WebsocketNoiseTunnelServer implements Managed {
|
|||
final PrivateKey tlsPrivateKey,
|
||||
final NioEventLoopGroup eventLoopGroup,
|
||||
final Executor delegatedTaskExecutor,
|
||||
final ClientConnectionManager clientConnectionManager,
|
||||
final ClientPublicKeysManager clientPublicKeysManager,
|
||||
final ECKeyPair ecKeyPair,
|
||||
final byte[] publicKeySignature,
|
||||
final LocalAddress authenticatedGrpcServerAddress,
|
||||
final LocalAddress anonymousGrpcServerAddress) throws SSLException {
|
||||
final LocalAddress anonymousGrpcServerAddress,
|
||||
final String recognizedProxySecret) throws SSLException {
|
||||
|
||||
final SslProvider sslProvider;
|
||||
|
||||
|
@ -88,10 +90,10 @@ public class WebsocketNoiseTunnelServer implements Managed {
|
|||
.addLast(new RejectUnsupportedMessagesHandler())
|
||||
// The WebSocket handshake complete listener will replace itself with an appropriate Noise handshake handler once
|
||||
// a WebSocket handshake has been completed
|
||||
.addLast(new WebsocketHandshakeCompleteListener(clientPublicKeysManager, ecKeyPair, publicKeySignature))
|
||||
.addLast(new WebsocketHandshakeCompleteHandler(clientPublicKeysManager, ecKeyPair, publicKeySignature, recognizedProxySecret))
|
||||
// This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler
|
||||
// once the Noise handshake has completed
|
||||
.addLast(new EstablishLocalGrpcConnectionHandler(authenticatedGrpcServerAddress, anonymousGrpcServerAddress))
|
||||
.addLast(new EstablishLocalGrpcConnectionHandler(clientConnectionManager, authenticatedGrpcServerAddress, anonymousGrpcServerAddress))
|
||||
.addLast(new ErrorHandler());
|
||||
}
|
||||
});
|
||||
|
|
|
@ -7,15 +7,12 @@ package org.whispersystems.textsecuregcm.util.ua;
|
|||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.vdurmont.semver4j.Semver;
|
||||
import io.grpc.Context;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
public class UserAgentUtil {
|
||||
|
||||
public static final Context.Key<UserAgent> USER_AGENT_CONTEXT_KEY = Context.key("x-signal-user-agent");
|
||||
|
||||
private static final Pattern STANDARD_UA_PATTERN = Pattern.compile("^Signal-(Android|Desktop|iOS)/([^ ]+)( (.+))?$", Pattern.CASE_INSENSITIVE);
|
||||
|
||||
public static UserAgent parseUserAgentString(final String userAgentString) throws UnrecognizedUserAgentException {
|
||||
|
@ -36,10 +33,6 @@ public class UserAgentUtil {
|
|||
throw new UnrecognizedUserAgentException();
|
||||
}
|
||||
|
||||
public static UserAgent userAgentFromGrpcContext() {
|
||||
return USER_AGENT_CONTEXT_KEY.get();
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
static UserAgent parseStandardUserAgentString(final String userAgentString) {
|
||||
final Matcher matcher = STANDARD_UA_PATTERN.matcher(userAgentString);
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
package org.whispersystems.textsecuregcm.auth.grpc;
|
||||
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
import io.grpc.ManagedChannel;
|
||||
import io.grpc.Server;
|
||||
import io.grpc.netty.NettyChannelBuilder;
|
||||
import io.grpc.netty.NettyServerBuilder;
|
||||
import io.netty.channel.DefaultEventLoopGroup;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import io.netty.channel.local.LocalChannel;
|
||||
import io.netty.channel.local.LocalServerChannel;
|
||||
import java.io.IOException;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
|
||||
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
|
||||
import org.signal.chat.rpc.RequestAttributesGrpc;
|
||||
import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager;
|
||||
|
||||
abstract class AbstractAuthenticationInterceptorTest {
|
||||
|
||||
private static DefaultEventLoopGroup eventLoopGroup;
|
||||
|
||||
private ClientConnectionManager clientConnectionManager;
|
||||
|
||||
private Server server;
|
||||
private ManagedChannel managedChannel;
|
||||
|
||||
@BeforeAll
|
||||
static void setUpBeforeAll() {
|
||||
eventLoopGroup = new DefaultEventLoopGroup();
|
||||
}
|
||||
|
||||
@BeforeEach
|
||||
void setUp() throws IOException {
|
||||
final LocalAddress serverAddress = new LocalAddress("test-authentication-interceptor-server");
|
||||
|
||||
clientConnectionManager = mock(ClientConnectionManager.class);
|
||||
|
||||
// `RequestAttributesInterceptor` operates on `LocalAddresses`, so we need to do some slightly fancy plumbing to make
|
||||
// sure that we're using local channels and addresses
|
||||
server = NettyServerBuilder.forAddress(serverAddress)
|
||||
.channelType(LocalServerChannel.class)
|
||||
.bossEventLoopGroup(eventLoopGroup)
|
||||
.workerEventLoopGroup(eventLoopGroup)
|
||||
.intercept(getInterceptor())
|
||||
.addService(new RequestAttributesServiceImpl())
|
||||
.build()
|
||||
.start();
|
||||
|
||||
managedChannel = NettyChannelBuilder.forAddress(serverAddress)
|
||||
.channelType(LocalChannel.class)
|
||||
.eventLoopGroup(eventLoopGroup)
|
||||
.usePlaintext()
|
||||
.build();
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
void tearDown() {
|
||||
managedChannel.shutdown();
|
||||
server.shutdown();
|
||||
}
|
||||
|
||||
protected abstract AbstractAuthenticationInterceptor getInterceptor();
|
||||
|
||||
protected ClientConnectionManager getClientConnectionManager() {
|
||||
return clientConnectionManager;
|
||||
}
|
||||
|
||||
protected GetAuthenticatedDeviceResponse getAuthenticatedDevice() {
|
||||
return RequestAttributesGrpc.newBlockingStub(managedChannel)
|
||||
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
|
||||
}
|
||||
}
|
|
@ -1,135 +0,0 @@
|
|||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.auth.grpc;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import io.grpc.CallCredentials;
|
||||
import io.grpc.ManagedChannel;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.Server;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.StatusRuntimeException;
|
||||
import io.grpc.inprocess.InProcessChannelBuilder;
|
||||
import io.grpc.inprocess.InProcessServerBuilder;
|
||||
import java.io.IOException;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.Executor;
|
||||
import java.util.stream.Stream;
|
||||
import org.apache.commons.lang3.RandomStringUtils;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.signal.chat.rpc.EchoRequest;
|
||||
import org.signal.chat.rpc.EchoServiceGrpc;
|
||||
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
|
||||
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
|
||||
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.util.HeaderUtils;
|
||||
import org.whispersystems.textsecuregcm.util.Pair;
|
||||
|
||||
class BasicCredentialAuthenticationInterceptorTest {
|
||||
|
||||
private Server server;
|
||||
private ManagedChannel managedChannel;
|
||||
|
||||
private AccountAuthenticator accountAuthenticator;
|
||||
|
||||
|
||||
@BeforeEach
|
||||
void setUp() throws IOException {
|
||||
accountAuthenticator = mock(AccountAuthenticator.class);
|
||||
|
||||
final BasicCredentialAuthenticationInterceptor authenticationInterceptor =
|
||||
new BasicCredentialAuthenticationInterceptor(accountAuthenticator);
|
||||
|
||||
final String serverName = InProcessServerBuilder.generateName();
|
||||
|
||||
server = InProcessServerBuilder.forName(serverName)
|
||||
.directExecutor()
|
||||
.intercept(authenticationInterceptor)
|
||||
.addService(new EchoServiceImpl())
|
||||
.build()
|
||||
.start();
|
||||
|
||||
managedChannel = InProcessChannelBuilder.forName(serverName)
|
||||
.directExecutor()
|
||||
.build();
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
void tearDown() {
|
||||
managedChannel.shutdown();
|
||||
server.shutdown();
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void interceptCall(final Metadata headers, final boolean acceptCredentials, final boolean expectAuthentication) {
|
||||
if (acceptCredentials) {
|
||||
final Account account = mock(Account.class);
|
||||
when(account.getUuid()).thenReturn(UUID.randomUUID());
|
||||
|
||||
final Device device = mock(Device.class);
|
||||
when(device.getId()).thenReturn(Device.PRIMARY_ID);
|
||||
|
||||
when(accountAuthenticator.authenticate(any()))
|
||||
.thenReturn(Optional.of(new AuthenticatedAccount(account, device)));
|
||||
} else {
|
||||
when(accountAuthenticator.authenticate(any()))
|
||||
.thenReturn(Optional.empty());
|
||||
}
|
||||
|
||||
final EchoServiceGrpc.EchoServiceBlockingStub stub = EchoServiceGrpc.newBlockingStub(managedChannel)
|
||||
.withCallCredentials(new CallCredentials() {
|
||||
@Override
|
||||
public void applyRequestMetadata(final RequestInfo requestInfo, final Executor appExecutor, final MetadataApplier applier) {
|
||||
applier.apply(headers);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void thisUsesUnstableApi() {
|
||||
}
|
||||
});
|
||||
|
||||
if (expectAuthentication) {
|
||||
assertDoesNotThrow(() -> stub.echo(EchoRequest.newBuilder().build()));
|
||||
} else {
|
||||
final StatusRuntimeException exception =
|
||||
assertThrows(StatusRuntimeException.class, () -> stub.echo(EchoRequest.newBuilder().build()));
|
||||
|
||||
assertEquals(Status.UNAUTHENTICATED.getCode(), exception.getStatus().getCode());
|
||||
}
|
||||
}
|
||||
|
||||
private static Stream<Arguments> interceptCall() {
|
||||
final Metadata malformedCredentialHeaders = new Metadata();
|
||||
malformedCredentialHeaders.put(BasicCredentialAuthenticationInterceptor.BASIC_CREDENTIALS, "Incorrect");
|
||||
|
||||
final Metadata structurallyValidCredentialHeaders = new Metadata();
|
||||
structurallyValidCredentialHeaders.put(
|
||||
BasicCredentialAuthenticationInterceptor.BASIC_CREDENTIALS,
|
||||
HeaderUtils.basicAuthHeader(UUID.randomUUID().toString(), RandomStringUtils.randomAlphanumeric(16))
|
||||
);
|
||||
|
||||
return Stream.of(
|
||||
Arguments.of(new Metadata(), true, false),
|
||||
Arguments.of(malformedCredentialHeaders, true, false),
|
||||
Arguments.of(structurallyValidCredentialHeaders, false, false),
|
||||
Arguments.of(structurallyValidCredentialHeaders, true, true)
|
||||
);
|
||||
}
|
||||
}
|
|
@ -13,15 +13,14 @@ import io.grpc.ServerCallHandler;
|
|||
import io.grpc.ServerInterceptor;
|
||||
import java.util.UUID;
|
||||
import javax.annotation.Nullable;
|
||||
import org.whispersystems.textsecuregcm.util.Pair;
|
||||
|
||||
public class MockAuthenticationInterceptor implements ServerInterceptor {
|
||||
|
||||
@Nullable
|
||||
private Pair<UUID, Byte> authenticatedDevice;
|
||||
private AuthenticatedDevice authenticatedDevice;
|
||||
|
||||
public void setAuthenticatedDevice(final UUID accountIdentifier, final byte deviceId) {
|
||||
authenticatedDevice = new Pair<>(accountIdentifier, deviceId);
|
||||
authenticatedDevice = new AuthenticatedDevice(accountIdentifier, deviceId);
|
||||
}
|
||||
|
||||
public void clearAuthenticatedDevice() {
|
||||
|
@ -33,14 +32,10 @@ public class MockAuthenticationInterceptor implements ServerInterceptor {
|
|||
final Metadata headers,
|
||||
final ServerCallHandler<ReqT, RespT> next) {
|
||||
|
||||
if (authenticatedDevice != null) {
|
||||
final Context context = Context.current()
|
||||
.withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_ACCOUNT_IDENTIFIER_KEY, authenticatedDevice.first())
|
||||
.withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY, authenticatedDevice.second());
|
||||
|
||||
return Contexts.interceptCall(context, call, headers, next);
|
||||
}
|
||||
|
||||
return next.startCall(call, headers);
|
||||
return authenticatedDevice != null
|
||||
? Contexts.interceptCall(
|
||||
Context.current().withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE, authenticatedDevice),
|
||||
call, headers, next)
|
||||
: next.startCall(call, headers);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
package org.whispersystems.textsecuregcm.auth.grpc;
|
||||
|
||||
import io.grpc.Status;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
|
||||
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
class ProhibitAuthenticationInterceptorTest extends AbstractAuthenticationInterceptorTest {
|
||||
|
||||
@Override
|
||||
protected AbstractAuthenticationInterceptor getInterceptor() {
|
||||
return new ProhibitAuthenticationInterceptor(getClientConnectionManager());
|
||||
}
|
||||
|
||||
@Test
|
||||
void interceptCall() {
|
||||
final ClientConnectionManager clientConnectionManager = getClientConnectionManager();
|
||||
|
||||
when(clientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty());
|
||||
|
||||
final GetAuthenticatedDeviceResponse response = getAuthenticatedDevice();
|
||||
assertTrue(response.getAccountIdentifier().isEmpty());
|
||||
assertEquals(0, response.getDeviceId());
|
||||
|
||||
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
|
||||
when(clientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice));
|
||||
|
||||
GrpcTestUtils.assertStatusException(Status.UNAUTHENTICATED, this::getAuthenticatedDevice);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
package org.whispersystems.textsecuregcm.auth.grpc;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import io.grpc.Status;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
|
||||
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||
|
||||
class RequireAuthenticationInterceptorTest extends AbstractAuthenticationInterceptorTest {
|
||||
|
||||
@Override
|
||||
protected AbstractAuthenticationInterceptor getInterceptor() {
|
||||
return new RequireAuthenticationInterceptor(getClientConnectionManager());
|
||||
}
|
||||
|
||||
@Test
|
||||
void interceptCall() {
|
||||
final ClientConnectionManager clientConnectionManager = getClientConnectionManager();
|
||||
|
||||
when(clientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty());
|
||||
|
||||
GrpcTestUtils.assertStatusException(Status.UNAUTHENTICATED, this::getAuthenticatedDevice);
|
||||
|
||||
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
|
||||
when(clientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice));
|
||||
|
||||
final GetAuthenticatedDeviceResponse response = getAuthenticatedDevice();
|
||||
assertEquals(UUIDUtil.toByteString(authenticatedDevice.accountIdentifier()), response.getAccountIdentifier());
|
||||
assertEquals(authenticatedDevice.deviceId(), response.getDeviceId());
|
||||
}
|
||||
}
|
|
@ -39,10 +39,12 @@ import org.signal.chat.rpc.EchoServiceGrpc;
|
|||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRemoteDeprecationConfiguration;
|
||||
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
|
||||
import org.whispersystems.textsecuregcm.grpc.MockRequestAttributesInterceptor;
|
||||
import org.whispersystems.textsecuregcm.grpc.StatusConstants;
|
||||
import org.whispersystems.textsecuregcm.grpc.UserAgentInterceptor;
|
||||
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
|
||||
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
|
||||
|
||||
class RemoteDeprecationFilterTest {
|
||||
|
||||
|
@ -126,17 +128,25 @@ class RemoteDeprecationFilterTest {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource(value="testFilter")
|
||||
void testGrpcFilter(final String userAgent, final boolean expectDeprecation) throws Exception {
|
||||
void testGrpcFilter(final String userAgentString, final boolean expectDeprecation) throws IOException, InterruptedException {
|
||||
final MockRequestAttributesInterceptor mockRequestAttributesInterceptor = new MockRequestAttributesInterceptor();
|
||||
|
||||
try {
|
||||
mockRequestAttributesInterceptor.setUserAgent(UserAgentUtil.parseUserAgentString(userAgentString));
|
||||
} catch (UnrecognizedUserAgentException ignored) {
|
||||
}
|
||||
|
||||
final Server testServer = InProcessServerBuilder.forName("RemoteDeprecationFilterTest")
|
||||
.directExecutor()
|
||||
.addService(new EchoServiceImpl())
|
||||
.intercept(filterConfiguredForTest())
|
||||
.intercept(new UserAgentInterceptor())
|
||||
.intercept(mockRequestAttributesInterceptor)
|
||||
.build()
|
||||
.start();
|
||||
|
||||
final ManagedChannel channel = InProcessChannelBuilder.forName("RemoteDeprecationFilterTest")
|
||||
.directExecutor()
|
||||
.userAgent(userAgent)
|
||||
.userAgent(userAgentString)
|
||||
.build();
|
||||
|
||||
try {
|
||||
|
|
|
@ -1,79 +0,0 @@
|
|||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import io.grpc.ManagedChannel;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.Server;
|
||||
import io.grpc.inprocess.InProcessChannelBuilder;
|
||||
import io.grpc.inprocess.InProcessServerBuilder;
|
||||
import io.grpc.stub.MetadataUtils;
|
||||
import io.grpc.stub.StreamObserver;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.signal.chat.rpc.EchoRequest;
|
||||
import org.signal.chat.rpc.EchoResponse;
|
||||
import org.signal.chat.rpc.EchoServiceGrpc;
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class AcceptLanguageInterceptorTest {
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void parseLocale(final String header, final List<Locale> expectedLocales) throws IOException, InterruptedException {
|
||||
final AtomicReference<List<Locale>> observedLocales = new AtomicReference<>(null);
|
||||
final EchoServiceImpl serviceImpl = new EchoServiceImpl() {
|
||||
@Override
|
||||
public void echo(EchoRequest req, StreamObserver<EchoResponse> responseObserver) {
|
||||
observedLocales.set(AcceptLanguageUtil.localeFromGrpcContext());
|
||||
super.echo(req, responseObserver);
|
||||
}
|
||||
};
|
||||
|
||||
final Server testServer = InProcessServerBuilder.forName("AcceptLanguageTest")
|
||||
.directExecutor()
|
||||
.addService(serviceImpl)
|
||||
.intercept(new AcceptLanguageInterceptor())
|
||||
.intercept(new UserAgentInterceptor())
|
||||
.build()
|
||||
.start();
|
||||
|
||||
try {
|
||||
final ManagedChannel channel = InProcessChannelBuilder.forName("AcceptLanguageTest")
|
||||
.directExecutor()
|
||||
.userAgent("Signal-Android/1.2.3")
|
||||
.build();
|
||||
|
||||
final Metadata metadata = new Metadata();
|
||||
metadata.put(AcceptLanguageInterceptor.ACCEPTABLE_LANGUAGES_GRPC_HEADER, header);
|
||||
|
||||
final EchoServiceGrpc.EchoServiceBlockingStub client = EchoServiceGrpc.newBlockingStub(channel)
|
||||
.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
|
||||
|
||||
final EchoRequest request = EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("test request")).build();
|
||||
client.echo(request);
|
||||
assertEquals(expectedLocales, observedLocales.get());
|
||||
} finally {
|
||||
testServer.shutdownNow();
|
||||
testServer.awaitTermination();
|
||||
}
|
||||
}
|
||||
|
||||
private static Stream<Arguments> parseLocale() {
|
||||
return Stream.of(
|
||||
// en-US-POSIX is a special locale that exists alongside en-US. It matches because of the definition of
|
||||
// basic filtering in RFC 4647 (https://datatracker.ietf.org/doc/html/rfc4647#section-3.3.1)
|
||||
Arguments.of("en-US,fr-CA", List.of(Locale.forLanguageTag("en-US-POSIX"), Locale.forLanguageTag("en-US"), Locale.forLanguageTag("fr-CA"))),
|
||||
Arguments.of("en-US; q=0.9, fr-CA", List.of(Locale.forLanguageTag("fr-CA"), Locale.forLanguageTag("en-US-POSIX"), Locale.forLanguageTag("en-US"))),
|
||||
Arguments.of("invalid-locale,fr-CA", List.of(Locale.forLanguageTag("fr-CA"))),
|
||||
Arguments.of("", Collections.emptyList()),
|
||||
Arguments.of("acompletely,unexpectedfor , mat", Collections.emptyList())
|
||||
);
|
||||
}
|
||||
}
|
|
@ -13,9 +13,9 @@ import static org.mockito.ArgumentMatchers.anyString;
|
|||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.google.common.net.InetAddresses;
|
||||
import com.google.protobuf.ByteString;
|
||||
import io.grpc.Status;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.time.Duration;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
|
@ -72,7 +72,7 @@ class AccountsAnonymousGrpcServiceTest extends
|
|||
|
||||
when(rateLimiter.validateReactive(anyString())).thenReturn(Mono.empty());
|
||||
|
||||
getMockRemoteAddressInterceptor().setRemoteAddress(new InetSocketAddress("127.0.0.1", 12345));
|
||||
getMockRequestAttributesInterceptor().setRemoteAddress(InetAddresses.forString("127.0.0.1"));
|
||||
|
||||
return new AccountsAnonymousGrpcService(accountsManager, rateLimiters);
|
||||
}
|
||||
|
|
|
@ -29,21 +29,21 @@ public final class GrpcTestUtils {
|
|||
public static void setupAuthenticatedExtension(
|
||||
final GrpcServerExtension extension,
|
||||
final MockAuthenticationInterceptor mockAuthenticationInterceptor,
|
||||
final MockRemoteAddressInterceptor mockRemoteAddressInterceptor,
|
||||
final MockRequestAttributesInterceptor mockRequestAttributesInterceptor,
|
||||
final UUID authenticatedAci,
|
||||
final byte authenticatedDeviceId,
|
||||
final BindableService service) {
|
||||
mockAuthenticationInterceptor.setAuthenticatedDevice(authenticatedAci, authenticatedDeviceId);
|
||||
extension.getServiceRegistry()
|
||||
.addService(ServerInterceptors.intercept(service, mockRemoteAddressInterceptor, mockAuthenticationInterceptor, new ErrorMappingInterceptor()));
|
||||
.addService(ServerInterceptors.intercept(service, mockRequestAttributesInterceptor, mockAuthenticationInterceptor, new ErrorMappingInterceptor()));
|
||||
}
|
||||
|
||||
public static void setupUnauthenticatedExtension(
|
||||
final GrpcServerExtension extension,
|
||||
final MockRemoteAddressInterceptor mockRemoteAddressInterceptor,
|
||||
final MockRequestAttributesInterceptor mockRequestAttributesInterceptor,
|
||||
final BindableService service) {
|
||||
extension.getServiceRegistry()
|
||||
.addService(ServerInterceptors.intercept(service, mockRemoteAddressInterceptor, new ErrorMappingInterceptor()));
|
||||
.addService(ServerInterceptors.intercept(service, mockRequestAttributesInterceptor, new ErrorMappingInterceptor()));
|
||||
}
|
||||
|
||||
public static void assertStatusException(final Status expected, final Executable serviceCall) {
|
||||
|
|
|
@ -1,37 +0,0 @@
|
|||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import io.grpc.Context;
|
||||
import io.grpc.Contexts;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerCallHandler;
|
||||
import io.grpc.ServerInterceptor;
|
||||
import java.net.SocketAddress;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
public class MockRemoteAddressInterceptor implements ServerInterceptor {
|
||||
|
||||
@Nullable
|
||||
private SocketAddress remoteAddress;
|
||||
|
||||
public void setRemoteAddress(@Nullable final SocketAddress remoteAddress) {
|
||||
this.remoteAddress = remoteAddress;
|
||||
}
|
||||
|
||||
@Override
|
||||
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> serverCall,
|
||||
final Metadata headers,
|
||||
final ServerCallHandler<ReqT, RespT> next) {
|
||||
|
||||
return remoteAddress == null
|
||||
? next.startCall(serverCall, headers)
|
||||
: Contexts.interceptCall(
|
||||
Context.current().withValue(RemoteAddressUtil.REMOTE_ADDRESS_CONTEXT_KEY, remoteAddress),
|
||||
serverCall, headers, next);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,64 @@
|
|||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import io.grpc.Context;
|
||||
import io.grpc.Contexts;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerCallHandler;
|
||||
import io.grpc.ServerInterceptor;
|
||||
import java.net.InetAddress;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import javax.annotation.Nullable;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
|
||||
|
||||
public class MockRequestAttributesInterceptor implements ServerInterceptor {
|
||||
|
||||
@Nullable
|
||||
private InetAddress remoteAddress;
|
||||
|
||||
@Nullable
|
||||
private UserAgent userAgent;
|
||||
|
||||
@Nullable
|
||||
private List<Locale.LanguageRange> acceptLanguage;
|
||||
|
||||
public void setRemoteAddress(@Nullable final InetAddress remoteAddress) {
|
||||
this.remoteAddress = remoteAddress;
|
||||
}
|
||||
|
||||
public void setUserAgent(@Nullable final UserAgent userAgent) {
|
||||
this.userAgent = userAgent;
|
||||
}
|
||||
|
||||
public void setAcceptLanguage(@Nullable final List<Locale.LanguageRange> acceptLanguage) {
|
||||
this.acceptLanguage = acceptLanguage;
|
||||
}
|
||||
|
||||
@Override
|
||||
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> serverCall,
|
||||
final Metadata headers,
|
||||
final ServerCallHandler<ReqT, RespT> next) {
|
||||
|
||||
Context context = Context.current();
|
||||
|
||||
if (remoteAddress != null) {
|
||||
context = context.withValue(RequestAttributesUtil.REMOTE_ADDRESS_CONTEXT_KEY, remoteAddress);
|
||||
}
|
||||
|
||||
if (userAgent != null) {
|
||||
context = context.withValue(RequestAttributesUtil.USER_AGENT_CONTEXT_KEY, userAgent);
|
||||
}
|
||||
|
||||
if (acceptLanguage != null) {
|
||||
context = context.withValue(RequestAttributesUtil.ACCEPT_LANGUAGE_CONTEXT_KEY, acceptLanguage);
|
||||
}
|
||||
|
||||
return Contexts.interceptCall(context, serverCall, headers, next);
|
||||
}
|
||||
}
|
|
@ -17,16 +17,13 @@ import static org.mockito.Mockito.when;
|
|||
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import io.grpc.Channel;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.stub.MetadataUtils;
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.time.Instant;
|
||||
import java.time.temporal.ChronoUnit;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
|
@ -79,6 +76,8 @@ import org.whispersystems.textsecuregcm.storage.VersionedProfile;
|
|||
import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper;
|
||||
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
|
||||
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
|
||||
|
||||
public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<ProfileAnonymousGrpcService, ProfileAnonymousGrpc.ProfileAnonymousBlockingStub> {
|
||||
|
||||
|
@ -100,6 +99,14 @@ public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<ProfileA
|
|||
|
||||
@Override
|
||||
protected ProfileAnonymousGrpcService createServiceBeforeEachTest() {
|
||||
getMockRequestAttributesInterceptor().setAcceptLanguage(Locale.LanguageRange.parse("en-us"));
|
||||
|
||||
try {
|
||||
getMockRequestAttributesInterceptor().setUserAgent(UserAgentUtil.parseUserAgentString("Signal-Android/1.2.3"));
|
||||
} catch (final UnrecognizedUserAgentException e) {
|
||||
throw new IllegalArgumentException(e);
|
||||
}
|
||||
|
||||
return new ProfileAnonymousGrpcService(
|
||||
accountsManager,
|
||||
profilesManager,
|
||||
|
@ -108,14 +115,6 @@ public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<ProfileA
|
|||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ProfileAnonymousGrpc.ProfileAnonymousBlockingStub createStub(final Channel channel) throws ClassNotFoundException, InvocationTargetException, NoSuchMethodException, IllegalAccessException {
|
||||
final Metadata metadata = new Metadata();
|
||||
metadata.put(AcceptLanguageInterceptor.ACCEPTABLE_LANGUAGES_GRPC_HEADER, "en-us");
|
||||
metadata.put(UserAgentInterceptor.USER_AGENT_GRPC_HEADER, "Signal-Android/1.2.3");
|
||||
return super.createStub(channel).withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
|
||||
}
|
||||
|
||||
@Test
|
||||
void getUnversionedProfile() {
|
||||
final UUID targetUuid = UUID.randomUUID();
|
||||
|
|
|
@ -25,7 +25,6 @@ import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusEx
|
|||
import com.google.i18n.phonenumbers.PhoneNumberUtil;
|
||||
import com.google.i18n.phonenumbers.Phonenumber;
|
||||
import com.google.protobuf.ByteString;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.Status;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.time.Clock;
|
||||
|
@ -34,6 +33,7 @@ import java.time.Instant;
|
|||
import java.time.temporal.ChronoUnit;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
|
@ -103,6 +103,8 @@ import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper;
|
|||
import org.whispersystems.textsecuregcm.util.MockUtils;
|
||||
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
|
||||
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
|
||||
import reactor.core.publisher.Mono;
|
||||
import software.amazon.awssdk.services.s3.S3AsyncClient;
|
||||
import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
|
||||
|
@ -167,9 +169,14 @@ public class ProfileGrpcServiceTest extends SimpleBaseGrpcTest<ProfileGrpcServic
|
|||
final String phoneNumber = PhoneNumberUtil.getInstance().format(
|
||||
PhoneNumberUtil.getInstance().getExampleNumber("US"),
|
||||
PhoneNumberUtil.PhoneNumberFormat.E164);
|
||||
final Metadata metadata = new Metadata();
|
||||
metadata.put(AcceptLanguageInterceptor.ACCEPTABLE_LANGUAGES_GRPC_HEADER, "en-us");
|
||||
metadata.put(UserAgentInterceptor.USER_AGENT_GRPC_HEADER, "Signal-Android/1.2.3");
|
||||
|
||||
getMockRequestAttributesInterceptor().setAcceptLanguage(Locale.LanguageRange.parse("en-us"));
|
||||
|
||||
try {
|
||||
getMockRequestAttributesInterceptor().setUserAgent(UserAgentUtil.parseUserAgentString("Signal-Android/1.2.3"));
|
||||
} catch (final UnrecognizedUserAgentException e) {
|
||||
throw new IllegalArgumentException(e);
|
||||
}
|
||||
|
||||
when(rateLimiters.getProfileLimiter()).thenReturn(rateLimiter);
|
||||
when(rateLimiter.validateReactive(any(UUID.class))).thenReturn(Mono.empty());
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import io.grpc.stub.StreamObserver;
|
||||
import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
|
||||
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
|
||||
import org.signal.chat.rpc.GetRequestAttributesRequest;
|
||||
import org.signal.chat.rpc.GetRequestAttributesResponse;
|
||||
import org.signal.chat.rpc.RequestAttributesGrpc;
|
||||
import org.signal.chat.rpc.UserAgent;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticationUtil;
|
||||
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||
|
||||
public class RequestAttributesServiceImpl extends RequestAttributesGrpc.RequestAttributesImplBase {
|
||||
|
||||
@Override
|
||||
public void getRequestAttributes(final GetRequestAttributesRequest request,
|
||||
final StreamObserver<GetRequestAttributesResponse> responseObserver) {
|
||||
|
||||
final GetRequestAttributesResponse.Builder responseBuilder = GetRequestAttributesResponse.newBuilder();
|
||||
|
||||
RequestAttributesUtil.getAcceptableLanguages().ifPresent(acceptableLanguages ->
|
||||
acceptableLanguages.forEach(languageRange -> responseBuilder.addAcceptableLanguages(languageRange.toString())));
|
||||
|
||||
RequestAttributesUtil.getAvailableAcceptedLocales().forEach(locale ->
|
||||
responseBuilder.addAvailableAcceptedLocales(locale.toLanguageTag()));
|
||||
|
||||
responseBuilder.setRemoteAddress(RequestAttributesUtil.getRemoteAddress().getHostAddress());
|
||||
|
||||
RequestAttributesUtil.getUserAgent().ifPresent(userAgent -> responseBuilder.setUserAgent(UserAgent.newBuilder()
|
||||
.setPlatform(userAgent.getPlatform().toString())
|
||||
.setVersion(userAgent.getVersion().toString())
|
||||
.setAdditionalSpecifiers(userAgent.getAdditionalSpecifiers().orElse(""))
|
||||
.build()));
|
||||
|
||||
responseObserver.onNext(responseBuilder.build());
|
||||
responseObserver.onCompleted();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void getAuthenticatedDevice(final GetAuthenticatedDeviceRequest request,
|
||||
final StreamObserver<GetAuthenticatedDeviceResponse> responseObserver) {
|
||||
|
||||
final GetAuthenticatedDeviceResponse.Builder responseBuilder = GetAuthenticatedDeviceResponse.newBuilder();
|
||||
|
||||
try {
|
||||
final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedDevice();
|
||||
|
||||
responseBuilder.setAccountIdentifier(UUIDUtil.toByteString(authenticatedDevice.accountIdentifier()));
|
||||
responseBuilder.setDeviceId(authenticatedDevice.deviceId());
|
||||
} catch (final Exception ignored) {
|
||||
}
|
||||
|
||||
responseObserver.onNext(responseBuilder.build());
|
||||
responseObserver.onCompleted();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,160 @@
|
|||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.google.common.net.InetAddresses;
|
||||
import io.grpc.ManagedChannel;
|
||||
import io.grpc.Server;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.netty.NettyChannelBuilder;
|
||||
import io.grpc.netty.NettyServerBuilder;
|
||||
import io.netty.channel.DefaultEventLoopGroup;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import io.netty.channel.local.LocalChannel;
|
||||
import io.netty.channel.local.LocalServerChannel;
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Optional;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.junit.jupiter.api.AfterAll;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.signal.chat.rpc.GetRequestAttributesRequest;
|
||||
import org.signal.chat.rpc.GetRequestAttributesResponse;
|
||||
import org.signal.chat.rpc.RequestAttributesGrpc;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
|
||||
|
||||
class RequestAttributesUtilTest {
|
||||
|
||||
private static DefaultEventLoopGroup eventLoopGroup;
|
||||
|
||||
private ClientConnectionManager clientConnectionManager;
|
||||
|
||||
private Server server;
|
||||
private ManagedChannel managedChannel;
|
||||
|
||||
@BeforeAll
|
||||
static void setUpBeforeAll() {
|
||||
eventLoopGroup = new DefaultEventLoopGroup();
|
||||
}
|
||||
|
||||
@BeforeEach
|
||||
void setUp() throws IOException {
|
||||
final LocalAddress serverAddress = new LocalAddress("test-request-metadata-server");
|
||||
|
||||
clientConnectionManager = mock(ClientConnectionManager.class);
|
||||
|
||||
when(clientConnectionManager.getRemoteAddress(any()))
|
||||
.thenReturn(Optional.of(InetAddresses.forString("127.0.0.1")));
|
||||
|
||||
// `RequestAttributesInterceptor` operates on `LocalAddresses`, so we need to do some slightly fancy plumbing to make
|
||||
// sure that we're using local channels and addresses
|
||||
server = NettyServerBuilder.forAddress(serverAddress)
|
||||
.channelType(LocalServerChannel.class)
|
||||
.bossEventLoopGroup(eventLoopGroup)
|
||||
.workerEventLoopGroup(eventLoopGroup)
|
||||
.intercept(new RequestAttributesInterceptor(clientConnectionManager))
|
||||
.addService(new RequestAttributesServiceImpl())
|
||||
.build()
|
||||
.start();
|
||||
|
||||
managedChannel = NettyChannelBuilder.forAddress(serverAddress)
|
||||
.channelType(LocalChannel.class)
|
||||
.eventLoopGroup(eventLoopGroup)
|
||||
.usePlaintext()
|
||||
.build();
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
void tearDown() {
|
||||
managedChannel.shutdown();
|
||||
server.shutdown();
|
||||
}
|
||||
|
||||
@AfterAll
|
||||
static void tearDownAfterAll() throws InterruptedException {
|
||||
eventLoopGroup.shutdownGracefully().await();
|
||||
}
|
||||
|
||||
@Test
|
||||
void getAcceptableLanguages() {
|
||||
when(clientConnectionManager.getAcceptableLanguages(any()))
|
||||
.thenReturn(Optional.empty());
|
||||
|
||||
assertTrue(getRequestAttributes().getAcceptableLanguagesList().isEmpty());
|
||||
|
||||
when(clientConnectionManager.getAcceptableLanguages(any()))
|
||||
.thenReturn(Optional.of(Locale.LanguageRange.parse("en,ja")));
|
||||
|
||||
assertEquals(List.of("en", "ja"), getRequestAttributes().getAcceptableLanguagesList());
|
||||
}
|
||||
|
||||
@Test
|
||||
void getAvailableAcceptedLocales() {
|
||||
when(clientConnectionManager.getAcceptableLanguages(any()))
|
||||
.thenReturn(Optional.empty());
|
||||
|
||||
assertTrue(getRequestAttributes().getAvailableAcceptedLocalesList().isEmpty());
|
||||
|
||||
when(clientConnectionManager.getAcceptableLanguages(any()))
|
||||
.thenReturn(Optional.of(Locale.LanguageRange.parse("en,ja")));
|
||||
|
||||
final GetRequestAttributesResponse response = getRequestAttributes();
|
||||
|
||||
assertFalse(response.getAvailableAcceptedLocalesList().isEmpty());
|
||||
response.getAvailableAcceptedLocalesList().forEach(languageTag -> {
|
||||
final Locale locale = Locale.forLanguageTag(languageTag);
|
||||
assertTrue("en".equals(locale.getLanguage()) || "ja".equals(locale.getLanguage()));
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
void getRemoteAddress() {
|
||||
when(clientConnectionManager.getRemoteAddress(any()))
|
||||
.thenReturn(Optional.empty());
|
||||
|
||||
GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getRequestAttributes);
|
||||
|
||||
final String remoteAddressString = "6.7.8.9";
|
||||
|
||||
when(clientConnectionManager.getRemoteAddress(any()))
|
||||
.thenReturn(Optional.of(InetAddresses.forString(remoteAddressString)));
|
||||
|
||||
assertEquals(remoteAddressString, getRequestAttributes().getRemoteAddress());
|
||||
}
|
||||
|
||||
@Test
|
||||
void getUserAgent() throws UnrecognizedUserAgentException {
|
||||
when(clientConnectionManager.getUserAgent(any()))
|
||||
.thenReturn(Optional.empty());
|
||||
|
||||
assertFalse(getRequestAttributes().hasUserAgent());
|
||||
|
||||
final UserAgent userAgent = UserAgentUtil.parseUserAgentString("Signal-Desktop/1.2.3 Linux");
|
||||
|
||||
when(clientConnectionManager.getUserAgent(any()))
|
||||
.thenReturn(Optional.of(userAgent));
|
||||
|
||||
final GetRequestAttributesResponse response = getRequestAttributes();
|
||||
assertTrue(response.hasUserAgent());
|
||||
assertEquals("DESKTOP", response.getUserAgent().getPlatform());
|
||||
assertEquals("1.2.3", response.getUserAgent().getVersion());
|
||||
assertEquals("Linux", response.getUserAgent().getAdditionalSpecifiers());
|
||||
}
|
||||
|
||||
private GetRequestAttributesResponse getRequestAttributes() {
|
||||
return RequestAttributesGrpc.newBlockingStub(managedChannel)
|
||||
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build());
|
||||
}
|
||||
}
|
|
@ -60,7 +60,7 @@ public abstract class SimpleBaseGrpcTest<SERVICE extends BindableService, STUB e
|
|||
private AutoCloseable mocksCloseable;
|
||||
|
||||
private final MockAuthenticationInterceptor mockAuthenticationInterceptor = new MockAuthenticationInterceptor();
|
||||
private final MockRemoteAddressInterceptor mockRemoteAddressInterceptor = new MockRemoteAddressInterceptor();
|
||||
private final MockRequestAttributesInterceptor mockRequestAttributesInterceptor = new MockRequestAttributesInterceptor();
|
||||
|
||||
private SERVICE service;
|
||||
|
||||
|
@ -114,8 +114,8 @@ public abstract class SimpleBaseGrpcTest<SERVICE extends BindableService, STUB e
|
|||
mocksCloseable = MockitoAnnotations.openMocks(this);
|
||||
service = requireNonNull(createServiceBeforeEachTest(), "created service must not be `null`");
|
||||
GrpcTestUtils.setupAuthenticatedExtension(
|
||||
GRPC_SERVER_EXTENSION_AUTHENTICATED, mockAuthenticationInterceptor, mockRemoteAddressInterceptor, AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID, service);
|
||||
GrpcTestUtils.setupUnauthenticatedExtension(GRPC_SERVER_EXTENSION_UNAUTHENTICATED, mockRemoteAddressInterceptor, service);
|
||||
GRPC_SERVER_EXTENSION_AUTHENTICATED, mockAuthenticationInterceptor, mockRequestAttributesInterceptor, AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID, service);
|
||||
GrpcTestUtils.setupUnauthenticatedExtension(GRPC_SERVER_EXTENSION_UNAUTHENTICATED, mockRequestAttributesInterceptor, service);
|
||||
try {
|
||||
authenticatedServiceStub = createStub(GRPC_SERVER_EXTENSION_AUTHENTICATED.getChannel());
|
||||
unauthenticatedServiceStub = createStub(GRPC_SERVER_EXTENSION_UNAUTHENTICATED.getChannel());
|
||||
|
@ -145,8 +145,8 @@ public abstract class SimpleBaseGrpcTest<SERVICE extends BindableService, STUB e
|
|||
return unauthenticatedServiceStub;
|
||||
}
|
||||
|
||||
protected MockRemoteAddressInterceptor getMockRemoteAddressInterceptor() {
|
||||
return mockRemoteAddressInterceptor;
|
||||
protected MockRequestAttributesInterceptor getMockRequestAttributesInterceptor() {
|
||||
return mockRequestAttributesInterceptor;
|
||||
}
|
||||
|
||||
protected MockAuthenticationInterceptor getMockAuthenticationInterceptor() {
|
||||
|
|
|
@ -1,90 +0,0 @@
|
|||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import com.vdurmont.semver4j.Semver;
|
||||
import io.grpc.ManagedChannel;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.Server;
|
||||
import io.grpc.inprocess.InProcessChannelBuilder;
|
||||
import io.grpc.inprocess.InProcessServerBuilder;
|
||||
import io.grpc.stub.MetadataUtils;
|
||||
import io.grpc.stub.StreamObserver;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.signal.chat.rpc.EchoRequest;
|
||||
import org.signal.chat.rpc.EchoResponse;
|
||||
import org.signal.chat.rpc.EchoServiceGrpc;
|
||||
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
|
||||
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
|
||||
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public class UserAgentInterceptorTest {
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void testInterceptor(final String header, final ClientPlatform platform, final String version) throws Exception {
|
||||
|
||||
final AtomicReference<UserAgent> observedUserAgent = new AtomicReference<>(null);
|
||||
final EchoServiceImpl serviceImpl = new EchoServiceImpl() {
|
||||
@Override
|
||||
public void echo(EchoRequest req, StreamObserver<EchoResponse> responseObserver) {
|
||||
observedUserAgent.set(UserAgentUtil.userAgentFromGrpcContext());
|
||||
super.echo(req, responseObserver);
|
||||
}
|
||||
};
|
||||
|
||||
final Server testServer = InProcessServerBuilder.forName("RemoteDeprecationFilterTest")
|
||||
.directExecutor()
|
||||
.addService(serviceImpl)
|
||||
.intercept(new UserAgentInterceptor())
|
||||
.build()
|
||||
.start();
|
||||
|
||||
try {
|
||||
final ManagedChannel channel = InProcessChannelBuilder.forName("RemoteDeprecationFilterTest")
|
||||
.directExecutor()
|
||||
.userAgent(header)
|
||||
.build();
|
||||
|
||||
final EchoServiceGrpc.EchoServiceBlockingStub client = EchoServiceGrpc.newBlockingStub(channel);
|
||||
|
||||
final EchoRequest req = EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("cluck cluck, i'm a parrot")).build();
|
||||
assertEquals("cluck cluck, i'm a parrot", client.echo(req).getPayload().toStringUtf8());
|
||||
if (platform == null) {
|
||||
assertNull(observedUserAgent.get());
|
||||
} else {
|
||||
assertEquals(platform, observedUserAgent.get().getPlatform());
|
||||
assertEquals(new Semver(version), observedUserAgent.get().getVersion());
|
||||
// can't assert on the additional specifiers because they include internal details of the grpc in-process channel itself
|
||||
}
|
||||
} finally {
|
||||
testServer.shutdownNow();
|
||||
testServer.awaitTermination();
|
||||
}
|
||||
}
|
||||
|
||||
private static Stream<Arguments> testInterceptor() {
|
||||
return Stream.of(
|
||||
Arguments.of(null, null, null),
|
||||
Arguments.of("", null, null),
|
||||
Arguments.of("Unrecognized UA", null, null),
|
||||
Arguments.of("Signal-Android/4.68.3", ClientPlatform.ANDROID, "4.68.3"),
|
||||
Arguments.of("Signal-iOS/3.9.0", ClientPlatform.IOS, "3.9.0"),
|
||||
Arguments.of("Signal-Desktop/1.2.3", ClientPlatform.DESKTOP, "1.2.3"),
|
||||
Arguments.of("Signal-Desktop/8.0.0-beta.2", ClientPlatform.DESKTOP, "8.0.0-beta.2"),
|
||||
Arguments.of("Signal-iOS/8.0.0-beta.2", ClientPlatform.IOS, "8.0.0-beta.2"));
|
||||
}
|
||||
}
|
|
@ -1,21 +0,0 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import io.grpc.stub.StreamObserver;
|
||||
import org.signal.chat.rpc.AuthenticationTypeGrpc;
|
||||
import org.signal.chat.rpc.GetAuthenticatedRequest;
|
||||
import org.signal.chat.rpc.GetAuthenticatedResponse;
|
||||
|
||||
public class AuthenticationTypeService extends AuthenticationTypeGrpc.AuthenticationTypeImplBase {
|
||||
|
||||
private final boolean authenticated;
|
||||
|
||||
public AuthenticationTypeService(final boolean authenticated) {
|
||||
this.authenticated = authenticated;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void getAuthenticated(final GetAuthenticatedRequest request, final StreamObserver<GetAuthenticatedResponse> responseObserver) {
|
||||
responseObserver.onNext(GetAuthenticatedResponse.newBuilder().setAuthenticated(authenticated).build());
|
||||
responseObserver.onCompleted();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,283 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import com.google.common.net.InetAddresses;
|
||||
import com.vdurmont.semver4j.Semver;
|
||||
import io.netty.bootstrap.Bootstrap;
|
||||
import io.netty.bootstrap.ServerBootstrap;
|
||||
import io.netty.channel.Channel;
|
||||
import io.netty.channel.ChannelInitializer;
|
||||
import io.netty.channel.DefaultEventLoopGroup;
|
||||
import io.netty.channel.EventLoopGroup;
|
||||
import io.netty.channel.embedded.EmbeddedChannel;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import io.netty.channel.local.LocalChannel;
|
||||
import io.netty.channel.local.LocalServerChannel;
|
||||
import org.junit.jupiter.api.AfterAll;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import java.net.InetAddress;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
class ClientConnectionManagerTest {
|
||||
|
||||
private static EventLoopGroup eventLoopGroup;
|
||||
|
||||
private LocalChannel localChannel;
|
||||
private LocalChannel remoteChannel;
|
||||
|
||||
private LocalServerChannel localServerChannel;
|
||||
|
||||
private ClientConnectionManager clientConnectionManager;
|
||||
|
||||
@BeforeAll
|
||||
static void setUpBeforeAll() {
|
||||
eventLoopGroup = new DefaultEventLoopGroup();
|
||||
}
|
||||
|
||||
@BeforeEach
|
||||
void setUp() throws InterruptedException {
|
||||
eventLoopGroup = new DefaultEventLoopGroup();
|
||||
|
||||
clientConnectionManager = new ClientConnectionManager();
|
||||
|
||||
// We have to jump through some hoops to get "real" LocalChannel instances to test with, and so we run a trivial
|
||||
// local server to which we can open trivial local connections
|
||||
localServerChannel = (LocalServerChannel) new ServerBootstrap()
|
||||
.group(eventLoopGroup)
|
||||
.channel(LocalServerChannel.class)
|
||||
.childHandler(new ChannelInitializer<>() {
|
||||
@Override
|
||||
protected void initChannel(final Channel channel) {
|
||||
}
|
||||
})
|
||||
.bind(new LocalAddress("test-server"))
|
||||
.await()
|
||||
.channel();
|
||||
|
||||
final Bootstrap clientBootstrap = new Bootstrap()
|
||||
.group(eventLoopGroup)
|
||||
.channel(LocalChannel.class)
|
||||
.handler(new ChannelInitializer<>() {
|
||||
@Override
|
||||
protected void initChannel(final Channel ch) {
|
||||
}
|
||||
});
|
||||
|
||||
localChannel = (LocalChannel) clientBootstrap.connect(localServerChannel.localAddress()).await().channel();
|
||||
remoteChannel = (LocalChannel) clientBootstrap.connect(localServerChannel.localAddress()).await().channel();
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
void tearDown() throws InterruptedException {
|
||||
localChannel.close().await();
|
||||
remoteChannel.close().await();
|
||||
localServerChannel.close().await();
|
||||
}
|
||||
|
||||
@AfterAll
|
||||
static void tearDownAfterAll() throws InterruptedException {
|
||||
eventLoopGroup.shutdownGracefully().await();
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void getAuthenticatedDevice(@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<AuthenticatedDevice> maybeAuthenticatedDevice) {
|
||||
clientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, maybeAuthenticatedDevice);
|
||||
|
||||
assertEquals(maybeAuthenticatedDevice,
|
||||
clientConnectionManager.getAuthenticatedDevice(localChannel.localAddress()));
|
||||
}
|
||||
|
||||
private static List<Optional<AuthenticatedDevice>> getAuthenticatedDevice() {
|
||||
return List.of(
|
||||
Optional.of(new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID)),
|
||||
Optional.empty()
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
void getAcceptableLanguages() {
|
||||
clientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
|
||||
|
||||
assertEquals(Optional.empty(),
|
||||
clientConnectionManager.getAcceptableLanguages(localChannel.localAddress()));
|
||||
|
||||
final List<Locale.LanguageRange> acceptLanguageRanges = Locale.LanguageRange.parse("en,ja");
|
||||
remoteChannel.attr(ClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).set(acceptLanguageRanges);
|
||||
|
||||
assertEquals(Optional.of(acceptLanguageRanges),
|
||||
clientConnectionManager.getAcceptableLanguages(localChannel.localAddress()));
|
||||
}
|
||||
|
||||
@Test
|
||||
void getRemoteAddress() {
|
||||
clientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
|
||||
|
||||
assertEquals(Optional.empty(),
|
||||
clientConnectionManager.getRemoteAddress(localChannel.localAddress()));
|
||||
|
||||
final InetAddress remoteAddress = InetAddresses.forString("6.7.8.9");
|
||||
remoteChannel.attr(ClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).set(remoteAddress);
|
||||
|
||||
assertEquals(Optional.of(remoteAddress),
|
||||
clientConnectionManager.getRemoteAddress(localChannel.localAddress()));
|
||||
}
|
||||
|
||||
@Test
|
||||
void getUserAgent() throws UnrecognizedUserAgentException {
|
||||
clientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
|
||||
|
||||
assertEquals(Optional.empty(),
|
||||
clientConnectionManager.getUserAgent(localChannel.localAddress()));
|
||||
|
||||
final UserAgent userAgent = UserAgentUtil.parseUserAgentString("Signal-Desktop/1.2.3 Linux");
|
||||
remoteChannel.attr(ClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY).set(userAgent);
|
||||
|
||||
assertEquals(Optional.of(userAgent),
|
||||
clientConnectionManager.getUserAgent(localChannel.localAddress()));
|
||||
}
|
||||
|
||||
@Test
|
||||
void closeConnection() throws InterruptedException {
|
||||
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
|
||||
|
||||
clientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice));
|
||||
|
||||
assertTrue(remoteChannel.isOpen());
|
||||
|
||||
assertEquals(remoteChannel, clientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
|
||||
assertEquals(List.of(remoteChannel),
|
||||
clientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
|
||||
|
||||
remoteChannel.close().await();
|
||||
|
||||
assertNull(clientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
|
||||
assertNull(clientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleWebSocketHandshakeCompleteRemoteAddress() {
|
||||
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
|
||||
|
||||
final InetAddress preferredRemoteAddress = InetAddresses.forString("192.168.1.1");
|
||||
|
||||
ClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel,
|
||||
preferredRemoteAddress,
|
||||
null,
|
||||
null);
|
||||
|
||||
assertEquals(preferredRemoteAddress,
|
||||
embeddedChannel.attr(ClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).get());
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void handleWebSocketHandshakeCompleteUserAgent(@Nullable final String userAgentHeader,
|
||||
@Nullable final UserAgent expectedParsedUserAgent) {
|
||||
|
||||
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
|
||||
|
||||
ClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel,
|
||||
InetAddresses.forString("127.0.0.1"),
|
||||
userAgentHeader,
|
||||
null);
|
||||
|
||||
assertEquals(userAgentHeader,
|
||||
embeddedChannel.attr(ClientConnectionManager.RAW_USER_AGENT_ATTRIBUTE_KEY).get());
|
||||
|
||||
assertEquals(expectedParsedUserAgent,
|
||||
embeddedChannel.attr(ClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY).get());
|
||||
}
|
||||
|
||||
private static List<Arguments> handleWebSocketHandshakeCompleteUserAgent() {
|
||||
return List.of(
|
||||
// Recognized user-agent
|
||||
Arguments.of("Signal-Desktop/1.2.3 Linux", new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "Linux")),
|
||||
|
||||
// Unrecognized user-agent
|
||||
Arguments.of("Not a valid user-agent string", null),
|
||||
|
||||
// Missing user-agent
|
||||
Arguments.of(null, null)
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void handleWebSocketHandshakeCompleteAcceptLanguage(@Nullable final String acceptLanguageHeader,
|
||||
@Nullable final List<Locale.LanguageRange> expectedLanguageRanges) {
|
||||
|
||||
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
|
||||
|
||||
ClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel,
|
||||
InetAddresses.forString("127.0.0.1"),
|
||||
null,
|
||||
acceptLanguageHeader);
|
||||
|
||||
assertEquals(expectedLanguageRanges,
|
||||
embeddedChannel.attr(ClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).get());
|
||||
}
|
||||
|
||||
private static List<Arguments> handleWebSocketHandshakeCompleteAcceptLanguage() {
|
||||
return List.of(
|
||||
// Parseable list
|
||||
Arguments.of("ja,en;q=0.4", Locale.LanguageRange.parse("ja,en;q=0.4")),
|
||||
|
||||
// Unparsable list
|
||||
Arguments.of("This is not a valid language preference list", null),
|
||||
|
||||
// Missing list
|
||||
Arguments.of(null, null)
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleConnectionEstablishedAuthenticated() throws InterruptedException {
|
||||
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
|
||||
|
||||
assertNull(clientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
|
||||
assertNull(clientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
|
||||
|
||||
clientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice));
|
||||
|
||||
assertEquals(remoteChannel, clientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
|
||||
assertEquals(List.of(remoteChannel), clientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
|
||||
|
||||
remoteChannel.close().await();
|
||||
|
||||
assertNull(clientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
|
||||
assertNull(clientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleConnectionEstablishedAnonymous() throws InterruptedException {
|
||||
assertNull(clientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
|
||||
|
||||
clientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
|
||||
|
||||
assertEquals(remoteChannel, clientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
|
||||
|
||||
remoteChannel.close().await();
|
||||
|
||||
assertNull(clientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
|
||||
}
|
||||
}
|
|
@ -8,8 +8,8 @@ import io.netty.channel.ChannelInboundHandlerAdapter;
|
|||
import io.netty.channel.ChannelInitializer;
|
||||
import io.netty.channel.socket.SocketChannel;
|
||||
import io.netty.channel.socket.nio.NioSocketChannel;
|
||||
import io.netty.handler.codec.http.DefaultHttpHeaders;
|
||||
import io.netty.handler.codec.http.HttpClientCodec;
|
||||
import io.netty.handler.codec.http.HttpHeaders;
|
||||
import io.netty.handler.codec.http.HttpObjectAggregator;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketVersion;
|
||||
|
@ -35,6 +35,7 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
|
|||
private final ECPublicKey rootPublicKey;
|
||||
@Nullable private final UUID accountIdentifier;
|
||||
private final byte deviceId;
|
||||
private final HttpHeaders headers;
|
||||
private final SocketAddress remoteServerAddress;
|
||||
private final WebSocketCloseListener webSocketCloseListener;
|
||||
|
||||
|
@ -50,6 +51,7 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
|
|||
final ECPublicKey rootPublicKey,
|
||||
@Nullable final UUID accountIdentifier,
|
||||
final byte deviceId,
|
||||
final HttpHeaders headers,
|
||||
final SocketAddress remoteServerAddress,
|
||||
final WebSocketCloseListener webSocketCloseListener) {
|
||||
|
||||
|
@ -60,6 +62,7 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
|
|||
this.rootPublicKey = rootPublicKey;
|
||||
this.accountIdentifier = accountIdentifier;
|
||||
this.deviceId = deviceId;
|
||||
this.headers = headers;
|
||||
this.remoteServerAddress = remoteServerAddress;
|
||||
this.webSocketCloseListener = webSocketCloseListener;
|
||||
}
|
||||
|
@ -87,7 +90,7 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
|
|||
WebSocketVersion.V13,
|
||||
null,
|
||||
false,
|
||||
new DefaultHttpHeaders(),
|
||||
headers,
|
||||
Noise.MAX_PACKET_LEN,
|
||||
10_000))
|
||||
.addLast(new OutboundCloseWebSocketFrameHandler(webSocketCloseListener))
|
||||
|
|
|
@ -11,6 +11,7 @@ import java.net.SocketAddress;
|
|||
import java.net.URI;
|
||||
import java.security.cert.X509Certificate;
|
||||
import java.util.UUID;
|
||||
import io.netty.handler.codec.http.HttpHeaders;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||
import javax.annotation.Nullable;
|
||||
|
@ -30,6 +31,7 @@ class WebSocketNoiseTunnelClient implements AutoCloseable {
|
|||
final ECPublicKey rootPublicKey,
|
||||
@Nullable final UUID accountIdentifier,
|
||||
final byte deviceId,
|
||||
final HttpHeaders headers,
|
||||
final X509Certificate trustedServerCertificate,
|
||||
final NioEventLoopGroup eventLoopGroup,
|
||||
final WebSocketCloseListener webSocketCloseListener) {
|
||||
|
@ -48,6 +50,7 @@ class WebSocketNoiseTunnelClient implements AutoCloseable {
|
|||
rootPublicKey,
|
||||
accountIdentifier,
|
||||
deviceId,
|
||||
headers,
|
||||
remoteServerAddress,
|
||||
webSocketCloseListener));
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyByte;
|
||||
|
@ -17,6 +16,8 @@ import io.netty.channel.DefaultEventLoopGroup;
|
|||
import io.netty.channel.local.LocalAddress;
|
||||
import io.netty.channel.local.LocalChannel;
|
||||
import io.netty.channel.nio.NioEventLoopGroup;
|
||||
import io.netty.handler.codec.http.DefaultHttpHeaders;
|
||||
import io.netty.handler.codec.http.HttpHeaders;
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
|
@ -35,28 +36,41 @@ import java.security.cert.X509Certificate;
|
|||
import java.security.spec.InvalidKeySpecException;
|
||||
import java.security.spec.PKCS8EncodedKeySpec;
|
||||
import java.util.Base64;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import javax.net.ssl.SSLContext;
|
||||
import javax.net.ssl.TrustManagerFactory;
|
||||
import org.apache.commons.lang3.RandomStringUtils;
|
||||
import org.junit.jupiter.api.AfterAll;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.signal.chat.rpc.AuthenticationTypeGrpc;
|
||||
import org.signal.chat.rpc.GetAuthenticatedRequest;
|
||||
import org.signal.chat.rpc.GetAuthenticatedResponse;
|
||||
import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
|
||||
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
|
||||
import org.signal.chat.rpc.GetRequestAttributesRequest;
|
||||
import org.signal.chat.rpc.GetRequestAttributesResponse;
|
||||
import org.signal.chat.rpc.RequestAttributesGrpc;
|
||||
import org.signal.libsignal.protocol.ecc.Curve;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor;
|
||||
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
|
||||
import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor;
|
||||
import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||
|
||||
class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTest {
|
||||
|
||||
|
@ -66,6 +80,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
|
||||
private static X509Certificate serverTlsCertificate;
|
||||
|
||||
private ClientConnectionManager clientConnectionManager;
|
||||
private ClientPublicKeysManager clientPublicKeysManager;
|
||||
|
||||
private ECKeyPair rootKeyPair;
|
||||
|
@ -79,6 +94,8 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
private static final UUID ACCOUNT_IDENTIFIER = UUID.randomUUID();
|
||||
private static final byte DEVICE_ID = Device.PRIMARY_ID;
|
||||
|
||||
private static final String RECOGNIZED_PROXY_SECRET = RandomStringUtils.randomAlphanumeric(16);
|
||||
|
||||
// Please note that this certificate/key are used only for testing and are not used anywhere outside of this test.
|
||||
// They were generated with:
|
||||
//
|
||||
|
@ -133,6 +150,8 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
clientKeyPair = Curve.generateKeyPair();
|
||||
final ECKeyPair serverKeyPair = Curve.generateKeyPair();
|
||||
|
||||
clientConnectionManager = new ClientConnectionManager();
|
||||
|
||||
clientPublicKeysManager = mock(ClientPublicKeysManager.class);
|
||||
when(clientPublicKeysManager.findPublicKey(any(), anyByte()))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
|
||||
|
@ -146,7 +165,9 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, defaultEventLoopGroup) {
|
||||
@Override
|
||||
protected void configureServer(final ServerBuilder<?> serverBuilder) {
|
||||
serverBuilder.addService(new AuthenticationTypeService(true));
|
||||
serverBuilder.addService(new RequestAttributesServiceImpl())
|
||||
.intercept(new RequestAttributesInterceptor(clientConnectionManager))
|
||||
.intercept(new RequireAuthenticationInterceptor(clientConnectionManager));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -155,7 +176,9 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, defaultEventLoopGroup) {
|
||||
@Override
|
||||
protected void configureServer(final ServerBuilder<?> serverBuilder) {
|
||||
serverBuilder.addService(new AuthenticationTypeService(false));
|
||||
serverBuilder.addService(new RequestAttributesServiceImpl())
|
||||
.intercept(new RequestAttributesInterceptor(clientConnectionManager))
|
||||
.intercept(new ProhibitAuthenticationInterceptor(clientConnectionManager));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -166,11 +189,13 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
serverTlsPrivateKey,
|
||||
nioEventLoopGroup,
|
||||
delegatedTaskExecutor,
|
||||
clientConnectionManager,
|
||||
clientPublicKeysManager,
|
||||
serverKeyPair,
|
||||
rootKeyPair.getPrivateKey().calculateSignature(serverKeyPair.getPublicKey().getPublicKeyBytes()),
|
||||
authenticatedGrpcServerAddress,
|
||||
anonymousGrpcServerAddress);
|
||||
anonymousGrpcServerAddress,
|
||||
RECOGNIZED_PROXY_SECRET);
|
||||
|
||||
websocketNoiseTunnelServer.start();
|
||||
}
|
||||
|
@ -198,10 +223,11 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
|
||||
|
||||
try {
|
||||
final GetAuthenticatedResponse response = AuthenticationTypeGrpc.newBlockingStub(channel)
|
||||
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build());
|
||||
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
|
||||
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
|
||||
|
||||
assertTrue(response.getAuthenticated());
|
||||
assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
|
||||
assertEquals(DEVICE_ID, response.getDeviceId());
|
||||
} finally {
|
||||
channel.shutdown();
|
||||
}
|
||||
|
@ -215,15 +241,15 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
|
||||
// Try to verify the server's public key with something other than the key with which it was signed
|
||||
try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient =
|
||||
buildAndStartAuthenticatedClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey())) {
|
||||
buildAndStartAuthenticatedClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey(), new DefaultHttpHeaders())) {
|
||||
|
||||
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
|
||||
|
||||
try {
|
||||
//noinspection ResultOfMethodCallIgnored
|
||||
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
|
||||
() -> AuthenticationTypeGrpc.newBlockingStub(channel)
|
||||
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build()));
|
||||
() -> RequestAttributesGrpc.newBlockingStub(channel)
|
||||
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
|
||||
} finally {
|
||||
channel.shutdown();
|
||||
}
|
||||
|
@ -247,8 +273,8 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
try {
|
||||
//noinspection ResultOfMethodCallIgnored
|
||||
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
|
||||
() -> AuthenticationTypeGrpc.newBlockingStub(channel)
|
||||
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build()));
|
||||
() -> RequestAttributesGrpc.newBlockingStub(channel)
|
||||
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
|
||||
} finally {
|
||||
channel.shutdown();
|
||||
}
|
||||
|
@ -272,8 +298,8 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
try {
|
||||
//noinspection ResultOfMethodCallIgnored
|
||||
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
|
||||
() -> AuthenticationTypeGrpc.newBlockingStub(channel)
|
||||
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build()));
|
||||
() -> RequestAttributesGrpc.newBlockingStub(channel)
|
||||
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
|
||||
} finally {
|
||||
channel.shutdown();
|
||||
}
|
||||
|
@ -294,6 +320,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
rootKeyPair.getPublicKey(),
|
||||
ACCOUNT_IDENTIFIER,
|
||||
DEVICE_ID,
|
||||
new DefaultHttpHeaders(),
|
||||
serverTlsCertificate,
|
||||
nioEventLoopGroup,
|
||||
webSocketCloseListener)
|
||||
|
@ -304,8 +331,8 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
try {
|
||||
//noinspection ResultOfMethodCallIgnored
|
||||
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
|
||||
() -> AuthenticationTypeGrpc.newBlockingStub(channel)
|
||||
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build()));
|
||||
() -> RequestAttributesGrpc.newBlockingStub(channel)
|
||||
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
|
||||
} finally {
|
||||
channel.shutdown();
|
||||
}
|
||||
|
@ -320,10 +347,11 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
|
||||
|
||||
try {
|
||||
final GetAuthenticatedResponse response = AuthenticationTypeGrpc.newBlockingStub(channel)
|
||||
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build());
|
||||
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
|
||||
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
|
||||
|
||||
assertFalse(response.getAuthenticated());
|
||||
assertTrue(response.getAccountIdentifier().isEmpty());
|
||||
assertEquals(0, response.getDeviceId());
|
||||
} finally {
|
||||
channel.shutdown();
|
||||
}
|
||||
|
@ -336,15 +364,15 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
|
||||
// Try to verify the server's public key with something other than the key with which it was signed
|
||||
try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient =
|
||||
buildAndStartAnonymousClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey())) {
|
||||
buildAndStartAnonymousClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey(), new DefaultHttpHeaders())) {
|
||||
|
||||
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
|
||||
|
||||
try {
|
||||
//noinspection ResultOfMethodCallIgnored
|
||||
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
|
||||
() -> AuthenticationTypeGrpc.newBlockingStub(channel)
|
||||
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build()));
|
||||
() -> RequestAttributesGrpc.newBlockingStub(channel)
|
||||
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
|
||||
} finally {
|
||||
channel.shutdown();
|
||||
}
|
||||
|
@ -365,6 +393,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
rootKeyPair.getPublicKey(),
|
||||
null,
|
||||
(byte) 0,
|
||||
new DefaultHttpHeaders(),
|
||||
serverTlsCertificate,
|
||||
nioEventLoopGroup,
|
||||
webSocketCloseListener)
|
||||
|
@ -375,8 +404,8 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
try {
|
||||
//noinspection ResultOfMethodCallIgnored
|
||||
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
|
||||
() -> AuthenticationTypeGrpc.newBlockingStub(channel)
|
||||
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build()));
|
||||
() -> RequestAttributesGrpc.newBlockingStub(channel)
|
||||
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
|
||||
} finally {
|
||||
channel.shutdown();
|
||||
}
|
||||
|
@ -438,6 +467,86 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void getRequestAttributes() throws InterruptedException {
|
||||
final String remoteAddress = "4.5.6.7";
|
||||
final String acceptLanguage = "en";
|
||||
final String userAgent = "Signal-Desktop/1.2.3 Linux";
|
||||
|
||||
final HttpHeaders headers = new DefaultHttpHeaders()
|
||||
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
|
||||
.add("X-Forwarded-For", remoteAddress)
|
||||
.add("Accept-Language", acceptLanguage)
|
||||
.add("User-Agent", userAgent);
|
||||
|
||||
try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient =
|
||||
buildAndStartAnonymousClient(WebSocketCloseListener.NOOP_LISTENER, rootKeyPair.getPublicKey(), headers)) {
|
||||
|
||||
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
|
||||
|
||||
try {
|
||||
final GetRequestAttributesResponse response = RequestAttributesGrpc.newBlockingStub(channel)
|
||||
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build());
|
||||
|
||||
assertEquals(remoteAddress, response.getRemoteAddress());
|
||||
assertEquals(List.of(acceptLanguage), response.getAcceptableLanguagesList());
|
||||
|
||||
assertEquals("DESKTOP", response.getUserAgent().getPlatform());
|
||||
assertEquals("1.2.3", response.getUserAgent().getVersion());
|
||||
assertEquals("Linux", response.getUserAgent().getAdditionalSpecifiers());
|
||||
} finally {
|
||||
channel.shutdown();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void closeForReauthentication() throws InterruptedException {
|
||||
final CountDownLatch connectionCloseLatch = new CountDownLatch(1);
|
||||
final AtomicInteger serverCloseStatusCode = new AtomicInteger(0);
|
||||
final AtomicBoolean closedByServer = new AtomicBoolean(false);
|
||||
|
||||
final WebSocketCloseListener webSocketCloseListener = new WebSocketCloseListener() {
|
||||
|
||||
@Override
|
||||
public void handleWebSocketClosedByClient(final int statusCode) {
|
||||
serverCloseStatusCode.set(statusCode);
|
||||
closedByServer.set(false);
|
||||
connectionCloseLatch.countDown();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handleWebSocketClosedByServer(final int statusCode) {
|
||||
serverCloseStatusCode.set(statusCode);
|
||||
closedByServer.set(true);
|
||||
connectionCloseLatch.countDown();
|
||||
}
|
||||
};
|
||||
|
||||
try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient = buildAndStartAuthenticatedClient(webSocketCloseListener)) {
|
||||
|
||||
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
|
||||
|
||||
try {
|
||||
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
|
||||
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
|
||||
|
||||
assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
|
||||
assertEquals(DEVICE_ID, response.getDeviceId());
|
||||
|
||||
clientConnectionManager.closeConnection(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID));
|
||||
assertTrue(connectionCloseLatch.await(2, TimeUnit.SECONDS));
|
||||
|
||||
assertEquals(ApplicationWebSocketCloseReason.REAUTHENTICATION_REQUIRED.getStatusCode(),
|
||||
serverCloseStatusCode.get());
|
||||
|
||||
assertTrue(closedByServer.get());
|
||||
} finally {
|
||||
channel.shutdown();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private WebSocketNoiseTunnelClient buildAndStartAuthenticatedClient() throws InterruptedException {
|
||||
return buildAndStartAuthenticatedClient(WebSocketCloseListener.NOOP_LISTENER);
|
||||
}
|
||||
|
@ -445,11 +554,12 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
private WebSocketNoiseTunnelClient buildAndStartAuthenticatedClient(final WebSocketCloseListener webSocketCloseListener)
|
||||
throws InterruptedException {
|
||||
|
||||
return buildAndStartAuthenticatedClient(webSocketCloseListener, rootKeyPair.getPublicKey());
|
||||
return buildAndStartAuthenticatedClient(webSocketCloseListener, rootKeyPair.getPublicKey(), new DefaultHttpHeaders());
|
||||
}
|
||||
|
||||
private WebSocketNoiseTunnelClient buildAndStartAuthenticatedClient(final WebSocketCloseListener webSocketCloseListener,
|
||||
final ECPublicKey rootPublicKey) throws InterruptedException {
|
||||
final ECPublicKey rootPublicKey,
|
||||
final HttpHeaders headers) throws InterruptedException {
|
||||
|
||||
return new WebSocketNoiseTunnelClient(websocketNoiseTunnelServer.getLocalAddress(),
|
||||
WebSocketNoiseTunnelClient.AUTHENTICATED_WEBSOCKET_URI,
|
||||
|
@ -458,6 +568,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
rootPublicKey,
|
||||
ACCOUNT_IDENTIFIER,
|
||||
DEVICE_ID,
|
||||
headers,
|
||||
serverTlsCertificate,
|
||||
nioEventLoopGroup,
|
||||
webSocketCloseListener)
|
||||
|
@ -465,11 +576,12 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
}
|
||||
|
||||
private WebSocketNoiseTunnelClient buildAndStartAnonymousClient() throws InterruptedException {
|
||||
return buildAndStartAnonymousClient(WebSocketCloseListener.NOOP_LISTENER, rootKeyPair.getPublicKey());
|
||||
return buildAndStartAnonymousClient(WebSocketCloseListener.NOOP_LISTENER, rootKeyPair.getPublicKey(), new DefaultHttpHeaders());
|
||||
}
|
||||
|
||||
private WebSocketNoiseTunnelClient buildAndStartAnonymousClient(final WebSocketCloseListener webSocketCloseListener,
|
||||
final ECPublicKey rootPublicKey) throws InterruptedException {
|
||||
final ECPublicKey rootPublicKey,
|
||||
final HttpHeaders headers) throws InterruptedException {
|
||||
|
||||
return new WebSocketNoiseTunnelClient(websocketNoiseTunnelServer.getLocalAddress(),
|
||||
WebSocketNoiseTunnelClient.ANONYMOUS_WEBSOCKET_URI,
|
||||
|
@ -478,6 +590,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
rootPublicKey,
|
||||
null,
|
||||
(byte) 0,
|
||||
headers,
|
||||
serverTlsCertificate,
|
||||
nioEventLoopGroup,
|
||||
webSocketCloseListener)
|
||||
|
|
|
@ -0,0 +1,202 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
import com.google.common.net.InetAddresses;
|
||||
import com.vdurmont.semver4j.Semver;
|
||||
import io.netty.channel.ChannelHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.channel.embedded.EmbeddedChannel;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import io.netty.handler.codec.http.DefaultHttpHeaders;
|
||||
import io.netty.handler.codec.http.HttpHeaderNames;
|
||||
import io.netty.handler.codec.http.HttpHeaders;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
|
||||
import java.net.InetAddress;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.net.SocketAddress;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import javax.annotation.Nullable;
|
||||
import org.apache.commons.lang3.RandomStringUtils;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.signal.libsignal.protocol.ecc.Curve;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
|
||||
|
||||
class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
|
||||
|
||||
private UserEventRecordingHandler userEventRecordingHandler;
|
||||
private MutableRemoteAddressEmbeddedChannel embeddedChannel;
|
||||
|
||||
private static final String RECOGNIZED_PROXY_SECRET = RandomStringUtils.randomAlphanumeric(16);
|
||||
|
||||
private static class UserEventRecordingHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private final List<Object> receivedEvents = new ArrayList<>();
|
||||
|
||||
@Override
|
||||
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
|
||||
receivedEvents.add(event);
|
||||
}
|
||||
|
||||
public List<Object> getReceivedEvents() {
|
||||
return receivedEvents;
|
||||
}
|
||||
}
|
||||
|
||||
private static class MutableRemoteAddressEmbeddedChannel extends EmbeddedChannel {
|
||||
|
||||
private SocketAddress remoteAddress;
|
||||
|
||||
public MutableRemoteAddressEmbeddedChannel(final ChannelHandler... handlers) {
|
||||
super(handlers);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected SocketAddress remoteAddress0() {
|
||||
return isActive() ? remoteAddress : null;
|
||||
}
|
||||
|
||||
public void setRemoteAddress(final SocketAddress remoteAddress) {
|
||||
this.remoteAddress = remoteAddress;
|
||||
}
|
||||
}
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
userEventRecordingHandler = new UserEventRecordingHandler();
|
||||
|
||||
embeddedChannel = new MutableRemoteAddressEmbeddedChannel(
|
||||
new WebsocketHandshakeCompleteHandler(mock(ClientPublicKeysManager.class),
|
||||
Curve.generateKeyPair(),
|
||||
new byte[64],
|
||||
RECOGNIZED_PROXY_SECRET),
|
||||
userEventRecordingHandler);
|
||||
|
||||
embeddedChannel.setRemoteAddress(new InetSocketAddress("127.0.0.1", 0));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void handleWebSocketHandshakeComplete(final String uri, final Class<? extends AbstractNoiseHandshakeHandler> expectedHandlerClass) {
|
||||
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent =
|
||||
new WebSocketServerProtocolHandler.HandshakeComplete(uri, new DefaultHttpHeaders(), null);
|
||||
|
||||
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
|
||||
|
||||
assertNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteHandler.class));
|
||||
assertNotNull(embeddedChannel.pipeline().get(expectedHandlerClass));
|
||||
|
||||
assertEquals(List.of(handshakeCompleteEvent), userEventRecordingHandler.getReceivedEvents());
|
||||
}
|
||||
|
||||
private static List<Arguments> handleWebSocketHandshakeComplete() {
|
||||
return List.of(
|
||||
Arguments.of(WebsocketNoiseTunnelServer.AUTHENTICATED_SERVICE_PATH, NoiseXXHandshakeHandler.class),
|
||||
Arguments.of(WebsocketNoiseTunnelServer.ANONYMOUS_SERVICE_PATH, NoiseNXHandshakeHandler.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleWebSocketHandshakeCompleteUnexpectedPath() {
|
||||
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent =
|
||||
new WebSocketServerProtocolHandler.HandshakeComplete("/incorrect", new DefaultHttpHeaders(), null);
|
||||
|
||||
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
|
||||
|
||||
assertNotNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteHandler.class));
|
||||
assertThrows(IllegalArgumentException.class, () -> embeddedChannel.checkException());
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleUnrecognizedEvent() {
|
||||
final Object unrecognizedEvent = new Object();
|
||||
|
||||
embeddedChannel.pipeline().fireUserEventTriggered(unrecognizedEvent);
|
||||
assertEquals(List.of(unrecognizedEvent), userEventRecordingHandler.getReceivedEvents());
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void getRemoteAddress(final HttpHeaders headers, final SocketAddress remoteAddress, @Nullable InetAddress expectedRemoteAddress) {
|
||||
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent =
|
||||
new WebSocketServerProtocolHandler.HandshakeComplete(
|
||||
WebsocketNoiseTunnelServer.ANONYMOUS_SERVICE_PATH, headers, null);
|
||||
|
||||
embeddedChannel.setRemoteAddress(remoteAddress);
|
||||
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
|
||||
|
||||
assertEquals(expectedRemoteAddress,
|
||||
embeddedChannel.attr(ClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).get());
|
||||
}
|
||||
|
||||
private static List<Arguments> getRemoteAddress() {
|
||||
final InetSocketAddress remoteAddress = new InetSocketAddress("5.6.7.8", 0);
|
||||
final InetAddress clientAddress = InetAddresses.forString("1.2.3.4");
|
||||
final InetAddress proxyAddress = InetAddresses.forString("4.3.2.1");
|
||||
|
||||
return List.of(
|
||||
// Recognized proxy, single forwarded-for address
|
||||
Arguments.of(new DefaultHttpHeaders()
|
||||
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
|
||||
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
|
||||
remoteAddress,
|
||||
clientAddress),
|
||||
|
||||
// Recognized proxy, multiple forwarded-for addresses
|
||||
Arguments.of(new DefaultHttpHeaders()
|
||||
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
|
||||
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress() + "," + proxyAddress.getHostAddress()),
|
||||
remoteAddress,
|
||||
proxyAddress),
|
||||
|
||||
// No recognized proxy header, single forwarded-for address
|
||||
Arguments.of(new DefaultHttpHeaders()
|
||||
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
|
||||
remoteAddress,
|
||||
remoteAddress.getAddress()),
|
||||
|
||||
// No recognized proxy header, no forwarded-for address
|
||||
Arguments.of(new DefaultHttpHeaders(),
|
||||
remoteAddress,
|
||||
remoteAddress.getAddress()),
|
||||
|
||||
// Incorrect proxy header, single forwarded-for address
|
||||
Arguments.of(new DefaultHttpHeaders()
|
||||
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET + "-incorrect")
|
||||
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
|
||||
remoteAddress,
|
||||
remoteAddress.getAddress()),
|
||||
|
||||
// Recognized proxy, no forwarded-for address
|
||||
Arguments.of(new DefaultHttpHeaders()
|
||||
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET),
|
||||
remoteAddress,
|
||||
remoteAddress.getAddress()),
|
||||
|
||||
// Recognized proxy, bogus forwarded-for address
|
||||
Arguments.of(new DefaultHttpHeaders()
|
||||
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
|
||||
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, "not a valid address"),
|
||||
remoteAddress,
|
||||
null),
|
||||
|
||||
// No forwarded-for address, non-InetSocketAddress remote address
|
||||
Arguments.of(new DefaultHttpHeaders()
|
||||
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET),
|
||||
new LocalAddress("local-address"),
|
||||
null)
|
||||
);
|
||||
}
|
||||
}
|
|
@ -1,91 +0,0 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.channel.embedded.EmbeddedChannel;
|
||||
import io.netty.handler.codec.http.DefaultHttpHeaders;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.signal.libsignal.protocol.ecc.Curve;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
class WebsocketHandshakeCompleteListenerTest extends AbstractLeakDetectionTest {
|
||||
|
||||
private UserEventRecordingHandler userEventRecordingHandler;
|
||||
private EmbeddedChannel embeddedChannel;
|
||||
|
||||
private static class UserEventRecordingHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private final List<Object> receivedEvents = new ArrayList<>();
|
||||
|
||||
@Override
|
||||
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
|
||||
receivedEvents.add(event);
|
||||
}
|
||||
|
||||
public List<Object> getReceivedEvents() {
|
||||
return receivedEvents;
|
||||
}
|
||||
}
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
userEventRecordingHandler = new UserEventRecordingHandler();
|
||||
|
||||
embeddedChannel = new EmbeddedChannel(
|
||||
new WebsocketHandshakeCompleteListener(mock(ClientPublicKeysManager.class), Curve.generateKeyPair(), new byte[64]),
|
||||
userEventRecordingHandler);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void handleWebSocketHandshakeComplete(final String uri, final Class<? extends AbstractNoiseHandshakeHandler> expectedHandlerClass) {
|
||||
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent =
|
||||
new WebSocketServerProtocolHandler.HandshakeComplete(uri, new DefaultHttpHeaders(), null);
|
||||
|
||||
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
|
||||
|
||||
assertNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteListener.class));
|
||||
assertNotNull(embeddedChannel.pipeline().get(expectedHandlerClass));
|
||||
|
||||
assertEquals(List.of(handshakeCompleteEvent), userEventRecordingHandler.getReceivedEvents());
|
||||
}
|
||||
|
||||
private static List<Arguments> handleWebSocketHandshakeComplete() {
|
||||
return List.of(
|
||||
Arguments.of(WebsocketNoiseTunnelServer.AUTHENTICATED_SERVICE_PATH, NoiseXXHandshakeHandler.class),
|
||||
Arguments.of(WebsocketNoiseTunnelServer.ANONYMOUS_SERVICE_PATH, NoiseNXHandshakeHandler.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleWebSocketHandshakeCompleteUnexpectedPath() {
|
||||
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent =
|
||||
new WebSocketServerProtocolHandler.HandshakeComplete("/incorrect", new DefaultHttpHeaders(), null);
|
||||
|
||||
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
|
||||
|
||||
assertNotNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteListener.class));
|
||||
assertThrows(IllegalArgumentException.class, () -> embeddedChannel.checkException());
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleUnrecognizedEvent() {
|
||||
final Object unrecognizedEvent = new Object();
|
||||
|
||||
embeddedChannel.pipeline().fireUserEventTriggered(unrecognizedEvent);
|
||||
assertEquals(List.of(unrecognizedEvent), userEventRecordingHandler.getReceivedEvents());
|
||||
}
|
||||
}
|
|
@ -1,22 +0,0 @@
|
|||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
option java_multiple_files = true;
|
||||
|
||||
package org.signal.chat.rpc;
|
||||
|
||||
// A simple test service that identifies its authentication type to callers
|
||||
service AuthenticationType {
|
||||
rpc GetAuthenticated (GetAuthenticatedRequest) returns (GetAuthenticatedResponse) {}
|
||||
}
|
||||
|
||||
message GetAuthenticatedRequest {
|
||||
}
|
||||
|
||||
message GetAuthenticatedResponse {
|
||||
bool authenticated = 1;
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
option java_multiple_files = true;
|
||||
|
||||
package org.signal.chat.rpc;
|
||||
|
||||
// A simple test service that echoes request attributes to callers
|
||||
service RequestAttributes {
|
||||
rpc GetRequestAttributes (GetRequestAttributesRequest) returns (GetRequestAttributesResponse) {}
|
||||
|
||||
rpc GetAuthenticatedDevice (GetAuthenticatedDeviceRequest) returns (GetAuthenticatedDeviceResponse) {}
|
||||
}
|
||||
|
||||
message GetRequestAttributesRequest {
|
||||
}
|
||||
|
||||
message GetRequestAttributesResponse {
|
||||
repeated string acceptable_languages = 1;
|
||||
repeated string available_accepted_locales = 2;
|
||||
string remote_address = 3;
|
||||
UserAgent user_agent = 4;
|
||||
}
|
||||
|
||||
message UserAgent {
|
||||
string platform = 1;
|
||||
string version = 2;
|
||||
string additional_specifiers = 3;
|
||||
}
|
||||
|
||||
message GetAuthenticatedDeviceRequest {
|
||||
}
|
||||
|
||||
message GetAuthenticatedDeviceResponse {
|
||||
bytes account_identifier = 1;
|
||||
uint32 device_id = 2;
|
||||
}
|
Loading…
Reference in New Issue