Introduce a Noise-over-WebSocket client connection manager

This commit is contained in:
Jon Chambers 2024-03-22 15:20:55 -04:00 committed by GitHub
parent 075a08884b
commit aec6ac019f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
53 changed files with 1818 additions and 933 deletions

View File

@ -95,3 +95,5 @@ turn.secret: AAAAAAAAAAA=
linkDevice.secret: AAAAAAAAAAA= linkDevice.secret: AAAAAAAAAAA=
tlsKeyStore.password: unset tlsKeyStore.password: unset
noiseTunnel.recognizedProxySecret: ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789AAAAAAA

View File

@ -468,3 +468,7 @@ callingTurnManualTable:
s3Bucket: a-bucket s3Bucket: a-bucket
objectKey: an-object.tar.gz objectKey: an-object.tar.gz
maxSize: 32777216 maxSize: 32777216
noiseTunnel:
port: 8443
recognizedProxySecret: secret://noiseTunnel.recognizedProxySecret

View File

@ -39,6 +39,7 @@ import org.whispersystems.textsecuregcm.configuration.MaxDeviceConfiguration;
import org.whispersystems.textsecuregcm.configuration.MessageByteLimitCardinalityEstimatorConfiguration; import org.whispersystems.textsecuregcm.configuration.MessageByteLimitCardinalityEstimatorConfiguration;
import org.whispersystems.textsecuregcm.configuration.MessageCacheConfiguration; import org.whispersystems.textsecuregcm.configuration.MessageCacheConfiguration;
import org.whispersystems.textsecuregcm.configuration.MonitoredS3ObjectConfiguration; import org.whispersystems.textsecuregcm.configuration.MonitoredS3ObjectConfiguration;
import org.whispersystems.textsecuregcm.configuration.NoiseWebSocketTunnelConfiguration;
import org.whispersystems.textsecuregcm.configuration.OneTimeDonationConfiguration; import org.whispersystems.textsecuregcm.configuration.OneTimeDonationConfiguration;
import org.whispersystems.textsecuregcm.configuration.PaymentsServiceConfiguration; import org.whispersystems.textsecuregcm.configuration.PaymentsServiceConfiguration;
import org.whispersystems.textsecuregcm.configuration.RedisClusterConfiguration; import org.whispersystems.textsecuregcm.configuration.RedisClusterConfiguration;
@ -333,6 +334,11 @@ public class WhisperServerConfiguration extends Configuration {
@JsonProperty @JsonProperty
private MonitoredS3ObjectConfiguration callingTurnManualTable; private MonitoredS3ObjectConfiguration callingTurnManualTable;
@Valid
@NotNull
@JsonProperty
private NoiseWebSocketTunnelConfiguration noiseTunnel;
public TlsKeyStoreConfiguration getTlsKeyStoreConfiguration() { public TlsKeyStoreConfiguration getTlsKeyStoreConfiguration() {
return tlsKeyStore; return tlsKeyStore;
} }
@ -555,4 +561,8 @@ public class WhisperServerConfiguration extends Configuration {
public MonitoredS3ObjectConfiguration getCallingTurnManualTable() { public MonitoredS3ObjectConfiguration getCallingTurnManualTable() {
return callingTurnManualTable; return callingTurnManualTable;
} }
public NoiseWebSocketTunnelConfiguration getNoiseWebSocketTunnelConfiguration() {
return noiseTunnel;
}
} }

View File

@ -71,7 +71,8 @@ import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager;
import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager; import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager;
import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator; import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator;
import org.whispersystems.textsecuregcm.auth.WebsocketRefreshApplicationEventListener; import org.whispersystems.textsecuregcm.auth.WebsocketRefreshApplicationEventListener;
import org.whispersystems.textsecuregcm.auth.grpc.BasicCredentialAuthenticationInterceptor; import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.backup.BackupAuthManager; import org.whispersystems.textsecuregcm.backup.BackupAuthManager;
import org.whispersystems.textsecuregcm.backup.BackupManager; import org.whispersystems.textsecuregcm.backup.BackupManager;
import org.whispersystems.textsecuregcm.backup.BackupsDb; import org.whispersystems.textsecuregcm.backup.BackupsDb;
@ -127,7 +128,6 @@ import org.whispersystems.textsecuregcm.filters.RemoteDeprecationFilter;
import org.whispersystems.textsecuregcm.filters.RequestStatisticsFilter; import org.whispersystems.textsecuregcm.filters.RequestStatisticsFilter;
import org.whispersystems.textsecuregcm.filters.TimestampResponseFilter; import org.whispersystems.textsecuregcm.filters.TimestampResponseFilter;
import org.whispersystems.textsecuregcm.geo.MaxMindDatabaseManager; import org.whispersystems.textsecuregcm.geo.MaxMindDatabaseManager;
import org.whispersystems.textsecuregcm.grpc.AcceptLanguageInterceptor;
import org.whispersystems.textsecuregcm.grpc.AccountsAnonymousGrpcService; import org.whispersystems.textsecuregcm.grpc.AccountsAnonymousGrpcService;
import org.whispersystems.textsecuregcm.grpc.AccountsGrpcService; import org.whispersystems.textsecuregcm.grpc.AccountsGrpcService;
import org.whispersystems.textsecuregcm.grpc.ErrorMappingInterceptor; import org.whispersystems.textsecuregcm.grpc.ErrorMappingInterceptor;
@ -138,7 +138,8 @@ import org.whispersystems.textsecuregcm.grpc.KeysGrpcService;
import org.whispersystems.textsecuregcm.grpc.PaymentsGrpcService; import org.whispersystems.textsecuregcm.grpc.PaymentsGrpcService;
import org.whispersystems.textsecuregcm.grpc.ProfileAnonymousGrpcService; import org.whispersystems.textsecuregcm.grpc.ProfileAnonymousGrpcService;
import org.whispersystems.textsecuregcm.grpc.ProfileGrpcService; import org.whispersystems.textsecuregcm.grpc.ProfileGrpcService;
import org.whispersystems.textsecuregcm.grpc.UserAgentInterceptor; import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor;
import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager;
import org.whispersystems.textsecuregcm.grpc.net.ManagedDefaultEventLoopGroup; import org.whispersystems.textsecuregcm.grpc.net.ManagedDefaultEventLoopGroup;
import org.whispersystems.textsecuregcm.grpc.net.ManagedLocalGrpcServer; import org.whispersystems.textsecuregcm.grpc.net.ManagedLocalGrpcServer;
import org.whispersystems.textsecuregcm.jetty.JettyHttpConfigurationCustomizer; import org.whispersystems.textsecuregcm.jetty.JettyHttpConfigurationCustomizer;
@ -752,8 +753,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
geoIpCityDatabaseManager geoIpCityDatabaseManager
); );
final BasicCredentialAuthenticationInterceptor basicCredentialAuthenticationInterceptor = final ClientConnectionManager clientConnectionManager = new ClientConnectionManager();
new BasicCredentialAuthenticationInterceptor(new AccountAuthenticator(accountsManager));
final ManagedDefaultEventLoopGroup localEventLoopGroup = new ManagedDefaultEventLoopGroup(); final ManagedDefaultEventLoopGroup localEventLoopGroup = new ManagedDefaultEventLoopGroup();
@ -762,8 +762,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
new MetricCollectingServerInterceptor(Metrics.globalRegistry); new MetricCollectingServerInterceptor(Metrics.globalRegistry);
final ErrorMappingInterceptor errorMappingInterceptor = new ErrorMappingInterceptor(); final ErrorMappingInterceptor errorMappingInterceptor = new ErrorMappingInterceptor();
final AcceptLanguageInterceptor acceptLanguageInterceptor = new AcceptLanguageInterceptor(); final RequestAttributesInterceptor requestAttributesInterceptor =
final UserAgentInterceptor userAgentInterceptor = new UserAgentInterceptor(); new RequestAttributesInterceptor(clientConnectionManager);
final LocalAddress anonymousGrpcServerAddress = new LocalAddress("grpc-anonymous"); final LocalAddress anonymousGrpcServerAddress = new LocalAddress("grpc-anonymous");
final LocalAddress authenticatedGrpcServerAddress = new LocalAddress("grpc-authenticated"); final LocalAddress authenticatedGrpcServerAddress = new LocalAddress("grpc-authenticated");
@ -778,9 +778,9 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
// TODO: specialize metrics with user-agent platform // TODO: specialize metrics with user-agent platform
.intercept(metricCollectingServerInterceptor) .intercept(metricCollectingServerInterceptor)
.intercept(errorMappingInterceptor) .intercept(errorMappingInterceptor)
.intercept(acceptLanguageInterceptor)
.intercept(remoteDeprecationFilter) .intercept(remoteDeprecationFilter)
.intercept(userAgentInterceptor) .intercept(requestAttributesInterceptor)
.intercept(new ProhibitAuthenticationInterceptor(clientConnectionManager))
.addService(new AccountsAnonymousGrpcService(accountsManager, rateLimiters)) .addService(new AccountsAnonymousGrpcService(accountsManager, rateLimiters))
.addService(new KeysAnonymousGrpcService(accountsManager, keysManager)) .addService(new KeysAnonymousGrpcService(accountsManager, keysManager))
.addService(new PaymentsGrpcService(currencyManager)) .addService(new PaymentsGrpcService(currencyManager))
@ -799,10 +799,9 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
// TODO: specialize metrics with user-agent platform // TODO: specialize metrics with user-agent platform
.intercept(metricCollectingServerInterceptor) .intercept(metricCollectingServerInterceptor)
.intercept(errorMappingInterceptor) .intercept(errorMappingInterceptor)
.intercept(acceptLanguageInterceptor)
.intercept(remoteDeprecationFilter) .intercept(remoteDeprecationFilter)
.intercept(userAgentInterceptor) .intercept(requestAttributesInterceptor)
.intercept(new BasicCredentialAuthenticationInterceptor(new AccountAuthenticator(accountsManager))) .intercept(new RequireAuthenticationInterceptor(clientConnectionManager))
.addService(new AccountsGrpcService(accountsManager, rateLimiters, usernameHashZkProofVerifier, registrationRecoveryPasswordsManager)) .addService(new AccountsGrpcService(accountsManager, rateLimiters, usernameHashZkProofVerifier, registrationRecoveryPasswordsManager))
.addService(ExternalServiceCredentialsGrpcService.createForAllExternalServices(config, rateLimiters)) .addService(ExternalServiceCredentialsGrpcService.createForAllExternalServices(config, rateLimiters))
.addService(new KeysGrpcService(accountsManager, keysManager, rateLimiters)) .addService(new KeysGrpcService(accountsManager, keysManager, rateLimiters))

View File

@ -0,0 +1,34 @@
package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.netty.channel.local.LocalAddress;
import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager;
import java.util.Optional;
abstract class AbstractAuthenticationInterceptor implements ServerInterceptor {
private final ClientConnectionManager clientConnectionManager;
private static final Metadata EMPTY_TRAILERS = new Metadata();
AbstractAuthenticationInterceptor(final ClientConnectionManager clientConnectionManager) {
this.clientConnectionManager = clientConnectionManager;
}
protected Optional<AuthenticatedDevice> getAuthenticatedDevice(final ServerCall<?, ?> call) {
if (call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress) {
return clientConnectionManager.getAuthenticatedDevice(localAddress);
} else {
throw new AssertionError("Unexpected channel type: " + call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
}
}
protected <ReqT, RespT> ServerCall.Listener<ReqT> closeAsUnauthenticated(final ServerCall<ReqT, RespT> call) {
call.close(Status.UNAUTHENTICATED, EMPTY_TRAILERS);
return new ServerCall.Listener<>() {};
}
}

View File

@ -7,7 +7,6 @@ package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.Context; import io.grpc.Context;
import io.grpc.Status; import io.grpc.Status;
import java.util.UUID;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
@ -16,8 +15,7 @@ import org.whispersystems.textsecuregcm.storage.Device;
*/ */
public class AuthenticationUtil { public class AuthenticationUtil {
static final Context.Key<UUID> CONTEXT_AUTHENTICATED_ACCOUNT_IDENTIFIER_KEY = Context.key("authenticated-aci"); static final Context.Key<AuthenticatedDevice> CONTEXT_AUTHENTICATED_DEVICE = Context.key("authenticated-device");
static final Context.Key<Byte> CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY = Context.key("authenticated-device-id");
/** /**
* Returns the account/device authenticated in the current gRPC context or throws an "unauthenticated" exception if * Returns the account/device authenticated in the current gRPC context or throws an "unauthenticated" exception if
@ -29,11 +27,10 @@ public class AuthenticationUtil {
* could be retrieved from the current gRPC context * could be retrieved from the current gRPC context
*/ */
public static AuthenticatedDevice requireAuthenticatedDevice() { public static AuthenticatedDevice requireAuthenticatedDevice() {
@Nullable final UUID accountIdentifier = CONTEXT_AUTHENTICATED_ACCOUNT_IDENTIFIER_KEY.get(); @Nullable final AuthenticatedDevice authenticatedDevice = CONTEXT_AUTHENTICATED_DEVICE.get();
@Nullable final Byte deviceId = CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY.get();
if (accountIdentifier != null && deviceId != null) { if (authenticatedDevice != null) {
return new AuthenticatedDevice(accountIdentifier, deviceId); return authenticatedDevice;
} }
throw Status.UNAUTHENTICATED.asRuntimeException(); throw Status.UNAUTHENTICATED.asRuntimeException();

View File

@ -1,85 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth.grpc;
import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.auth.basic.BasicCredentials;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import java.util.Optional;
import org.apache.commons.lang3.StringUtils;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
/**
* A basic credential authentication interceptor enforces the presence of a valid username and password on every call.
* Callers supply credentials by providing a username (UUID and optional device ID) and password pair in the
* {@code x-signal-basic-auth-credentials} call header.
* <p/>
* Downstream services can retrieve the identity of the authenticated caller using methods in
* {@link AuthenticationUtil}.
* <p/>
* Note that this authentication, while fully functional, is intended only for development and testing purposes and is
* intended to be replaced with a more robust and efficient strategy before widespread client adoption.
*
* @see AuthenticationUtil
* @see AccountAuthenticator
*/
public class BasicCredentialAuthenticationInterceptor implements ServerInterceptor {
private final AccountAuthenticator accountAuthenticator;
@VisibleForTesting
static final Metadata.Key<String> BASIC_CREDENTIALS =
Metadata.Key.of("x-signal-auth", Metadata.ASCII_STRING_MARSHALLER);
private static final Metadata EMPTY_TRAILERS = new Metadata();
public BasicCredentialAuthenticationInterceptor(final AccountAuthenticator accountAuthenticator) {
this.accountAuthenticator = accountAuthenticator;
}
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
final ServerCall<ReqT, RespT> call,
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
final String authHeader = headers.get(BASIC_CREDENTIALS);
if (StringUtils.isNotBlank(authHeader)) {
final Optional<BasicCredentials> maybeCredentials = HeaderUtils.basicCredentialsFromAuthHeader(authHeader);
if (maybeCredentials.isEmpty()) {
call.close(Status.UNAUTHENTICATED.withDescription("Could not parse credentials"), EMPTY_TRAILERS);
} else {
final Optional<AuthenticatedAccount> maybeAuthenticatedAccount =
accountAuthenticator.authenticate(maybeCredentials.get());
if (maybeAuthenticatedAccount.isPresent()) {
final AuthenticatedAccount authenticatedAccount = maybeAuthenticatedAccount.get();
final Context context = Context.current()
.withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_ACCOUNT_IDENTIFIER_KEY, authenticatedAccount.getAccount().getUuid())
.withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY, authenticatedAccount.getAuthenticatedDevice().getId());
return Contexts.interceptCall(context, call, headers, next);
} else {
call.close(Status.UNAUTHENTICATED.withDescription("Credentials not accepted"), EMPTY_TRAILERS);
}
}
} else {
call.close(Status.UNAUTHENTICATED.withDescription("No credentials provided"), EMPTY_TRAILERS);
}
return new ServerCall.Listener<>() {};
}
}

View File

@ -0,0 +1,28 @@
package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager;
/**
* A "prohibit authentication" interceptor ensures that requests to endpoints that should be invoked anonymously do not
* originate from a channel that is associated with an authenticated device. Calls with an associated authenticated
* device are closed with an {@code UNAUTHENTICATED} status.
*/
public class ProhibitAuthenticationInterceptor extends AbstractAuthenticationInterceptor {
public ProhibitAuthenticationInterceptor(final ClientConnectionManager clientConnectionManager) {
super(clientConnectionManager);
}
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
return getAuthenticatedDevice(call)
.map(ignored -> closeAsUnauthenticated(call))
.orElseGet(() -> next.startCall(call, headers));
}
}

View File

@ -0,0 +1,32 @@
package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager;
/**
* A "require authentication" interceptor requires that requests be issued from a connection that is associated with an
* authenticated device. Calls without an associated authenticated device are closed with an {@code UNAUTHENTICATED}
* status.
*/
public class RequireAuthenticationInterceptor extends AbstractAuthenticationInterceptor {
public RequireAuthenticationInterceptor(final ClientConnectionManager clientConnectionManager) {
super(clientConnectionManager);
}
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
return getAuthenticatedDevice(call)
.map(authenticatedDevice -> Contexts.interceptCall(Context.current()
.withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE, authenticatedDevice),
call, headers, next))
.orElseGet(() -> closeAsUnauthenticated(call));
}
}

View File

@ -0,0 +1,8 @@
package org.whispersystems.textsecuregcm.configuration;
import javax.validation.constraints.NotNull;
import javax.validation.constraints.Positive;
import org.whispersystems.textsecuregcm.configuration.secrets.SecretString;
public record NoiseWebSocketTunnelConfiguration(@Positive int port, @NotNull SecretString recognizedProxySecret) {
}

View File

@ -5,7 +5,6 @@
package org.whispersystems.textsecuregcm.filters; package org.whispersystems.textsecuregcm.filters;
import com.google.common.annotations.VisibleForTesting;
import java.io.IOException; import java.io.IOException;
import java.util.Optional; import java.util.Optional;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -72,8 +71,7 @@ public class RemoteAddressFilter implements Filter {
* @see <a href="https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For">X-Forwarded-For - HTTP | * @see <a href="https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For">X-Forwarded-For - HTTP |
* MDN</a> * MDN</a>
*/ */
@VisibleForTesting public static Optional<String> getMostRecentProxy(@Nullable final String forwardedFor) {
static Optional<String> getMostRecentProxy(@Nullable final String forwardedFor) {
return Optional.ofNullable(forwardedFor) return Optional.ofNullable(forwardedFor)
.map(ff -> { .map(ff -> {
final int idx = forwardedFor.lastIndexOf(',') + 1; final int idx = forwardedFor.lastIndexOf(',') + 1;

View File

@ -17,6 +17,7 @@ import io.micrometer.core.instrument.Metrics;
import java.io.IOException; import java.io.IOException;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import javax.annotation.Nullable;
import javax.servlet.Filter; import javax.servlet.Filter;
import javax.servlet.FilterChain; import javax.servlet.FilterChain;
import javax.servlet.ServletException; import javax.servlet.ServletException;
@ -26,6 +27,7 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRemoteDeprecationConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRemoteDeprecationConfiguration;
import org.whispersystems.textsecuregcm.grpc.RequestAttributesUtil;
import org.whispersystems.textsecuregcm.grpc.StatusConstants; import org.whispersystems.textsecuregcm.grpc.StatusConstants;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
@ -79,7 +81,7 @@ public class RemoteDeprecationFilter implements Filter, ServerInterceptor {
final Metadata headers, final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) { final ServerCallHandler<ReqT, RespT> next) {
if (shouldBlock(UserAgentUtil.userAgentFromGrpcContext())) { if (shouldBlock(RequestAttributesUtil.getUserAgent().orElse(null))) {
call.close(StatusConstants.UPGRADE_NEEDED_STATUS, new Metadata()); call.close(StatusConstants.UPGRADE_NEEDED_STATUS, new Metadata());
return new ServerCall.Listener<>() {}; return new ServerCall.Listener<>() {};
} else { } else {
@ -87,7 +89,7 @@ public class RemoteDeprecationFilter implements Filter, ServerInterceptor {
} }
} }
private boolean shouldBlock(final UserAgent userAgent) { private boolean shouldBlock(@Nullable final UserAgent userAgent) {
final DynamicRemoteDeprecationConfiguration configuration = dynamicConfigurationManager final DynamicRemoteDeprecationConfiguration configuration = dynamicConfigurationManager
.getConfiguration().getRemoteDeprecationConfiguration(); .getConfiguration().getRemoteDeprecationConfiguration();
final Map<ClientPlatform, Semver> minimumVersionsByPlatform = configuration.getMinimumVersions(); final Map<ClientPlatform, Semver> minimumVersionsByPlatform = configuration.getMinimumVersions();

View File

@ -1,68 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import com.google.common.annotations.VisibleForTesting;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.micrometer.core.instrument.Metrics;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import javax.annotation.Nullable;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
public class AcceptLanguageInterceptor implements ServerInterceptor {
private static final Logger logger = LoggerFactory.getLogger(AcceptLanguageInterceptor.class);
private static final String INVALID_ACCEPT_LANGUAGE_COUNTER_NAME = name(AcceptLanguageInterceptor.class, "invalidAcceptLanguage");
@VisibleForTesting
public static final Metadata.Key<String> ACCEPTABLE_LANGUAGES_GRPC_HEADER =
Metadata.Key.of("accept-language", Metadata.ASCII_STRING_MARSHALLER);
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
final List<Locale> locales = parseLocales(headers.get(ACCEPTABLE_LANGUAGES_GRPC_HEADER));
return Contexts.interceptCall(
Context.current().withValue(AcceptLanguageUtil.ACCEPTABLE_LANGUAGES_CONTEXT_KEY, locales),
call,
headers,
next);
}
static List<Locale> parseLocales(@Nullable final String acceptableLanguagesHeader) {
if (acceptableLanguagesHeader == null) {
return Collections.emptyList();
}
try {
final List<Locale.LanguageRange> languageRanges = Locale.LanguageRange.parse(acceptableLanguagesHeader);
return Locale.filter(languageRanges, Arrays.asList(Locale.getAvailableLocales()));
} catch (final IllegalArgumentException e) {
final UserAgent userAgent = UserAgentUtil.userAgentFromGrpcContext();
Metrics.counter(INVALID_ACCEPT_LANGUAGE_COUNTER_NAME, "platform", userAgent.getPlatform().name().toLowerCase()).increment();
logger.debug("Could not get acceptable languages; Accept-Language: {}; User-Agent: {}",
acceptableLanguagesHeader,
userAgent,
e);
return Collections.emptyList();
}
}
}

View File

@ -1,17 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Context;
import java.util.List;
import java.util.Locale;
public class AcceptLanguageUtil {
static final Context.Key<List<Locale>> ACCEPTABLE_LANGUAGES_CONTEXT_KEY = Context.key("accept-language");
public static List<Locale> localeFromGrpcContext() {
return ACCEPTABLE_LANGUAGES_CONTEXT_KEY.get();
}
}

View File

@ -111,7 +111,7 @@ public class ProfileGrpcHelper {
case ACI -> { case ACI -> {
responseBuilder.setUnrestrictedUnidentifiedAccess(targetAccount.isUnrestrictedUnidentifiedAccess()) responseBuilder.setUnrestrictedUnidentifiedAccess(targetAccount.isUnrestrictedUnidentifiedAccess())
.addAllBadges(buildBadges(profileBadgeConverter.convert( .addAllBadges(buildBadges(profileBadgeConverter.convert(
AcceptLanguageUtil.localeFromGrpcContext(), RequestAttributesUtil.getAvailableAcceptedLocales(),
targetAccount.getBadges(), targetAccount.getBadges(),
ProfileHelper.isSelfProfileRequest(requesterUuid, (AciServiceIdentifier) targetIdentifier)))); ProfileHelper.isSelfProfileRequest(requesterUuid, (AciServiceIdentifier) targetIdentifier))));

View File

@ -5,28 +5,12 @@
package org.whispersystems.textsecuregcm.grpc; package org.whispersystems.textsecuregcm.grpc;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import java.net.SocketAddress;
import java.time.Duration;
class RateLimitUtil { class RateLimitUtil {
private static final RateLimitExceededException UNKNOWN_REMOTE_ADDRESS_EXCEPTION =
new RateLimitExceededException(Duration.ofHours(1), true);
static Mono<Void> rateLimitByRemoteAddress(final RateLimiter rateLimiter) { static Mono<Void> rateLimitByRemoteAddress(final RateLimiter rateLimiter) {
return rateLimitByRemoteAddress(rateLimiter, true); return rateLimiter.validateReactive(RequestAttributesUtil.getRemoteAddress().getHostAddress());
}
static Mono<Void> rateLimitByRemoteAddress(final RateLimiter rateLimiter, final boolean failOnUnknownRemoteAddress) {
final SocketAddress remoteAddress = RemoteAddressUtil.getRemoteAddress();
if (remoteAddress != null) {
return rateLimiter.validateReactive(remoteAddress.toString());
} else {
return failOnUnknownRemoteAddress ? Mono.error(UNKNOWN_REMOTE_ADDRESS_EXCEPTION) : Mono.empty();
}
} }
} }

View File

@ -1,33 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import java.net.SocketAddress;
public class RemoteAddressInterceptor implements ServerInterceptor {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> serverCall,
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
// Note: the specific implementation for getting a remote client address may change depending on the client
// connection strategy. The important thing is that the remote address wind up in the context for the current
// call so it can be retrieved by `RemoteAddressUtil`.
final SocketAddress remoteAddress = serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR);
return Contexts.interceptCall(
Context.current().withValue(RemoteAddressUtil.REMOTE_ADDRESS_CONTEXT_KEY, remoteAddress),
serverCall, headers, next);
}
}

View File

@ -1,23 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Context;
import java.net.SocketAddress;
public class RemoteAddressUtil {
static final Context.Key<SocketAddress> REMOTE_ADDRESS_CONTEXT_KEY = Context.key("remote-address");
/**
* Returns the socket address of the remote client in the current gRPC request context.
*
* @return the socket address of the remote client
*/
public static SocketAddress getRemoteAddress() {
return REMOTE_ADDRESS_CONTEXT_KEY.get();
}
}

View File

@ -0,0 +1,75 @@
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.netty.channel.local.LocalAddress;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import java.net.InetAddress;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
public class RequestAttributesInterceptor implements ServerInterceptor {
private final ClientConnectionManager clientConnectionManager;
private static final Logger log = LoggerFactory.getLogger(RequestAttributesInterceptor.class);
public RequestAttributesInterceptor(final ClientConnectionManager clientConnectionManager) {
this.clientConnectionManager = clientConnectionManager;
}
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
if (call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress) {
Context context = Context.current();
{
final Optional<InetAddress> maybeRemoteAddress = clientConnectionManager.getRemoteAddress(localAddress);
if (maybeRemoteAddress.isEmpty()) {
// We should never have a call from a party whose remote address we can't identify
log.warn("No remote address available");
call.close(Status.INTERNAL, new Metadata());
return new ServerCall.Listener<>() {};
}
context = context.withValue(RequestAttributesUtil.REMOTE_ADDRESS_CONTEXT_KEY, maybeRemoteAddress.get());
}
{
final Optional<List<Locale.LanguageRange>> maybeAcceptLanguage =
clientConnectionManager.getAcceptableLanguages(localAddress);
if (maybeAcceptLanguage.isPresent()) {
context = context.withValue(RequestAttributesUtil.ACCEPT_LANGUAGE_CONTEXT_KEY, maybeAcceptLanguage.get());
}
}
{
final Optional<UserAgent> maybeUserAgent = clientConnectionManager.getUserAgent(localAddress);
if (maybeUserAgent.isPresent()) {
context = context.withValue(RequestAttributesUtil.USER_AGENT_CONTEXT_KEY, maybeUserAgent.get());
}
}
return Contexts.interceptCall(context, call, headers, next);
} else {
throw new AssertionError("Unexpected channel type: " + call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
}
}
}

View File

@ -0,0 +1,59 @@
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Context;
import java.net.InetAddress;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
public class RequestAttributesUtil {
static final Context.Key<List<Locale.LanguageRange>> ACCEPT_LANGUAGE_CONTEXT_KEY = Context.key("accept-language");
static final Context.Key<InetAddress> REMOTE_ADDRESS_CONTEXT_KEY = Context.key("remote-address");
static final Context.Key<UserAgent> USER_AGENT_CONTEXT_KEY = Context.key("user-agent");
private static final List<Locale> AVAILABLE_LOCALES = Arrays.asList(Locale.getAvailableLocales());
/**
* Returns the acceptable languages listed by the remote client in the current gRPC request context.
*
* @return the acceptable languages listed by the remote client; may be empty if unparseable or not specified
*/
public static Optional<List<Locale.LanguageRange>> getAcceptableLanguages() {
return Optional.ofNullable(ACCEPT_LANGUAGE_CONTEXT_KEY.get());
}
/**
* Returns a list of distinct locales supported by the JVM and accepted by the remote client in the current gRPC
* context. May be empty if the client did not supply a list of acceptable languages, if the list of acceptable
* languages could not be parsed, or if none of the acceptable languages are available in the current JVM.
*
* @return a list of distinct locales acceptable to the remote client and available in this JVM
*/
public static List<Locale> getAvailableAcceptedLocales() {
return getAcceptableLanguages()
.map(languageRanges -> Locale.filter(languageRanges, AVAILABLE_LOCALES))
.orElseGet(Collections::emptyList);
}
/**
* Returns the remote address of the remote client in the current gRPC request context.
*
* @return the remote address of the remote client
*/
public static InetAddress getRemoteAddress() {
return REMOTE_ADDRESS_CONTEXT_KEY.get();
}
/**
* Returns the parsed user-agent of the remote client in the current gRPC request context.
*
* @return the parsed user-agent of the remote client; may be null if unparseable or not specified
*/
public static Optional<UserAgent> getUserAgent() {
return Optional.ofNullable(USER_AGENT_CONTEXT_KEY.get());
}
}

View File

@ -1,43 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import com.google.common.annotations.VisibleForTesting;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
public class UserAgentInterceptor implements ServerInterceptor {
@VisibleForTesting
public static final Metadata.Key<String> USER_AGENT_GRPC_HEADER =
Metadata.Key.of("user-agent", Metadata.ASCII_STRING_MARSHALLER);
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
UserAgent userAgent;
try {
userAgent = UserAgentUtil.parseUserAgentString(headers.get(USER_AGENT_GRPC_HEADER));
} catch (final UnrecognizedUserAgentException e) {
userAgent = null;
}
final Context context = Context.current().withValue(UserAgentUtil.USER_AGENT_CONTEXT_KEY, userAgent);
return Contexts.interceptCall(context, call, headers, next);
}
}

View File

@ -0,0 +1,218 @@
package org.whispersystems.textsecuregcm.grpc.net;
import com.google.common.annotations.VisibleForTesting;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.util.AttributeKey;
import java.net.InetAddress;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
/**
* A client connection manager associates a local connection to a local gRPC server with a remote connection through a
* Noise-over-WebSocket tunnel. It provides access to metadata associated with the remote connection, including the
* authenticated identity of the device that opened the connection (for non-anonymous connections). It can also close
* connections associated with a given device if that device's credentials have changed and clients must reauthenticate.
*/
public class ClientConnectionManager {
private final Map<LocalAddress, Channel> remoteChannelsByLocalAddress = new ConcurrentHashMap<>();
private final Map<AuthenticatedDevice, List<Channel>> remoteChannelsByAuthenticatedDevice = new ConcurrentHashMap<>();
@VisibleForTesting
static final AttributeKey<AuthenticatedDevice> AUTHENTICATED_DEVICE_ATTRIBUTE_KEY =
AttributeKey.valueOf(ClientConnectionManager.class, "authenticatedDevice");
@VisibleForTesting
static final AttributeKey<InetAddress> REMOTE_ADDRESS_ATTRIBUTE_KEY =
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "remoteAddress");
@VisibleForTesting
static final AttributeKey<String> RAW_USER_AGENT_ATTRIBUTE_KEY =
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "rawUserAgent");
@VisibleForTesting
static final AttributeKey<UserAgent> PARSED_USER_AGENT_ATTRIBUTE_KEY =
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "userAgent");
@VisibleForTesting
static final AttributeKey<List<Locale.LanguageRange>> ACCEPT_LANGUAGE_ATTRIBUTE_KEY =
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "acceptLanguage");
private static final Logger log = LoggerFactory.getLogger(ClientConnectionManager.class);
/**
* Returns the authenticated device associated with the given local address, if any. An authenticated device is
* available if and only if the given local address maps to an active local connection and that connection is
* authenticated (i.e. not anonymous).
*
* @param localAddress the local address for which to find an authenticated device
*
* @return the authenticated device associated with the given local address, if any
*/
public Optional<AuthenticatedDevice> getAuthenticatedDevice(final LocalAddress localAddress) {
return getAuthenticatedDevice(remoteChannelsByLocalAddress.get(localAddress));
}
private Optional<AuthenticatedDevice> getAuthenticatedDevice(@Nullable final Channel remoteChannel) {
return Optional.ofNullable(remoteChannel)
.map(channel -> channel.attr(AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).get());
}
/**
* Returns the parsed acceptable languages associated with the given local address, if any. Acceptable languages may
* be unavailable if the local connection associated with the given local address has already closed, if the client
* did not provide a list of acceptable languages, or the list provided by the client could not be parsed.
*
* @param localAddress the local address for which to find acceptable languages
*
* @return the acceptable languages associated with the given local address, if any
*/
public Optional<List<Locale.LanguageRange>> getAcceptableLanguages(final LocalAddress localAddress) {
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
.map(remoteChannel -> remoteChannel.attr(ACCEPT_LANGUAGE_ATTRIBUTE_KEY).get());
}
/**
* Returns the remote address associated with the given local address, if any. A remote address may be unavailable if
* the local connection associated with the given local address has already closed.
*
* @param localAddress the local address for which to find a remote address
*
* @return the remote address associated with the given local address, if any
*/
public Optional<InetAddress> getRemoteAddress(final LocalAddress localAddress) {
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
.map(remoteChannel -> remoteChannel.attr(REMOTE_ADDRESS_ATTRIBUTE_KEY).get());
}
/**
* Returns the parsed user agent provided by the client the opened the connection associated with the given local
* address. This method may return an empty value if no active local connection is associated with the given local
* address or if the client's user-agent string was not recognized.
*
* @param localAddress the local address for which to find a User-Agent string
*
* @return the user agent associated with the given local address
*/
public Optional<UserAgent> getUserAgent(final LocalAddress localAddress) {
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
.map(remoteChannel -> remoteChannel.attr(PARSED_USER_AGENT_ATTRIBUTE_KEY).get());
}
/**
* Closes any client connections to this host associated with the given authenticated device.
*
* @param authenticatedDevice the authenticated device for which to close connections
*/
public void closeConnection(final AuthenticatedDevice authenticatedDevice) {
// Channels will actually get removed from the list/map by their closeFuture listeners
remoteChannelsByAuthenticatedDevice.getOrDefault(authenticatedDevice, Collections.emptyList()).forEach(channel ->
channel.writeAndFlush(new CloseWebSocketFrame(ApplicationWebSocketCloseReason.REAUTHENTICATION_REQUIRED
.toWebSocketCloseStatus("Reauthentication required")))
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE));
}
@VisibleForTesting
@Nullable List<Channel> getRemoteChannelsByAuthenticatedDevice(final AuthenticatedDevice authenticatedDevice) {
return remoteChannelsByAuthenticatedDevice.get(authenticatedDevice);
}
@VisibleForTesting
Channel getRemoteChannelByLocalAddress(final LocalAddress localAddress) {
return remoteChannelsByLocalAddress.get(localAddress);
}
/**
* Handles successful completion of a WebSocket handshake and associates attributes and headers from the handshake
* request with the channel via which the handshake took place.
*
* @param channel the channel that completed a WebSocket handshake
* @param preferredRemoteAddress the preferred remote address (potentially from a request header) for the handshake
* @param userAgentHeader the value of the User-Agent header provided in the handshake request; may be {@code null}
* @param acceptLanguageHeader the value of the Accept-Language header provided in the handshake request; may be
* {@code null}
*/
static void handleWebSocketHandshakeComplete(final Channel channel,
final InetAddress preferredRemoteAddress,
@Nullable final String userAgentHeader,
@Nullable final String acceptLanguageHeader) {
channel.attr(ClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).set(preferredRemoteAddress);
if (StringUtils.isNotBlank(userAgentHeader)) {
channel.attr(ClientConnectionManager.RAW_USER_AGENT_ATTRIBUTE_KEY).set(userAgentHeader);
try {
channel.attr(ClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY)
.set(UserAgentUtil.parseUserAgentString(userAgentHeader));
} catch (final UnrecognizedUserAgentException ignored) {
}
}
if (StringUtils.isNotBlank(acceptLanguageHeader)) {
try {
channel.attr(ClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).set(Locale.LanguageRange.parse(acceptLanguageHeader));
} catch (final IllegalArgumentException e) {
log.debug("Invalid Accept-Language header from User-Agent {}: {}", userAgentHeader, acceptLanguageHeader, e);
}
}
}
/**
* Handles successful establishment of a Noise-over-WebSocket connection from a remote client to a local gRPC server.
*
* @param localChannel the newly-opened local channel between the Noise-over-WebSocket tunnel and the local gRPC
* server
* @param remoteChannel the channel from the remote client to the Noise-over-WebSocket tunnel
* @param maybeAuthenticatedDevice the authenticated device (if any) associated with the new connection
*/
void handleConnectionEstablished(final LocalChannel localChannel,
final Channel remoteChannel,
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<AuthenticatedDevice> maybeAuthenticatedDevice) {
maybeAuthenticatedDevice.ifPresent(authenticatedDevice ->
remoteChannel.attr(ClientConnectionManager.AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).set(authenticatedDevice));
remoteChannelsByLocalAddress.put(localChannel.localAddress(), remoteChannel);
getAuthenticatedDevice(remoteChannel).ifPresent(authenticatedDevice ->
remoteChannelsByAuthenticatedDevice.compute(authenticatedDevice, (ignored, existingChannelList) -> {
final List<Channel> channels = existingChannelList != null ? existingChannelList : new ArrayList<>();
channels.add(remoteChannel);
return channels;
}));
remoteChannel.closeFuture().addListener(closeFuture -> {
remoteChannelsByLocalAddress.remove(localChannel.localAddress());
getAuthenticatedDevice(remoteChannel).ifPresent(authenticatedDevice ->
remoteChannelsByAuthenticatedDevice.compute(authenticatedDevice, (ignored, existingChannelList) -> {
if (existingChannelList == null) {
return null;
}
existingChannelList.remove(remoteChannel);
return existingChannelList.isEmpty() ? null : existingChannelList;
}));
});
}
}

View File

@ -22,15 +22,21 @@ import org.slf4j.LoggerFactory;
*/ */
class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter { class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
private final ClientConnectionManager clientConnectionManager;
private final LocalAddress authenticatedGrpcServerAddress; private final LocalAddress authenticatedGrpcServerAddress;
private final LocalAddress anonymousGrpcServerAddress; private final LocalAddress anonymousGrpcServerAddress;
private final List<Object> pendingReads = new ArrayList<>(); private final List<Object> pendingReads = new ArrayList<>();
private static final Logger log = LoggerFactory.getLogger(EstablishLocalGrpcConnectionHandler.class); private static final Logger log = LoggerFactory.getLogger(EstablishLocalGrpcConnectionHandler.class);
public EstablishLocalGrpcConnectionHandler(final LocalAddress authenticatedGrpcServerAddress, public EstablishLocalGrpcConnectionHandler(final ClientConnectionManager clientConnectionManager,
final LocalAddress authenticatedGrpcServerAddress,
final LocalAddress anonymousGrpcServerAddress) { final LocalAddress anonymousGrpcServerAddress) {
this.clientConnectionManager = clientConnectionManager;
this.authenticatedGrpcServerAddress = authenticatedGrpcServerAddress; this.authenticatedGrpcServerAddress = authenticatedGrpcServerAddress;
this.anonymousGrpcServerAddress = anonymousGrpcServerAddress; this.anonymousGrpcServerAddress = anonymousGrpcServerAddress;
} }
@ -41,7 +47,7 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
} }
@Override @Override
public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) throws Exception { public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) {
if (event instanceof NoiseHandshakeCompleteEvent noiseHandshakeCompleteEvent) { if (event instanceof NoiseHandshakeCompleteEvent noiseHandshakeCompleteEvent) {
// We assume that we'll only get a completed handshake event if the handshake met all authentication requirements // We assume that we'll only get a completed handshake event if the handshake met all authentication requirements
// for the requested service. If the handshake doesn't have an authenticated device, we assume we're trying to // for the requested service. If the handshake doesn't have an authenticated device, we assume we're trying to
@ -53,7 +59,6 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
new Bootstrap() new Bootstrap()
.remoteAddress(grpcServerAddress) .remoteAddress(grpcServerAddress)
// TODO Set local address
.channel(LocalChannel.class) .channel(LocalChannel.class)
.group(remoteChannelContext.channel().eventLoop()) .group(remoteChannelContext.channel().eventLoop())
.handler(new ChannelInitializer<LocalChannel>() { .handler(new ChannelInitializer<LocalChannel>() {
@ -63,15 +68,19 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
} }
}) })
.connect() .connect()
.addListener((ChannelFutureListener) future -> { .addListener((ChannelFutureListener) localChannelFuture -> {
if (future.isSuccess()) { if (localChannelFuture.isSuccess()) {
clientConnectionManager.handleConnectionEstablished((LocalChannel) localChannelFuture.channel(),
remoteChannelContext.channel(),
noiseHandshakeCompleteEvent.authenticatedDevice());
// Close the local connection if the remote channel closes and vice versa // Close the local connection if the remote channel closes and vice versa
remoteChannelContext.channel().closeFuture().addListener(closeFuture -> future.channel().close()); remoteChannelContext.channel().closeFuture().addListener(closeFuture -> localChannelFuture.channel().close());
future.channel().closeFuture().addListener(closeFuture -> localChannelFuture.channel().closeFuture().addListener(closeFuture ->
remoteChannelContext.write(new CloseWebSocketFrame(WebSocketCloseStatus.SERVICE_RESTART))); remoteChannelContext.write(new CloseWebSocketFrame(WebSocketCloseStatus.SERVICE_RESTART)));
remoteChannelContext.pipeline() remoteChannelContext.pipeline()
.addAfter(remoteChannelContext.name(), null, new ProxyHandler(future.channel())); .addAfter(remoteChannelContext.name(), null, new ProxyHandler(localChannelFuture.channel()));
// Flush any buffered reads we accumulated while waiting to open the connection // Flush any buffered reads we accumulated while waiting to open the connection
pendingReads.forEach(remoteChannelContext::fireChannelRead); pendingReads.forEach(remoteChannelContext::fireChannelRead);
@ -79,7 +88,7 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
remoteChannelContext.pipeline().remove(EstablishLocalGrpcConnectionHandler.this); remoteChannelContext.pipeline().remove(EstablishLocalGrpcConnectionHandler.this);
} else { } else {
log.warn("Failed to establish local connection to gRPC server", future.cause()); log.warn("Failed to establish local connection to gRPC server", localChannelFuture.cause());
remoteChannelContext.close(); remoteChannelContext.close();
} }
}); });

View File

@ -0,0 +1,134 @@
package org.whispersystems.textsecuregcm.grpc.net;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.net.InetAddresses;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.Optional;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
/**
* A WebSocket handshake handler waits for a WebSocket handshake to complete, then replaces itself with the appropriate
* Noise handshake handler for the requested path.
*/
class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
private final ClientPublicKeysManager clientPublicKeysManager;
private final ECKeyPair ecKeyPair;
private final byte[] publicKeySignature;
private final byte[] recognizedProxySecret;
private static final Logger log = LoggerFactory.getLogger(WebsocketHandshakeCompleteHandler.class);
@VisibleForTesting
static final String RECOGNIZED_PROXY_SECRET_HEADER = "X-Signal-Recognized-Proxy";
@VisibleForTesting
static final String FORWARDED_FOR_HEADER = "X-Forwarded-For";
WebsocketHandshakeCompleteHandler(final ClientPublicKeysManager clientPublicKeysManager,
final ECKeyPair ecKeyPair,
final byte[] publicKeySignature,
final String recognizedProxySecret) {
this.clientPublicKeysManager = clientPublicKeysManager;
this.ecKeyPair = ecKeyPair;
this.publicKeySignature = publicKeySignature;
// The recognized proxy secret is an arbitrary string, and not an encoded byte sequence (i.e. a base64- or hex-
// encoded value). We convert it into a byte array here for easier constant-time comparisons via
// MessageDigest.equals() later.
this.recognizedProxySecret = recognizedProxySecret.getBytes(StandardCharsets.UTF_8);
}
@Override
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) {
final InetAddress preferredRemoteAddress;
{
final Optional<InetAddress> maybePreferredRemoteAddress =
getPreferredRemoteAddress(context, handshakeCompleteEvent);
if (maybePreferredRemoteAddress.isEmpty()) {
context.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INTERNAL_SERVER_ERROR,
"Could not determine remote address"))
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
return;
}
preferredRemoteAddress = maybePreferredRemoteAddress.get();
}
ClientConnectionManager.handleWebSocketHandshakeComplete(context.channel(),
preferredRemoteAddress,
handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.USER_AGENT),
handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.ACCEPT_LANGUAGE));
final ChannelHandler noiseHandshakeHandler = switch (handshakeCompleteEvent.requestUri()) {
case WebsocketNoiseTunnelServer.AUTHENTICATED_SERVICE_PATH ->
new NoiseXXHandshakeHandler(clientPublicKeysManager, ecKeyPair, publicKeySignature);
case WebsocketNoiseTunnelServer.ANONYMOUS_SERVICE_PATH ->
new NoiseNXHandshakeHandler(ecKeyPair, publicKeySignature);
default -> {
// The WebSocketOpeningHandshakeHandler should have caught all of these cases already; we'll consider it an
// internal error if something slipped through.
throw new IllegalArgumentException("Unexpected URI: " + handshakeCompleteEvent.requestUri());
}
};
context.pipeline().replace(WebsocketHandshakeCompleteHandler.this, null, noiseHandshakeHandler);
}
context.fireUserEventTriggered(event);
}
private Optional<InetAddress> getPreferredRemoteAddress(final ChannelHandlerContext context,
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) {
final byte[] recognizedProxySecretFromHeader =
handshakeCompleteEvent.requestHeaders().get(RECOGNIZED_PROXY_SECRET_HEADER, "")
.getBytes(StandardCharsets.UTF_8);
final boolean trustForwardedFor = MessageDigest.isEqual(recognizedProxySecret, recognizedProxySecretFromHeader);
if (trustForwardedFor && handshakeCompleteEvent.requestHeaders().contains(FORWARDED_FOR_HEADER)) {
final String forwardedFor = handshakeCompleteEvent.requestHeaders().get(FORWARDED_FOR_HEADER);
return RemoteAddressFilter.getMostRecentProxy(forwardedFor).map(mostRecentProxy -> {
try {
return InetAddresses.forString(mostRecentProxy);
} catch (final IllegalArgumentException e) {
log.warn("Failed to parse forwarded-for address: {}", forwardedFor, e);
return null;
}
});
} else {
// Either we don't trust the forwarded-for header or it's not present
if (context.channel().remoteAddress() instanceof InetSocketAddress inetSocketAddress) {
return Optional.of(inetSocketAddress.getAddress());
} else {
log.warn("Channel's remote address was not an InetSocketAddress");
return Optional.empty();
}
}
}
}

View File

@ -1,52 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
/**
* A WebSocket handshake listener waits for a WebSocket handshake to complete, then replaces itself with the appropriate
* Noise handshake handler for the requested path.
*/
class WebsocketHandshakeCompleteListener extends ChannelInboundHandlerAdapter {
private final ClientPublicKeysManager clientPublicKeysManager;
private final ECKeyPair ecKeyPair;
private final byte[] publicKeySignature;
WebsocketHandshakeCompleteListener(final ClientPublicKeysManager clientPublicKeysManager,
final ECKeyPair ecKeyPair,
final byte[] publicKeySignature) {
this.clientPublicKeysManager = clientPublicKeysManager;
this.ecKeyPair = ecKeyPair;
this.publicKeySignature = publicKeySignature;
}
@Override
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) {
final ChannelHandler noiseHandshakeHandler = switch (handshakeCompleteEvent.requestUri()) {
case WebsocketNoiseTunnelServer.AUTHENTICATED_SERVICE_PATH ->
new NoiseXXHandshakeHandler(clientPublicKeysManager, ecKeyPair, publicKeySignature);
case WebsocketNoiseTunnelServer.ANONYMOUS_SERVICE_PATH ->
new NoiseNXHandshakeHandler(ecKeyPair, publicKeySignature);
default -> {
// The HttpHandler should have caught all of these cases already; we'll consider it an internal error if
// something slipped through.
throw new IllegalArgumentException("Unexpected URI: " + handshakeCompleteEvent.requestUri());
}
};
context.pipeline().replace(WebsocketHandshakeCompleteListener.this, null, noiseHandshakeHandler);
}
context.fireUserEventTriggered(event);
}
}

View File

@ -48,11 +48,13 @@ public class WebsocketNoiseTunnelServer implements Managed {
final PrivateKey tlsPrivateKey, final PrivateKey tlsPrivateKey,
final NioEventLoopGroup eventLoopGroup, final NioEventLoopGroup eventLoopGroup,
final Executor delegatedTaskExecutor, final Executor delegatedTaskExecutor,
final ClientConnectionManager clientConnectionManager,
final ClientPublicKeysManager clientPublicKeysManager, final ClientPublicKeysManager clientPublicKeysManager,
final ECKeyPair ecKeyPair, final ECKeyPair ecKeyPair,
final byte[] publicKeySignature, final byte[] publicKeySignature,
final LocalAddress authenticatedGrpcServerAddress, final LocalAddress authenticatedGrpcServerAddress,
final LocalAddress anonymousGrpcServerAddress) throws SSLException { final LocalAddress anonymousGrpcServerAddress,
final String recognizedProxySecret) throws SSLException {
final SslProvider sslProvider; final SslProvider sslProvider;
@ -88,10 +90,10 @@ public class WebsocketNoiseTunnelServer implements Managed {
.addLast(new RejectUnsupportedMessagesHandler()) .addLast(new RejectUnsupportedMessagesHandler())
// The WebSocket handshake complete listener will replace itself with an appropriate Noise handshake handler once // The WebSocket handshake complete listener will replace itself with an appropriate Noise handshake handler once
// a WebSocket handshake has been completed // a WebSocket handshake has been completed
.addLast(new WebsocketHandshakeCompleteListener(clientPublicKeysManager, ecKeyPair, publicKeySignature)) .addLast(new WebsocketHandshakeCompleteHandler(clientPublicKeysManager, ecKeyPair, publicKeySignature, recognizedProxySecret))
// This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler // This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler
// once the Noise handshake has completed // once the Noise handshake has completed
.addLast(new EstablishLocalGrpcConnectionHandler(authenticatedGrpcServerAddress, anonymousGrpcServerAddress)) .addLast(new EstablishLocalGrpcConnectionHandler(clientConnectionManager, authenticatedGrpcServerAddress, anonymousGrpcServerAddress))
.addLast(new ErrorHandler()); .addLast(new ErrorHandler());
} }
}); });

View File

@ -7,15 +7,12 @@ package org.whispersystems.textsecuregcm.util.ua;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.vdurmont.semver4j.Semver; import com.vdurmont.semver4j.Semver;
import io.grpc.Context;
import java.util.regex.Matcher; import java.util.regex.Matcher;
import java.util.regex.Pattern; import java.util.regex.Pattern;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
public class UserAgentUtil { public class UserAgentUtil {
public static final Context.Key<UserAgent> USER_AGENT_CONTEXT_KEY = Context.key("x-signal-user-agent");
private static final Pattern STANDARD_UA_PATTERN = Pattern.compile("^Signal-(Android|Desktop|iOS)/([^ ]+)( (.+))?$", Pattern.CASE_INSENSITIVE); private static final Pattern STANDARD_UA_PATTERN = Pattern.compile("^Signal-(Android|Desktop|iOS)/([^ ]+)( (.+))?$", Pattern.CASE_INSENSITIVE);
public static UserAgent parseUserAgentString(final String userAgentString) throws UnrecognizedUserAgentException { public static UserAgent parseUserAgentString(final String userAgentString) throws UnrecognizedUserAgentException {
@ -36,10 +33,6 @@ public class UserAgentUtil {
throw new UnrecognizedUserAgentException(); throw new UnrecognizedUserAgentException();
} }
public static UserAgent userAgentFromGrpcContext() {
return USER_AGENT_CONTEXT_KEY.get();
}
@VisibleForTesting @VisibleForTesting
static UserAgent parseStandardUserAgentString(final String userAgentString) { static UserAgent parseStandardUserAgentString(final String userAgentString) {
final Matcher matcher = STANDARD_UA_PATTERN.matcher(userAgentString); final Matcher matcher = STANDARD_UA_PATTERN.matcher(userAgentString);

View File

@ -0,0 +1,77 @@
package org.whispersystems.textsecuregcm.auth.grpc;
import static org.mockito.Mockito.mock;
import io.grpc.ManagedChannel;
import io.grpc.Server;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.NettyServerBuilder;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import java.io.IOException;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.signal.chat.rpc.RequestAttributesGrpc;
import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl;
import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager;
abstract class AbstractAuthenticationInterceptorTest {
private static DefaultEventLoopGroup eventLoopGroup;
private ClientConnectionManager clientConnectionManager;
private Server server;
private ManagedChannel managedChannel;
@BeforeAll
static void setUpBeforeAll() {
eventLoopGroup = new DefaultEventLoopGroup();
}
@BeforeEach
void setUp() throws IOException {
final LocalAddress serverAddress = new LocalAddress("test-authentication-interceptor-server");
clientConnectionManager = mock(ClientConnectionManager.class);
// `RequestAttributesInterceptor` operates on `LocalAddresses`, so we need to do some slightly fancy plumbing to make
// sure that we're using local channels and addresses
server = NettyServerBuilder.forAddress(serverAddress)
.channelType(LocalServerChannel.class)
.bossEventLoopGroup(eventLoopGroup)
.workerEventLoopGroup(eventLoopGroup)
.intercept(getInterceptor())
.addService(new RequestAttributesServiceImpl())
.build()
.start();
managedChannel = NettyChannelBuilder.forAddress(serverAddress)
.channelType(LocalChannel.class)
.eventLoopGroup(eventLoopGroup)
.usePlaintext()
.build();
}
@AfterEach
void tearDown() {
managedChannel.shutdown();
server.shutdown();
}
protected abstract AbstractAuthenticationInterceptor getInterceptor();
protected ClientConnectionManager getClientConnectionManager() {
return clientConnectionManager;
}
protected GetAuthenticatedDeviceResponse getAuthenticatedDevice() {
return RequestAttributesGrpc.newBlockingStub(managedChannel)
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
}
}

View File

@ -1,135 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth.grpc;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import io.grpc.CallCredentials;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.Server;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import java.io.IOException;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.Executor;
import java.util.stream.Stream;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.signal.chat.rpc.EchoRequest;
import org.signal.chat.rpc.EchoServiceGrpc;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Pair;
class BasicCredentialAuthenticationInterceptorTest {
private Server server;
private ManagedChannel managedChannel;
private AccountAuthenticator accountAuthenticator;
@BeforeEach
void setUp() throws IOException {
accountAuthenticator = mock(AccountAuthenticator.class);
final BasicCredentialAuthenticationInterceptor authenticationInterceptor =
new BasicCredentialAuthenticationInterceptor(accountAuthenticator);
final String serverName = InProcessServerBuilder.generateName();
server = InProcessServerBuilder.forName(serverName)
.directExecutor()
.intercept(authenticationInterceptor)
.addService(new EchoServiceImpl())
.build()
.start();
managedChannel = InProcessChannelBuilder.forName(serverName)
.directExecutor()
.build();
}
@AfterEach
void tearDown() {
managedChannel.shutdown();
server.shutdown();
}
@ParameterizedTest
@MethodSource
void interceptCall(final Metadata headers, final boolean acceptCredentials, final boolean expectAuthentication) {
if (acceptCredentials) {
final Account account = mock(Account.class);
when(account.getUuid()).thenReturn(UUID.randomUUID());
final Device device = mock(Device.class);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(accountAuthenticator.authenticate(any()))
.thenReturn(Optional.of(new AuthenticatedAccount(account, device)));
} else {
when(accountAuthenticator.authenticate(any()))
.thenReturn(Optional.empty());
}
final EchoServiceGrpc.EchoServiceBlockingStub stub = EchoServiceGrpc.newBlockingStub(managedChannel)
.withCallCredentials(new CallCredentials() {
@Override
public void applyRequestMetadata(final RequestInfo requestInfo, final Executor appExecutor, final MetadataApplier applier) {
applier.apply(headers);
}
@Override
public void thisUsesUnstableApi() {
}
});
if (expectAuthentication) {
assertDoesNotThrow(() -> stub.echo(EchoRequest.newBuilder().build()));
} else {
final StatusRuntimeException exception =
assertThrows(StatusRuntimeException.class, () -> stub.echo(EchoRequest.newBuilder().build()));
assertEquals(Status.UNAUTHENTICATED.getCode(), exception.getStatus().getCode());
}
}
private static Stream<Arguments> interceptCall() {
final Metadata malformedCredentialHeaders = new Metadata();
malformedCredentialHeaders.put(BasicCredentialAuthenticationInterceptor.BASIC_CREDENTIALS, "Incorrect");
final Metadata structurallyValidCredentialHeaders = new Metadata();
structurallyValidCredentialHeaders.put(
BasicCredentialAuthenticationInterceptor.BASIC_CREDENTIALS,
HeaderUtils.basicAuthHeader(UUID.randomUUID().toString(), RandomStringUtils.randomAlphanumeric(16))
);
return Stream.of(
Arguments.of(new Metadata(), true, false),
Arguments.of(malformedCredentialHeaders, true, false),
Arguments.of(structurallyValidCredentialHeaders, false, false),
Arguments.of(structurallyValidCredentialHeaders, true, true)
);
}
}

View File

@ -13,15 +13,14 @@ import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor; import io.grpc.ServerInterceptor;
import java.util.UUID; import java.util.UUID;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.util.Pair;
public class MockAuthenticationInterceptor implements ServerInterceptor { public class MockAuthenticationInterceptor implements ServerInterceptor {
@Nullable @Nullable
private Pair<UUID, Byte> authenticatedDevice; private AuthenticatedDevice authenticatedDevice;
public void setAuthenticatedDevice(final UUID accountIdentifier, final byte deviceId) { public void setAuthenticatedDevice(final UUID accountIdentifier, final byte deviceId) {
authenticatedDevice = new Pair<>(accountIdentifier, deviceId); authenticatedDevice = new AuthenticatedDevice(accountIdentifier, deviceId);
} }
public void clearAuthenticatedDevice() { public void clearAuthenticatedDevice() {
@ -33,14 +32,10 @@ public class MockAuthenticationInterceptor implements ServerInterceptor {
final Metadata headers, final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) { final ServerCallHandler<ReqT, RespT> next) {
if (authenticatedDevice != null) { return authenticatedDevice != null
final Context context = Context.current() ? Contexts.interceptCall(
.withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_ACCOUNT_IDENTIFIER_KEY, authenticatedDevice.first()) Context.current().withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE, authenticatedDevice),
.withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY, authenticatedDevice.second()); call, headers, next)
: next.startCall(call, headers);
return Contexts.interceptCall(context, call, headers, next);
}
return next.startCall(call, headers);
} }
} }

View File

@ -0,0 +1,40 @@
package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.Status;
import org.junit.jupiter.api.Test;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import java.util.Optional;
import java.util.UUID;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
class ProhibitAuthenticationInterceptorTest extends AbstractAuthenticationInterceptorTest {
@Override
protected AbstractAuthenticationInterceptor getInterceptor() {
return new ProhibitAuthenticationInterceptor(getClientConnectionManager());
}
@Test
void interceptCall() {
final ClientConnectionManager clientConnectionManager = getClientConnectionManager();
when(clientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty());
final GetAuthenticatedDeviceResponse response = getAuthenticatedDevice();
assertTrue(response.getAccountIdentifier().isEmpty());
assertEquals(0, response.getDeviceId());
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
when(clientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice));
GrpcTestUtils.assertStatusException(Status.UNAUTHENTICATED, this::getAuthenticatedDevice);
}
}

View File

@ -0,0 +1,39 @@
package org.whispersystems.textsecuregcm.auth.grpc;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
import io.grpc.Status;
import java.util.Optional;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
class RequireAuthenticationInterceptorTest extends AbstractAuthenticationInterceptorTest {
@Override
protected AbstractAuthenticationInterceptor getInterceptor() {
return new RequireAuthenticationInterceptor(getClientConnectionManager());
}
@Test
void interceptCall() {
final ClientConnectionManager clientConnectionManager = getClientConnectionManager();
when(clientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty());
GrpcTestUtils.assertStatusException(Status.UNAUTHENTICATED, this::getAuthenticatedDevice);
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
when(clientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice));
final GetAuthenticatedDeviceResponse response = getAuthenticatedDevice();
assertEquals(UUIDUtil.toByteString(authenticatedDevice.accountIdentifier()), response.getAccountIdentifier());
assertEquals(authenticatedDevice.deviceId(), response.getDeviceId());
}
}

View File

@ -39,10 +39,12 @@ import org.signal.chat.rpc.EchoServiceGrpc;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRemoteDeprecationConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRemoteDeprecationConfiguration;
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl; import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
import org.whispersystems.textsecuregcm.grpc.MockRequestAttributesInterceptor;
import org.whispersystems.textsecuregcm.grpc.StatusConstants; import org.whispersystems.textsecuregcm.grpc.StatusConstants;
import org.whispersystems.textsecuregcm.grpc.UserAgentInterceptor;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
class RemoteDeprecationFilterTest { class RemoteDeprecationFilterTest {
@ -126,17 +128,25 @@ class RemoteDeprecationFilterTest {
@ParameterizedTest @ParameterizedTest
@MethodSource(value="testFilter") @MethodSource(value="testFilter")
void testGrpcFilter(final String userAgent, final boolean expectDeprecation) throws Exception { void testGrpcFilter(final String userAgentString, final boolean expectDeprecation) throws IOException, InterruptedException {
final MockRequestAttributesInterceptor mockRequestAttributesInterceptor = new MockRequestAttributesInterceptor();
try {
mockRequestAttributesInterceptor.setUserAgent(UserAgentUtil.parseUserAgentString(userAgentString));
} catch (UnrecognizedUserAgentException ignored) {
}
final Server testServer = InProcessServerBuilder.forName("RemoteDeprecationFilterTest") final Server testServer = InProcessServerBuilder.forName("RemoteDeprecationFilterTest")
.directExecutor() .directExecutor()
.addService(new EchoServiceImpl()) .addService(new EchoServiceImpl())
.intercept(filterConfiguredForTest()) .intercept(filterConfiguredForTest())
.intercept(new UserAgentInterceptor()) .intercept(mockRequestAttributesInterceptor)
.build() .build()
.start(); .start();
final ManagedChannel channel = InProcessChannelBuilder.forName("RemoteDeprecationFilterTest") final ManagedChannel channel = InProcessChannelBuilder.forName("RemoteDeprecationFilterTest")
.directExecutor() .directExecutor()
.userAgent(userAgent) .userAgent(userAgentString)
.build(); .build();
try { try {

View File

@ -1,79 +0,0 @@
package org.whispersystems.textsecuregcm.grpc;
import com.google.protobuf.ByteString;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.Server;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.stub.MetadataUtils;
import io.grpc.stub.StreamObserver;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.signal.chat.rpc.EchoRequest;
import org.signal.chat.rpc.EchoResponse;
import org.signal.chat.rpc.EchoServiceGrpc;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class AcceptLanguageInterceptorTest {
@ParameterizedTest
@MethodSource
void parseLocale(final String header, final List<Locale> expectedLocales) throws IOException, InterruptedException {
final AtomicReference<List<Locale>> observedLocales = new AtomicReference<>(null);
final EchoServiceImpl serviceImpl = new EchoServiceImpl() {
@Override
public void echo(EchoRequest req, StreamObserver<EchoResponse> responseObserver) {
observedLocales.set(AcceptLanguageUtil.localeFromGrpcContext());
super.echo(req, responseObserver);
}
};
final Server testServer = InProcessServerBuilder.forName("AcceptLanguageTest")
.directExecutor()
.addService(serviceImpl)
.intercept(new AcceptLanguageInterceptor())
.intercept(new UserAgentInterceptor())
.build()
.start();
try {
final ManagedChannel channel = InProcessChannelBuilder.forName("AcceptLanguageTest")
.directExecutor()
.userAgent("Signal-Android/1.2.3")
.build();
final Metadata metadata = new Metadata();
metadata.put(AcceptLanguageInterceptor.ACCEPTABLE_LANGUAGES_GRPC_HEADER, header);
final EchoServiceGrpc.EchoServiceBlockingStub client = EchoServiceGrpc.newBlockingStub(channel)
.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
final EchoRequest request = EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("test request")).build();
client.echo(request);
assertEquals(expectedLocales, observedLocales.get());
} finally {
testServer.shutdownNow();
testServer.awaitTermination();
}
}
private static Stream<Arguments> parseLocale() {
return Stream.of(
// en-US-POSIX is a special locale that exists alongside en-US. It matches because of the definition of
// basic filtering in RFC 4647 (https://datatracker.ietf.org/doc/html/rfc4647#section-3.3.1)
Arguments.of("en-US,fr-CA", List.of(Locale.forLanguageTag("en-US-POSIX"), Locale.forLanguageTag("en-US"), Locale.forLanguageTag("fr-CA"))),
Arguments.of("en-US; q=0.9, fr-CA", List.of(Locale.forLanguageTag("fr-CA"), Locale.forLanguageTag("en-US-POSIX"), Locale.forLanguageTag("en-US"))),
Arguments.of("invalid-locale,fr-CA", List.of(Locale.forLanguageTag("fr-CA"))),
Arguments.of("", Collections.emptyList()),
Arguments.of("acompletely,unexpectedfor , mat", Collections.emptyList())
);
}
}

View File

@ -13,9 +13,9 @@ import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.google.common.net.InetAddresses;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import io.grpc.Status; import io.grpc.Status;
import java.net.InetSocketAddress;
import java.time.Duration; import java.time.Duration;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
@ -72,7 +72,7 @@ class AccountsAnonymousGrpcServiceTest extends
when(rateLimiter.validateReactive(anyString())).thenReturn(Mono.empty()); when(rateLimiter.validateReactive(anyString())).thenReturn(Mono.empty());
getMockRemoteAddressInterceptor().setRemoteAddress(new InetSocketAddress("127.0.0.1", 12345)); getMockRequestAttributesInterceptor().setRemoteAddress(InetAddresses.forString("127.0.0.1"));
return new AccountsAnonymousGrpcService(accountsManager, rateLimiters); return new AccountsAnonymousGrpcService(accountsManager, rateLimiters);
} }

View File

@ -29,21 +29,21 @@ public final class GrpcTestUtils {
public static void setupAuthenticatedExtension( public static void setupAuthenticatedExtension(
final GrpcServerExtension extension, final GrpcServerExtension extension,
final MockAuthenticationInterceptor mockAuthenticationInterceptor, final MockAuthenticationInterceptor mockAuthenticationInterceptor,
final MockRemoteAddressInterceptor mockRemoteAddressInterceptor, final MockRequestAttributesInterceptor mockRequestAttributesInterceptor,
final UUID authenticatedAci, final UUID authenticatedAci,
final byte authenticatedDeviceId, final byte authenticatedDeviceId,
final BindableService service) { final BindableService service) {
mockAuthenticationInterceptor.setAuthenticatedDevice(authenticatedAci, authenticatedDeviceId); mockAuthenticationInterceptor.setAuthenticatedDevice(authenticatedAci, authenticatedDeviceId);
extension.getServiceRegistry() extension.getServiceRegistry()
.addService(ServerInterceptors.intercept(service, mockRemoteAddressInterceptor, mockAuthenticationInterceptor, new ErrorMappingInterceptor())); .addService(ServerInterceptors.intercept(service, mockRequestAttributesInterceptor, mockAuthenticationInterceptor, new ErrorMappingInterceptor()));
} }
public static void setupUnauthenticatedExtension( public static void setupUnauthenticatedExtension(
final GrpcServerExtension extension, final GrpcServerExtension extension,
final MockRemoteAddressInterceptor mockRemoteAddressInterceptor, final MockRequestAttributesInterceptor mockRequestAttributesInterceptor,
final BindableService service) { final BindableService service) {
extension.getServiceRegistry() extension.getServiceRegistry()
.addService(ServerInterceptors.intercept(service, mockRemoteAddressInterceptor, new ErrorMappingInterceptor())); .addService(ServerInterceptors.intercept(service, mockRequestAttributesInterceptor, new ErrorMappingInterceptor()));
} }
public static void assertStatusException(final Status expected, final Executable serviceCall) { public static void assertStatusException(final Status expected, final Executable serviceCall) {

View File

@ -1,37 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import java.net.SocketAddress;
import javax.annotation.Nullable;
public class MockRemoteAddressInterceptor implements ServerInterceptor {
@Nullable
private SocketAddress remoteAddress;
public void setRemoteAddress(@Nullable final SocketAddress remoteAddress) {
this.remoteAddress = remoteAddress;
}
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> serverCall,
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
return remoteAddress == null
? next.startCall(serverCall, headers)
: Contexts.interceptCall(
Context.current().withValue(RemoteAddressUtil.REMOTE_ADDRESS_CONTEXT_KEY, remoteAddress),
serverCall, headers, next);
}
}

View File

@ -0,0 +1,64 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import java.net.InetAddress;
import java.util.List;
import java.util.Locale;
import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
public class MockRequestAttributesInterceptor implements ServerInterceptor {
@Nullable
private InetAddress remoteAddress;
@Nullable
private UserAgent userAgent;
@Nullable
private List<Locale.LanguageRange> acceptLanguage;
public void setRemoteAddress(@Nullable final InetAddress remoteAddress) {
this.remoteAddress = remoteAddress;
}
public void setUserAgent(@Nullable final UserAgent userAgent) {
this.userAgent = userAgent;
}
public void setAcceptLanguage(@Nullable final List<Locale.LanguageRange> acceptLanguage) {
this.acceptLanguage = acceptLanguage;
}
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> serverCall,
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
Context context = Context.current();
if (remoteAddress != null) {
context = context.withValue(RequestAttributesUtil.REMOTE_ADDRESS_CONTEXT_KEY, remoteAddress);
}
if (userAgent != null) {
context = context.withValue(RequestAttributesUtil.USER_AGENT_CONTEXT_KEY, userAgent);
}
if (acceptLanguage != null) {
context = context.withValue(RequestAttributesUtil.ACCEPT_LANGUAGE_CONTEXT_KEY, acceptLanguage);
}
return Contexts.interceptCall(context, serverCall, headers, next);
}
}

View File

@ -17,16 +17,13 @@ import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException; import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import io.grpc.Channel;
import io.grpc.Metadata;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.stub.MetadataUtils;
import java.lang.reflect.InvocationTargetException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.time.Instant; import java.time.Instant;
import java.time.temporal.ChronoUnit; import java.time.temporal.ChronoUnit;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Locale;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
@ -79,6 +76,8 @@ import org.whispersystems.textsecuregcm.storage.VersionedProfile;
import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper; import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper;
import org.whispersystems.textsecuregcm.util.TestRandomUtil; import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil; import org.whispersystems.textsecuregcm.util.UUIDUtil;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<ProfileAnonymousGrpcService, ProfileAnonymousGrpc.ProfileAnonymousBlockingStub> { public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<ProfileAnonymousGrpcService, ProfileAnonymousGrpc.ProfileAnonymousBlockingStub> {
@ -100,6 +99,14 @@ public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<ProfileA
@Override @Override
protected ProfileAnonymousGrpcService createServiceBeforeEachTest() { protected ProfileAnonymousGrpcService createServiceBeforeEachTest() {
getMockRequestAttributesInterceptor().setAcceptLanguage(Locale.LanguageRange.parse("en-us"));
try {
getMockRequestAttributesInterceptor().setUserAgent(UserAgentUtil.parseUserAgentString("Signal-Android/1.2.3"));
} catch (final UnrecognizedUserAgentException e) {
throw new IllegalArgumentException(e);
}
return new ProfileAnonymousGrpcService( return new ProfileAnonymousGrpcService(
accountsManager, accountsManager,
profilesManager, profilesManager,
@ -108,14 +115,6 @@ public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<ProfileA
); );
} }
@Override
protected ProfileAnonymousGrpc.ProfileAnonymousBlockingStub createStub(final Channel channel) throws ClassNotFoundException, InvocationTargetException, NoSuchMethodException, IllegalAccessException {
final Metadata metadata = new Metadata();
metadata.put(AcceptLanguageInterceptor.ACCEPTABLE_LANGUAGES_GRPC_HEADER, "en-us");
metadata.put(UserAgentInterceptor.USER_AGENT_GRPC_HEADER, "Signal-Android/1.2.3");
return super.createStub(channel).withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
}
@Test @Test
void getUnversionedProfile() { void getUnversionedProfile() {
final UUID targetUuid = UUID.randomUUID(); final UUID targetUuid = UUID.randomUUID();

View File

@ -25,7 +25,6 @@ import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusEx
import com.google.i18n.phonenumbers.PhoneNumberUtil; import com.google.i18n.phonenumbers.PhoneNumberUtil;
import com.google.i18n.phonenumbers.Phonenumber; import com.google.i18n.phonenumbers.Phonenumber;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import io.grpc.Metadata;
import io.grpc.Status; import io.grpc.Status;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.time.Clock; import java.time.Clock;
@ -34,6 +33,7 @@ import java.time.Instant;
import java.time.temporal.ChronoUnit; import java.time.temporal.ChronoUnit;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Locale;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
@ -103,6 +103,8 @@ import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper;
import org.whispersystems.textsecuregcm.util.MockUtils; import org.whispersystems.textsecuregcm.util.MockUtils;
import org.whispersystems.textsecuregcm.util.TestRandomUtil; import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil; import org.whispersystems.textsecuregcm.util.UUIDUtil;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
@ -167,9 +169,14 @@ public class ProfileGrpcServiceTest extends SimpleBaseGrpcTest<ProfileGrpcServic
final String phoneNumber = PhoneNumberUtil.getInstance().format( final String phoneNumber = PhoneNumberUtil.getInstance().format(
PhoneNumberUtil.getInstance().getExampleNumber("US"), PhoneNumberUtil.getInstance().getExampleNumber("US"),
PhoneNumberUtil.PhoneNumberFormat.E164); PhoneNumberUtil.PhoneNumberFormat.E164);
final Metadata metadata = new Metadata();
metadata.put(AcceptLanguageInterceptor.ACCEPTABLE_LANGUAGES_GRPC_HEADER, "en-us"); getMockRequestAttributesInterceptor().setAcceptLanguage(Locale.LanguageRange.parse("en-us"));
metadata.put(UserAgentInterceptor.USER_AGENT_GRPC_HEADER, "Signal-Android/1.2.3");
try {
getMockRequestAttributesInterceptor().setUserAgent(UserAgentUtil.parseUserAgentString("Signal-Android/1.2.3"));
} catch (final UnrecognizedUserAgentException e) {
throw new IllegalArgumentException(e);
}
when(rateLimiters.getProfileLimiter()).thenReturn(rateLimiter); when(rateLimiters.getProfileLimiter()).thenReturn(rateLimiter);
when(rateLimiter.validateReactive(any(UUID.class))).thenReturn(Mono.empty()); when(rateLimiter.validateReactive(any(UUID.class))).thenReturn(Mono.empty());

View File

@ -0,0 +1,57 @@
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.stub.StreamObserver;
import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.signal.chat.rpc.GetRequestAttributesRequest;
import org.signal.chat.rpc.GetRequestAttributesResponse;
import org.signal.chat.rpc.RequestAttributesGrpc;
import org.signal.chat.rpc.UserAgent;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticationUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
public class RequestAttributesServiceImpl extends RequestAttributesGrpc.RequestAttributesImplBase {
@Override
public void getRequestAttributes(final GetRequestAttributesRequest request,
final StreamObserver<GetRequestAttributesResponse> responseObserver) {
final GetRequestAttributesResponse.Builder responseBuilder = GetRequestAttributesResponse.newBuilder();
RequestAttributesUtil.getAcceptableLanguages().ifPresent(acceptableLanguages ->
acceptableLanguages.forEach(languageRange -> responseBuilder.addAcceptableLanguages(languageRange.toString())));
RequestAttributesUtil.getAvailableAcceptedLocales().forEach(locale ->
responseBuilder.addAvailableAcceptedLocales(locale.toLanguageTag()));
responseBuilder.setRemoteAddress(RequestAttributesUtil.getRemoteAddress().getHostAddress());
RequestAttributesUtil.getUserAgent().ifPresent(userAgent -> responseBuilder.setUserAgent(UserAgent.newBuilder()
.setPlatform(userAgent.getPlatform().toString())
.setVersion(userAgent.getVersion().toString())
.setAdditionalSpecifiers(userAgent.getAdditionalSpecifiers().orElse(""))
.build()));
responseObserver.onNext(responseBuilder.build());
responseObserver.onCompleted();
}
@Override
public void getAuthenticatedDevice(final GetAuthenticatedDeviceRequest request,
final StreamObserver<GetAuthenticatedDeviceResponse> responseObserver) {
final GetAuthenticatedDeviceResponse.Builder responseBuilder = GetAuthenticatedDeviceResponse.newBuilder();
try {
final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedDevice();
responseBuilder.setAccountIdentifier(UUIDUtil.toByteString(authenticatedDevice.accountIdentifier()));
responseBuilder.setDeviceId(authenticatedDevice.deviceId());
} catch (final Exception ignored) {
}
responseObserver.onNext(responseBuilder.build());
responseObserver.onCompleted();
}
}

View File

@ -0,0 +1,160 @@
package org.whispersystems.textsecuregcm.grpc;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.common.net.InetAddresses;
import io.grpc.ManagedChannel;
import io.grpc.Server;
import io.grpc.Status;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.NettyServerBuilder;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import java.io.IOException;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import org.apache.commons.lang3.StringUtils;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.signal.chat.rpc.GetRequestAttributesRequest;
import org.signal.chat.rpc.GetRequestAttributesResponse;
import org.signal.chat.rpc.RequestAttributesGrpc;
import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
class RequestAttributesUtilTest {
private static DefaultEventLoopGroup eventLoopGroup;
private ClientConnectionManager clientConnectionManager;
private Server server;
private ManagedChannel managedChannel;
@BeforeAll
static void setUpBeforeAll() {
eventLoopGroup = new DefaultEventLoopGroup();
}
@BeforeEach
void setUp() throws IOException {
final LocalAddress serverAddress = new LocalAddress("test-request-metadata-server");
clientConnectionManager = mock(ClientConnectionManager.class);
when(clientConnectionManager.getRemoteAddress(any()))
.thenReturn(Optional.of(InetAddresses.forString("127.0.0.1")));
// `RequestAttributesInterceptor` operates on `LocalAddresses`, so we need to do some slightly fancy plumbing to make
// sure that we're using local channels and addresses
server = NettyServerBuilder.forAddress(serverAddress)
.channelType(LocalServerChannel.class)
.bossEventLoopGroup(eventLoopGroup)
.workerEventLoopGroup(eventLoopGroup)
.intercept(new RequestAttributesInterceptor(clientConnectionManager))
.addService(new RequestAttributesServiceImpl())
.build()
.start();
managedChannel = NettyChannelBuilder.forAddress(serverAddress)
.channelType(LocalChannel.class)
.eventLoopGroup(eventLoopGroup)
.usePlaintext()
.build();
}
@AfterEach
void tearDown() {
managedChannel.shutdown();
server.shutdown();
}
@AfterAll
static void tearDownAfterAll() throws InterruptedException {
eventLoopGroup.shutdownGracefully().await();
}
@Test
void getAcceptableLanguages() {
when(clientConnectionManager.getAcceptableLanguages(any()))
.thenReturn(Optional.empty());
assertTrue(getRequestAttributes().getAcceptableLanguagesList().isEmpty());
when(clientConnectionManager.getAcceptableLanguages(any()))
.thenReturn(Optional.of(Locale.LanguageRange.parse("en,ja")));
assertEquals(List.of("en", "ja"), getRequestAttributes().getAcceptableLanguagesList());
}
@Test
void getAvailableAcceptedLocales() {
when(clientConnectionManager.getAcceptableLanguages(any()))
.thenReturn(Optional.empty());
assertTrue(getRequestAttributes().getAvailableAcceptedLocalesList().isEmpty());
when(clientConnectionManager.getAcceptableLanguages(any()))
.thenReturn(Optional.of(Locale.LanguageRange.parse("en,ja")));
final GetRequestAttributesResponse response = getRequestAttributes();
assertFalse(response.getAvailableAcceptedLocalesList().isEmpty());
response.getAvailableAcceptedLocalesList().forEach(languageTag -> {
final Locale locale = Locale.forLanguageTag(languageTag);
assertTrue("en".equals(locale.getLanguage()) || "ja".equals(locale.getLanguage()));
});
}
@Test
void getRemoteAddress() {
when(clientConnectionManager.getRemoteAddress(any()))
.thenReturn(Optional.empty());
GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getRequestAttributes);
final String remoteAddressString = "6.7.8.9";
when(clientConnectionManager.getRemoteAddress(any()))
.thenReturn(Optional.of(InetAddresses.forString(remoteAddressString)));
assertEquals(remoteAddressString, getRequestAttributes().getRemoteAddress());
}
@Test
void getUserAgent() throws UnrecognizedUserAgentException {
when(clientConnectionManager.getUserAgent(any()))
.thenReturn(Optional.empty());
assertFalse(getRequestAttributes().hasUserAgent());
final UserAgent userAgent = UserAgentUtil.parseUserAgentString("Signal-Desktop/1.2.3 Linux");
when(clientConnectionManager.getUserAgent(any()))
.thenReturn(Optional.of(userAgent));
final GetRequestAttributesResponse response = getRequestAttributes();
assertTrue(response.hasUserAgent());
assertEquals("DESKTOP", response.getUserAgent().getPlatform());
assertEquals("1.2.3", response.getUserAgent().getVersion());
assertEquals("Linux", response.getUserAgent().getAdditionalSpecifiers());
}
private GetRequestAttributesResponse getRequestAttributes() {
return RequestAttributesGrpc.newBlockingStub(managedChannel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build());
}
}

View File

@ -60,7 +60,7 @@ public abstract class SimpleBaseGrpcTest<SERVICE extends BindableService, STUB e
private AutoCloseable mocksCloseable; private AutoCloseable mocksCloseable;
private final MockAuthenticationInterceptor mockAuthenticationInterceptor = new MockAuthenticationInterceptor(); private final MockAuthenticationInterceptor mockAuthenticationInterceptor = new MockAuthenticationInterceptor();
private final MockRemoteAddressInterceptor mockRemoteAddressInterceptor = new MockRemoteAddressInterceptor(); private final MockRequestAttributesInterceptor mockRequestAttributesInterceptor = new MockRequestAttributesInterceptor();
private SERVICE service; private SERVICE service;
@ -114,8 +114,8 @@ public abstract class SimpleBaseGrpcTest<SERVICE extends BindableService, STUB e
mocksCloseable = MockitoAnnotations.openMocks(this); mocksCloseable = MockitoAnnotations.openMocks(this);
service = requireNonNull(createServiceBeforeEachTest(), "created service must not be `null`"); service = requireNonNull(createServiceBeforeEachTest(), "created service must not be `null`");
GrpcTestUtils.setupAuthenticatedExtension( GrpcTestUtils.setupAuthenticatedExtension(
GRPC_SERVER_EXTENSION_AUTHENTICATED, mockAuthenticationInterceptor, mockRemoteAddressInterceptor, AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID, service); GRPC_SERVER_EXTENSION_AUTHENTICATED, mockAuthenticationInterceptor, mockRequestAttributesInterceptor, AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID, service);
GrpcTestUtils.setupUnauthenticatedExtension(GRPC_SERVER_EXTENSION_UNAUTHENTICATED, mockRemoteAddressInterceptor, service); GrpcTestUtils.setupUnauthenticatedExtension(GRPC_SERVER_EXTENSION_UNAUTHENTICATED, mockRequestAttributesInterceptor, service);
try { try {
authenticatedServiceStub = createStub(GRPC_SERVER_EXTENSION_AUTHENTICATED.getChannel()); authenticatedServiceStub = createStub(GRPC_SERVER_EXTENSION_AUTHENTICATED.getChannel());
unauthenticatedServiceStub = createStub(GRPC_SERVER_EXTENSION_UNAUTHENTICATED.getChannel()); unauthenticatedServiceStub = createStub(GRPC_SERVER_EXTENSION_UNAUTHENTICATED.getChannel());
@ -145,8 +145,8 @@ public abstract class SimpleBaseGrpcTest<SERVICE extends BindableService, STUB e
return unauthenticatedServiceStub; return unauthenticatedServiceStub;
} }
protected MockRemoteAddressInterceptor getMockRemoteAddressInterceptor() { protected MockRequestAttributesInterceptor getMockRequestAttributesInterceptor() {
return mockRemoteAddressInterceptor; return mockRequestAttributesInterceptor;
} }
protected MockAuthenticationInterceptor getMockAuthenticationInterceptor() { protected MockAuthenticationInterceptor getMockAuthenticationInterceptor() {

View File

@ -1,90 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import com.google.protobuf.ByteString;
import com.vdurmont.semver4j.Semver;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.Server;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.stub.MetadataUtils;
import io.grpc.stub.StreamObserver;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.signal.chat.rpc.EchoRequest;
import org.signal.chat.rpc.EchoResponse;
import org.signal.chat.rpc.EchoServiceGrpc;
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Stream;
public class UserAgentInterceptorTest {
@ParameterizedTest
@MethodSource
void testInterceptor(final String header, final ClientPlatform platform, final String version) throws Exception {
final AtomicReference<UserAgent> observedUserAgent = new AtomicReference<>(null);
final EchoServiceImpl serviceImpl = new EchoServiceImpl() {
@Override
public void echo(EchoRequest req, StreamObserver<EchoResponse> responseObserver) {
observedUserAgent.set(UserAgentUtil.userAgentFromGrpcContext());
super.echo(req, responseObserver);
}
};
final Server testServer = InProcessServerBuilder.forName("RemoteDeprecationFilterTest")
.directExecutor()
.addService(serviceImpl)
.intercept(new UserAgentInterceptor())
.build()
.start();
try {
final ManagedChannel channel = InProcessChannelBuilder.forName("RemoteDeprecationFilterTest")
.directExecutor()
.userAgent(header)
.build();
final EchoServiceGrpc.EchoServiceBlockingStub client = EchoServiceGrpc.newBlockingStub(channel);
final EchoRequest req = EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("cluck cluck, i'm a parrot")).build();
assertEquals("cluck cluck, i'm a parrot", client.echo(req).getPayload().toStringUtf8());
if (platform == null) {
assertNull(observedUserAgent.get());
} else {
assertEquals(platform, observedUserAgent.get().getPlatform());
assertEquals(new Semver(version), observedUserAgent.get().getVersion());
// can't assert on the additional specifiers because they include internal details of the grpc in-process channel itself
}
} finally {
testServer.shutdownNow();
testServer.awaitTermination();
}
}
private static Stream<Arguments> testInterceptor() {
return Stream.of(
Arguments.of(null, null, null),
Arguments.of("", null, null),
Arguments.of("Unrecognized UA", null, null),
Arguments.of("Signal-Android/4.68.3", ClientPlatform.ANDROID, "4.68.3"),
Arguments.of("Signal-iOS/3.9.0", ClientPlatform.IOS, "3.9.0"),
Arguments.of("Signal-Desktop/1.2.3", ClientPlatform.DESKTOP, "1.2.3"),
Arguments.of("Signal-Desktop/8.0.0-beta.2", ClientPlatform.DESKTOP, "8.0.0-beta.2"),
Arguments.of("Signal-iOS/8.0.0-beta.2", ClientPlatform.IOS, "8.0.0-beta.2"));
}
}

View File

@ -1,21 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.grpc.stub.StreamObserver;
import org.signal.chat.rpc.AuthenticationTypeGrpc;
import org.signal.chat.rpc.GetAuthenticatedRequest;
import org.signal.chat.rpc.GetAuthenticatedResponse;
public class AuthenticationTypeService extends AuthenticationTypeGrpc.AuthenticationTypeImplBase {
private final boolean authenticated;
public AuthenticationTypeService(final boolean authenticated) {
this.authenticated = authenticated;
}
@Override
public void getAuthenticated(final GetAuthenticatedRequest request, final StreamObserver<GetAuthenticatedResponse> responseObserver) {
responseObserver.onNext(GetAuthenticatedResponse.newBuilder().setAuthenticated(authenticated).build());
responseObserver.onCompleted();
}
}

View File

@ -0,0 +1,283 @@
package org.whispersystems.textsecuregcm.grpc.net;
import com.google.common.net.InetAddresses;
import com.vdurmont.semver4j.Semver;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import javax.annotation.Nullable;
import java.net.InetAddress;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.UUID;
import static org.junit.jupiter.api.Assertions.*;
class ClientConnectionManagerTest {
private static EventLoopGroup eventLoopGroup;
private LocalChannel localChannel;
private LocalChannel remoteChannel;
private LocalServerChannel localServerChannel;
private ClientConnectionManager clientConnectionManager;
@BeforeAll
static void setUpBeforeAll() {
eventLoopGroup = new DefaultEventLoopGroup();
}
@BeforeEach
void setUp() throws InterruptedException {
eventLoopGroup = new DefaultEventLoopGroup();
clientConnectionManager = new ClientConnectionManager();
// We have to jump through some hoops to get "real" LocalChannel instances to test with, and so we run a trivial
// local server to which we can open trivial local connections
localServerChannel = (LocalServerChannel) new ServerBootstrap()
.group(eventLoopGroup)
.channel(LocalServerChannel.class)
.childHandler(new ChannelInitializer<>() {
@Override
protected void initChannel(final Channel channel) {
}
})
.bind(new LocalAddress("test-server"))
.await()
.channel();
final Bootstrap clientBootstrap = new Bootstrap()
.group(eventLoopGroup)
.channel(LocalChannel.class)
.handler(new ChannelInitializer<>() {
@Override
protected void initChannel(final Channel ch) {
}
});
localChannel = (LocalChannel) clientBootstrap.connect(localServerChannel.localAddress()).await().channel();
remoteChannel = (LocalChannel) clientBootstrap.connect(localServerChannel.localAddress()).await().channel();
}
@AfterEach
void tearDown() throws InterruptedException {
localChannel.close().await();
remoteChannel.close().await();
localServerChannel.close().await();
}
@AfterAll
static void tearDownAfterAll() throws InterruptedException {
eventLoopGroup.shutdownGracefully().await();
}
@ParameterizedTest
@MethodSource
void getAuthenticatedDevice(@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<AuthenticatedDevice> maybeAuthenticatedDevice) {
clientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, maybeAuthenticatedDevice);
assertEquals(maybeAuthenticatedDevice,
clientConnectionManager.getAuthenticatedDevice(localChannel.localAddress()));
}
private static List<Optional<AuthenticatedDevice>> getAuthenticatedDevice() {
return List.of(
Optional.of(new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID)),
Optional.empty()
);
}
@Test
void getAcceptableLanguages() {
clientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertEquals(Optional.empty(),
clientConnectionManager.getAcceptableLanguages(localChannel.localAddress()));
final List<Locale.LanguageRange> acceptLanguageRanges = Locale.LanguageRange.parse("en,ja");
remoteChannel.attr(ClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).set(acceptLanguageRanges);
assertEquals(Optional.of(acceptLanguageRanges),
clientConnectionManager.getAcceptableLanguages(localChannel.localAddress()));
}
@Test
void getRemoteAddress() {
clientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertEquals(Optional.empty(),
clientConnectionManager.getRemoteAddress(localChannel.localAddress()));
final InetAddress remoteAddress = InetAddresses.forString("6.7.8.9");
remoteChannel.attr(ClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).set(remoteAddress);
assertEquals(Optional.of(remoteAddress),
clientConnectionManager.getRemoteAddress(localChannel.localAddress()));
}
@Test
void getUserAgent() throws UnrecognizedUserAgentException {
clientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertEquals(Optional.empty(),
clientConnectionManager.getUserAgent(localChannel.localAddress()));
final UserAgent userAgent = UserAgentUtil.parseUserAgentString("Signal-Desktop/1.2.3 Linux");
remoteChannel.attr(ClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY).set(userAgent);
assertEquals(Optional.of(userAgent),
clientConnectionManager.getUserAgent(localChannel.localAddress()));
}
@Test
void closeConnection() throws InterruptedException {
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
clientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice));
assertTrue(remoteChannel.isOpen());
assertEquals(remoteChannel, clientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
assertEquals(List.of(remoteChannel),
clientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
remoteChannel.close().await();
assertNull(clientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
assertNull(clientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
}
@Test
void handleWebSocketHandshakeCompleteRemoteAddress() {
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
final InetAddress preferredRemoteAddress = InetAddresses.forString("192.168.1.1");
ClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel,
preferredRemoteAddress,
null,
null);
assertEquals(preferredRemoteAddress,
embeddedChannel.attr(ClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).get());
}
@ParameterizedTest
@MethodSource
void handleWebSocketHandshakeCompleteUserAgent(@Nullable final String userAgentHeader,
@Nullable final UserAgent expectedParsedUserAgent) {
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
ClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel,
InetAddresses.forString("127.0.0.1"),
userAgentHeader,
null);
assertEquals(userAgentHeader,
embeddedChannel.attr(ClientConnectionManager.RAW_USER_AGENT_ATTRIBUTE_KEY).get());
assertEquals(expectedParsedUserAgent,
embeddedChannel.attr(ClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY).get());
}
private static List<Arguments> handleWebSocketHandshakeCompleteUserAgent() {
return List.of(
// Recognized user-agent
Arguments.of("Signal-Desktop/1.2.3 Linux", new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "Linux")),
// Unrecognized user-agent
Arguments.of("Not a valid user-agent string", null),
// Missing user-agent
Arguments.of(null, null)
);
}
@ParameterizedTest
@MethodSource
void handleWebSocketHandshakeCompleteAcceptLanguage(@Nullable final String acceptLanguageHeader,
@Nullable final List<Locale.LanguageRange> expectedLanguageRanges) {
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
ClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel,
InetAddresses.forString("127.0.0.1"),
null,
acceptLanguageHeader);
assertEquals(expectedLanguageRanges,
embeddedChannel.attr(ClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).get());
}
private static List<Arguments> handleWebSocketHandshakeCompleteAcceptLanguage() {
return List.of(
// Parseable list
Arguments.of("ja,en;q=0.4", Locale.LanguageRange.parse("ja,en;q=0.4")),
// Unparsable list
Arguments.of("This is not a valid language preference list", null),
// Missing list
Arguments.of(null, null)
);
}
@Test
void handleConnectionEstablishedAuthenticated() throws InterruptedException {
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
assertNull(clientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
assertNull(clientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
clientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice));
assertEquals(remoteChannel, clientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
assertEquals(List.of(remoteChannel), clientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
remoteChannel.close().await();
assertNull(clientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
assertNull(clientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
}
@Test
void handleConnectionEstablishedAnonymous() throws InterruptedException {
assertNull(clientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
clientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertEquals(remoteChannel, clientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
remoteChannel.close().await();
assertNull(clientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
}
}

View File

@ -8,8 +8,8 @@ import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpClientCodec; import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler; import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
import io.netty.handler.codec.http.websocketx.WebSocketVersion; import io.netty.handler.codec.http.websocketx.WebSocketVersion;
@ -35,6 +35,7 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
private final ECPublicKey rootPublicKey; private final ECPublicKey rootPublicKey;
@Nullable private final UUID accountIdentifier; @Nullable private final UUID accountIdentifier;
private final byte deviceId; private final byte deviceId;
private final HttpHeaders headers;
private final SocketAddress remoteServerAddress; private final SocketAddress remoteServerAddress;
private final WebSocketCloseListener webSocketCloseListener; private final WebSocketCloseListener webSocketCloseListener;
@ -50,6 +51,7 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
final ECPublicKey rootPublicKey, final ECPublicKey rootPublicKey,
@Nullable final UUID accountIdentifier, @Nullable final UUID accountIdentifier,
final byte deviceId, final byte deviceId,
final HttpHeaders headers,
final SocketAddress remoteServerAddress, final SocketAddress remoteServerAddress,
final WebSocketCloseListener webSocketCloseListener) { final WebSocketCloseListener webSocketCloseListener) {
@ -60,6 +62,7 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
this.rootPublicKey = rootPublicKey; this.rootPublicKey = rootPublicKey;
this.accountIdentifier = accountIdentifier; this.accountIdentifier = accountIdentifier;
this.deviceId = deviceId; this.deviceId = deviceId;
this.headers = headers;
this.remoteServerAddress = remoteServerAddress; this.remoteServerAddress = remoteServerAddress;
this.webSocketCloseListener = webSocketCloseListener; this.webSocketCloseListener = webSocketCloseListener;
} }
@ -87,7 +90,7 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
WebSocketVersion.V13, WebSocketVersion.V13,
null, null,
false, false,
new DefaultHttpHeaders(), headers,
Noise.MAX_PACKET_LEN, Noise.MAX_PACKET_LEN,
10_000)) 10_000))
.addLast(new OutboundCloseWebSocketFrameHandler(webSocketCloseListener)) .addLast(new OutboundCloseWebSocketFrameHandler(webSocketCloseListener))

View File

@ -11,6 +11,7 @@ import java.net.SocketAddress;
import java.net.URI; import java.net.URI;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
import java.util.UUID; import java.util.UUID;
import io.netty.handler.codec.http.HttpHeaders;
import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey; import org.signal.libsignal.protocol.ecc.ECPublicKey;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -30,6 +31,7 @@ class WebSocketNoiseTunnelClient implements AutoCloseable {
final ECPublicKey rootPublicKey, final ECPublicKey rootPublicKey,
@Nullable final UUID accountIdentifier, @Nullable final UUID accountIdentifier,
final byte deviceId, final byte deviceId,
final HttpHeaders headers,
final X509Certificate trustedServerCertificate, final X509Certificate trustedServerCertificate,
final NioEventLoopGroup eventLoopGroup, final NioEventLoopGroup eventLoopGroup,
final WebSocketCloseListener webSocketCloseListener) { final WebSocketCloseListener webSocketCloseListener) {
@ -48,6 +50,7 @@ class WebSocketNoiseTunnelClient implements AutoCloseable {
rootPublicKey, rootPublicKey,
accountIdentifier, accountIdentifier,
deviceId, deviceId,
headers,
remoteServerAddress, remoteServerAddress,
webSocketCloseListener)); webSocketCloseListener));
} }

View File

@ -1,7 +1,6 @@
package org.whispersystems.textsecuregcm.grpc.net; package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.ArgumentMatchers.anyByte;
@ -17,6 +16,8 @@ import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel; import io.netty.channel.local.LocalChannel;
import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpHeaders;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.IOException; import java.io.IOException;
import java.net.URI; import java.net.URI;
@ -35,28 +36,41 @@ import java.security.cert.X509Certificate;
import java.security.spec.InvalidKeySpecException; import java.security.spec.InvalidKeySpecException;
import java.security.spec.PKCS8EncodedKeySpec; import java.security.spec.PKCS8EncodedKeySpec;
import java.util.Base64; import java.util.Base64;
import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import javax.net.ssl.SSLContext; import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManagerFactory; import javax.net.ssl.TrustManagerFactory;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.signal.chat.rpc.AuthenticationTypeGrpc; import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
import org.signal.chat.rpc.GetAuthenticatedRequest; import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.signal.chat.rpc.GetAuthenticatedResponse; import org.signal.chat.rpc.GetRequestAttributesRequest;
import org.signal.chat.rpc.GetRequestAttributesResponse;
import org.signal.chat.rpc.RequestAttributesGrpc;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey; import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils; import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor;
import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTest { class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTest {
@ -66,6 +80,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
private static X509Certificate serverTlsCertificate; private static X509Certificate serverTlsCertificate;
private ClientConnectionManager clientConnectionManager;
private ClientPublicKeysManager clientPublicKeysManager; private ClientPublicKeysManager clientPublicKeysManager;
private ECKeyPair rootKeyPair; private ECKeyPair rootKeyPair;
@ -79,6 +94,8 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
private static final UUID ACCOUNT_IDENTIFIER = UUID.randomUUID(); private static final UUID ACCOUNT_IDENTIFIER = UUID.randomUUID();
private static final byte DEVICE_ID = Device.PRIMARY_ID; private static final byte DEVICE_ID = Device.PRIMARY_ID;
private static final String RECOGNIZED_PROXY_SECRET = RandomStringUtils.randomAlphanumeric(16);
// Please note that this certificate/key are used only for testing and are not used anywhere outside of this test. // Please note that this certificate/key are used only for testing and are not used anywhere outside of this test.
// They were generated with: // They were generated with:
// //
@ -133,6 +150,8 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
clientKeyPair = Curve.generateKeyPair(); clientKeyPair = Curve.generateKeyPair();
final ECKeyPair serverKeyPair = Curve.generateKeyPair(); final ECKeyPair serverKeyPair = Curve.generateKeyPair();
clientConnectionManager = new ClientConnectionManager();
clientPublicKeysManager = mock(ClientPublicKeysManager.class); clientPublicKeysManager = mock(ClientPublicKeysManager.class);
when(clientPublicKeysManager.findPublicKey(any(), anyByte())) when(clientPublicKeysManager.findPublicKey(any(), anyByte()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty())); .thenReturn(CompletableFuture.completedFuture(Optional.empty()));
@ -146,7 +165,9 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, defaultEventLoopGroup) { authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, defaultEventLoopGroup) {
@Override @Override
protected void configureServer(final ServerBuilder<?> serverBuilder) { protected void configureServer(final ServerBuilder<?> serverBuilder) {
serverBuilder.addService(new AuthenticationTypeService(true)); serverBuilder.addService(new RequestAttributesServiceImpl())
.intercept(new RequestAttributesInterceptor(clientConnectionManager))
.intercept(new RequireAuthenticationInterceptor(clientConnectionManager));
} }
}; };
@ -155,7 +176,9 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, defaultEventLoopGroup) { anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, defaultEventLoopGroup) {
@Override @Override
protected void configureServer(final ServerBuilder<?> serverBuilder) { protected void configureServer(final ServerBuilder<?> serverBuilder) {
serverBuilder.addService(new AuthenticationTypeService(false)); serverBuilder.addService(new RequestAttributesServiceImpl())
.intercept(new RequestAttributesInterceptor(clientConnectionManager))
.intercept(new ProhibitAuthenticationInterceptor(clientConnectionManager));
} }
}; };
@ -166,11 +189,13 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
serverTlsPrivateKey, serverTlsPrivateKey,
nioEventLoopGroup, nioEventLoopGroup,
delegatedTaskExecutor, delegatedTaskExecutor,
clientConnectionManager,
clientPublicKeysManager, clientPublicKeysManager,
serverKeyPair, serverKeyPair,
rootKeyPair.getPrivateKey().calculateSignature(serverKeyPair.getPublicKey().getPublicKeyBytes()), rootKeyPair.getPrivateKey().calculateSignature(serverKeyPair.getPublicKey().getPublicKeyBytes()),
authenticatedGrpcServerAddress, authenticatedGrpcServerAddress,
anonymousGrpcServerAddress); anonymousGrpcServerAddress,
RECOGNIZED_PROXY_SECRET);
websocketNoiseTunnelServer.start(); websocketNoiseTunnelServer.start();
} }
@ -198,10 +223,11 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress()); final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
try { try {
final GetAuthenticatedResponse response = AuthenticationTypeGrpc.newBlockingStub(channel) final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build()); .getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
assertTrue(response.getAuthenticated()); assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
assertEquals(DEVICE_ID, response.getDeviceId());
} finally { } finally {
channel.shutdown(); channel.shutdown();
} }
@ -215,15 +241,15 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
// Try to verify the server's public key with something other than the key with which it was signed // Try to verify the server's public key with something other than the key with which it was signed
try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient = try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient =
buildAndStartAuthenticatedClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey())) { buildAndStartAuthenticatedClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey(), new DefaultHttpHeaders())) {
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress()); final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
try { try {
//noinspection ResultOfMethodCallIgnored //noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> AuthenticationTypeGrpc.newBlockingStub(channel) () -> RequestAttributesGrpc.newBlockingStub(channel)
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build())); .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
} finally { } finally {
channel.shutdown(); channel.shutdown();
} }
@ -247,8 +273,8 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
try { try {
//noinspection ResultOfMethodCallIgnored //noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> AuthenticationTypeGrpc.newBlockingStub(channel) () -> RequestAttributesGrpc.newBlockingStub(channel)
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build())); .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
} finally { } finally {
channel.shutdown(); channel.shutdown();
} }
@ -272,8 +298,8 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
try { try {
//noinspection ResultOfMethodCallIgnored //noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> AuthenticationTypeGrpc.newBlockingStub(channel) () -> RequestAttributesGrpc.newBlockingStub(channel)
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build())); .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
} finally { } finally {
channel.shutdown(); channel.shutdown();
} }
@ -294,6 +320,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
rootKeyPair.getPublicKey(), rootKeyPair.getPublicKey(),
ACCOUNT_IDENTIFIER, ACCOUNT_IDENTIFIER,
DEVICE_ID, DEVICE_ID,
new DefaultHttpHeaders(),
serverTlsCertificate, serverTlsCertificate,
nioEventLoopGroup, nioEventLoopGroup,
webSocketCloseListener) webSocketCloseListener)
@ -304,8 +331,8 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
try { try {
//noinspection ResultOfMethodCallIgnored //noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> AuthenticationTypeGrpc.newBlockingStub(channel) () -> RequestAttributesGrpc.newBlockingStub(channel)
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build())); .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
} finally { } finally {
channel.shutdown(); channel.shutdown();
} }
@ -320,10 +347,11 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress()); final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
try { try {
final GetAuthenticatedResponse response = AuthenticationTypeGrpc.newBlockingStub(channel) final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build()); .getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
assertFalse(response.getAuthenticated()); assertTrue(response.getAccountIdentifier().isEmpty());
assertEquals(0, response.getDeviceId());
} finally { } finally {
channel.shutdown(); channel.shutdown();
} }
@ -336,15 +364,15 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
// Try to verify the server's public key with something other than the key with which it was signed // Try to verify the server's public key with something other than the key with which it was signed
try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient = try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient =
buildAndStartAnonymousClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey())) { buildAndStartAnonymousClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey(), new DefaultHttpHeaders())) {
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress()); final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
try { try {
//noinspection ResultOfMethodCallIgnored //noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> AuthenticationTypeGrpc.newBlockingStub(channel) () -> RequestAttributesGrpc.newBlockingStub(channel)
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build())); .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
} finally { } finally {
channel.shutdown(); channel.shutdown();
} }
@ -365,6 +393,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
rootKeyPair.getPublicKey(), rootKeyPair.getPublicKey(),
null, null,
(byte) 0, (byte) 0,
new DefaultHttpHeaders(),
serverTlsCertificate, serverTlsCertificate,
nioEventLoopGroup, nioEventLoopGroup,
webSocketCloseListener) webSocketCloseListener)
@ -375,8 +404,8 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
try { try {
//noinspection ResultOfMethodCallIgnored //noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> AuthenticationTypeGrpc.newBlockingStub(channel) () -> RequestAttributesGrpc.newBlockingStub(channel)
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build())); .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
} finally { } finally {
channel.shutdown(); channel.shutdown();
} }
@ -438,6 +467,86 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
} }
} }
@Test
void getRequestAttributes() throws InterruptedException {
final String remoteAddress = "4.5.6.7";
final String acceptLanguage = "en";
final String userAgent = "Signal-Desktop/1.2.3 Linux";
final HttpHeaders headers = new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add("X-Forwarded-For", remoteAddress)
.add("Accept-Language", acceptLanguage)
.add("User-Agent", userAgent);
try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient =
buildAndStartAnonymousClient(WebSocketCloseListener.NOOP_LISTENER, rootKeyPair.getPublicKey(), headers)) {
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
try {
final GetRequestAttributesResponse response = RequestAttributesGrpc.newBlockingStub(channel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build());
assertEquals(remoteAddress, response.getRemoteAddress());
assertEquals(List.of(acceptLanguage), response.getAcceptableLanguagesList());
assertEquals("DESKTOP", response.getUserAgent().getPlatform());
assertEquals("1.2.3", response.getUserAgent().getVersion());
assertEquals("Linux", response.getUserAgent().getAdditionalSpecifiers());
} finally {
channel.shutdown();
}
}
}
@Test
void closeForReauthentication() throws InterruptedException {
final CountDownLatch connectionCloseLatch = new CountDownLatch(1);
final AtomicInteger serverCloseStatusCode = new AtomicInteger(0);
final AtomicBoolean closedByServer = new AtomicBoolean(false);
final WebSocketCloseListener webSocketCloseListener = new WebSocketCloseListener() {
@Override
public void handleWebSocketClosedByClient(final int statusCode) {
serverCloseStatusCode.set(statusCode);
closedByServer.set(false);
connectionCloseLatch.countDown();
}
@Override
public void handleWebSocketClosedByServer(final int statusCode) {
serverCloseStatusCode.set(statusCode);
closedByServer.set(true);
connectionCloseLatch.countDown();
}
};
try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient = buildAndStartAuthenticatedClient(webSocketCloseListener)) {
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
try {
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
assertEquals(DEVICE_ID, response.getDeviceId());
clientConnectionManager.closeConnection(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID));
assertTrue(connectionCloseLatch.await(2, TimeUnit.SECONDS));
assertEquals(ApplicationWebSocketCloseReason.REAUTHENTICATION_REQUIRED.getStatusCode(),
serverCloseStatusCode.get());
assertTrue(closedByServer.get());
} finally {
channel.shutdown();
}
}
}
private WebSocketNoiseTunnelClient buildAndStartAuthenticatedClient() throws InterruptedException { private WebSocketNoiseTunnelClient buildAndStartAuthenticatedClient() throws InterruptedException {
return buildAndStartAuthenticatedClient(WebSocketCloseListener.NOOP_LISTENER); return buildAndStartAuthenticatedClient(WebSocketCloseListener.NOOP_LISTENER);
} }
@ -445,11 +554,12 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
private WebSocketNoiseTunnelClient buildAndStartAuthenticatedClient(final WebSocketCloseListener webSocketCloseListener) private WebSocketNoiseTunnelClient buildAndStartAuthenticatedClient(final WebSocketCloseListener webSocketCloseListener)
throws InterruptedException { throws InterruptedException {
return buildAndStartAuthenticatedClient(webSocketCloseListener, rootKeyPair.getPublicKey()); return buildAndStartAuthenticatedClient(webSocketCloseListener, rootKeyPair.getPublicKey(), new DefaultHttpHeaders());
} }
private WebSocketNoiseTunnelClient buildAndStartAuthenticatedClient(final WebSocketCloseListener webSocketCloseListener, private WebSocketNoiseTunnelClient buildAndStartAuthenticatedClient(final WebSocketCloseListener webSocketCloseListener,
final ECPublicKey rootPublicKey) throws InterruptedException { final ECPublicKey rootPublicKey,
final HttpHeaders headers) throws InterruptedException {
return new WebSocketNoiseTunnelClient(websocketNoiseTunnelServer.getLocalAddress(), return new WebSocketNoiseTunnelClient(websocketNoiseTunnelServer.getLocalAddress(),
WebSocketNoiseTunnelClient.AUTHENTICATED_WEBSOCKET_URI, WebSocketNoiseTunnelClient.AUTHENTICATED_WEBSOCKET_URI,
@ -458,6 +568,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
rootPublicKey, rootPublicKey,
ACCOUNT_IDENTIFIER, ACCOUNT_IDENTIFIER,
DEVICE_ID, DEVICE_ID,
headers,
serverTlsCertificate, serverTlsCertificate,
nioEventLoopGroup, nioEventLoopGroup,
webSocketCloseListener) webSocketCloseListener)
@ -465,11 +576,12 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
} }
private WebSocketNoiseTunnelClient buildAndStartAnonymousClient() throws InterruptedException { private WebSocketNoiseTunnelClient buildAndStartAnonymousClient() throws InterruptedException {
return buildAndStartAnonymousClient(WebSocketCloseListener.NOOP_LISTENER, rootKeyPair.getPublicKey()); return buildAndStartAnonymousClient(WebSocketCloseListener.NOOP_LISTENER, rootKeyPair.getPublicKey(), new DefaultHttpHeaders());
} }
private WebSocketNoiseTunnelClient buildAndStartAnonymousClient(final WebSocketCloseListener webSocketCloseListener, private WebSocketNoiseTunnelClient buildAndStartAnonymousClient(final WebSocketCloseListener webSocketCloseListener,
final ECPublicKey rootPublicKey) throws InterruptedException { final ECPublicKey rootPublicKey,
final HttpHeaders headers) throws InterruptedException {
return new WebSocketNoiseTunnelClient(websocketNoiseTunnelServer.getLocalAddress(), return new WebSocketNoiseTunnelClient(websocketNoiseTunnelServer.getLocalAddress(),
WebSocketNoiseTunnelClient.ANONYMOUS_WEBSOCKET_URI, WebSocketNoiseTunnelClient.ANONYMOUS_WEBSOCKET_URI,
@ -478,6 +590,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
rootPublicKey, rootPublicKey,
null, null,
(byte) 0, (byte) 0,
headers,
serverTlsCertificate, serverTlsCertificate,
nioEventLoopGroup, nioEventLoopGroup,
webSocketCloseListener) webSocketCloseListener)

View File

@ -0,0 +1,202 @@
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.mock;
import com.google.common.net.InetAddresses;
import com.vdurmont.semver4j.Semver;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.local.LocalAddress;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import javax.annotation.Nullable;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.signal.libsignal.protocol.ecc.Curve;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
private UserEventRecordingHandler userEventRecordingHandler;
private MutableRemoteAddressEmbeddedChannel embeddedChannel;
private static final String RECOGNIZED_PROXY_SECRET = RandomStringUtils.randomAlphanumeric(16);
private static class UserEventRecordingHandler extends ChannelInboundHandlerAdapter {
private final List<Object> receivedEvents = new ArrayList<>();
@Override
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
receivedEvents.add(event);
}
public List<Object> getReceivedEvents() {
return receivedEvents;
}
}
private static class MutableRemoteAddressEmbeddedChannel extends EmbeddedChannel {
private SocketAddress remoteAddress;
public MutableRemoteAddressEmbeddedChannel(final ChannelHandler... handlers) {
super(handlers);
}
@Override
protected SocketAddress remoteAddress0() {
return isActive() ? remoteAddress : null;
}
public void setRemoteAddress(final SocketAddress remoteAddress) {
this.remoteAddress = remoteAddress;
}
}
@BeforeEach
void setUp() {
userEventRecordingHandler = new UserEventRecordingHandler();
embeddedChannel = new MutableRemoteAddressEmbeddedChannel(
new WebsocketHandshakeCompleteHandler(mock(ClientPublicKeysManager.class),
Curve.generateKeyPair(),
new byte[64],
RECOGNIZED_PROXY_SECRET),
userEventRecordingHandler);
embeddedChannel.setRemoteAddress(new InetSocketAddress("127.0.0.1", 0));
}
@ParameterizedTest
@MethodSource
void handleWebSocketHandshakeComplete(final String uri, final Class<? extends AbstractNoiseHandshakeHandler> expectedHandlerClass) {
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent =
new WebSocketServerProtocolHandler.HandshakeComplete(uri, new DefaultHttpHeaders(), null);
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
assertNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteHandler.class));
assertNotNull(embeddedChannel.pipeline().get(expectedHandlerClass));
assertEquals(List.of(handshakeCompleteEvent), userEventRecordingHandler.getReceivedEvents());
}
private static List<Arguments> handleWebSocketHandshakeComplete() {
return List.of(
Arguments.of(WebsocketNoiseTunnelServer.AUTHENTICATED_SERVICE_PATH, NoiseXXHandshakeHandler.class),
Arguments.of(WebsocketNoiseTunnelServer.ANONYMOUS_SERVICE_PATH, NoiseNXHandshakeHandler.class));
}
@Test
void handleWebSocketHandshakeCompleteUnexpectedPath() {
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent =
new WebSocketServerProtocolHandler.HandshakeComplete("/incorrect", new DefaultHttpHeaders(), null);
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
assertNotNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteHandler.class));
assertThrows(IllegalArgumentException.class, () -> embeddedChannel.checkException());
}
@Test
void handleUnrecognizedEvent() {
final Object unrecognizedEvent = new Object();
embeddedChannel.pipeline().fireUserEventTriggered(unrecognizedEvent);
assertEquals(List.of(unrecognizedEvent), userEventRecordingHandler.getReceivedEvents());
}
@ParameterizedTest
@MethodSource
void getRemoteAddress(final HttpHeaders headers, final SocketAddress remoteAddress, @Nullable InetAddress expectedRemoteAddress) {
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent =
new WebSocketServerProtocolHandler.HandshakeComplete(
WebsocketNoiseTunnelServer.ANONYMOUS_SERVICE_PATH, headers, null);
embeddedChannel.setRemoteAddress(remoteAddress);
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
assertEquals(expectedRemoteAddress,
embeddedChannel.attr(ClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).get());
}
private static List<Arguments> getRemoteAddress() {
final InetSocketAddress remoteAddress = new InetSocketAddress("5.6.7.8", 0);
final InetAddress clientAddress = InetAddresses.forString("1.2.3.4");
final InetAddress proxyAddress = InetAddresses.forString("4.3.2.1");
return List.of(
// Recognized proxy, single forwarded-for address
Arguments.of(new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
remoteAddress,
clientAddress),
// Recognized proxy, multiple forwarded-for addresses
Arguments.of(new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress() + "," + proxyAddress.getHostAddress()),
remoteAddress,
proxyAddress),
// No recognized proxy header, single forwarded-for address
Arguments.of(new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
remoteAddress,
remoteAddress.getAddress()),
// No recognized proxy header, no forwarded-for address
Arguments.of(new DefaultHttpHeaders(),
remoteAddress,
remoteAddress.getAddress()),
// Incorrect proxy header, single forwarded-for address
Arguments.of(new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET + "-incorrect")
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
remoteAddress,
remoteAddress.getAddress()),
// Recognized proxy, no forwarded-for address
Arguments.of(new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET),
remoteAddress,
remoteAddress.getAddress()),
// Recognized proxy, bogus forwarded-for address
Arguments.of(new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, "not a valid address"),
remoteAddress,
null),
// No forwarded-for address, non-InetSocketAddress remote address
Arguments.of(new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET),
new LocalAddress("local-address"),
null)
);
}
}

View File

@ -1,91 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.signal.libsignal.protocol.ecc.Curve;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.mock;
class WebsocketHandshakeCompleteListenerTest extends AbstractLeakDetectionTest {
private UserEventRecordingHandler userEventRecordingHandler;
private EmbeddedChannel embeddedChannel;
private static class UserEventRecordingHandler extends ChannelInboundHandlerAdapter {
private final List<Object> receivedEvents = new ArrayList<>();
@Override
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
receivedEvents.add(event);
}
public List<Object> getReceivedEvents() {
return receivedEvents;
}
}
@BeforeEach
void setUp() {
userEventRecordingHandler = new UserEventRecordingHandler();
embeddedChannel = new EmbeddedChannel(
new WebsocketHandshakeCompleteListener(mock(ClientPublicKeysManager.class), Curve.generateKeyPair(), new byte[64]),
userEventRecordingHandler);
}
@ParameterizedTest
@MethodSource
void handleWebSocketHandshakeComplete(final String uri, final Class<? extends AbstractNoiseHandshakeHandler> expectedHandlerClass) {
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent =
new WebSocketServerProtocolHandler.HandshakeComplete(uri, new DefaultHttpHeaders(), null);
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
assertNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteListener.class));
assertNotNull(embeddedChannel.pipeline().get(expectedHandlerClass));
assertEquals(List.of(handshakeCompleteEvent), userEventRecordingHandler.getReceivedEvents());
}
private static List<Arguments> handleWebSocketHandshakeComplete() {
return List.of(
Arguments.of(WebsocketNoiseTunnelServer.AUTHENTICATED_SERVICE_PATH, NoiseXXHandshakeHandler.class),
Arguments.of(WebsocketNoiseTunnelServer.ANONYMOUS_SERVICE_PATH, NoiseNXHandshakeHandler.class));
}
@Test
void handleWebSocketHandshakeCompleteUnexpectedPath() {
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent =
new WebSocketServerProtocolHandler.HandshakeComplete("/incorrect", new DefaultHttpHeaders(), null);
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
assertNotNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteListener.class));
assertThrows(IllegalArgumentException.class, () -> embeddedChannel.checkException());
}
@Test
void handleUnrecognizedEvent() {
final Object unrecognizedEvent = new Object();
embeddedChannel.pipeline().fireUserEventTriggered(unrecognizedEvent);
assertEquals(List.of(unrecognizedEvent), userEventRecordingHandler.getReceivedEvents());
}
}

View File

@ -1,22 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
syntax = "proto3";
option java_multiple_files = true;
package org.signal.chat.rpc;
// A simple test service that identifies its authentication type to callers
service AuthenticationType {
rpc GetAuthenticated (GetAuthenticatedRequest) returns (GetAuthenticatedResponse) {}
}
message GetAuthenticatedRequest {
}
message GetAuthenticatedResponse {
bool authenticated = 1;
}

View File

@ -0,0 +1,41 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
syntax = "proto3";
option java_multiple_files = true;
package org.signal.chat.rpc;
// A simple test service that echoes request attributes to callers
service RequestAttributes {
rpc GetRequestAttributes (GetRequestAttributesRequest) returns (GetRequestAttributesResponse) {}
rpc GetAuthenticatedDevice (GetAuthenticatedDeviceRequest) returns (GetAuthenticatedDeviceResponse) {}
}
message GetRequestAttributesRequest {
}
message GetRequestAttributesResponse {
repeated string acceptable_languages = 1;
repeated string available_accepted_locales = 2;
string remote_address = 3;
UserAgent user_agent = 4;
}
message UserAgent {
string platform = 1;
string version = 2;
string additional_specifiers = 3;
}
message GetAuthenticatedDeviceRequest {
}
message GetAuthenticatedDeviceResponse {
bytes account_identifier = 1;
uint32 device_id = 2;
}