Compare commits

..

No commits in common. "main" and "v20250415.0.0" have entirely different histories.

46 changed files with 699 additions and 779 deletions

View File

@ -482,8 +482,7 @@ turn:
- turn:%s
- turn:%s:80?transport=tcp
- turns:%s:443?transport=tcp
requestedCredentialTtl: PT24H
clientCredentialTtl: PT12H
ttl: 86400
hostname: turn.cloudflare.example.com
numHttpClients: 1

View File

@ -673,8 +673,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager = new CloudflareTurnCredentialsManager(
config.getTurnConfiguration().cloudflare().apiToken().value(),
config.getTurnConfiguration().cloudflare().endpoint(),
config.getTurnConfiguration().cloudflare().requestedCredentialTtl(),
config.getTurnConfiguration().cloudflare().clientCredentialTtl(),
config.getTurnConfiguration().cloudflare().ttl(),
config.getTurnConfiguration().cloudflare().urls(),
config.getTurnConfiguration().cloudflare().urlsWithIps(),
config.getTurnConfiguration().cloudflare().hostname(),

View File

@ -15,7 +15,6 @@ import java.net.Inet6Address;
import java.net.URI;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletionException;
@ -40,18 +39,16 @@ public class CloudflareTurnCredentialsManager {
private final List<String> cloudflareTurnUrls;
private final List<String> cloudflareTurnUrlsWithIps;
private final String cloudflareTurnHostname;
private final HttpRequest getCredentialsRequest;
private final HttpRequest request;
private final FaultTolerantHttpClient cloudflareTurnClient;
private final DnsNameResolver dnsNameResolver;
private final Duration clientCredentialTtl;
record CredentialRequest(long ttl) {}
private record CredentialRequest(long ttl) {}
record CloudflareTurnResponse(IceServer iceServers) {
private record CloudflareTurnResponse(IceServer iceServers) {
private record IceServer(
record IceServer(
String username,
String credential,
List<String> urls) {
@ -59,17 +56,10 @@ public class CloudflareTurnCredentialsManager {
}
public CloudflareTurnCredentialsManager(final String cloudflareTurnApiToken,
final String cloudflareTurnEndpoint,
final Duration requestedCredentialTtl,
final Duration clientCredentialTtl,
final List<String> cloudflareTurnUrls,
final List<String> cloudflareTurnUrlsWithIps,
final String cloudflareTurnHostname,
final int cloudflareTurnNumHttpClients,
final CircuitBreakerConfiguration circuitBreaker,
final ExecutorService executor,
final RetryConfiguration retry,
final ScheduledExecutorService retryExecutor,
final String cloudflareTurnEndpoint, final long cloudflareTurnTtl, final List<String> cloudflareTurnUrls,
final List<String> cloudflareTurnUrlsWithIps, final String cloudflareTurnHostname,
final int cloudflareTurnNumHttpClients, final CircuitBreakerConfiguration circuitBreaker,
final ExecutorService executor, final RetryConfiguration retry, final ScheduledExecutorService retryExecutor,
final DnsNameResolver dnsNameResolver) {
this.cloudflareTurnClient = FaultTolerantHttpClient.newBuilder()
@ -85,24 +75,17 @@ public class CloudflareTurnCredentialsManager {
this.cloudflareTurnHostname = cloudflareTurnHostname;
this.dnsNameResolver = dnsNameResolver;
final String credentialsRequestBody;
try {
credentialsRequestBody =
SystemMapper.jsonMapper().writeValueAsString(new CredentialRequest(requestedCredentialTtl.toSeconds()));
} catch (final JsonProcessingException e) {
final String body = SystemMapper.jsonMapper().writeValueAsString(new CredentialRequest(cloudflareTurnTtl));
this.request = HttpRequest.newBuilder()
.uri(URI.create(cloudflareTurnEndpoint))
.header("Content-Type", "application/json")
.header("Authorization", String.format("Bearer %s", cloudflareTurnApiToken))
.POST(HttpRequest.BodyPublishers.ofString(body))
.build();
} catch (JsonProcessingException e) {
throw new IllegalArgumentException(e);
}
// We repeat the same request to Cloudflare every time, so we can construct it once and re-use it
this.getCredentialsRequest = HttpRequest.newBuilder()
.uri(URI.create(cloudflareTurnEndpoint))
.header("Content-Type", "application/json")
.header("Authorization", String.format("Bearer %s", cloudflareTurnApiToken))
.POST(HttpRequest.BodyPublishers.ofString(credentialsRequestBody))
.build();
this.clientCredentialTtl = clientCredentialTtl;
}
public TurnToken retrieveFromCloudflare() throws IOException {
@ -122,7 +105,7 @@ public class CloudflareTurnCredentialsManager {
final Timer.Sample sample = Timer.start();
final HttpResponse<String> response;
try {
response = cloudflareTurnClient.sendAsync(getCredentialsRequest, HttpResponse.BodyHandlers.ofString()).join();
response = cloudflareTurnClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()).join();
sample.stop(Timer.builder(CREDENTIAL_FETCH_TIMER_NAME)
.publishPercentileHistogram(true)
.tags("outcome", "success")
@ -147,7 +130,6 @@ public class CloudflareTurnCredentialsManager {
return new TurnToken(
cloudflareTurnResponse.iceServers().username(),
cloudflareTurnResponse.iceServers().credential(),
clientCredentialTtl.toSeconds(),
cloudflareTurnUrls == null ? Collections.emptyList() : cloudflareTurnUrls,
cloudflareTurnComposedUrls,
cloudflareTurnHostname

View File

@ -5,15 +5,13 @@
package org.whispersystems.textsecuregcm.auth;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.List;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.List;
public record TurnToken(
String username,
String password,
@JsonProperty("ttl") long ttlSeconds,
@Nonnull List<String> urls,
@Nonnull List<String> urlsWithIps,
@Nullable String hostname) {

View File

@ -1,22 +1,34 @@
package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerInterceptor;
import java.util.Optional;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import io.grpc.Status;
import io.netty.channel.local.LocalAddress;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import java.util.Optional;
abstract class AbstractAuthenticationInterceptor implements ServerInterceptor {
private final GrpcClientConnectionManager grpcClientConnectionManager;
private static final Metadata EMPTY_TRAILERS = new Metadata();
AbstractAuthenticationInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
this.grpcClientConnectionManager = grpcClientConnectionManager;
}
protected Optional<AuthenticatedDevice> getAuthenticatedDevice(final ServerCall<?, ?> call)
throws ChannelNotFoundException {
protected Optional<AuthenticatedDevice> getAuthenticatedDevice(final ServerCall<?, ?> call) {
if (call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress) {
return grpcClientConnectionManager.getAuthenticatedDevice(localAddress);
} else {
throw new AssertionError("Unexpected channel type: " + call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
}
}
return grpcClientConnectionManager.getAuthenticatedDevice(call);
protected <ReqT, RespT> ServerCall.Listener<ReqT> closeAsUnauthenticated(final ServerCall<ReqT, RespT> call) {
call.close(Status.UNAUTHENTICATED, EMPTY_TRAILERS);
return new ServerCall.Listener<>() {};
}
}

View File

@ -3,17 +3,12 @@ package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.Status;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.ServerInterceptorUtil;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
/**
* A "prohibit authentication" interceptor ensures that requests to endpoints that should be invoked anonymously do not
* originate from a channel that is associated with an authenticated device. Calls with an associated authenticated
* device are closed with an {@code UNAUTHENTICATED} status. If a call's authentication status cannot be determined
* (i.e. because the underlying remote channel closed before the {@code ServerCall} started), the interceptor will
* reject the call with a status of {@code UNAVAILABLE}.
* device are closed with an {@code UNAUTHENTICATED} status.
*/
public class ProhibitAuthenticationInterceptor extends AbstractAuthenticationInterceptor {
@ -26,15 +21,8 @@ public class ProhibitAuthenticationInterceptor extends AbstractAuthenticationInt
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
try {
return getAuthenticatedDevice(call)
// Status.INTERNAL may seem a little surprising here, but if a caller is reaching an authentication-prohibited
// service via an authenticated connection, then that's actually a server configuration issue and not a
// problem with the client's request.
.map(ignored -> ServerInterceptorUtil.closeWithStatus(call, Status.INTERNAL))
.orElseGet(() -> next.startCall(call, headers));
} catch (final ChannelNotFoundException e) {
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
}
return getAuthenticatedDevice(call)
.map(ignored -> closeAsUnauthenticated(call))
.orElseGet(() -> next.startCall(call, headers));
}
}

View File

@ -5,16 +5,12 @@ import io.grpc.Contexts;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.Status;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.ServerInterceptorUtil;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
/**
* A "require authentication" interceptor requires that requests be issued from a connection that is associated with an
* authenticated device. Calls without an associated authenticated device are closed with an {@code UNAUTHENTICATED}
* status. If a call's authentication status cannot be determined (i.e. because the underlying remote channel closed
* before the {@code ServerCall} started), the interceptor will reject the call with a status of {@code UNAVAILABLE}.
* status.
*/
public class RequireAuthenticationInterceptor extends AbstractAuthenticationInterceptor {
@ -27,17 +23,10 @@ public class RequireAuthenticationInterceptor extends AbstractAuthenticationInte
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
try {
return getAuthenticatedDevice(call)
.map(authenticatedDevice -> Contexts.interceptCall(Context.current()
.withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE, authenticatedDevice),
call, headers, next))
// Status.INTERNAL may seem a little surprising here, but if a caller is reaching an authentication-required
// service via an unauthenticated connection, then that's actually a server configuration issue and not a
// problem with the client's request.
.orElseGet(() -> ServerInterceptorUtil.closeWithStatus(call, Status.INTERNAL));
} catch (final ChannelNotFoundException e) {
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
}
return getAuthenticatedDevice(call)
.map(authenticatedDevice -> Contexts.interceptCall(Context.current()
.withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE, authenticatedDevice),
call, headers, next))
.orElseGet(() -> closeAsUnauthenticated(call));
}
}

View File

@ -6,36 +6,16 @@
package org.whispersystems.textsecuregcm.configuration;
import jakarta.validation.Valid;
import jakarta.validation.constraints.AssertTrue;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
import java.time.Duration;
import java.util.List;
import jakarta.validation.constraints.Positive;
import org.whispersystems.textsecuregcm.configuration.secrets.SecretString;
/**
* Configuration properties for Cloudflare TURN integration.
*
* @param apiToken the API token to use when requesting TURN tokens from Cloudflare
* @param endpoint the URI of the Cloudflare API endpoint that vends TURN tokens
* @param requestedCredentialTtl the lifetime of TURN tokens to request from Cloudflare
* @param clientCredentialTtl the time clients may cache a TURN token; must be less than or equal to {@link #requestedCredentialTtl}
* @param urls a collection of TURN URLs to include verbatim in responses to clients
* @param urlsWithIps a collection of {@link String#format(String, Object...)} patterns to be populated with resolved IP
* addresses for {@link #hostname} in responses to clients; each pattern must include a single
* {@code %s} placeholder for the IP address
* @param circuitBreaker a circuit breaker for requests to Cloudflare
* @param retry a retry policy for requests to Cloudflare
* @param hostname the hostname to resolve to IP addresses for use with {@link #urlsWithIps}; also transmitted to
* clients for use as an SNI when connecting to pre-resolved hosts
* @param numHttpClients the number of parallel HTTP clients to use to communicate with Cloudflare
*/
public record CloudflareTurnConfiguration(@NotNull SecretString apiToken,
@NotBlank String endpoint,
@NotNull Duration requestedCredentialTtl,
@NotNull Duration clientCredentialTtl,
@NotBlank long ttl,
@NotNull @NotEmpty @Valid List<@NotBlank String> urls,
@NotNull @NotEmpty @Valid List<@NotBlank String> urlsWithIps,
@NotNull @Valid CircuitBreakerConfiguration circuitBreaker,
@ -55,9 +35,4 @@ public record CloudflareTurnConfiguration(@NotNull SecretString apiToken,
retry = new RetryConfiguration();
}
}
@AssertTrue
public boolean isClientTtlShorterThanRequestedTtl() {
return clientCredentialTtl.compareTo(requestedCredentialTtl) <= 0;
}
}

View File

@ -15,12 +15,16 @@ import io.swagger.v3.oas.annotations.responses.ApiResponse;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.container.ContainerRequestContext;
import jakarta.ws.rs.core.Context;
import jakarta.ws.rs.core.MediaType;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.CloudflareTurnCredentialsManager;
import org.whispersystems.textsecuregcm.auth.TurnToken;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.websocket.auth.ReadOnly;
@ -28,16 +32,14 @@ import org.whispersystems.websocket.auth.ReadOnly;
@Path("/v2/calling")
public class CallRoutingControllerV2 {
private static final Counter CLOUDFLARE_TURN_ERROR_COUNTER = Metrics.counter(name(CallRoutingControllerV2.class, "cloudflareTurnError"));
private final RateLimiters rateLimiters;
private final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager;
private static final Counter CLOUDFLARE_TURN_ERROR_COUNTER =
Metrics.counter(name(CallRoutingControllerV2.class, "cloudflareTurnError"));
public CallRoutingControllerV2(
final RateLimiters rateLimiters,
final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager) {
final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager
) {
this.rateLimiters = rateLimiters;
this.cloudflareTurnCredentialsManager = cloudflareTurnCredentialsManager;
}
@ -56,17 +58,25 @@ public class CallRoutingControllerV2 {
@ApiResponse(responseCode = "401", description = "Account authentication check failed.")
@ApiResponse(responseCode = "422", description = "Invalid request format.")
@ApiResponse(responseCode = "429", description = "Rate limited.")
public GetCallingRelaysResponse getCallingRelays(final @ReadOnly @Auth AuthenticatedDevice auth)
throws RateLimitExceededException, IOException {
final UUID aci = auth.getAccount().getUuid();
public GetCallingRelaysResponse getCallingRelays(
final @ReadOnly @Auth AuthenticatedDevice auth
) throws RateLimitExceededException, IOException {
UUID aci = auth.getAccount().getUuid();
rateLimiters.getCallEndpointLimiter().validate(aci);
List<TurnToken> tokens = new ArrayList<>();
try {
return new GetCallingRelaysResponse(List.of(cloudflareTurnCredentialsManager.retrieveFromCloudflare()));
} catch (final Exception e) {
CLOUDFLARE_TURN_ERROR_COUNTER.increment();
tokens.add(cloudflareTurnCredentialsManager.retrieveFromCloudflare());
} catch (Exception e) {
CallRoutingControllerV2.CLOUDFLARE_TURN_ERROR_COUNTER.increment();
throw e;
}
return new GetCallingRelaysResponse(tokens);
}
public record GetCallingRelaysResponse(
List<TurnToken> relays
) {
}
}

View File

@ -104,12 +104,14 @@ public class DonationController {
.type(MediaType.TEXT_PLAIN_TYPE).build());
}
return accountsManager.updateAsync(auth.getAccount(), a -> {
a.addBadge(clock, new AccountBadge(badgeId, receiptExpiration, request.isVisible()));
if (request.isPrimary()) {
a.makeBadgePrimaryIfExists(clock, badgeId);
}
})
return accountsManager.getByAccountIdentifierAsync(auth.getAccount().getUuid())
.thenCompose(optionalAccount ->
optionalAccount.map(account -> accountsManager.updateAsync(account, a -> {
a.addBadge(clock, new AccountBadge(badgeId, receiptExpiration, request.isVisible()));
if (request.isPrimary()) {
a.makeBadgePrimaryIfExists(clock, badgeId);
}
})).orElse(CompletableFuture.completedFuture(null)))
.thenApply(ignored -> Response.ok().build());
});
}).thenCompose(Function.identity());

View File

@ -1,13 +0,0 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import org.whispersystems.textsecuregcm.auth.TurnToken;
import java.util.List;
public record GetCallingRelaysResponse(List<TurnToken> relays) {
}

View File

@ -81,16 +81,7 @@ public class RemoteDeprecationFilter implements Filter, ServerInterceptor {
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
@Nullable final UserAgent userAgent = RequestAttributesUtil.getUserAgent()
.map(userAgentString -> {
try {
return UserAgentUtil.parseUserAgentString(userAgentString);
} catch (final UnrecognizedUserAgentException e) {
return null;
}
}).orElse(null);
if (shouldBlock(userAgent)) {
if (shouldBlock(RequestAttributesUtil.getUserAgent().orElse(null))) {
call.close(StatusConstants.UPGRADE_NEEDED_STATUS, new Metadata());
return new ServerCall.Listener<>() {};
} else {

View File

@ -1,12 +0,0 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
/**
* Indicates that a remote channel was not found for a given server call or remote address.
*/
public class ChannelNotFoundException extends Exception {
}

View File

@ -253,7 +253,7 @@ public class MessagesAnonymousGrpcService extends SimpleMessagesAnonymousGrpc.Me
story,
ephemeral,
urgent,
RequestAttributesUtil.getUserAgent().orElse(null));
RequestAttributesUtil.getRawUserAgent().orElse(null));
final SendMultiRecipientMessageResponse.Builder responseBuilder = SendMultiRecipientMessageResponse.newBuilder();

View File

@ -55,7 +55,7 @@ public class MessagesGrpcHelper {
messagesByDeviceId,
registrationIdsByDeviceId,
syncMessageSenderDeviceId,
RequestAttributesUtil.getUserAgent().orElse(null));
RequestAttributesUtil.getRawUserAgent().orElse(null));
return SEND_MESSAGE_SUCCESS_RESPONSE;
} catch (final MismatchedDevicesException e) {

View File

@ -1,16 +0,0 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import java.net.InetAddress;
import java.util.List;
import java.util.Locale;
import javax.annotation.Nullable;
public record RequestAttributes(InetAddress remoteAddress,
@Nullable String userAgent,
List<Locale.LanguageRange> acceptLanguage) {
}

View File

@ -2,25 +2,28 @@ package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.netty.channel.local.LocalAddress;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import java.net.InetAddress;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
/**
* The request attributes interceptor makes request attributes from the underlying remote channel available to service
* implementations by attaching them to a {@link Context} attribute that can be read via {@link RequestAttributesUtil}.
* All server calls should have request attributes, and calls will be rejected with a status of {@code UNAVAILABLE} if
* request attributes are unavailable (i.e. the underlying channel closed before the {@code ServerCall} started).
*
* @see RequestAttributesUtil
*/
public class RequestAttributesInterceptor implements ServerInterceptor {
private final GrpcClientConnectionManager grpcClientConnectionManager;
private static final Logger log = LoggerFactory.getLogger(RequestAttributesInterceptor.class);
public RequestAttributesInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
this.grpcClientConnectionManager = grpcClientConnectionManager;
}
@ -30,12 +33,52 @@ public class RequestAttributesInterceptor implements ServerInterceptor {
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
try {
return Contexts.interceptCall(Context.current()
.withValue(RequestAttributesUtil.REQUEST_ATTRIBUTES_CONTEXT_KEY,
grpcClientConnectionManager.getRequestAttributes(call)), call, headers, next);
} catch (final ChannelNotFoundException e) {
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
if (call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress) {
Context context = Context.current();
{
final Optional<InetAddress> maybeRemoteAddress = grpcClientConnectionManager.getRemoteAddress(localAddress);
if (maybeRemoteAddress.isEmpty()) {
// We should never have a call from a party whose remote address we can't identify
log.warn("No remote address available");
call.close(Status.INTERNAL, new Metadata());
return new ServerCall.Listener<>() {};
}
context = context.withValue(RequestAttributesUtil.REMOTE_ADDRESS_CONTEXT_KEY, maybeRemoteAddress.get());
}
{
final Optional<List<Locale.LanguageRange>> maybeAcceptLanguage =
grpcClientConnectionManager.getAcceptableLanguages(localAddress);
if (maybeAcceptLanguage.isPresent()) {
context = context.withValue(RequestAttributesUtil.ACCEPT_LANGUAGE_CONTEXT_KEY, maybeAcceptLanguage.get());
}
}
{
final Optional<String> maybeRawUserAgent =
grpcClientConnectionManager.getRawUserAgent(localAddress);
if (maybeRawUserAgent.isPresent()) {
context = context.withValue(RequestAttributesUtil.RAW_USER_AGENT_CONTEXT_KEY, maybeRawUserAgent.get());
}
}
{
final Optional<UserAgent> maybeUserAgent = grpcClientConnectionManager.getUserAgent(localAddress);
if (maybeUserAgent.isPresent()) {
context = context.withValue(RequestAttributesUtil.USER_AGENT_CONTEXT_KEY, maybeUserAgent.get());
}
}
return Contexts.interceptCall(context, call, headers, next);
} else {
throw new AssertionError("Unexpected channel type: " + call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
}
}
}

View File

@ -3,13 +3,18 @@ package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Context;
import java.net.InetAddress;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
public class RequestAttributesUtil {
static final Context.Key<RequestAttributes> REQUEST_ATTRIBUTES_CONTEXT_KEY = Context.key("request-attributes");
static final Context.Key<List<Locale.LanguageRange>> ACCEPT_LANGUAGE_CONTEXT_KEY = Context.key("accept-language");
static final Context.Key<InetAddress> REMOTE_ADDRESS_CONTEXT_KEY = Context.key("remote-address");
static final Context.Key<String> RAW_USER_AGENT_CONTEXT_KEY = Context.key("unparsed-user-agent");
static final Context.Key<UserAgent> USER_AGENT_CONTEXT_KEY = Context.key("parsed-user-agent");
private static final List<Locale> AVAILABLE_LOCALES = Arrays.asList(Locale.getAvailableLocales());
@ -18,8 +23,8 @@ public class RequestAttributesUtil {
*
* @return the acceptable languages listed by the remote client; may be empty if unparseable or not specified
*/
public static List<Locale.LanguageRange> getAcceptableLanguages() {
return REQUEST_ATTRIBUTES_CONTEXT_KEY.get().acceptLanguage();
public static Optional<List<Locale.LanguageRange>> getAcceptableLanguages() {
return Optional.ofNullable(ACCEPT_LANGUAGE_CONTEXT_KEY.get());
}
/**
@ -30,7 +35,9 @@ public class RequestAttributesUtil {
* @return a list of distinct locales acceptable to the remote client and available in this JVM
*/
public static List<Locale> getAvailableAcceptedLocales() {
return Locale.filter(getAcceptableLanguages(), AVAILABLE_LOCALES);
return getAcceptableLanguages()
.map(languageRanges -> Locale.filter(languageRanges, AVAILABLE_LOCALES))
.orElseGet(Collections::emptyList);
}
/**
@ -39,7 +46,16 @@ public class RequestAttributesUtil {
* @return the remote address of the remote client
*/
public static InetAddress getRemoteAddress() {
return REQUEST_ATTRIBUTES_CONTEXT_KEY.get().remoteAddress();
return REMOTE_ADDRESS_CONTEXT_KEY.get();
}
/**
* Returns the parsed user-agent of the remote client in the current gRPC request context.
*
* @return the parsed user-agent of the remote client; may be empty if unparseable or not specified
*/
public static Optional<UserAgent> getUserAgent() {
return Optional.ofNullable(USER_AGENT_CONTEXT_KEY.get());
}
/**
@ -47,7 +63,7 @@ public class RequestAttributesUtil {
*
* @return the unparsed user-agent of the remote client; may be empty if not specified
*/
public static Optional<String> getUserAgent() {
return Optional.ofNullable(REQUEST_ATTRIBUTES_CONTEXT_KEY.get().userAgent());
public static Optional<String> getRawUserAgent() {
return Optional.ofNullable(RAW_USER_AGENT_CONTEXT_KEY.get());
}
}

View File

@ -1,39 +0,0 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.Status;
public class ServerInterceptorUtil {
@SuppressWarnings("rawtypes")
private static final ServerCall.Listener NO_OP_LISTENER = new ServerCall.Listener<>() {};
private static final Metadata EMPTY_TRAILERS = new Metadata();
private ServerInterceptorUtil() {
}
/**
* Closes the given server call with the given status, returning a no-op listener.
*
* @param call the server call to close
* @param status the status with which to close the call
*
* @return a no-op server call listener
*
* @param <ReqT> the type of request object handled by the server call
* @param <RespT> the type of response object returned by the server call
*/
public static <ReqT, RespT> ServerCall.Listener<ReqT> closeWithStatus(final ServerCall<ReqT, RespT> call, final Status status) {
call.close(status, EMPTY_TRAILERS);
//noinspection unchecked
return NO_OP_LISTENER;
}
}

View File

@ -12,10 +12,8 @@ import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
import io.netty.util.ReferenceCountUtil;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
/**
* An "establish local connection" handler waits for a Noise handshake to complete upstream in the pipeline, buffering
@ -50,12 +48,12 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
@Override
public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) {
if (event instanceof NoiseIdentityDeterminedEvent(final Optional<AuthenticatedDevice> authenticatedDevice)) {
if (event instanceof NoiseIdentityDeterminedEvent noiseIdentityDeterminedEvent) {
// We assume that we'll only get a completed handshake event if the handshake met all authentication requirements
// for the requested service. If the handshake doesn't have an authenticated device, we assume we're trying to
// connect to the anonymous service. If it does have an authenticated device, we assume we're aiming for the
// authenticated service.
final LocalAddress grpcServerAddress = authenticatedDevice.isPresent()
final LocalAddress grpcServerAddress = noiseIdentityDeterminedEvent.authenticatedDevice().isPresent()
? authenticatedGrpcServerAddress
: anonymousGrpcServerAddress;
@ -74,7 +72,7 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
if (localChannelFuture.isSuccess()) {
grpcClientConnectionManager.handleConnectionEstablished((LocalChannel) localChannelFuture.channel(),
remoteChannelContext.channel(),
authenticatedDevice);
noiseIdentityDeterminedEvent.authenticatedDevice());
// Close the local connection if the remote channel closes and vice versa
remoteChannelContext.channel().closeFuture().addListener(closeFuture -> localChannelFuture.channel().close());

View File

@ -1,8 +1,6 @@
package org.whispersystems.textsecuregcm.grpc.net;
import com.google.common.annotations.VisibleForTesting;
import io.grpc.Grpc;
import io.grpc.ServerCall;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.local.LocalAddress;
@ -25,25 +23,15 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
/**
* A client connection manager associates a local connection to a local gRPC server with a remote connection through a
* Noise tunnel. It provides access to metadata associated with the remote connection, including the authenticated
* identity of the device that opened the connection (for non-anonymous connections). It can also close connections
* associated with a given device if that device's credentials have changed and clients must reauthenticate.
* <p>
* In general, all {@link ServerCall}s <em>must</em> have a local address that in turn <em>should</em> be resolvable to
* a remote channel, which <em>must</em> have associated request attributes and authentication status. It is possible
* that a server call's local address may not be resolvable to a remote channel if the remote channel closed in the
* narrow window between a server call being created and the start of call execution, in which case accessor methods
* in this class will throw a {@link ChannelNotFoundException}.
* <p>
* A gRPC client connection manager's methods for getting request attributes accept {@link ServerCall} entities to
* identify connections. In general, these methods should only be called from {@link io.grpc.ServerInterceptor}s.
* Methods for requesting connection closure accept an {@link AuthenticatedDevice} to identify the connection and may
* be called from any application code.
* Noise-over-WebSocket tunnel. It provides access to metadata associated with the remote connection, including the
* authenticated identity of the device that opened the connection (for non-anonymous connections). It can also close
* connections associated with a given device if that device's credentials have changed and clients must reauthenticate.
*/
public class GrpcClientConnectionManager implements DisconnectionRequestListener {
@ -55,56 +43,94 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
AttributeKey.valueOf(GrpcClientConnectionManager.class, "authenticatedDevice");
@VisibleForTesting
public static final AttributeKey<RequestAttributes> REQUEST_ATTRIBUTES_KEY =
AttributeKey.valueOf(GrpcClientConnectionManager.class, "requestAttributes");
static final AttributeKey<InetAddress> REMOTE_ADDRESS_ATTRIBUTE_KEY =
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "remoteAddress");
@VisibleForTesting
static final AttributeKey<String> RAW_USER_AGENT_ATTRIBUTE_KEY =
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "rawUserAgent");
@VisibleForTesting
static final AttributeKey<UserAgent> PARSED_USER_AGENT_ATTRIBUTE_KEY =
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "userAgent");
@VisibleForTesting
static final AttributeKey<List<Locale.LanguageRange>> ACCEPT_LANGUAGE_ATTRIBUTE_KEY =
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "acceptLanguage");
private static final Logger log = LoggerFactory.getLogger(GrpcClientConnectionManager.class);
/**
* Returns the authenticated device associated with the given server call, if any. If the connection is anonymous
* (i.e. unauthenticated), the returned value will be empty.
* Returns the authenticated device associated with the given local address, if any. An authenticated device is
* available if and only if the given local address maps to an active local connection and that connection is
* authenticated (i.e. not anonymous).
*
* @param serverCall the gRPC server call for which to find an authenticated device
* @param localAddress the local address for which to find an authenticated device
*
* @return the authenticated device associated with the given local address, if any
*
* @throws ChannelNotFoundException if the server call is not associated with a known channel; in practice, this
* generally indicates that the channel has closed while request processing is still in progress
*/
public Optional<AuthenticatedDevice> getAuthenticatedDevice(final ServerCall<?, ?> serverCall)
throws ChannelNotFoundException {
return getAuthenticatedDevice(getRemoteChannel(serverCall));
public Optional<AuthenticatedDevice> getAuthenticatedDevice(final LocalAddress localAddress) {
return getAuthenticatedDevice(remoteChannelsByLocalAddress.get(localAddress));
}
@VisibleForTesting
Optional<AuthenticatedDevice> getAuthenticatedDevice(final Channel remoteChannel) {
return Optional.ofNullable(remoteChannel.attr(AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).get());
private Optional<AuthenticatedDevice> getAuthenticatedDevice(@Nullable final Channel remoteChannel) {
return Optional.ofNullable(remoteChannel)
.map(channel -> channel.attr(AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).get());
}
/**
* Returns the request attributes associated with the given server call.
* Returns the parsed acceptable languages associated with the given local address, if any. Acceptable languages may
* be unavailable if the local connection associated with the given local address has already closed, if the client
* did not provide a list of acceptable languages, or the list provided by the client could not be parsed.
*
* @param serverCall the gRPC server call for which to retrieve request attributes
* @param localAddress the local address for which to find acceptable languages
*
* @return the request attributes associated with the given server call
*
* @throws ChannelNotFoundException if the server call is not associated with a known channel; in practice, this
* generally indicates that the channel has closed while request processing is still in progress
* @return the acceptable languages associated with the given local address, if any
*/
public RequestAttributes getRequestAttributes(final ServerCall<?, ?> serverCall) throws ChannelNotFoundException {
return getRequestAttributes(getRemoteChannel(serverCall));
public Optional<List<Locale.LanguageRange>> getAcceptableLanguages(final LocalAddress localAddress) {
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
.map(remoteChannel -> remoteChannel.attr(ACCEPT_LANGUAGE_ATTRIBUTE_KEY).get());
}
@VisibleForTesting
RequestAttributes getRequestAttributes(final Channel remoteChannel) {
final RequestAttributes requestAttributes = remoteChannel.attr(REQUEST_ATTRIBUTES_KEY).get();
/**
* Returns the remote address associated with the given local address, if any. A remote address may be unavailable if
* the local connection associated with the given local address has already closed.
*
* @param localAddress the local address for which to find a remote address
*
* @return the remote address associated with the given local address, if any
*/
public Optional<InetAddress> getRemoteAddress(final LocalAddress localAddress) {
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
.map(remoteChannel -> remoteChannel.attr(REMOTE_ADDRESS_ATTRIBUTE_KEY).get());
}
if (requestAttributes == null) {
throw new IllegalStateException("Channel does not have request attributes");
}
/**
* Returns the unparsed user agent provided by the client that opened the connection associated with the given local
* address. This method may return an empty value if no active local connection is associated with the given local
* address.
*
* @param localAddress the local address for which to find a User-Agent string
*
* @return the user agent string associated with the given local address
*/
public Optional<String> getRawUserAgent(final LocalAddress localAddress) {
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
.map(remoteChannel -> remoteChannel.attr(RAW_USER_AGENT_ATTRIBUTE_KEY).get());
}
return requestAttributes;
/**
* Returns the parsed user agent provided by the client that opened the connection associated with the given local
* address. This method may return an empty value if no active local connection is associated with the given local
* address or if the client's user-agent string was not recognized.
*
* @param localAddress the local address for which to find a User-Agent string
*
* @return the user agent associated with the given local address
*/
public Optional<UserAgent> getUserAgent(final LocalAddress localAddress) {
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
.map(remoteChannel -> remoteChannel.attr(PARSED_USER_AGENT_ATTRIBUTE_KEY).get());
}
/**
@ -130,32 +156,11 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
return remoteChannelsByAuthenticatedDevice.get(authenticatedDevice);
}
private Channel getRemoteChannel(final ServerCall<?, ?> serverCall) throws ChannelNotFoundException {
return getRemoteChannel(getLocalAddress(serverCall));
}
@VisibleForTesting
Channel getRemoteChannel(final LocalAddress localAddress) throws ChannelNotFoundException {
final Channel remoteChannel = remoteChannelsByLocalAddress.get(localAddress);
if (remoteChannel == null) {
throw new ChannelNotFoundException();
}
Channel getRemoteChannelByLocalAddress(final LocalAddress localAddress) {
return remoteChannelsByLocalAddress.get(localAddress);
}
private static LocalAddress getLocalAddress(final ServerCall<?, ?> serverCall) {
// In this server, gRPC's "remote" channel is actually a local channel that proxies to a distinct Noise channel.
// The gRPC "remote" address is the "local address" for the proxy connection, and the local address uniquely maps to
// a proxied Noise channel.
if (!(serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress)) {
throw new IllegalArgumentException("Unexpected channel type: " + serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
}
return localAddress;
}
/**
* Handles successful completion of a WebSocket handshake and associates attributes and headers from the handshake
* request with the channel via which the handshake took place.
@ -166,23 +171,30 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
* @param acceptLanguageHeader the value of the Accept-Language header provided in the handshake request; may be
* {@code null}
*/
static void handleHandshakeComplete(final Channel channel,
static void handleWebSocketHandshakeComplete(final Channel channel,
final InetAddress preferredRemoteAddress,
@Nullable final String userAgentHeader,
@Nullable final String acceptLanguageHeader) {
@Nullable List<Locale.LanguageRange> acceptLanguages = Collections.emptyList();
channel.attr(GrpcClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).set(preferredRemoteAddress);
if (StringUtils.isNotBlank(userAgentHeader)) {
channel.attr(GrpcClientConnectionManager.RAW_USER_AGENT_ATTRIBUTE_KEY).set(userAgentHeader);
try {
channel.attr(GrpcClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY)
.set(UserAgentUtil.parseUserAgentString(userAgentHeader));
} catch (final UnrecognizedUserAgentException ignored) {
}
}
if (StringUtils.isNotBlank(acceptLanguageHeader)) {
try {
acceptLanguages = Locale.LanguageRange.parse(acceptLanguageHeader);
channel.attr(GrpcClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).set(Locale.LanguageRange.parse(acceptLanguageHeader));
} catch (final IllegalArgumentException e) {
log.debug("Invalid Accept-Language header from User-Agent {}: {}", userAgentHeader, acceptLanguageHeader, e);
}
}
channel.attr(REQUEST_ATTRIBUTES_KEY)
.set(new RequestAttributes(preferredRemoteAddress, userAgentHeader, acceptLanguages));
}
/**

View File

@ -74,7 +74,7 @@ class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
preferredRemoteAddress = maybePreferredRemoteAddress.get();
}
GrpcClientConnectionManager.handleHandshakeComplete(context.channel(),
GrpcClientConnectionManager.handleWebSocketHandshakeComplete(context.channel(),
preferredRemoteAddress,
handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.USER_AGENT),
handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.ACCEPT_LANGUAGE));

View File

@ -56,7 +56,6 @@ public class MessageSender {
// Note that these names deliberately reference `MessageController` for metric continuity
private static final String REJECT_OVERSIZE_MESSAGE_COUNTER_NAME = name(MessageController.class, "rejectOversizeMessage");
private static final String CONTENT_SIZE_DISTRIBUTION_NAME = MetricsUtil.name(MessageController.class, "messageContentSize");
private static final String EMPTY_MESSAGE_LIST_COUNTER_NAME = MetricsUtil.name(MessageSender.class, "emptyMessageList");
private static final String SEND_COUNTER_NAME = name(MessageSender.class, "sendMessage");
private static final String EPHEMERAL_TAG_NAME = "ephemeral";
@ -106,13 +105,6 @@ public class MessageSender {
throw new IllegalArgumentException("Destination account not identified by destination service identifier");
}
final Tag platformTag = UserAgentTagUtil.getPlatformTag(userAgent);
if (messagesByDeviceId.isEmpty()) {
Metrics.counter(EMPTY_MESSAGE_LIST_COUNTER_NAME,
Tags.of("sync", String.valueOf(syncMessageSenderDeviceId.isPresent())).and(platformTag)).increment();
}
final byte excludedDeviceId;
if (syncMessageSenderDeviceId.isPresent()) {
if (messagesByDeviceId.values().stream().anyMatch(message -> StringUtils.isBlank(message.getSourceServiceId()) ||
@ -159,7 +151,7 @@ public class MessageSender {
STORY_TAG_NAME, String.valueOf(message.getStory()),
SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceServiceId()),
MULTI_RECIPIENT_TAG_NAME, "false")
.and(platformTag);
.and(UserAgentTagUtil.getPlatformTag(userAgent));
Metrics.counter(SEND_COUNTER_NAME, tags).increment();
});

View File

@ -7,7 +7,6 @@ package org.whispersystems.textsecuregcm.subscriptions;
import com.google.common.base.Strings;
import com.google.common.collect.Lists;
import com.stripe.Stripe;
import com.stripe.StripeClient;
import com.stripe.exception.CardException;
import com.stripe.exception.IdempotencyException;
@ -72,7 +71,6 @@ import javax.crypto.spec.SecretKeySpec;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.WhisperServerVersion;
import org.whispersystems.textsecuregcm.storage.PaymentTime;
import org.whispersystems.textsecuregcm.storage.SubscriptionException;
import org.whispersystems.textsecuregcm.util.Conversions;
@ -99,9 +97,6 @@ public class StripeManager implements CustomerAwareSubscriptionPaymentProcessor
if (Strings.isNullOrEmpty(apiKey)) {
throw new IllegalArgumentException("apiKey cannot be empty");
}
Stripe.setAppInfo("Signal-Server", WhisperServerVersion.getServerVersion());
this.stripeClient = new StripeClient(apiKey);
this.executor = Objects.requireNonNull(executor);
this.idempotencyKeyGenerator = Objects.requireNonNull(idempotencyKeyGenerator);

View File

@ -5,6 +5,7 @@
package org.whispersystems.textsecuregcm.util.ua;
import com.google.common.annotations.VisibleForTesting;
import com.vdurmont.semver4j.Semver;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@ -20,10 +21,10 @@ public class UserAgentUtil {
}
try {
final Matcher matcher = STANDARD_UA_PATTERN.matcher(userAgentString);
final UserAgent standardUserAgent = parseStandardUserAgentString(userAgentString);
if (matcher.matches()) {
return new UserAgent(ClientPlatform.valueOf(matcher.group(1).toUpperCase()), new Semver(matcher.group(2)), StringUtils.stripToNull(matcher.group(4)));
if (standardUserAgent != null) {
return standardUserAgent;
}
} catch (final Exception e) {
throw new UnrecognizedUserAgentException(e);
@ -31,4 +32,15 @@ public class UserAgentUtil {
throw new UnrecognizedUserAgentException();
}
@VisibleForTesting
static UserAgent parseStandardUserAgentString(final String userAgentString) {
final Matcher matcher = STANDARD_UA_PATTERN.matcher(userAgentString);
if (matcher.matches()) {
return new UserAgent(ClientPlatform.valueOf(matcher.group(1).toUpperCase()), new Semver(matcher.group(2)), StringUtils.stripToNull(matcher.group(4)));
}
return null;
}
}

View File

@ -5,11 +5,8 @@
package org.whispersystems.textsecuregcm.auth;
import static com.github.tomakehurst.wiremock.client.WireMock.created;
import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson;
import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
import static com.github.tomakehurst.wiremock.client.WireMock.post;
import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor;
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig;
import static org.assertj.core.api.Assertions.assertThat;
@ -18,19 +15,17 @@ import static org.mockito.Mockito.when;
import com.github.tomakehurst.wiremock.junit5.WireMockExtension;
import io.netty.resolver.dns.DnsNameResolver;
import io.netty.util.concurrent.GlobalEventExecutor;
import io.netty.util.concurrent.SucceededFuture;
import io.netty.util.concurrent.Future;
import java.io.IOException;
import java.net.InetAddress;
import java.time.Duration;
import java.util.ArrayList;
import java.security.cert.CertificateException;
import java.util.List;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@ -40,41 +35,31 @@ import org.whispersystems.textsecuregcm.configuration.RetryConfiguration;
public class CloudflareTurnCredentialsManagerTest {
@RegisterExtension
private static final WireMockExtension wireMock = WireMockExtension.newInstance()
private final WireMockExtension wireMock = WireMockExtension.newInstance()
.options(wireMockConfig().dynamicPort().dynamicHttpsPort())
.build();
private static final String GET_CREDENTIALS_PATH = "/v1/turn/keys/LMNOP/credentials/generate";
private static final String TURN_HOSTNAME = "localhost";
private ExecutorService httpExecutor;
private ScheduledExecutorService retryExecutor;
private DnsNameResolver dnsResolver;
private Future<List<InetAddress>> dnsResult;
private CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager;
private static final String GET_CREDENTIALS_PATH = "/v1/turn/keys/LMNOP/credentials/generate";
private static final String TURN_HOSTNAME = "localhost";
private static final String API_TOKEN = RandomStringUtils.insecure().nextAlphanumeric(16);
private static final String USERNAME = RandomStringUtils.insecure().nextAlphanumeric(16);
private static final String CREDENTIAL = RandomStringUtils.insecure().nextAlphanumeric(16);
private static final List<String> CLOUDFLARE_TURN_URLS = List.of("turn:cf.example.com");
private static final Duration REQUESTED_CREDENTIAL_TTL = Duration.ofSeconds(100);
private static final Duration CLIENT_CREDENTIAL_TTL = REQUESTED_CREDENTIAL_TTL.dividedBy(2);
private static final List<String> IP_URL_PATTERNS = List.of("turn:%s", "turn:%s:80?transport=tcp", "turns:%s:443?transport=tcp");
private CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager = null;
@BeforeEach
void setUp() {
void setUp() throws CertificateException {
httpExecutor = Executors.newSingleThreadExecutor();
retryExecutor = Executors.newSingleThreadScheduledExecutor();
dnsResolver = mock(DnsNameResolver.class);
dnsResult = mock(Future.class);
cloudflareTurnCredentialsManager = new CloudflareTurnCredentialsManager(
API_TOKEN,
"API_TOKEN",
"http://localhost:" + wireMock.getPort() + GET_CREDENTIALS_PATH,
REQUESTED_CREDENTIAL_TTL,
CLIENT_CREDENTIAL_TTL,
CLOUDFLARE_TURN_URLS,
IP_URL_PATTERNS,
100,
List.of("turn:cf.example.com"),
List.of("turn:%s", "turn:%s:80?transport=tcp", "turns:%s:443?transport=tcp"),
TURN_HOSTNAME,
2,
new CircuitBreakerConfiguration(),
@ -88,61 +73,26 @@ public class CloudflareTurnCredentialsManagerTest {
@AfterEach
void tearDown() throws InterruptedException {
httpExecutor.shutdown();
retryExecutor.shutdown();
//noinspection ResultOfMethodCallIgnored
httpExecutor.awaitTermination(1, TimeUnit.SECONDS);
//noinspection ResultOfMethodCallIgnored
retryExecutor.shutdown();
retryExecutor.awaitTermination(1, TimeUnit.SECONDS);
}
@Test
public void testSuccess() throws IOException, CancellationException {
public void testSuccess() throws IOException, CancellationException, ExecutionException, InterruptedException {
wireMock.stubFor(post(urlEqualTo(GET_CREDENTIALS_PATH))
.willReturn(created()
.withHeader("Content-Type", "application/json")
.withBody("""
{
"iceServers": {
"urls": [
"turn:cloudflare.example.com:3478?transport=udp"
],
"username": "%s",
"credential": "%s"
}
}
""".formatted(USERNAME, CREDENTIAL))));
.willReturn(aResponse().withStatus(201).withHeader("Content-Type", new String[]{"application/json"}).withBody("{\"iceServers\":{\"urls\":[\"turn:cloudflare.example.com:3478?transport=udp\"],\"username\":\"ABC\",\"credential\":\"XYZ\"}}")));
when(dnsResult.get())
.thenReturn(List.of(InetAddress.getByName("127.0.0.1"), InetAddress.getByName("::1")));
when(dnsResolver.resolveAll(TURN_HOSTNAME))
.thenReturn(new SucceededFuture<>(GlobalEventExecutor.INSTANCE,
List.of(InetAddress.getByName("127.0.0.1"), InetAddress.getByName("::1"))));
.thenReturn(dnsResult);
TurnToken token = cloudflareTurnCredentialsManager.retrieveFromCloudflare();
wireMock.verify(postRequestedFor(urlEqualTo(GET_CREDENTIALS_PATH))
.withHeader("Content-Type", equalTo("application/json"))
.withHeader("Authorization", equalTo("Bearer " + API_TOKEN))
.withRequestBody(equalToJson("""
{
"ttl": %d
}
""".formatted(REQUESTED_CREDENTIAL_TTL.toSeconds()))));
assertThat(token.username()).isEqualTo(USERNAME);
assertThat(token.password()).isEqualTo(CREDENTIAL);
assertThat(token.hostname()).isEqualTo(TURN_HOSTNAME);
assertThat(token.urls()).isEqualTo(CLOUDFLARE_TURN_URLS);
assertThat(token.ttlSeconds()).isEqualTo(CLIENT_CREDENTIAL_TTL.toSeconds());
final List<String> expectedUrlsWithIps = new ArrayList<>();
for (final String ip : new String[] {"127.0.0.1", "[0:0:0:0:0:0:0:1]"}) {
for (final String pattern : IP_URL_PATTERNS) {
expectedUrlsWithIps.add(pattern.formatted(ip));
}
}
assertThat(token.urlsWithIps()).containsExactlyElementsOf(expectedUrlsWithIps);
assertThat(token.username()).isEqualTo("ABC");
assertThat(token.password()).isEqualTo("XYZ");
assertThat(token.hostname()).isEqualTo("localhost");
assertThat(token.urlsWithIps()).containsAll(List.of("turn:127.0.0.1", "turn:127.0.0.1:80?transport=tcp", "turns:127.0.0.1:443?transport=tcp", "turn:[0:0:0:0:0:0:0:1]", "turn:[0:0:0:0:0:0:0:1]:80?transport=tcp", "turns:[0:0:0:0:0:0:0:1]:443?transport=tcp"));;
assertThat(token.urls()).isEqualTo(List.of("turn:cf.example.com"));
}
}

View File

@ -3,7 +3,6 @@ package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.Status;
import org.junit.jupiter.api.Test;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.storage.Device;
@ -23,7 +22,7 @@ class ProhibitAuthenticationInterceptorTest extends AbstractAuthenticationInterc
}
@Test
void interceptCall() throws ChannelNotFoundException {
void interceptCall() {
final GrpcClientConnectionManager grpcClientConnectionManager = getClientConnectionManager();
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty());
@ -35,10 +34,6 @@ class ProhibitAuthenticationInterceptorTest extends AbstractAuthenticationInterc
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice));
GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getAuthenticatedDevice);
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenThrow(ChannelNotFoundException.class);
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, this::getAuthenticatedDevice);
GrpcTestUtils.assertStatusException(Status.UNAUTHENTICATED, this::getAuthenticatedDevice);
}
}

View File

@ -9,7 +9,6 @@ import java.util.Optional;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.storage.Device;
@ -23,12 +22,12 @@ class RequireAuthenticationInterceptorTest extends AbstractAuthenticationInterce
}
@Test
void interceptCall() throws ChannelNotFoundException {
void interceptCall() {
final GrpcClientConnectionManager grpcClientConnectionManager = getClientConnectionManager();
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty());
GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getAuthenticatedDevice);
GrpcTestUtils.assertStatusException(Status.UNAUTHENTICATED, this::getAuthenticatedDevice);
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice));
@ -36,9 +35,5 @@ class RequireAuthenticationInterceptorTest extends AbstractAuthenticationInterce
final GetAuthenticatedDeviceResponse response = getAuthenticatedDevice();
assertEquals(UUIDUtil.toByteString(authenticatedDevice.accountIdentifier()), response.getAccountIdentifier());
assertEquals(authenticatedDevice.deviceId(), response.getDeviceId());
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenThrow(ChannelNotFoundException.class);
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, this::getAuthenticatedDevice);
}
}

View File

@ -40,15 +40,14 @@ class CallRoutingControllerV2Test {
private static final TurnToken CLOUDFLARE_TURN_TOKEN = new TurnToken(
"ABC",
"XYZ",
43_200,
List.of("turn:cloudflare.example.com:3478?transport=udp"),
null,
"cf.example.com");
private static final RateLimiters rateLimiters = mock(RateLimiters.class);
private static final RateLimiter getCallEndpointLimiter = mock(RateLimiter.class);
private static final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager =
mock(CloudflareTurnCredentialsManager.class);
private static final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager = mock(
CloudflareTurnCredentialsManager.class);
private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
@ -67,14 +66,21 @@ class CallRoutingControllerV2Test {
@AfterEach
void tearDown() {
reset(rateLimiters, getCallEndpointLimiter);
reset( rateLimiters, getCallEndpointLimiter);
}
void initializeMocksWith(TurnToken cloudflareToken) {
try {
when(cloudflareTurnCredentialsManager.retrieveFromCloudflare()).thenReturn(cloudflareToken);
} catch (IOException ignored) {
}
}
@Test
void testGetRelaysBothRouting() throws IOException {
when(cloudflareTurnCredentialsManager.retrieveFromCloudflare()).thenReturn(CLOUDFLARE_TURN_TOKEN);
void testGetRelaysBothRouting() {
initializeMocksWith(CLOUDFLARE_TURN_TOKEN);
try (final Response rawResponse = resources.getJerseyTest()
try (Response rawResponse = resources.getJerseyTest()
.target(GET_CALL_RELAYS_PATH)
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
@ -82,8 +88,11 @@ class CallRoutingControllerV2Test {
assertThat(rawResponse.getStatus()).isEqualTo(200);
assertThat(rawResponse.readEntity(GetCallingRelaysResponse.class).relays())
.isEqualTo(List.of(CLOUDFLARE_TURN_TOKEN));
CallRoutingControllerV2.GetCallingRelaysResponse response = rawResponse.readEntity(
CallRoutingControllerV2.GetCallingRelaysResponse.class);
List<TurnToken> relays = response.relays();
assertThat(relays).isEqualTo(List.of(CLOUDFLARE_TURN_TOKEN));
}
}

View File

@ -1140,58 +1140,6 @@ class ProfileControllerTest {
new AccountBadge("TEST2", Instant.ofEpochSecond(42 + 86400), false),
new AccountBadge("TEST3", Instant.ofEpochSecond(42 + 86400), false));
}
}
@Test
void testSetProfileBadgeAfterUpdateTries() throws Exception {
final ProfileKeyCommitment commitment = new ProfileKey(new byte[32]).getCommitment(
new ServiceId.Aci(AuthHelper.VALID_UUID));
final byte[] name = TestRandomUtil.nextBytes(81);
final byte[] emoji = TestRandomUtil.nextBytes(60);
final byte[] about = TestRandomUtil.nextBytes(156);
final String version = versionHex("anotherversion");
clearInvocations(AuthHelper.VALID_ACCOUNT_TWO);
reset(accountsManager);
final int accountsManagerUpdateRetryCount = 2;
AccountsHelper.setupMockUpdateWithRetries(accountsManager, accountsManagerUpdateRetryCount);
// set up two invocations -- one for each AccountsManager#update try
when(AuthHelper.VALID_ACCOUNT_TWO.getBadges())
.thenReturn(List.of(
new AccountBadge("TEST2", Instant.ofEpochSecond(42 + 86400), true),
new AccountBadge("TEST3", Instant.ofEpochSecond(42 + 86400), true)
))
.thenReturn(List.of(
new AccountBadge("TEST2", Instant.ofEpochSecond(42 + 86400), true),
new AccountBadge("TEST3", Instant.ofEpochSecond(42 + 86400), true),
new AccountBadge("TEST4", Instant.ofEpochSecond(43 + 86400), true)
));
try (final Response response = resources.getJerseyTest()
.target("/v1/profile/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID_TWO, AuthHelper.VALID_PASSWORD_TWO))
.put(Entity.entity(new CreateProfileRequest(commitment, version, name, emoji, about, null, false, false,
Optional.of(List.of("TEST1")), null), MediaType.APPLICATION_JSON_TYPE))) {
assertThat(response.getStatus()).isEqualTo(200);
assertThat(response.hasEntity()).isFalse();
//noinspection unchecked
final ArgumentCaptor<List<AccountBadge>> badgeCaptor = ArgumentCaptor.forClass(List.class);
verify(AuthHelper.VALID_ACCOUNT_TWO, times(accountsManagerUpdateRetryCount)).setBadges(refEq(clock), badgeCaptor.capture());
// since the stubbing of getBadges() is brittle, we need to verify the number of invocations, to protect against upstream changes
verify(AuthHelper.VALID_ACCOUNT_TWO, times(accountsManagerUpdateRetryCount)).getBadges();
final List<AccountBadge> badges = badgeCaptor.getValue();
assertThat(badges).isNotNull().hasSize(4).containsOnly(
new AccountBadge("TEST1", Instant.ofEpochSecond(42 + 86400), true),
new AccountBadge("TEST2", Instant.ofEpochSecond(42 + 86400), false),
new AccountBadge("TEST3", Instant.ofEpochSecond(42 + 86400), false),
new AccountBadge("TEST4", Instant.ofEpochSecond(43 + 86400), false));
}
}
@ParameterizedTest

View File

@ -7,7 +7,6 @@ package org.whispersystems.textsecuregcm.filters;
import static org.junit.jupiter.api.Assertions.assertEquals;
import com.google.common.net.InetAddresses;
import com.google.protobuf.ByteString;
import io.dropwizard.core.Application;
import io.dropwizard.core.Configuration;
@ -25,6 +24,7 @@ import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.client.Client;
import jakarta.ws.rs.core.Response;
import java.net.InetAddress;
import java.util.Collections;
import java.util.EnumSet;
import java.util.Set;
@ -39,7 +39,6 @@ import org.signal.chat.rpc.EchoServiceGrpc;
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.MockRequestAttributesInterceptor;
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.util.InetAddressRange;
@ExtendWith(DropwizardExtensionsSupport.class)
@ -158,7 +157,7 @@ class ExternalRequestFilterTest {
@BeforeEach
void setUp() throws Exception {
final MockRequestAttributesInterceptor mockRequestAttributesInterceptor = new MockRequestAttributesInterceptor();
mockRequestAttributesInterceptor.setRequestAttributes(new RequestAttributes(InetAddresses.forString("127.0.0.1"), null, null));
mockRequestAttributesInterceptor.setRemoteAddress(InetAddress.getByName("127.0.0.1"));
testServer = InProcessServerBuilder.forName("ExternalRequestFilterTest")
.directExecutor()

View File

@ -15,7 +15,6 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.common.net.HttpHeaders;
import com.google.common.net.InetAddresses;
import com.google.protobuf.ByteString;
import com.vdurmont.semver4j.Semver;
import io.grpc.ManagedChannel;
@ -41,10 +40,11 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRemoteDeprecationConfiguration;
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
import org.whispersystems.textsecuregcm.grpc.MockRequestAttributesInterceptor;
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.grpc.StatusConstants;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
class RemoteDeprecationFilterTest {
@ -130,7 +130,11 @@ class RemoteDeprecationFilterTest {
@MethodSource(value="testFilter")
void testGrpcFilter(final String userAgentString, final boolean expectDeprecation) throws IOException, InterruptedException {
final MockRequestAttributesInterceptor mockRequestAttributesInterceptor = new MockRequestAttributesInterceptor();
mockRequestAttributesInterceptor.setRequestAttributes(new RequestAttributes(InetAddresses.forString("127.0.0.1"), userAgentString, null));
try {
mockRequestAttributesInterceptor.setUserAgent(UserAgentUtil.parseUserAgentString(userAgentString));
} catch (UnrecognizedUserAgentException ignored) {
}
final Server testServer = InProcessServerBuilder.forName("RemoteDeprecationFilterTest")
.directExecutor()

View File

@ -72,8 +72,7 @@ class AccountsAnonymousGrpcServiceTest extends
when(rateLimiter.validateReactive(anyString())).thenReturn(Mono.empty());
getMockRequestAttributesInterceptor().setRequestAttributes(
new RequestAttributes(InetAddresses.forString("127.0.0.1"), null, null));
getMockRequestAttributesInterceptor().setRemoteAddress(InetAddresses.forString("127.0.0.1"));
return new AccountsAnonymousGrpcService(accountsManager, rateLimiters);
}

View File

@ -5,7 +5,6 @@
package org.whispersystems.textsecuregcm.grpc;
import com.google.common.net.InetAddresses;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Metadata;
@ -20,10 +19,25 @@ import org.whispersystems.textsecuregcm.util.ua.UserAgent;
public class MockRequestAttributesInterceptor implements ServerInterceptor {
private RequestAttributes requestAttributes = new RequestAttributes(InetAddresses.forString("127.0.0.1"), null, null);
@Nullable
private InetAddress remoteAddress;
public void setRequestAttributes(final RequestAttributes requestAttributes) {
this.requestAttributes = requestAttributes;
@Nullable
private UserAgent userAgent;
@Nullable
private List<Locale.LanguageRange> acceptLanguage;
public void setRemoteAddress(@Nullable final InetAddress remoteAddress) {
this.remoteAddress = remoteAddress;
}
public void setUserAgent(@Nullable final UserAgent userAgent) {
this.userAgent = userAgent;
}
public void setAcceptLanguage(@Nullable final List<Locale.LanguageRange> acceptLanguage) {
this.acceptLanguage = acceptLanguage;
}
@Override
@ -31,7 +45,20 @@ public class MockRequestAttributesInterceptor implements ServerInterceptor {
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
return Contexts.interceptCall(Context.current()
.withValue(RequestAttributesUtil.REQUEST_ATTRIBUTES_CONTEXT_KEY, requestAttributes), serverCall, headers, next);
Context context = Context.current();
if (remoteAddress != null) {
context = context.withValue(RequestAttributesUtil.REMOTE_ADDRESS_CONTEXT_KEY, remoteAddress);
}
if (userAgent != null) {
context = context.withValue(RequestAttributesUtil.USER_AGENT_CONTEXT_KEY, userAgent);
}
if (acceptLanguage != null) {
context = context.withValue(RequestAttributesUtil.ACCEPT_LANGUAGE_CONTEXT_KEY, acceptLanguage);
}
return Contexts.interceptCall(context, serverCall, headers, next);
}
}

View File

@ -15,7 +15,6 @@ import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
import com.google.common.net.InetAddresses;
import com.google.protobuf.ByteString;
import io.grpc.Status;
import java.nio.charset.StandardCharsets;
@ -76,6 +75,8 @@ import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<ProfileAnonymousGrpcService, ProfileAnonymousGrpc.ProfileAnonymousBlockingStub> {
@ -95,9 +96,13 @@ public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<ProfileA
@Override
protected ProfileAnonymousGrpcService createServiceBeforeEachTest() {
getMockRequestAttributesInterceptor().setRequestAttributes(new RequestAttributes(InetAddresses.forString("127.0.0.1"),
"Signal-Android/1.2.3",
Locale.LanguageRange.parse("en-us")));
getMockRequestAttributesInterceptor().setAcceptLanguage(Locale.LanguageRange.parse("en-us"));
try {
getMockRequestAttributesInterceptor().setUserAgent(UserAgentUtil.parseUserAgentString("Signal-Android/1.2.3"));
} catch (final UnrecognizedUserAgentException e) {
throw new IllegalArgumentException(e);
}
return new ProfileAnonymousGrpcService(
accountsManager,

View File

@ -15,16 +15,13 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.refEq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
import com.google.common.net.InetAddresses;
import com.google.i18n.phonenumbers.PhoneNumberUtil;
import com.google.i18n.phonenumbers.Phonenumber;
import com.google.protobuf.ByteString;
@ -33,7 +30,6 @@ import java.nio.charset.StandardCharsets;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.time.ZoneId;
import java.time.temporal.ChronoUnit;
import java.util.Collections;
import java.util.List;
@ -97,18 +93,18 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.s3.PolicySigner;
import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountBadge;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DeviceCapability;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.ProfilesManager;
import org.whispersystems.textsecuregcm.storage.VersionedProfile;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper;
import org.whispersystems.textsecuregcm.util.MockUtils;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import reactor.core.publisher.Mono;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
@ -148,8 +144,6 @@ public class ProfileGrpcServiceTest extends SimpleBaseGrpcTest<ProfileGrpcServic
@Mock
private ServerZkProfileOperations serverZkProfileOperations;
private Clock clock;
@Override
protected ProfileGrpcService createServiceBeforeEachTest() {
@SuppressWarnings("unchecked") final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
@ -176,9 +170,13 @@ public class ProfileGrpcServiceTest extends SimpleBaseGrpcTest<ProfileGrpcServic
PhoneNumberUtil.getInstance().getExampleNumber("US"),
PhoneNumberUtil.PhoneNumberFormat.E164);
getMockRequestAttributesInterceptor().setRequestAttributes(new RequestAttributes(InetAddresses.forString("127.0.0.1"),
"Signal-Android/1.2.3",
Locale.LanguageRange.parse("en-us")));
getMockRequestAttributesInterceptor().setAcceptLanguage(Locale.LanguageRange.parse("en-us"));
try {
getMockRequestAttributesInterceptor().setUserAgent(UserAgentUtil.parseUserAgentString("Signal-Android/1.2.3"));
} catch (final UnrecognizedUserAgentException e) {
throw new IllegalArgumentException(e);
}
when(rateLimiters.getProfileLimiter()).thenReturn(rateLimiter);
when(rateLimiter.validateReactive(any(UUID.class))).thenReturn(Mono.empty());
@ -205,10 +203,8 @@ public class ProfileGrpcServiceTest extends SimpleBaseGrpcTest<ProfileGrpcServic
when(asyncS3client.deleteObject(any(DeleteObjectRequest.class))).thenReturn(CompletableFuture.completedFuture(null));
clock = Clock.fixed(Instant.ofEpochSecond(42), ZoneId.of("Etc/UTC"));
return new ProfileGrpcService(
clock,
Clock.systemUTC(),
accountsManager,
profilesManager,
dynamicConfigurationManager,
@ -396,42 +392,6 @@ public class ProfileGrpcServiceTest extends SimpleBaseGrpcTest<ProfileGrpcServic
}
}
@Test
void setProfileBadges() throws InvalidInputException {
final byte[] commitment = new ProfileKey(new byte[32]).getCommitment(new ServiceId.Aci(AUTHENTICATED_ACI)).serialize();
final SetProfileRequest request = SetProfileRequest.newBuilder()
.setVersion(VERSION)
.setName(ByteString.copyFrom(VALID_NAME))
.setAvatarChange(AvatarChange.AVATAR_CHANGE_UNCHANGED)
.addAllBadgeIds(List.of("TEST3"))
.setCommitment(ByteString.copyFrom(commitment))
.build();
final int accountsManagerUpdateRetryCount = 2;
AccountsHelper.setupMockUpdateWithRetries(accountsManager, accountsManagerUpdateRetryCount);
// set up two invocations -- one for each AccountsManager#update try
when(account.getBadges())
.thenReturn(List.of(new AccountBadge("TEST3", Instant.ofEpochSecond(41), false)))
.thenReturn(List.of(new AccountBadge("TEST2", Instant.ofEpochSecond(41), true),
new AccountBadge("TEST3", Instant.ofEpochSecond(41), false)));
//noinspection ResultOfMethodCallIgnored
authenticatedServiceStub().setProfile(request);
//noinspection unchecked
final ArgumentCaptor<List<AccountBadge>> badgeCaptor = ArgumentCaptor.forClass(List.class);
verify(account, times(2)).setBadges(refEq(clock), badgeCaptor.capture());
// since the stubbing of getBadges() is brittle, we need to verify the number of invocations, to protect against upstream changes
verify(account, times(accountsManagerUpdateRetryCount)).getBadges();
assertEquals(List.of(
new AccountBadge("TEST3", Instant.ofEpochSecond(41), true),
new AccountBadge("TEST2", Instant.ofEpochSecond(41), false)),
badgeCaptor.getValue());
}
@ParameterizedTest
@EnumSource(value = org.signal.chat.common.IdentityType.class, names = {"IDENTITY_TYPE_ACI", "IDENTITY_TYPE_PNI"})
void getUnversionedProfile(final IdentityType identityType) {

View File

@ -1,11 +1,13 @@
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.stub.StreamObserver;
import org.apache.commons.lang3.StringUtils;
import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.signal.chat.rpc.GetRequestAttributesRequest;
import org.signal.chat.rpc.GetRequestAttributesResponse;
import org.signal.chat.rpc.RequestAttributesGrpc;
import org.signal.chat.rpc.UserAgent;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticationUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
@ -18,15 +20,21 @@ public class RequestAttributesServiceImpl extends RequestAttributesGrpc.RequestA
final GetRequestAttributesResponse.Builder responseBuilder = GetRequestAttributesResponse.newBuilder();
RequestAttributesUtil.getAcceptableLanguages()
.forEach(languageRange -> responseBuilder.addAcceptableLanguages(languageRange.toString()));
RequestAttributesUtil.getAcceptableLanguages().ifPresent(acceptableLanguages ->
acceptableLanguages.forEach(languageRange -> responseBuilder.addAcceptableLanguages(languageRange.toString())));
RequestAttributesUtil.getAvailableAcceptedLocales().forEach(locale ->
responseBuilder.addAvailableAcceptedLocales(locale.toLanguageTag()));
responseBuilder.setRemoteAddress(RequestAttributesUtil.getRemoteAddress().getHostAddress());
RequestAttributesUtil.getUserAgent().ifPresent(responseBuilder::setUserAgent);
RequestAttributesUtil.getUserAgent().ifPresent(userAgent -> responseBuilder.setUserAgent(UserAgent.newBuilder()
.setPlatform(userAgent.platform().toString())
.setVersion(userAgent.version().toString())
.setAdditionalSpecifiers(StringUtils.stripToEmpty(userAgent.additionalSpecifiers()))
.build()));
RequestAttributesUtil.getRawUserAgent().ifPresent(responseBuilder::setRawUserAgent);
responseObserver.onNext(responseBuilder.build());
responseObserver.onCompleted();

View File

@ -3,84 +3,172 @@ package org.whispersystems.textsecuregcm.grpc;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.common.net.InetAddresses;
import io.grpc.Context;
import java.net.InetAddress;
import java.util.Collections;
import io.grpc.ManagedChannel;
import io.grpc.Server;
import io.grpc.Status;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.NettyServerBuilder;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import java.io.IOException;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.concurrent.Callable;
import javax.annotation.Nullable;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.signal.chat.rpc.GetRequestAttributesRequest;
import org.signal.chat.rpc.GetRequestAttributesResponse;
import org.signal.chat.rpc.RequestAttributesGrpc;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
class RequestAttributesUtilTest {
private static final InetAddress REMOTE_ADDRESS = InetAddresses.forString("127.0.0.1");
private static DefaultEventLoopGroup eventLoopGroup;
@Test
void getAcceptableLanguages() throws Exception {
assertEquals(Collections.emptyList(),
callWithRequestAttributes(buildRequestAttributes(Collections.emptyList()),
RequestAttributesUtil::getAcceptableLanguages));
private GrpcClientConnectionManager grpcClientConnectionManager;
assertEquals(Locale.LanguageRange.parse("en,ja"),
callWithRequestAttributes(buildRequestAttributes(Locale.LanguageRange.parse("en,ja")),
RequestAttributesUtil::getAcceptableLanguages));
private Server server;
private ManagedChannel managedChannel;
@BeforeAll
static void setUpBeforeAll() {
eventLoopGroup = new DefaultEventLoopGroup();
}
@BeforeEach
void setUp() throws IOException {
final LocalAddress serverAddress = new LocalAddress("test-request-metadata-server");
grpcClientConnectionManager = mock(GrpcClientConnectionManager.class);
when(grpcClientConnectionManager.getRemoteAddress(any()))
.thenReturn(Optional.of(InetAddresses.forString("127.0.0.1")));
// `RequestAttributesInterceptor` operates on `LocalAddresses`, so we need to do some slightly fancy plumbing to make
// sure that we're using local channels and addresses
server = NettyServerBuilder.forAddress(serverAddress)
.channelType(LocalServerChannel.class)
.bossEventLoopGroup(eventLoopGroup)
.workerEventLoopGroup(eventLoopGroup)
.intercept(new RequestAttributesInterceptor(grpcClientConnectionManager))
.addService(new RequestAttributesServiceImpl())
.build()
.start();
managedChannel = NettyChannelBuilder.forAddress(serverAddress)
.channelType(LocalChannel.class)
.eventLoopGroup(eventLoopGroup)
.usePlaintext()
.build();
}
@AfterEach
void tearDown() {
managedChannel.shutdown();
server.shutdown();
}
@AfterAll
static void tearDownAfterAll() throws InterruptedException {
eventLoopGroup.shutdownGracefully().await();
}
@Test
void getAvailableAcceptedLocales() throws Exception {
assertEquals(Collections.emptyList(),
callWithRequestAttributes(buildRequestAttributes(Collections.emptyList()),
RequestAttributesUtil::getAvailableAcceptedLocales));
void getAcceptableLanguages() {
when(grpcClientConnectionManager.getAcceptableLanguages(any()))
.thenReturn(Optional.empty());
final List<Locale> availableAcceptedLocales =
callWithRequestAttributes(buildRequestAttributes(Locale.LanguageRange.parse("en,ja")),
RequestAttributesUtil::getAvailableAcceptedLocales);
assertTrue(getRequestAttributes().getAcceptableLanguagesList().isEmpty());
assertFalse(availableAcceptedLocales.isEmpty());
when(grpcClientConnectionManager.getAcceptableLanguages(any()))
.thenReturn(Optional.of(Locale.LanguageRange.parse("en,ja")));
availableAcceptedLocales.forEach(locale ->
assertTrue("en".equals(locale.getLanguage()) || "ja".equals(locale.getLanguage())));
assertEquals(List.of("en", "ja"), getRequestAttributes().getAcceptableLanguagesList());
}
@Test
void getRemoteAddress() throws Exception {
assertEquals(REMOTE_ADDRESS,
callWithRequestAttributes(new RequestAttributes(REMOTE_ADDRESS, null, null),
RequestAttributesUtil::getRemoteAddress));
void getAvailableAcceptedLocales() {
when(grpcClientConnectionManager.getAcceptableLanguages(any()))
.thenReturn(Optional.empty());
assertTrue(getRequestAttributes().getAvailableAcceptedLocalesList().isEmpty());
when(grpcClientConnectionManager.getAcceptableLanguages(any()))
.thenReturn(Optional.of(Locale.LanguageRange.parse("en,ja")));
final GetRequestAttributesResponse response = getRequestAttributes();
assertFalse(response.getAvailableAcceptedLocalesList().isEmpty());
response.getAvailableAcceptedLocalesList().forEach(languageTag -> {
final Locale locale = Locale.forLanguageTag(languageTag);
assertTrue("en".equals(locale.getLanguage()) || "ja".equals(locale.getLanguage()));
});
}
@Test
void getUserAgent() throws Exception {
assertEquals(Optional.empty(),
callWithRequestAttributes(buildRequestAttributes((String) null),
RequestAttributesUtil::getUserAgent));
void getRemoteAddress() {
when(grpcClientConnectionManager.getRemoteAddress(any()))
.thenReturn(Optional.empty());
assertEquals(Optional.of("Signal-Desktop/1.2.3 Linux"),
callWithRequestAttributes(buildRequestAttributes("Signal-Desktop/1.2.3 Linux"),
RequestAttributesUtil::getUserAgent));
GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getRequestAttributes);
final String remoteAddressString = "6.7.8.9";
when(grpcClientConnectionManager.getRemoteAddress(any()))
.thenReturn(Optional.of(InetAddresses.forString(remoteAddressString)));
assertEquals(remoteAddressString, getRequestAttributes().getRemoteAddress());
}
private static <V> V callWithRequestAttributes(final RequestAttributes requestAttributes, final Callable<V> callable) throws Exception {
return Context.current()
.withValue(RequestAttributesUtil.REQUEST_ATTRIBUTES_CONTEXT_KEY, requestAttributes)
.call(callable);
@Test
void getUserAgent() throws UnrecognizedUserAgentException {
when(grpcClientConnectionManager.getUserAgent(any()))
.thenReturn(Optional.empty());
assertFalse(getRequestAttributes().hasUserAgent());
final UserAgent userAgent = UserAgentUtil.parseUserAgentString("Signal-Desktop/1.2.3 Linux");
when(grpcClientConnectionManager.getUserAgent(any()))
.thenReturn(Optional.of(userAgent));
final GetRequestAttributesResponse response = getRequestAttributes();
assertTrue(response.hasUserAgent());
assertEquals("DESKTOP", response.getUserAgent().getPlatform());
assertEquals("1.2.3", response.getUserAgent().getVersion());
assertEquals("Linux", response.getUserAgent().getAdditionalSpecifiers());
}
private static RequestAttributes buildRequestAttributes(final String userAgent) {
return buildRequestAttributes(userAgent, Collections.emptyList());
@Test
void getRawUserAgent() {
when(grpcClientConnectionManager.getRawUserAgent(any()))
.thenReturn(Optional.empty());
assertTrue(getRequestAttributes().getRawUserAgent().isBlank());
final String userAgentString = "Signal-Desktop/1.2.3 Linux";
when(grpcClientConnectionManager.getRawUserAgent(any()))
.thenReturn(Optional.of(userAgentString));
assertEquals(userAgentString, getRequestAttributes().getRawUserAgent());
}
private static RequestAttributes buildRequestAttributes(final List<Locale.LanguageRange> acceptLanguage) {
return buildRequestAttributes(null, acceptLanguage);
}
private static RequestAttributes buildRequestAttributes(@Nullable final String userAgent,
final List<Locale.LanguageRange> acceptLanguage) {
return new RequestAttributes(REMOTE_ADDRESS, userAgent, acceptLanguage);
private GetRequestAttributesResponse getRequestAttributes() {
return RequestAttributesGrpc.newBlockingStub(managedChannel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build());
}
}

View File

@ -1,11 +1,7 @@
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import com.google.common.net.InetAddresses;
import com.vdurmont.semver4j.Semver;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
@ -16,12 +12,6 @@ import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import java.net.InetAddress;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.UUID;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
@ -31,9 +21,20 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import javax.annotation.Nullable;
import java.net.InetAddress;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.UUID;
import static org.junit.jupiter.api.Assertions.*;
class GrpcClientConnectionManagerTest {
@ -102,7 +103,7 @@ class GrpcClientConnectionManagerTest {
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, maybeAuthenticatedDevice);
assertEquals(maybeAuthenticatedDevice,
grpcClientConnectionManager.getAuthenticatedDevice(remoteChannel));
grpcClientConnectionManager.getAuthenticatedDevice(localChannel.localAddress()));
}
private static List<Optional<AuthenticatedDevice>> getAuthenticatedDevice() {
@ -113,115 +114,170 @@ class GrpcClientConnectionManagerTest {
}
@Test
void getRequestAttributes() {
void getAcceptableLanguages() {
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertThrows(IllegalStateException.class, () -> grpcClientConnectionManager.getRequestAttributes(remoteChannel));
assertEquals(Optional.empty(),
grpcClientConnectionManager.getAcceptableLanguages(localChannel.localAddress()));
final RequestAttributes requestAttributes = new RequestAttributes(InetAddresses.forString("6.7.8.9"), null, null);
remoteChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY).set(requestAttributes);
final List<Locale.LanguageRange> acceptLanguageRanges = Locale.LanguageRange.parse("en,ja");
remoteChannel.attr(GrpcClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).set(acceptLanguageRanges);
assertEquals(requestAttributes, grpcClientConnectionManager.getRequestAttributes(remoteChannel));
assertEquals(Optional.of(acceptLanguageRanges),
grpcClientConnectionManager.getAcceptableLanguages(localChannel.localAddress()));
}
@Test
void closeConnection() throws InterruptedException, ChannelNotFoundException {
void getRemoteAddress() {
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertEquals(Optional.empty(),
grpcClientConnectionManager.getRemoteAddress(localChannel.localAddress()));
final InetAddress remoteAddress = InetAddresses.forString("6.7.8.9");
remoteChannel.attr(GrpcClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).set(remoteAddress);
assertEquals(Optional.of(remoteAddress),
grpcClientConnectionManager.getRemoteAddress(localChannel.localAddress()));
}
@Test
void getUserAgent() throws UnrecognizedUserAgentException {
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertEquals(Optional.empty(),
grpcClientConnectionManager.getUserAgent(localChannel.localAddress()));
final UserAgent userAgent = UserAgentUtil.parseUserAgentString("Signal-Desktop/1.2.3 Linux");
remoteChannel.attr(GrpcClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY).set(userAgent);
assertEquals(Optional.of(userAgent),
grpcClientConnectionManager.getUserAgent(localChannel.localAddress()));
}
@Test
void closeConnection() throws InterruptedException {
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice));
assertTrue(remoteChannel.isOpen());
assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
assertEquals(List.of(remoteChannel),
grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
remoteChannel.close().await();
assertThrows(ChannelNotFoundException.class,
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
}
@Test
void handleWebSocketHandshakeCompleteRemoteAddress() {
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
final InetAddress preferredRemoteAddress = InetAddresses.forString("192.168.1.1");
GrpcClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel,
preferredRemoteAddress,
null,
null);
assertEquals(preferredRemoteAddress,
embeddedChannel.attr(GrpcClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).get());
}
@ParameterizedTest
@MethodSource
void handleHandshakeCompleteRequestAttributes(final InetAddress preferredRemoteAddress,
final String userAgentHeader,
final String acceptLanguageHeader,
final RequestAttributes expectedRequestAttributes) {
void handleWebSocketHandshakeCompleteUserAgent(@Nullable final String userAgentHeader,
@Nullable final UserAgent expectedParsedUserAgent) {
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
GrpcClientConnectionManager.handleHandshakeComplete(embeddedChannel,
preferredRemoteAddress,
GrpcClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel,
InetAddresses.forString("127.0.0.1"),
userAgentHeader,
acceptLanguageHeader);
null);
assertEquals(expectedRequestAttributes,
embeddedChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY).get());
assertEquals(userAgentHeader,
embeddedChannel.attr(GrpcClientConnectionManager.RAW_USER_AGENT_ATTRIBUTE_KEY).get());
assertEquals(expectedParsedUserAgent,
embeddedChannel.attr(GrpcClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY).get());
}
private static List<Arguments> handleHandshakeCompleteRequestAttributes() {
final InetAddress preferredRemoteAddress = InetAddresses.forString("192.168.1.1");
private static List<Arguments> handleWebSocketHandshakeCompleteUserAgent() {
return List.of(
Arguments.argumentSet("Null User-Agent and Accept-Language headers",
preferredRemoteAddress, null, null,
new RequestAttributes(preferredRemoteAddress, null, Collections.emptyList())),
// Recognized user-agent
Arguments.of("Signal-Desktop/1.2.3 Linux", new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "Linux")),
Arguments.argumentSet("Recognized User-Agent and null Accept-Language header",
preferredRemoteAddress, "Signal-Desktop/1.2.3 Linux", null,
new RequestAttributes(preferredRemoteAddress, "Signal-Desktop/1.2.3 Linux", Collections.emptyList())),
// Unrecognized user-agent
Arguments.of("Not a valid user-agent string", null),
Arguments.argumentSet("Unparsable User-Agent and null Accept-Language header",
preferredRemoteAddress, "Not a valid user-agent string", null,
new RequestAttributes(preferredRemoteAddress, "Not a valid user-agent string", Collections.emptyList())),
// Missing user-agent
Arguments.of(null, null)
);
}
Arguments.argumentSet("Null User-Agent and parsable Accept-Language header",
preferredRemoteAddress, null, "ja,en;q=0.4",
new RequestAttributes(preferredRemoteAddress, null, Locale.LanguageRange.parse("ja,en;q=0.4"))),
Arguments.argumentSet("Null User-Agent and unparsable Accept-Language header",
preferredRemoteAddress, null, "This is not a valid language preference list",
new RequestAttributes(preferredRemoteAddress, null, Collections.emptyList()))
@ParameterizedTest
@MethodSource
void handleWebSocketHandshakeCompleteAcceptLanguage(@Nullable final String acceptLanguageHeader,
@Nullable final List<Locale.LanguageRange> expectedLanguageRanges) {
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
GrpcClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel,
InetAddresses.forString("127.0.0.1"),
null,
acceptLanguageHeader);
assertEquals(expectedLanguageRanges,
embeddedChannel.attr(GrpcClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).get());
}
private static List<Arguments> handleWebSocketHandshakeCompleteAcceptLanguage() {
return List.of(
// Parseable list
Arguments.of("ja,en;q=0.4", Locale.LanguageRange.parse("ja,en;q=0.4")),
// Unparsable list
Arguments.of("This is not a valid language preference list", null),
// Missing list
Arguments.of(null, null)
);
}
@Test
void handleConnectionEstablishedAuthenticated() throws InterruptedException, ChannelNotFoundException {
void handleConnectionEstablishedAuthenticated() throws InterruptedException {
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
assertThrows(ChannelNotFoundException.class,
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice));
assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
assertEquals(List.of(remoteChannel), grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
remoteChannel.close().await();
assertThrows(ChannelNotFoundException.class,
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
}
@Test
void handleConnectionEstablishedAnonymous() throws InterruptedException, ChannelNotFoundException {
assertThrows(ChannelNotFoundException.class,
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
void handleConnectionEstablishedAnonymous() throws InterruptedException {
assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
remoteChannel.close().await();
assertThrows(ChannelNotFoundException.class,
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
}
}

View File

@ -523,7 +523,10 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
assertEquals(remoteAddress, response.getRemoteAddress());
assertEquals(List.of(acceptLanguage), response.getAcceptableLanguagesList());
assertEquals(userAgent, response.getUserAgent());
assertEquals("DESKTOP", response.getUserAgent().getPlatform());
assertEquals("1.2.3", response.getUserAgent().getVersion());
assertEquals("Linux", response.getUserAgent().getAdditionalSpecifiers());
} finally {
channel.shutdown();
}

View File

@ -4,7 +4,6 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.params.provider.Arguments.argumentSet;
import static org.junit.jupiter.params.provider.Arguments.arguments;
import static org.mockito.Mockito.mock;
@ -17,7 +16,6 @@ import io.netty.channel.local.LocalAddress;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.util.Attribute;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
@ -33,7 +31,6 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.signal.libsignal.protocol.ecc.Curve;
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
@ -137,13 +134,8 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
embeddedChannel.setRemoteAddress(remoteAddress);
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
assertEquals(expectedRemoteAddress,
Optional.ofNullable(embeddedChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY))
.map(Attribute::get)
.map(RequestAttributes::remoteAddress)
.orElse(null));
embeddedChannel.attr(GrpcClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).get());
}
private static List<Arguments> getRemoteAddress() {
@ -152,53 +144,53 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
final InetAddress proxyAddress = InetAddresses.forString("4.3.2.1");
return List.of(
argumentSet("Recognized proxy, single forwarded-for address",
new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
// Recognized proxy, single forwarded-for address
Arguments.of(new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
remoteAddress,
clientAddress),
argumentSet("Recognized proxy, multiple forwarded-for addresses",
new DefaultHttpHeaders()
// Recognized proxy, multiple forwarded-for addresses
Arguments.of(new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress() + "," + proxyAddress.getHostAddress()),
remoteAddress,
proxyAddress),
argumentSet("No recognized proxy header, single forwarded-for address",
new DefaultHttpHeaders()
// No recognized proxy header, single forwarded-for address
Arguments.of(new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
remoteAddress,
remoteAddress.getAddress()),
argumentSet("No recognized proxy header, no forwarded-for address",
new DefaultHttpHeaders(),
// No recognized proxy header, no forwarded-for address
Arguments.of(new DefaultHttpHeaders(),
remoteAddress,
remoteAddress.getAddress()),
argumentSet("Incorrect proxy header, single forwarded-for address",
new DefaultHttpHeaders()
// Incorrect proxy header, single forwarded-for address
Arguments.of(new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET + "-incorrect")
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
remoteAddress,
remoteAddress.getAddress()),
argumentSet("Recognized proxy, no forwarded-for address",
new DefaultHttpHeaders()
// Recognized proxy, no forwarded-for address
Arguments.of(new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET),
remoteAddress,
remoteAddress.getAddress()),
argumentSet("Recognized proxy, bogus forwarded-for address",
new DefaultHttpHeaders()
// Recognized proxy, bogus forwarded-for address
Arguments.of(new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, "not a valid address"),
remoteAddress,
null),
argumentSet("No forwarded-for address, non-InetSocketAddress remote address",
new DefaultHttpHeaders()
// No forwarded-for address, non-InetSocketAddress remote address
Arguments.of(new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET),
new LocalAddress("local-address"),
null)

View File

@ -31,8 +31,8 @@ import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.SystemMapper;
public class AccountsHelper {
@ -62,71 +62,6 @@ public class AccountsHelper {
setupMockUpdate(mockAccountsManager, false);
}
/**
* Sets up stubbing for:
* <ul>
* <li>{@link AccountsManager#update(Account, Consumer)}</li>
* <li>{@link AccountsManager#updateAsync(Account, Consumer)}</li>
* <li>{@link AccountsManager#updateDevice(Account, byte, Consumer)}</li>
* <li>{@link AccountsManager#updateDeviceAsync(Account, byte, Consumer)}</li>
* </ul>
*
* with multiple calls to the {@link Consumer<Account>}. This simulates retries from {@link org.whispersystems.textsecuregcm.storage.ContestedOptimisticLockException}.
* Callers will typically set up stubbing for relevant {@link Account} methods with multiple {@link org.mockito.stubbing.OngoingStubbing#thenReturn(Object)}
* calls:
* <pre>
* // example stubbing
* when(account.getNextDeviceId())
* .thenReturn(2)
* .thenReturn(3);
* </pre>
*/
@SuppressWarnings("unchecked")
public static void setupMockUpdateWithRetries(final AccountsManager mockAccountsManager, final int retryCount) {
when(mockAccountsManager.update(any(), any())).thenAnswer(answer -> {
final Account account = answer.getArgument(0, Account.class);
for (int i = 0; i < retryCount; i++) {
answer.getArgument(1, Consumer.class).accept(account);
}
return copyAndMarkStale(account);
});
when(mockAccountsManager.updateAsync(any(), any())).thenAnswer(answer -> {
final Account account = answer.getArgument(0, Account.class);
for (int i = 0; i < retryCount; i++) {
answer.getArgument(1, Consumer.class).accept(account);
}
return CompletableFuture.completedFuture(copyAndMarkStale(account));
});
when(mockAccountsManager.updateDevice(any(), anyByte(), any())).thenAnswer(answer -> {
final Account account = answer.getArgument(0, Account.class);
final byte deviceId = answer.getArgument(1, Byte.class);
for (int i = 0; i < retryCount; i++) {
account.getDevice(deviceId).ifPresent(answer.getArgument(2, Consumer.class));
}
return copyAndMarkStale(account);
});
when(mockAccountsManager.updateDeviceAsync(any(), anyByte(), any())).thenAnswer(answer -> {
final Account account = answer.getArgument(0, Account.class);
final byte deviceId = answer.getArgument(1, Byte.class);
for (int i = 0; i < retryCount; i++) {
account.getDevice(deviceId).ifPresent(answer.getArgument(2, Consumer.class));
}
return CompletableFuture.completedFuture(copyAndMarkStale(account));
});
}
@SuppressWarnings("unchecked")
private static void setupMockUpdate(final AccountsManager mockAccountsManager, final boolean markStale) {
when(mockAccountsManager.update(any(), any())).thenAnswer(answer -> {
final Account account = answer.getArgument(0, Account.class);

View File

@ -13,20 +13,28 @@ import java.util.stream.Stream;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import javax.annotation.Nullable;
class UserAgentUtilTest {
@ParameterizedTest
@MethodSource("argumentsForTestParseStandardUserAgentString")
void testParseStandardUserAgentString(final String userAgentString, @Nullable final UserAgent expectedUserAgent)
throws UnrecognizedUserAgentException {
@MethodSource
void testParseBogusUserAgentString(final String userAgentString) {
assertThrows(UnrecognizedUserAgentException.class, () -> UserAgentUtil.parseUserAgentString(userAgentString));
}
if (expectedUserAgent != null) {
assertEquals(expectedUserAgent, UserAgentUtil.parseUserAgentString(userAgentString));
} else {
assertThrows(UnrecognizedUserAgentException.class, () -> UserAgentUtil.parseUserAgentString(userAgentString));
}
@SuppressWarnings("unused")
private static Stream<String> testParseBogusUserAgentString() {
return Stream.of(
null,
"This is obviously not a reasonable User-Agent string.",
"Signal-Android/4.6-8.3.unreasonableversionstring-17"
);
}
@ParameterizedTest
@MethodSource("argumentsForTestParseStandardUserAgentString")
void testParseStandardUserAgentString(final String userAgentString, final UserAgent expectedUserAgent) {
assertEquals(expectedUserAgent, UserAgentUtil.parseStandardUserAgentString(userAgentString));
}
private static Stream<Arguments> argumentsForTestParseStandardUserAgentString() {

View File

@ -23,7 +23,14 @@ message GetRequestAttributesResponse {
repeated string acceptable_languages = 1;
repeated string available_accepted_locales = 2;
string remote_address = 3;
string user_agent = 4;
string raw_user_agent = 4;
UserAgent user_agent = 5;
}
message UserAgent {
string platform = 1;
string version = 2;
string additional_specifiers = 3;
}
message GetAuthenticatedDeviceRequest {

View File

@ -470,8 +470,7 @@ turn:
cloudflare:
apiToken: secret://turn.cloudflare.apiToken
endpoint: https://rtc.live.cloudflare.com/v1/turn/keys/LMNOP/credentials/generate
requestedCredentialTtl: PT24H
clientCredentialTtl: PT12H
ttl: 86400
urls:
- turn:turn.example.com:80
urlsWithIps:

@ -1 +1 @@
Subproject commit d9852e294a853b88c7feaa748e17fee38acbf849
Subproject commit 8f566196d763c8eb1f3c8fcefd5be3c35ff8d148