Compare commits

...

28 Commits

Author SHA1 Message Date
Jonathan Klabunde Tomer 21c1d71551 take advantage of list non-nullitude 2025-04-25 10:06:42 -05:00
Jonathan Klabunde Tomer 38befdb260 default lists to empty 2025-04-25 10:06:42 -05:00
Jonathan Klabunde Tomer 63c79173b2 limit prekey uploads to 100 2025-04-25 10:06:42 -05:00
Ameya Lokare d2ad003891 Remove free memory and OS memory gauges 2025-04-25 10:05:29 -05:00
Chris Eager eb89773819 Remove unused parameter 2025-04-25 10:05:18 -05:00
Chris Eager 403abd84f6 Run test action on pull_request events 2025-04-25 10:05:08 -05:00
Jon Chambers f62f79c95c Add a counter for cases where clients use both an authenticated identity and UAK when fetching profiles 2025-04-24 11:47:43 -04:00
Jon Chambers 144c4c9223 Add a "sync" dimension to the "sent message" counter 2025-04-24 10:33:39 -05:00
Ravi Khadiwala ab4fc4f459 Add skip low urgency push experiment 2025-04-24 10:32:46 -05:00
Jonathan Klabunde Tomer 51569ce0a5
Use cached partition topology for metrics/logs 2025-04-24 08:29:58 -07:00
Jon Chambers f191c68efc Close remote connections only after all active server calls have completed 2025-04-22 17:00:48 -04:00
Jon Chambers bb8ce6d981 Introduce `ClosableEpoch` 2025-04-22 17:00:48 -04:00
Katherine e0ee75e0d0
Fix Daylight Savings bug in recommended notification time calculation 2025-04-22 16:56:10 -04:00
Jon Chambers 1ef3a230a1 Tag queue size distribution with client platform 2025-04-22 16:55:16 -04:00
Jon Chambers b1805d4bf1 Add a "persisted bytes" counter 2025-04-22 16:55:16 -04:00
Jon Chambers cac979c7fd Count individual persisted messages 2025-04-22 16:55:16 -04:00
Jon Chambers 4072dcdda5 Introduce `DevicePlatformUtil` 2025-04-22 16:55:16 -04:00
Jonathan Klabunde Tomer ed382fff6d log slot number and shard host of message persister failures 2025-04-22 16:55:16 -04:00
Jon Chambers 23bb8277d5 Update to the latest version of the spam filter 2025-04-18 15:56:17 -04:00
Jon Chambers 8099d6465c
Clarify guarantees around remote channnel/request attribute presence 2025-04-18 15:44:21 -04:00
Jon Chambers 28a0b9e84e
Include a TURN credential TTL for clients in `GetCallingRelaysResponse` 2025-04-17 10:30:58 -04:00
Chris Eager 9287aaf7ce Add app info to Stripe API calls 2025-04-17 09:30:34 -05:00
Chris Eager 0585f862cb Add regression test for set profile badges calculation 2025-04-17 09:29:11 -05:00
Chris Eager 7cac6f6f72 Remove extraneous account fetch in POST /v1/donation/redeem-receipt 2025-04-17 09:28:57 -05:00
Jon Chambers 57be4d798b Add a counter for attempts to send empty message lists 2025-04-17 10:27:46 -04:00
Jon Chambers 05c74f1997 Simplify `UserAgentUtil` 2025-04-17 10:27:24 -04:00
Jon Chambers f5e49b6db7 Convert `UserAgent` to a record 2025-04-15 14:58:09 -04:00
Jon Chambers 3c40e72d27
Fix registration ID map construction when changing numbers 2025-04-15 14:57:28 -04:00
84 changed files with 1708 additions and 1069 deletions

View File

@ -1,6 +1,7 @@
name: Service CI name: Service CI
on: on:
pull_request:
push: push:
branches-ignore: branches-ignore:
- gh-pages - gh-pages

View File

@ -482,7 +482,8 @@ turn:
- turn:%s - turn:%s
- turn:%s:80?transport=tcp - turn:%s:80?transport=tcp
- turns:%s:443?transport=tcp - turns:%s:443?transport=tcp
ttl: 86400 requestedCredentialTtl: PT24H
clientCredentialTtl: PT12H
hostname: turn.cloudflare.example.com hostname: turn.cloudflare.example.com
numHttpClients: 1 numHttpClients: 1

View File

@ -668,12 +668,13 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
final AccountAuthenticator accountAuthenticator = new AccountAuthenticator(accountsManager); final AccountAuthenticator accountAuthenticator = new AccountAuthenticator(accountsManager);
final MessageSender messageSender = new MessageSender(messagesManager, pushNotificationManager); final MessageSender messageSender = new MessageSender(messagesManager, pushNotificationManager, experimentEnrollmentManager);
final ReceiptSender receiptSender = new ReceiptSender(accountsManager, messageSender, receiptSenderExecutor); final ReceiptSender receiptSender = new ReceiptSender(accountsManager, messageSender, receiptSenderExecutor);
final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager = new CloudflareTurnCredentialsManager( final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager = new CloudflareTurnCredentialsManager(
config.getTurnConfiguration().cloudflare().apiToken().value(), config.getTurnConfiguration().cloudflare().apiToken().value(),
config.getTurnConfiguration().cloudflare().endpoint(), config.getTurnConfiguration().cloudflare().endpoint(),
config.getTurnConfiguration().cloudflare().ttl(), config.getTurnConfiguration().cloudflare().requestedCredentialTtl(),
config.getTurnConfiguration().cloudflare().clientCredentialTtl(),
config.getTurnConfiguration().cloudflare().urls(), config.getTurnConfiguration().cloudflare().urls(),
config.getTurnConfiguration().cloudflare().urlsWithIps(), config.getTurnConfiguration().cloudflare().urlsWithIps(),
config.getTurnConfiguration().cloudflare().hostname(), config.getTurnConfiguration().cloudflare().hostname(),
@ -693,7 +694,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
PushChallengeManager pushChallengeManager = new PushChallengeManager(pushNotificationManager, PushChallengeManager pushChallengeManager = new PushChallengeManager(pushNotificationManager,
pushChallengeDynamoDb); pushChallengeDynamoDb);
ChangeNumberManager changeNumberManager = new ChangeNumberManager(messageSender, accountsManager); ChangeNumberManager changeNumberManager = new ChangeNumberManager(messageSender, accountsManager, Clock.systemUTC());
HttpClient currencyClient = HttpClient.newBuilder().version(HttpClient.Version.HTTP_2).connectTimeout(Duration.ofSeconds(10)).build(); HttpClient currencyClient = HttpClient.newBuilder().version(HttpClient.Version.HTTP_2).connectTimeout(Duration.ofSeconds(10)).build();
FixerClient fixerClient = config.getPaymentsServiceConfiguration().externalClients() FixerClient fixerClient = config.getPaymentsServiceConfiguration().externalClients()
@ -987,7 +988,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
webSocketEnvironment.setConnectListener( webSocketEnvironment.setConnectListener(
new AuthenticatedConnectListener(receiptSender, messagesManager, messageMetrics, pushNotificationManager, new AuthenticatedConnectListener(receiptSender, messagesManager, messageMetrics, pushNotificationManager,
pushNotificationScheduler, webSocketConnectionEventManager, websocketScheduledExecutor, pushNotificationScheduler, webSocketConnectionEventManager, websocketScheduledExecutor,
messageDeliveryScheduler, clientReleaseManager, messageDeliveryLoopMonitor)); messageDeliveryScheduler, clientReleaseManager, messageDeliveryLoopMonitor, experimentEnrollmentManager));
webSocketEnvironment.jersey() webSocketEnvironment.jersey()
.register(new WebsocketRefreshApplicationEventListener(accountsManager, disconnectionRequestManager)); .register(new WebsocketRefreshApplicationEventListener(accountsManager, disconnectionRequestManager));
webSocketEnvironment.jersey().register(new RateLimitByIpFilter(rateLimiters)); webSocketEnvironment.jersey().register(new RateLimitByIpFilter(rateLimiters));

View File

@ -15,6 +15,7 @@ import java.net.Inet6Address;
import java.net.URI; import java.net.URI;
import java.net.http.HttpRequest; import java.net.http.HttpRequest;
import java.net.http.HttpResponse; import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionException;
@ -39,16 +40,18 @@ public class CloudflareTurnCredentialsManager {
private final List<String> cloudflareTurnUrls; private final List<String> cloudflareTurnUrls;
private final List<String> cloudflareTurnUrlsWithIps; private final List<String> cloudflareTurnUrlsWithIps;
private final String cloudflareTurnHostname; private final String cloudflareTurnHostname;
private final HttpRequest request; private final HttpRequest getCredentialsRequest;
private final FaultTolerantHttpClient cloudflareTurnClient; private final FaultTolerantHttpClient cloudflareTurnClient;
private final DnsNameResolver dnsNameResolver; private final DnsNameResolver dnsNameResolver;
record CredentialRequest(long ttl) {} private final Duration clientCredentialTtl;
record CloudflareTurnResponse(IceServer iceServers) { private record CredentialRequest(long ttl) {}
record IceServer( private record CloudflareTurnResponse(IceServer iceServers) {
private record IceServer(
String username, String username,
String credential, String credential,
List<String> urls) { List<String> urls) {
@ -56,10 +59,17 @@ public class CloudflareTurnCredentialsManager {
} }
public CloudflareTurnCredentialsManager(final String cloudflareTurnApiToken, public CloudflareTurnCredentialsManager(final String cloudflareTurnApiToken,
final String cloudflareTurnEndpoint, final long cloudflareTurnTtl, final List<String> cloudflareTurnUrls, final String cloudflareTurnEndpoint,
final List<String> cloudflareTurnUrlsWithIps, final String cloudflareTurnHostname, final Duration requestedCredentialTtl,
final int cloudflareTurnNumHttpClients, final CircuitBreakerConfiguration circuitBreaker, final Duration clientCredentialTtl,
final ExecutorService executor, final RetryConfiguration retry, final ScheduledExecutorService retryExecutor, final List<String> cloudflareTurnUrls,
final List<String> cloudflareTurnUrlsWithIps,
final String cloudflareTurnHostname,
final int cloudflareTurnNumHttpClients,
final CircuitBreakerConfiguration circuitBreaker,
final ExecutorService executor,
final RetryConfiguration retry,
final ScheduledExecutorService retryExecutor,
final DnsNameResolver dnsNameResolver) { final DnsNameResolver dnsNameResolver) {
this.cloudflareTurnClient = FaultTolerantHttpClient.newBuilder() this.cloudflareTurnClient = FaultTolerantHttpClient.newBuilder()
@ -75,17 +85,24 @@ public class CloudflareTurnCredentialsManager {
this.cloudflareTurnHostname = cloudflareTurnHostname; this.cloudflareTurnHostname = cloudflareTurnHostname;
this.dnsNameResolver = dnsNameResolver; this.dnsNameResolver = dnsNameResolver;
final String credentialsRequestBody;
try { try {
final String body = SystemMapper.jsonMapper().writeValueAsString(new CredentialRequest(cloudflareTurnTtl)); credentialsRequestBody =
this.request = HttpRequest.newBuilder() SystemMapper.jsonMapper().writeValueAsString(new CredentialRequest(requestedCredentialTtl.toSeconds()));
} catch (final JsonProcessingException e) {
throw new IllegalArgumentException(e);
}
// We repeat the same request to Cloudflare every time, so we can construct it once and re-use it
this.getCredentialsRequest = HttpRequest.newBuilder()
.uri(URI.create(cloudflareTurnEndpoint)) .uri(URI.create(cloudflareTurnEndpoint))
.header("Content-Type", "application/json") .header("Content-Type", "application/json")
.header("Authorization", String.format("Bearer %s", cloudflareTurnApiToken)) .header("Authorization", String.format("Bearer %s", cloudflareTurnApiToken))
.POST(HttpRequest.BodyPublishers.ofString(body)) .POST(HttpRequest.BodyPublishers.ofString(credentialsRequestBody))
.build(); .build();
} catch (JsonProcessingException e) {
throw new IllegalArgumentException(e); this.clientCredentialTtl = clientCredentialTtl;
}
} }
public TurnToken retrieveFromCloudflare() throws IOException { public TurnToken retrieveFromCloudflare() throws IOException {
@ -105,7 +122,7 @@ public class CloudflareTurnCredentialsManager {
final Timer.Sample sample = Timer.start(); final Timer.Sample sample = Timer.start();
final HttpResponse<String> response; final HttpResponse<String> response;
try { try {
response = cloudflareTurnClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()).join(); response = cloudflareTurnClient.sendAsync(getCredentialsRequest, HttpResponse.BodyHandlers.ofString()).join();
sample.stop(Timer.builder(CREDENTIAL_FETCH_TIMER_NAME) sample.stop(Timer.builder(CREDENTIAL_FETCH_TIMER_NAME)
.publishPercentileHistogram(true) .publishPercentileHistogram(true)
.tags("outcome", "success") .tags("outcome", "success")
@ -130,6 +147,7 @@ public class CloudflareTurnCredentialsManager {
return new TurnToken( return new TurnToken(
cloudflareTurnResponse.iceServers().username(), cloudflareTurnResponse.iceServers().username(),
cloudflareTurnResponse.iceServers().credential(), cloudflareTurnResponse.iceServers().credential(),
clientCredentialTtl.toSeconds(),
cloudflareTurnUrls == null ? Collections.emptyList() : cloudflareTurnUrls, cloudflareTurnUrls == null ? Collections.emptyList() : cloudflareTurnUrls,
cloudflareTurnComposedUrls, cloudflareTurnComposedUrls,
cloudflareTurnHostname cloudflareTurnHostname

View File

@ -5,13 +5,15 @@
package org.whispersystems.textsecuregcm.auth; package org.whispersystems.textsecuregcm.auth;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.List;
import javax.annotation.Nonnull; import javax.annotation.Nonnull;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.List;
public record TurnToken( public record TurnToken(
String username, String username,
String password, String password,
@JsonProperty("ttl") long ttlSeconds,
@Nonnull List<String> urls, @Nonnull List<String> urls,
@Nonnull List<String> urlsWithIps, @Nonnull List<String> urlsWithIps,
@Nullable String hostname) { @Nullable String hostname) {

View File

@ -1,34 +1,22 @@
package org.whispersystems.textsecuregcm.auth.grpc; package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall; import io.grpc.ServerCall;
import io.grpc.ServerInterceptor; import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.netty.channel.local.LocalAddress;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import java.util.Optional; import java.util.Optional;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
abstract class AbstractAuthenticationInterceptor implements ServerInterceptor { abstract class AbstractAuthenticationInterceptor implements ServerInterceptor {
private final GrpcClientConnectionManager grpcClientConnectionManager; private final GrpcClientConnectionManager grpcClientConnectionManager;
private static final Metadata EMPTY_TRAILERS = new Metadata();
AbstractAuthenticationInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) { AbstractAuthenticationInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
this.grpcClientConnectionManager = grpcClientConnectionManager; this.grpcClientConnectionManager = grpcClientConnectionManager;
} }
protected Optional<AuthenticatedDevice> getAuthenticatedDevice(final ServerCall<?, ?> call) { protected Optional<AuthenticatedDevice> getAuthenticatedDevice(final ServerCall<?, ?> call)
if (call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress) { throws ChannelNotFoundException {
return grpcClientConnectionManager.getAuthenticatedDevice(localAddress);
} else {
throw new AssertionError("Unexpected channel type: " + call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
}
}
protected <ReqT, RespT> ServerCall.Listener<ReqT> closeAsUnauthenticated(final ServerCall<ReqT, RespT> call) { return grpcClientConnectionManager.getAuthenticatedDevice(call);
call.close(Status.UNAUTHENTICATED, EMPTY_TRAILERS);
return new ServerCall.Listener<>() {};
} }
} }

View File

@ -3,12 +3,17 @@ package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.ServerCall; import io.grpc.ServerCall;
import io.grpc.ServerCallHandler; import io.grpc.ServerCallHandler;
import io.grpc.Status;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.ServerInterceptorUtil;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
/** /**
* A "prohibit authentication" interceptor ensures that requests to endpoints that should be invoked anonymously do not * A "prohibit authentication" interceptor ensures that requests to endpoints that should be invoked anonymously do not
* originate from a channel that is associated with an authenticated device. Calls with an associated authenticated * originate from a channel that is associated with an authenticated device. Calls with an associated authenticated
* device are closed with an {@code UNAUTHENTICATED} status. * device are closed with an {@code UNAUTHENTICATED} status. If a call's authentication status cannot be determined
* (i.e. because the underlying remote channel closed before the {@code ServerCall} started), the interceptor will
* reject the call with a status of {@code UNAVAILABLE}.
*/ */
public class ProhibitAuthenticationInterceptor extends AbstractAuthenticationInterceptor { public class ProhibitAuthenticationInterceptor extends AbstractAuthenticationInterceptor {
@ -21,8 +26,15 @@ public class ProhibitAuthenticationInterceptor extends AbstractAuthenticationInt
final Metadata headers, final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) { final ServerCallHandler<ReqT, RespT> next) {
try {
return getAuthenticatedDevice(call) return getAuthenticatedDevice(call)
.map(ignored -> closeAsUnauthenticated(call)) // Status.INTERNAL may seem a little surprising here, but if a caller is reaching an authentication-prohibited
// service via an authenticated connection, then that's actually a server configuration issue and not a
// problem with the client's request.
.map(ignored -> ServerInterceptorUtil.closeWithStatus(call, Status.INTERNAL))
.orElseGet(() -> next.startCall(call, headers)); .orElseGet(() -> next.startCall(call, headers));
} catch (final ChannelNotFoundException e) {
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
}
} }
} }

View File

@ -5,12 +5,16 @@ import io.grpc.Contexts;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.ServerCall; import io.grpc.ServerCall;
import io.grpc.ServerCallHandler; import io.grpc.ServerCallHandler;
import io.grpc.Status;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.ServerInterceptorUtil;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
/** /**
* A "require authentication" interceptor requires that requests be issued from a connection that is associated with an * A "require authentication" interceptor requires that requests be issued from a connection that is associated with an
* authenticated device. Calls without an associated authenticated device are closed with an {@code UNAUTHENTICATED} * authenticated device. Calls without an associated authenticated device are closed with an {@code UNAUTHENTICATED}
* status. * status. If a call's authentication status cannot be determined (i.e. because the underlying remote channel closed
* before the {@code ServerCall} started), the interceptor will reject the call with a status of {@code UNAVAILABLE}.
*/ */
public class RequireAuthenticationInterceptor extends AbstractAuthenticationInterceptor { public class RequireAuthenticationInterceptor extends AbstractAuthenticationInterceptor {
@ -23,10 +27,17 @@ public class RequireAuthenticationInterceptor extends AbstractAuthenticationInte
final Metadata headers, final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) { final ServerCallHandler<ReqT, RespT> next) {
try {
return getAuthenticatedDevice(call) return getAuthenticatedDevice(call)
.map(authenticatedDevice -> Contexts.interceptCall(Context.current() .map(authenticatedDevice -> Contexts.interceptCall(Context.current()
.withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE, authenticatedDevice), .withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE, authenticatedDevice),
call, headers, next)) call, headers, next))
.orElseGet(() -> closeAsUnauthenticated(call)); // Status.INTERNAL may seem a little surprising here, but if a caller is reaching an authentication-required
// service via an unauthenticated connection, then that's actually a server configuration issue and not a
// problem with the client's request.
.orElseGet(() -> ServerInterceptorUtil.closeWithStatus(call, Status.INTERNAL));
} catch (final ChannelNotFoundException e) {
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
}
} }
} }

View File

@ -6,16 +6,36 @@
package org.whispersystems.textsecuregcm.configuration; package org.whispersystems.textsecuregcm.configuration;
import jakarta.validation.Valid; import jakarta.validation.Valid;
import jakarta.validation.constraints.AssertTrue;
import jakarta.validation.constraints.NotBlank; import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotEmpty; import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull; import jakarta.validation.constraints.NotNull;
import java.time.Duration;
import java.util.List; import java.util.List;
import jakarta.validation.constraints.Positive; import jakarta.validation.constraints.Positive;
import org.whispersystems.textsecuregcm.configuration.secrets.SecretString; import org.whispersystems.textsecuregcm.configuration.secrets.SecretString;
/**
* Configuration properties for Cloudflare TURN integration.
*
* @param apiToken the API token to use when requesting TURN tokens from Cloudflare
* @param endpoint the URI of the Cloudflare API endpoint that vends TURN tokens
* @param requestedCredentialTtl the lifetime of TURN tokens to request from Cloudflare
* @param clientCredentialTtl the time clients may cache a TURN token; must be less than or equal to {@link #requestedCredentialTtl}
* @param urls a collection of TURN URLs to include verbatim in responses to clients
* @param urlsWithIps a collection of {@link String#format(String, Object...)} patterns to be populated with resolved IP
* addresses for {@link #hostname} in responses to clients; each pattern must include a single
* {@code %s} placeholder for the IP address
* @param circuitBreaker a circuit breaker for requests to Cloudflare
* @param retry a retry policy for requests to Cloudflare
* @param hostname the hostname to resolve to IP addresses for use with {@link #urlsWithIps}; also transmitted to
* clients for use as an SNI when connecting to pre-resolved hosts
* @param numHttpClients the number of parallel HTTP clients to use to communicate with Cloudflare
*/
public record CloudflareTurnConfiguration(@NotNull SecretString apiToken, public record CloudflareTurnConfiguration(@NotNull SecretString apiToken,
@NotBlank String endpoint, @NotBlank String endpoint,
@NotBlank long ttl, @NotNull Duration requestedCredentialTtl,
@NotNull Duration clientCredentialTtl,
@NotNull @NotEmpty @Valid List<@NotBlank String> urls, @NotNull @NotEmpty @Valid List<@NotBlank String> urls,
@NotNull @NotEmpty @Valid List<@NotBlank String> urlsWithIps, @NotNull @NotEmpty @Valid List<@NotBlank String> urlsWithIps,
@NotNull @Valid CircuitBreakerConfiguration circuitBreaker, @NotNull @Valid CircuitBreakerConfiguration circuitBreaker,
@ -35,4 +55,9 @@ public record CloudflareTurnConfiguration(@NotNull SecretString apiToken,
retry = new RetryConfiguration(); retry = new RetryConfiguration();
} }
} }
@AssertTrue
public boolean isClientTtlShorterThanRequestedTtl() {
return clientCredentialTtl.compareTo(requestedCredentialTtl) <= 0;
}
} }

View File

@ -15,16 +15,12 @@ import io.swagger.v3.oas.annotations.responses.ApiResponse;
import jakarta.ws.rs.GET; import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path; import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces; import jakarta.ws.rs.Produces;
import jakarta.ws.rs.container.ContainerRequestContext;
import jakarta.ws.rs.core.Context;
import jakarta.ws.rs.core.MediaType; import jakarta.ws.rs.core.MediaType;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.UUID; import java.util.UUID;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.CloudflareTurnCredentialsManager; import org.whispersystems.textsecuregcm.auth.CloudflareTurnCredentialsManager;
import org.whispersystems.textsecuregcm.auth.TurnToken;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.websocket.auth.ReadOnly; import org.whispersystems.websocket.auth.ReadOnly;
@ -32,14 +28,16 @@ import org.whispersystems.websocket.auth.ReadOnly;
@Path("/v2/calling") @Path("/v2/calling")
public class CallRoutingControllerV2 { public class CallRoutingControllerV2 {
private static final Counter CLOUDFLARE_TURN_ERROR_COUNTER = Metrics.counter(name(CallRoutingControllerV2.class, "cloudflareTurnError"));
private final RateLimiters rateLimiters; private final RateLimiters rateLimiters;
private final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager; private final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager;
private static final Counter CLOUDFLARE_TURN_ERROR_COUNTER =
Metrics.counter(name(CallRoutingControllerV2.class, "cloudflareTurnError"));
public CallRoutingControllerV2( public CallRoutingControllerV2(
final RateLimiters rateLimiters, final RateLimiters rateLimiters,
final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager) {
) {
this.rateLimiters = rateLimiters; this.rateLimiters = rateLimiters;
this.cloudflareTurnCredentialsManager = cloudflareTurnCredentialsManager; this.cloudflareTurnCredentialsManager = cloudflareTurnCredentialsManager;
} }
@ -58,25 +56,17 @@ public class CallRoutingControllerV2 {
@ApiResponse(responseCode = "401", description = "Account authentication check failed.") @ApiResponse(responseCode = "401", description = "Account authentication check failed.")
@ApiResponse(responseCode = "422", description = "Invalid request format.") @ApiResponse(responseCode = "422", description = "Invalid request format.")
@ApiResponse(responseCode = "429", description = "Rate limited.") @ApiResponse(responseCode = "429", description = "Rate limited.")
public GetCallingRelaysResponse getCallingRelays( public GetCallingRelaysResponse getCallingRelays(final @ReadOnly @Auth AuthenticatedDevice auth)
final @ReadOnly @Auth AuthenticatedDevice auth throws RateLimitExceededException, IOException {
) throws RateLimitExceededException, IOException {
UUID aci = auth.getAccount().getUuid(); final UUID aci = auth.getAccount().getUuid();
rateLimiters.getCallEndpointLimiter().validate(aci); rateLimiters.getCallEndpointLimiter().validate(aci);
List<TurnToken> tokens = new ArrayList<>();
try { try {
tokens.add(cloudflareTurnCredentialsManager.retrieveFromCloudflare()); return new GetCallingRelaysResponse(List.of(cloudflareTurnCredentialsManager.retrieveFromCloudflare()));
} catch (Exception e) { } catch (final Exception e) {
CallRoutingControllerV2.CLOUDFLARE_TURN_ERROR_COUNTER.increment(); CLOUDFLARE_TURN_ERROR_COUNTER.increment();
throw e; throw e;
} }
return new GetCallingRelaysResponse(tokens);
}
public record GetCallingRelaysResponse(
List<TurnToken> relays
) {
} }
} }

View File

@ -44,7 +44,6 @@ import java.util.EnumMap;
import java.util.List; import java.util.List;
import java.util.Locale; import java.util.Locale;
import java.util.Map; import java.util.Map;
import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionException;
@ -52,7 +51,6 @@ import java.util.concurrent.CompletionStage;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.glassfish.jersey.server.ContainerRequest; import org.glassfish.jersey.server.ContainerRequest;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.BasicAuthorizationHeader; import org.whispersystems.textsecuregcm.auth.BasicAuthorizationHeader;
@ -74,6 +72,7 @@ import org.whispersystems.textsecuregcm.entities.TransferArchiveUploadedRequest;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.limits.RateLimitedByIp; import org.whispersystems.textsecuregcm.limits.RateLimitedByIp;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.DevicePlatformUtil;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
@ -402,7 +401,7 @@ public class DeviceController {
private AtomicInteger getCounterForLinkedDeviceListeners(final String userAgent) { private AtomicInteger getCounterForLinkedDeviceListeners(final String userAgent) {
try { try {
return linkedDeviceListenersByPlatform.get(UserAgentUtil.parseUserAgentString(userAgent).getPlatform()); return linkedDeviceListenersByPlatform.get(UserAgentUtil.parseUserAgentString(userAgent).platform());
} catch (final UnrecognizedUserAgentException ignored) { } catch (final UnrecognizedUserAgentException ignored) {
return linkedDeviceListenersForUnrecognizedPlatforms; return linkedDeviceListenersForUnrecognizedPlatforms;
} }
@ -600,25 +599,9 @@ public class DeviceController {
} }
private static io.micrometer.core.instrument.Tag primaryPlatformTag(final Account account) { private static io.micrometer.core.instrument.Tag primaryPlatformTag(final Account account) {
final Device primaryDevice = account.getPrimaryDevice();
Optional<ClientPlatform> clientPlatform = Optional.empty();
if (StringUtils.isNotBlank(primaryDevice.getGcmId())) {
clientPlatform = Optional.of(ClientPlatform.ANDROID);
} else if (StringUtils.isNotBlank(primaryDevice.getApnId())) {
clientPlatform = Optional.of(ClientPlatform.IOS);
}
clientPlatform = clientPlatform.or(() -> Optional.ofNullable(
switch (primaryDevice.getUserAgent()) {
case "OWA" -> ClientPlatform.ANDROID;
case "OWI", "OWP" -> ClientPlatform.IOS;
case "OWD" -> ClientPlatform.DESKTOP;
case null, default -> null;
}));
return io.micrometer.core.instrument.Tag.of( return io.micrometer.core.instrument.Tag.of(
"primaryPlatform", "primaryPlatform",
clientPlatform DevicePlatformUtil.getDevicePlatform(account.getPrimaryDevice())
.map(p -> p.name().toLowerCase(Locale.ROOT)) .map(p -> p.name().toLowerCase(Locale.ROOT))
.orElse("unknown")); .orElse("unknown"));
} }

View File

@ -104,14 +104,12 @@ public class DonationController {
.type(MediaType.TEXT_PLAIN_TYPE).build()); .type(MediaType.TEXT_PLAIN_TYPE).build());
} }
return accountsManager.getByAccountIdentifierAsync(auth.getAccount().getUuid()) return accountsManager.updateAsync(auth.getAccount(), a -> {
.thenCompose(optionalAccount ->
optionalAccount.map(account -> accountsManager.updateAsync(account, a -> {
a.addBadge(clock, new AccountBadge(badgeId, receiptExpiration, request.isVisible())); a.addBadge(clock, new AccountBadge(badgeId, receiptExpiration, request.isVisible()));
if (request.isPrimary()) { if (request.isPrimary()) {
a.makeBadgePrimaryIfExists(clock, badgeId); a.makeBadgePrimaryIfExists(clock, badgeId);
} }
})).orElse(CompletableFuture.completedFuture(null))) })
.thenApply(ignored -> Response.ok().build()); .thenApply(ignored -> Response.ok().build());
}); });
}).thenCompose(Function.identity()); }).thenCompose(Function.identity());

View File

@ -0,0 +1,13 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import org.whispersystems.textsecuregcm.auth.TurnToken;
import java.util.List;
public record GetCallingRelaysResponse(List<TurnToken> relays) {
}

View File

@ -152,7 +152,7 @@ public class KeysController {
final List<CompletableFuture<Void>> storeFutures = new ArrayList<>(4); final List<CompletableFuture<Void>> storeFutures = new ArrayList<>(4);
if (setKeysRequest.preKeys() != null && !setKeysRequest.preKeys().isEmpty()) { if (!setKeysRequest.preKeys().isEmpty()) {
Metrics.counter(STORE_KEYS_COUNTER_NAME, Metrics.counter(STORE_KEYS_COUNTER_NAME,
Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "ec"))) Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "ec")))
.increment(); .increment();
@ -168,7 +168,7 @@ public class KeysController {
storeFutures.add(keysManager.storeEcSignedPreKeys(identifier, device.getId(), setKeysRequest.signedPreKey())); storeFutures.add(keysManager.storeEcSignedPreKeys(identifier, device.getId(), setKeysRequest.signedPreKey()));
} }
if (setKeysRequest.pqPreKeys() != null && !setKeysRequest.pqPreKeys().isEmpty()) { if (!setKeysRequest.pqPreKeys().isEmpty()) {
Metrics.counter(STORE_KEYS_COUNTER_NAME, Metrics.counter(STORE_KEYS_COUNTER_NAME,
Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "kyber"))) Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "kyber")))
.increment(); .increment();
@ -192,11 +192,7 @@ public class KeysController {
final IdentityKey identityKey, final IdentityKey identityKey,
@Nullable final String userAgent) { @Nullable final String userAgent) {
final List<SignedPreKey<?>> signedPreKeys = new ArrayList<>(); final List<SignedPreKey<?>> signedPreKeys = new ArrayList<>(setKeysRequest.pqPreKeys());
if (setKeysRequest.pqPreKeys() != null) {
signedPreKeys.addAll(setKeysRequest.pqPreKeys());
}
if (setKeysRequest.pqLastResortPreKey() != null) { if (setKeysRequest.pqLastResortPreKey() != null) {
signedPreKeys.add(setKeysRequest.pqLastResortPreKey()); signedPreKeys.add(setKeysRequest.pqLastResortPreKey());

View File

@ -428,7 +428,7 @@ public class OneTimeDonationController {
@Nullable @Nullable
private static ClientPlatform getClientPlatform(@Nullable final String userAgentString) { private static ClientPlatform getClientPlatform(@Nullable final String userAgentString) {
try { try {
return UserAgentUtil.parseUserAgentString(userAgentString).getPlatform(); return UserAgentUtil.parseUserAgentString(userAgentString).platform();
} catch (final UnrecognizedUserAgentException e) { } catch (final UnrecognizedUserAgentException e) {
return null; return null;
} }

View File

@ -47,6 +47,7 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.glassfish.jersey.server.ManagedAsync; import org.glassfish.jersey.server.ManagedAsync;
import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ServiceId; import org.signal.libsignal.protocol.ServiceId;
@ -123,6 +124,7 @@ public class ProfileController {
private static final String EXPIRING_PROFILE_KEY_CREDENTIAL_TYPE = "expiringProfileKey"; private static final String EXPIRING_PROFILE_KEY_CREDENTIAL_TYPE = "expiringProfileKey";
private static final String VERSION_NOT_FOUND_COUNTER_NAME = name(ProfileController.class, "versionNotFound"); private static final String VERSION_NOT_FOUND_COUNTER_NAME = name(ProfileController.class, "versionNotFound");
private static final String DUPLICATE_AUTHENTICATION_COUNTER_NAME = name(ProfileController.class, "duplicateAuthentication");
public ProfileController( public ProfileController(
Clock clock, Clock clock,
@ -230,11 +232,12 @@ public class ProfileController {
@HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional<Anonymous> accessKey, @HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional<Anonymous> accessKey,
@Context ContainerRequestContext containerRequestContext, @Context ContainerRequestContext containerRequestContext,
@PathParam("identifier") AciServiceIdentifier accountIdentifier, @PathParam("identifier") AciServiceIdentifier accountIdentifier,
@PathParam("version") String version) @PathParam("version") String version,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent)
throws RateLimitExceededException { throws RateLimitExceededException {
final Optional<Account> maybeRequester = auth.map(AuthenticatedDevice::getAccount); final Optional<Account> maybeRequester = auth.map(AuthenticatedDevice::getAccount);
final Account targetAccount = verifyPermissionToReceiveProfile(maybeRequester, accessKey, accountIdentifier); final Account targetAccount = verifyPermissionToReceiveProfile(maybeRequester, accessKey, accountIdentifier, "getVersionedProfile", userAgent);
return buildVersionedProfileResponse(targetAccount, return buildVersionedProfileResponse(targetAccount,
version, version,
@ -253,7 +256,8 @@ public class ProfileController {
@PathParam("identifier") AciServiceIdentifier accountIdentifier, @PathParam("identifier") AciServiceIdentifier accountIdentifier,
@PathParam("version") String version, @PathParam("version") String version,
@PathParam("credentialRequest") String credentialRequest, @PathParam("credentialRequest") String credentialRequest,
@QueryParam("credentialType") String credentialType) @QueryParam("credentialType") String credentialType,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent)
throws RateLimitExceededException { throws RateLimitExceededException {
if (!EXPIRING_PROFILE_KEY_CREDENTIAL_TYPE.equals(credentialType)) { if (!EXPIRING_PROFILE_KEY_CREDENTIAL_TYPE.equals(credentialType)) {
@ -261,7 +265,7 @@ public class ProfileController {
} }
final Optional<Account> maybeRequester = auth.map(AuthenticatedDevice::getAccount); final Optional<Account> maybeRequester = auth.map(AuthenticatedDevice::getAccount);
final Account targetAccount = verifyPermissionToReceiveProfile(maybeRequester, accessKey, accountIdentifier); final Account targetAccount = verifyPermissionToReceiveProfile(maybeRequester, accessKey, accountIdentifier, "credentialRequest", userAgent);
final boolean isSelf = maybeRequester.map(requester -> ProfileHelper.isSelfProfileRequest(requester.getUuid(), accountIdentifier)).orElse(false); final boolean isSelf = maybeRequester.map(requester -> ProfileHelper.isSelfProfileRequest(requester.getUuid(), accountIdentifier)).orElse(false);
return buildExpiringProfileKeyCredentialProfileResponse(targetAccount, return buildExpiringProfileKeyCredentialProfileResponse(targetAccount,
@ -283,8 +287,7 @@ public class ProfileController {
@HeaderParam(HeaderUtils.GROUP_SEND_TOKEN) Optional<GroupSendTokenHeader> groupSendToken, @HeaderParam(HeaderUtils.GROUP_SEND_TOKEN) Optional<GroupSendTokenHeader> groupSendToken,
@Context ContainerRequestContext containerRequestContext, @Context ContainerRequestContext containerRequestContext,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent, @HeaderParam(HttpHeaders.USER_AGENT) String userAgent,
@PathParam("identifier") ServiceIdentifier identifier, @PathParam("identifier") ServiceIdentifier identifier)
@QueryParam("ca") boolean useCaCertificate)
throws RateLimitExceededException { throws RateLimitExceededException {
final Optional<Account> maybeRequester = auth.map(AuthenticatedDevice::getAccount); final Optional<Account> maybeRequester = auth.map(AuthenticatedDevice::getAccount);
@ -303,7 +306,7 @@ public class ProfileController {
} }
} else { } else {
targetAccount = verifyPermissionToReceiveProfile( targetAccount = verifyPermissionToReceiveProfile(
maybeRequester, accessKey.filter(ignored -> identifier.identityType() == IdentityType.ACI), identifier); maybeRequester, accessKey.filter(ignored -> identifier.identityType() == IdentityType.ACI), identifier, "getUnversionedProfile", userAgent);
} }
return switch (identifier.identityType()) { return switch (identifier.identityType()) {
case ACI -> buildBaseProfileResponseForAccountIdentity(targetAccount, case ACI -> buildBaseProfileResponseForAccountIdentity(targetAccount,
@ -386,7 +389,7 @@ public class ProfileController {
profileKeyCredentialResponse = ProfileHelper.getExpiringProfileKeyCredential(HexFormat.of().parseHex(encodedCredentialRequest), profileKeyCredentialResponse = ProfileHelper.getExpiringProfileKeyCredential(HexFormat.of().parseHex(encodedCredentialRequest),
profile, new ServiceId.Aci(account.getUuid()), zkProfileOperations); profile, new ServiceId.Aci(account.getUuid()), zkProfileOperations);
} catch (VerificationFailedException | InvalidInputException e) { } catch (VerificationFailedException | InvalidInputException e) {
throw new BadRequestException(Response.status(Response.Status.BAD_REQUEST).build(), e); throw new BadRequestException(e);
} }
return profileKeyCredentialResponse; return profileKeyCredentialResponse;
}) })
@ -474,7 +477,15 @@ public class ProfileController {
*/ */
private Account verifyPermissionToReceiveProfile(final Optional<Account> maybeRequester, private Account verifyPermissionToReceiveProfile(final Optional<Account> maybeRequester,
final Optional<Anonymous> maybeAccessKey, final Optional<Anonymous> maybeAccessKey,
final ServiceIdentifier accountIdentifier) throws RateLimitExceededException { final ServiceIdentifier accountIdentifier,
final String endpoint,
@Nullable final String userAgent) throws RateLimitExceededException {
if (maybeRequester.isPresent() && maybeAccessKey.isPresent()) {
Metrics.counter(DUPLICATE_AUTHENTICATION_COUNTER_NAME,
Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), io.micrometer.core.instrument.Tag.of("endpoint", endpoint)))
.increment();
}
if (maybeRequester.isPresent()) { if (maybeRequester.isPresent()) {
rateLimiters.getProfileLimiter().validate(maybeRequester.get().getUuid()); rateLimiters.getProfileLimiter().validate(maybeRequester.get().getUuid());

View File

@ -755,7 +755,7 @@ public class SubscriptionController {
@Nullable @Nullable
private static ClientPlatform getClientPlatform(@Nullable final String userAgentString) { private static ClientPlatform getClientPlatform(@Nullable final String userAgentString) {
try { try {
return UserAgentUtil.parseUserAgentString(userAgentString).getPlatform(); return UserAgentUtil.parseUserAgentString(userAgentString).platform();
} catch (final UnrecognizedUserAgentException e) { } catch (final UnrecognizedUserAgentException e) {
return null; return null;
} }

View File

@ -5,11 +5,15 @@
package org.whispersystems.textsecuregcm.entities; package org.whispersystems.textsecuregcm.entities;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;
import jakarta.validation.Valid; import jakarta.validation.Valid;
import java.util.List; import java.util.List;
public record SetKeysRequest( public record SetKeysRequest(
@NotNull
@Valid @Valid
@Size(max=100)
@Schema(description = """ @Schema(description = """
A list of unsigned elliptic-curve prekeys to use for this device. If present and not empty, replaces all stored A list of unsigned elliptic-curve prekeys to use for this device. If present and not empty, replaces all stored
unsigned EC prekeys for the device; if absent or empty, any stored unsigned EC prekeys for the device are not unsigned EC prekeys for the device; if absent or empty, any stored unsigned EC prekeys for the device are not
@ -25,7 +29,9 @@ public record SetKeysRequest(
""") """)
ECSignedPreKey signedPreKey, ECSignedPreKey signedPreKey,
@NotNull
@Valid @Valid
@Size(max=100)
@Schema(description = """ @Schema(description = """
A list of signed post-quantum one-time prekeys to use for this device. Each key must have a valid signature from A list of signed post-quantum one-time prekeys to use for this device. Each key must have a valid signature from
the identity key in this request. If present and not empty, replaces all stored unsigned PQ prekeys for the the identity key in this request. If present and not empty, replaces all stored unsigned PQ prekeys for the
@ -40,4 +46,16 @@ public record SetKeysRequest(
deleted. If present, must have a valid signature from the identity key in this request. deleted. If present, must have a valid signature from the identity key in this request.
""") """)
KEMSignedPreKey pqLastResortPreKey) { KEMSignedPreKey pqLastResortPreKey) {
public SetKeysRequest {
// Its a little counter-intuitive, but this compact constructor allows a default value
// to be used when one isnt specified, allowing the field to still be
// validated as @NotNull
if (preKeys == null) {
preKeys = List.of();
}
if (pqPreKeys == null) {
pqPreKeys = List.of();
}
}
} }

View File

@ -81,7 +81,16 @@ public class RemoteDeprecationFilter implements Filter, ServerInterceptor {
final Metadata headers, final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) { final ServerCallHandler<ReqT, RespT> next) {
if (shouldBlock(RequestAttributesUtil.getUserAgent().orElse(null))) { @Nullable final UserAgent userAgent = RequestAttributesUtil.getUserAgent()
.map(userAgentString -> {
try {
return UserAgentUtil.parseUserAgentString(userAgentString);
} catch (final UnrecognizedUserAgentException e) {
return null;
}
}).orElse(null);
if (shouldBlock(userAgent)) {
call.close(StatusConstants.UPGRADE_NEEDED_STATUS, new Metadata()); call.close(StatusConstants.UPGRADE_NEEDED_STATUS, new Metadata());
return new ServerCall.Listener<>() {}; return new ServerCall.Listener<>() {};
} else { } else {
@ -108,28 +117,28 @@ public class RemoteDeprecationFilter implements Filter, ServerInterceptor {
return true; return true;
} }
if (blockedVersionsByPlatform.containsKey(userAgent.getPlatform())) { if (blockedVersionsByPlatform.containsKey(userAgent.platform())) {
if (blockedVersionsByPlatform.get(userAgent.getPlatform()).contains(userAgent.getVersion())) { if (blockedVersionsByPlatform.get(userAgent.platform()).contains(userAgent.version())) {
recordDeprecation(userAgent, BLOCKED_CLIENT_REASON); recordDeprecation(userAgent, BLOCKED_CLIENT_REASON);
shouldBlock = true; shouldBlock = true;
} }
} }
if (minimumVersionsByPlatform.containsKey(userAgent.getPlatform())) { if (minimumVersionsByPlatform.containsKey(userAgent.platform())) {
if (userAgent.getVersion().isLowerThan(minimumVersionsByPlatform.get(userAgent.getPlatform()))) { if (userAgent.version().isLowerThan(minimumVersionsByPlatform.get(userAgent.platform()))) {
recordDeprecation(userAgent, EXPIRED_CLIENT_REASON); recordDeprecation(userAgent, EXPIRED_CLIENT_REASON);
shouldBlock = true; shouldBlock = true;
} }
} }
if (versionsPendingBlockByPlatform.containsKey(userAgent.getPlatform())) { if (versionsPendingBlockByPlatform.containsKey(userAgent.platform())) {
if (versionsPendingBlockByPlatform.get(userAgent.getPlatform()).contains(userAgent.getVersion())) { if (versionsPendingBlockByPlatform.get(userAgent.platform()).contains(userAgent.version())) {
recordPendingDeprecation(userAgent, BLOCKED_CLIENT_REASON); recordPendingDeprecation(userAgent, BLOCKED_CLIENT_REASON);
} }
} }
if (versionsPendingDeprecationByPlatform.containsKey(userAgent.getPlatform())) { if (versionsPendingDeprecationByPlatform.containsKey(userAgent.platform())) {
if (userAgent.getVersion().isLowerThan(versionsPendingDeprecationByPlatform.get(userAgent.getPlatform()))) { if (userAgent.version().isLowerThan(versionsPendingDeprecationByPlatform.get(userAgent.platform()))) {
recordPendingDeprecation(userAgent, EXPIRED_CLIENT_REASON); recordPendingDeprecation(userAgent, EXPIRED_CLIENT_REASON);
} }
} }
@ -139,13 +148,13 @@ public class RemoteDeprecationFilter implements Filter, ServerInterceptor {
private void recordDeprecation(final UserAgent userAgent, final String reason) { private void recordDeprecation(final UserAgent userAgent, final String reason) {
Metrics.counter(DEPRECATED_CLIENT_COUNTER_NAME, Metrics.counter(DEPRECATED_CLIENT_COUNTER_NAME,
PLATFORM_TAG, userAgent != null ? userAgent.getPlatform().name().toLowerCase() : "unrecognized", PLATFORM_TAG, userAgent != null ? userAgent.platform().name().toLowerCase() : "unrecognized",
REASON_TAG_NAME, reason).increment(); REASON_TAG_NAME, reason).increment();
} }
private void recordPendingDeprecation(final UserAgent userAgent, final String reason) { private void recordPendingDeprecation(final UserAgent userAgent, final String reason) {
Metrics.counter(PENDING_DEPRECATION_COUNTER_NAME, Metrics.counter(PENDING_DEPRECATION_COUNTER_NAME,
PLATFORM_TAG, userAgent.getPlatform().name().toLowerCase(), PLATFORM_TAG, userAgent.platform().name().toLowerCase(),
REASON_TAG_NAME, reason).increment(); REASON_TAG_NAME, reason).increment();
} }
} }

View File

@ -15,8 +15,6 @@ import jakarta.ws.rs.container.ContainerRequestFilter;
import jakarta.ws.rs.core.SecurityContext; import jakarta.ws.rs.core.SecurityContext;
import java.io.IOException; import java.io.IOException;
import java.util.Map; import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
@ -70,8 +68,8 @@ public class RestDeprecationFilter implements ContainerRequestFilter {
try { try {
final UserAgent userAgent = UserAgentUtil.parseUserAgentString(userAgentString); final UserAgent userAgent = UserAgentUtil.parseUserAgentString(userAgentString);
final ClientPlatform platform = userAgent.getPlatform(); final ClientPlatform platform = userAgent.platform();
final Semver version = userAgent.getVersion(); final Semver version = userAgent.version();
if (!minimumRestFreeVersion.containsKey(platform)) { if (!minimumRestFreeVersion.containsKey(platform)) {
return; return;
} }

View File

@ -0,0 +1,12 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
/**
* Indicates that a remote channel was not found for a given server call or remote address.
*/
public class ChannelNotFoundException extends Exception {
}

View File

@ -0,0 +1,55 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Context;
import io.grpc.ForwardingServerCallListener;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.netty.channel.local.LocalAddress;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
/**
* Then channel shutdown interceptor rejects new requests if a channel is shutting down and works in tandem with
* {@link GrpcClientConnectionManager} to maintain an active call count for each channel otherwise.
*/
public class ChannelShutdownInterceptor implements ServerInterceptor {
private final GrpcClientConnectionManager grpcClientConnectionManager;
public ChannelShutdownInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
this.grpcClientConnectionManager = grpcClientConnectionManager;
}
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
if (!grpcClientConnectionManager.handleServerCallStart(call)) {
// Don't allow new calls if the connection is getting ready to close
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
}
return new ForwardingServerCallListener.SimpleForwardingServerCallListener<>(next.startCall(call, headers)) {
@Override
public void onComplete() {
grpcClientConnectionManager.handleServerCallComplete(call);
super.onComplete();
}
@Override
public void onCancel() {
grpcClientConnectionManager.handleServerCallComplete(call);
super.onCancel();
}
};
}
}

View File

@ -253,7 +253,7 @@ public class MessagesAnonymousGrpcService extends SimpleMessagesAnonymousGrpc.Me
story, story,
ephemeral, ephemeral,
urgent, urgent,
RequestAttributesUtil.getRawUserAgent().orElse(null)); RequestAttributesUtil.getUserAgent().orElse(null));
final SendMultiRecipientMessageResponse.Builder responseBuilder = SendMultiRecipientMessageResponse.newBuilder(); final SendMultiRecipientMessageResponse.Builder responseBuilder = SendMultiRecipientMessageResponse.newBuilder();

View File

@ -55,7 +55,7 @@ public class MessagesGrpcHelper {
messagesByDeviceId, messagesByDeviceId,
registrationIdsByDeviceId, registrationIdsByDeviceId,
syncMessageSenderDeviceId, syncMessageSenderDeviceId,
RequestAttributesUtil.getRawUserAgent().orElse(null)); RequestAttributesUtil.getUserAgent().orElse(null));
return SEND_MESSAGE_SUCCESS_RESPONSE; return SEND_MESSAGE_SUCCESS_RESPONSE;
} catch (final MismatchedDevicesException e) { } catch (final MismatchedDevicesException e) {

View File

@ -0,0 +1,16 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import java.net.InetAddress;
import java.util.List;
import java.util.Locale;
import javax.annotation.Nullable;
public record RequestAttributes(InetAddress remoteAddress,
@Nullable String userAgent,
List<Locale.LanguageRange> acceptLanguage) {
}

View File

@ -2,28 +2,25 @@ package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Context; import io.grpc.Context;
import io.grpc.Contexts; import io.grpc.Contexts;
import io.grpc.Grpc;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.ServerCall; import io.grpc.ServerCall;
import io.grpc.ServerCallHandler; import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor; import io.grpc.ServerInterceptor;
import io.grpc.Status; import io.grpc.Status;
import io.netty.channel.local.LocalAddress;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import java.net.InetAddress;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
/**
* The request attributes interceptor makes request attributes from the underlying remote channel available to service
* implementations by attaching them to a {@link Context} attribute that can be read via {@link RequestAttributesUtil}.
* All server calls should have request attributes, and calls will be rejected with a status of {@code UNAVAILABLE} if
* request attributes are unavailable (i.e. the underlying channel closed before the {@code ServerCall} started).
*
* @see RequestAttributesUtil
*/
public class RequestAttributesInterceptor implements ServerInterceptor { public class RequestAttributesInterceptor implements ServerInterceptor {
private final GrpcClientConnectionManager grpcClientConnectionManager; private final GrpcClientConnectionManager grpcClientConnectionManager;
private static final Logger log = LoggerFactory.getLogger(RequestAttributesInterceptor.class);
public RequestAttributesInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) { public RequestAttributesInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
this.grpcClientConnectionManager = grpcClientConnectionManager; this.grpcClientConnectionManager = grpcClientConnectionManager;
} }
@ -33,52 +30,12 @@ public class RequestAttributesInterceptor implements ServerInterceptor {
final Metadata headers, final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) { final ServerCallHandler<ReqT, RespT> next) {
if (call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress) { try {
Context context = Context.current(); return Contexts.interceptCall(Context.current()
.withValue(RequestAttributesUtil.REQUEST_ATTRIBUTES_CONTEXT_KEY,
{ grpcClientConnectionManager.getRequestAttributes(call)), call, headers, next);
final Optional<InetAddress> maybeRemoteAddress = grpcClientConnectionManager.getRemoteAddress(localAddress); } catch (final ChannelNotFoundException e) {
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
if (maybeRemoteAddress.isEmpty()) {
// We should never have a call from a party whose remote address we can't identify
log.warn("No remote address available");
call.close(Status.INTERNAL, new Metadata());
return new ServerCall.Listener<>() {};
}
context = context.withValue(RequestAttributesUtil.REMOTE_ADDRESS_CONTEXT_KEY, maybeRemoteAddress.get());
}
{
final Optional<List<Locale.LanguageRange>> maybeAcceptLanguage =
grpcClientConnectionManager.getAcceptableLanguages(localAddress);
if (maybeAcceptLanguage.isPresent()) {
context = context.withValue(RequestAttributesUtil.ACCEPT_LANGUAGE_CONTEXT_KEY, maybeAcceptLanguage.get());
}
}
{
final Optional<String> maybeRawUserAgent =
grpcClientConnectionManager.getRawUserAgent(localAddress);
if (maybeRawUserAgent.isPresent()) {
context = context.withValue(RequestAttributesUtil.RAW_USER_AGENT_CONTEXT_KEY, maybeRawUserAgent.get());
}
}
{
final Optional<UserAgent> maybeUserAgent = grpcClientConnectionManager.getUserAgent(localAddress);
if (maybeUserAgent.isPresent()) {
context = context.withValue(RequestAttributesUtil.USER_AGENT_CONTEXT_KEY, maybeUserAgent.get());
}
}
return Contexts.interceptCall(context, call, headers, next);
} else {
throw new AssertionError("Unexpected channel type: " + call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
} }
} }
} }

View File

@ -3,18 +3,13 @@ package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Context; import io.grpc.Context;
import java.net.InetAddress; import java.net.InetAddress;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Locale; import java.util.Locale;
import java.util.Optional; import java.util.Optional;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
public class RequestAttributesUtil { public class RequestAttributesUtil {
static final Context.Key<List<Locale.LanguageRange>> ACCEPT_LANGUAGE_CONTEXT_KEY = Context.key("accept-language"); static final Context.Key<RequestAttributes> REQUEST_ATTRIBUTES_CONTEXT_KEY = Context.key("request-attributes");
static final Context.Key<InetAddress> REMOTE_ADDRESS_CONTEXT_KEY = Context.key("remote-address");
static final Context.Key<String> RAW_USER_AGENT_CONTEXT_KEY = Context.key("unparsed-user-agent");
static final Context.Key<UserAgent> USER_AGENT_CONTEXT_KEY = Context.key("parsed-user-agent");
private static final List<Locale> AVAILABLE_LOCALES = Arrays.asList(Locale.getAvailableLocales()); private static final List<Locale> AVAILABLE_LOCALES = Arrays.asList(Locale.getAvailableLocales());
@ -23,8 +18,8 @@ public class RequestAttributesUtil {
* *
* @return the acceptable languages listed by the remote client; may be empty if unparseable or not specified * @return the acceptable languages listed by the remote client; may be empty if unparseable or not specified
*/ */
public static Optional<List<Locale.LanguageRange>> getAcceptableLanguages() { public static List<Locale.LanguageRange> getAcceptableLanguages() {
return Optional.ofNullable(ACCEPT_LANGUAGE_CONTEXT_KEY.get()); return REQUEST_ATTRIBUTES_CONTEXT_KEY.get().acceptLanguage();
} }
/** /**
@ -35,9 +30,7 @@ public class RequestAttributesUtil {
* @return a list of distinct locales acceptable to the remote client and available in this JVM * @return a list of distinct locales acceptable to the remote client and available in this JVM
*/ */
public static List<Locale> getAvailableAcceptedLocales() { public static List<Locale> getAvailableAcceptedLocales() {
return getAcceptableLanguages() return Locale.filter(getAcceptableLanguages(), AVAILABLE_LOCALES);
.map(languageRanges -> Locale.filter(languageRanges, AVAILABLE_LOCALES))
.orElseGet(Collections::emptyList);
} }
/** /**
@ -46,16 +39,7 @@ public class RequestAttributesUtil {
* @return the remote address of the remote client * @return the remote address of the remote client
*/ */
public static InetAddress getRemoteAddress() { public static InetAddress getRemoteAddress() {
return REMOTE_ADDRESS_CONTEXT_KEY.get(); return REQUEST_ATTRIBUTES_CONTEXT_KEY.get().remoteAddress();
}
/**
* Returns the parsed user-agent of the remote client in the current gRPC request context.
*
* @return the parsed user-agent of the remote client; may be empty if unparseable or not specified
*/
public static Optional<UserAgent> getUserAgent() {
return Optional.ofNullable(USER_AGENT_CONTEXT_KEY.get());
} }
/** /**
@ -63,7 +47,7 @@ public class RequestAttributesUtil {
* *
* @return the unparsed user-agent of the remote client; may be empty if not specified * @return the unparsed user-agent of the remote client; may be empty if not specified
*/ */
public static Optional<String> getRawUserAgent() { public static Optional<String> getUserAgent() {
return Optional.ofNullable(RAW_USER_AGENT_CONTEXT_KEY.get()); return Optional.ofNullable(REQUEST_ATTRIBUTES_CONTEXT_KEY.get().userAgent());
} }
} }

View File

@ -0,0 +1,39 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.Status;
public class ServerInterceptorUtil {
@SuppressWarnings("rawtypes")
private static final ServerCall.Listener NO_OP_LISTENER = new ServerCall.Listener<>() {};
private static final Metadata EMPTY_TRAILERS = new Metadata();
private ServerInterceptorUtil() {
}
/**
* Closes the given server call with the given status, returning a no-op listener.
*
* @param call the server call to close
* @param status the status with which to close the call
*
* @return a no-op server call listener
*
* @param <ReqT> the type of request object handled by the server call
* @param <RespT> the type of response object returned by the server call
*/
public static <ReqT, RespT> ServerCall.Listener<ReqT> closeWithStatus(final ServerCall<ReqT, RespT> call, final Status status) {
call.close(status, EMPTY_TRAILERS);
//noinspection unchecked
return NO_OP_LISTENER;
}
}

View File

@ -12,8 +12,10 @@ import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCountUtil;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Optional;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
/** /**
* An "establish local connection" handler waits for a Noise handshake to complete upstream in the pipeline, buffering * An "establish local connection" handler waits for a Noise handshake to complete upstream in the pipeline, buffering
@ -48,12 +50,12 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
@Override @Override
public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) { public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) {
if (event instanceof NoiseIdentityDeterminedEvent noiseIdentityDeterminedEvent) { if (event instanceof NoiseIdentityDeterminedEvent(final Optional<AuthenticatedDevice> authenticatedDevice)) {
// We assume that we'll only get a completed handshake event if the handshake met all authentication requirements // We assume that we'll only get a completed handshake event if the handshake met all authentication requirements
// for the requested service. If the handshake doesn't have an authenticated device, we assume we're trying to // for the requested service. If the handshake doesn't have an authenticated device, we assume we're trying to
// connect to the anonymous service. If it does have an authenticated device, we assume we're aiming for the // connect to the anonymous service. If it does have an authenticated device, we assume we're aiming for the
// authenticated service. // authenticated service.
final LocalAddress grpcServerAddress = noiseIdentityDeterminedEvent.authenticatedDevice().isPresent() final LocalAddress grpcServerAddress = authenticatedDevice.isPresent()
? authenticatedGrpcServerAddress ? authenticatedGrpcServerAddress
: anonymousGrpcServerAddress; : anonymousGrpcServerAddress;
@ -72,7 +74,7 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
if (localChannelFuture.isSuccess()) { if (localChannelFuture.isSuccess()) {
grpcClientConnectionManager.handleConnectionEstablished((LocalChannel) localChannelFuture.channel(), grpcClientConnectionManager.handleConnectionEstablished((LocalChannel) localChannelFuture.channel(),
remoteChannelContext.channel(), remoteChannelContext.channel(),
noiseIdentityDeterminedEvent.authenticatedDevice()); authenticatedDevice);
// Close the local connection if the remote channel closes and vice versa // Close the local connection if the remote channel closes and vice versa
remoteChannelContext.channel().closeFuture().addListener(closeFuture -> localChannelFuture.channel().close()); remoteChannelContext.channel().closeFuture().addListener(closeFuture -> localChannelFuture.channel().close());

View File

@ -1,6 +1,8 @@
package org.whispersystems.textsecuregcm.grpc.net; package org.whispersystems.textsecuregcm.grpc.net;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import io.grpc.Grpc;
import io.grpc.ServerCall;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelFutureListener;
import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalAddress;
@ -23,15 +25,26 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener; import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent; import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; import org.whispersystems.textsecuregcm.util.ClosableEpoch;
/** /**
* A client connection manager associates a local connection to a local gRPC server with a remote connection through a * A client connection manager associates a local connection to a local gRPC server with a remote connection through a
* Noise-over-WebSocket tunnel. It provides access to metadata associated with the remote connection, including the * Noise tunnel. It provides access to metadata associated with the remote connection, including the authenticated
* authenticated identity of the device that opened the connection (for non-anonymous connections). It can also close * identity of the device that opened the connection (for non-anonymous connections). It can also close connections
* connections associated with a given device if that device's credentials have changed and clients must reauthenticate. * associated with a given device if that device's credentials have changed and clients must reauthenticate.
* <p>
* In general, all {@link ServerCall}s <em>must</em> have a local address that in turn <em>should</em> be resolvable to
* a remote channel, which <em>must</em> have associated request attributes and authentication status. It is possible
* that a server call's local address may not be resolvable to a remote channel if the remote channel closed in the
* narrow window between a server call being created and the start of call execution, in which case accessor methods
* in this class will throw a {@link ChannelNotFoundException}.
* <p>
* A gRPC client connection manager's methods for getting request attributes accept {@link ServerCall} entities to
* identify connections. In general, these methods should only be called from {@link io.grpc.ServerInterceptor}s.
* Methods for requesting connection closure accept an {@link AuthenticatedDevice} to identify the connection and may
* be called from any application code.
*/ */
public class GrpcClientConnectionManager implements DisconnectionRequestListener { public class GrpcClientConnectionManager implements DisconnectionRequestListener {
@ -43,94 +56,93 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
AttributeKey.valueOf(GrpcClientConnectionManager.class, "authenticatedDevice"); AttributeKey.valueOf(GrpcClientConnectionManager.class, "authenticatedDevice");
@VisibleForTesting @VisibleForTesting
static final AttributeKey<InetAddress> REMOTE_ADDRESS_ATTRIBUTE_KEY = public static final AttributeKey<RequestAttributes> REQUEST_ATTRIBUTES_KEY =
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "remoteAddress"); AttributeKey.valueOf(GrpcClientConnectionManager.class, "requestAttributes");
@VisibleForTesting @VisibleForTesting
static final AttributeKey<String> RAW_USER_AGENT_ATTRIBUTE_KEY = static final AttributeKey<ClosableEpoch> EPOCH_ATTRIBUTE_KEY =
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "rawUserAgent"); AttributeKey.valueOf(GrpcClientConnectionManager.class, "epoch");
@VisibleForTesting
static final AttributeKey<UserAgent> PARSED_USER_AGENT_ATTRIBUTE_KEY =
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "userAgent");
@VisibleForTesting
static final AttributeKey<List<Locale.LanguageRange>> ACCEPT_LANGUAGE_ATTRIBUTE_KEY =
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "acceptLanguage");
private static final Logger log = LoggerFactory.getLogger(GrpcClientConnectionManager.class); private static final Logger log = LoggerFactory.getLogger(GrpcClientConnectionManager.class);
/** /**
* Returns the authenticated device associated with the given local address, if any. An authenticated device is * Returns the authenticated device associated with the given server call, if any. If the connection is anonymous
* available if and only if the given local address maps to an active local connection and that connection is * (i.e. unauthenticated), the returned value will be empty.
* authenticated (i.e. not anonymous).
* *
* @param localAddress the local address for which to find an authenticated device * @param serverCall the gRPC server call for which to find an authenticated device
* *
* @return the authenticated device associated with the given local address, if any * @return the authenticated device associated with the given local address, if any
*
* @throws ChannelNotFoundException if the server call is not associated with a known channel; in practice, this
* generally indicates that the channel has closed while request processing is still in progress
*/ */
public Optional<AuthenticatedDevice> getAuthenticatedDevice(final LocalAddress localAddress) { public Optional<AuthenticatedDevice> getAuthenticatedDevice(final ServerCall<?, ?> serverCall)
return getAuthenticatedDevice(remoteChannelsByLocalAddress.get(localAddress)); throws ChannelNotFoundException {
return getAuthenticatedDevice(getRemoteChannel(serverCall));
} }
private Optional<AuthenticatedDevice> getAuthenticatedDevice(@Nullable final Channel remoteChannel) { @VisibleForTesting
return Optional.ofNullable(remoteChannel) Optional<AuthenticatedDevice> getAuthenticatedDevice(final Channel remoteChannel) {
.map(channel -> channel.attr(AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).get()); return Optional.ofNullable(remoteChannel.attr(AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).get());
} }
/** /**
* Returns the parsed acceptable languages associated with the given local address, if any. Acceptable languages may * Returns the request attributes associated with the given server call.
* be unavailable if the local connection associated with the given local address has already closed, if the client
* did not provide a list of acceptable languages, or the list provided by the client could not be parsed.
* *
* @param localAddress the local address for which to find acceptable languages * @param serverCall the gRPC server call for which to retrieve request attributes
* *
* @return the acceptable languages associated with the given local address, if any * @return the request attributes associated with the given server call
*
* @throws ChannelNotFoundException if the server call is not associated with a known channel; in practice, this
* generally indicates that the channel has closed while request processing is still in progress
*/ */
public Optional<List<Locale.LanguageRange>> getAcceptableLanguages(final LocalAddress localAddress) { public RequestAttributes getRequestAttributes(final ServerCall<?, ?> serverCall) throws ChannelNotFoundException {
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress)) return getRequestAttributes(getRemoteChannel(serverCall));
.map(remoteChannel -> remoteChannel.attr(ACCEPT_LANGUAGE_ATTRIBUTE_KEY).get()); }
@VisibleForTesting
RequestAttributes getRequestAttributes(final Channel remoteChannel) {
final RequestAttributes requestAttributes = remoteChannel.attr(REQUEST_ATTRIBUTES_KEY).get();
if (requestAttributes == null) {
throw new IllegalStateException("Channel does not have request attributes");
}
return requestAttributes;
} }
/** /**
* Returns the remote address associated with the given local address, if any. A remote address may be unavailable if * Handles the start of a server call, incrementing the active call count for the remote channel associated with the
* the local connection associated with the given local address has already closed. * given server call.
* *
* @param localAddress the local address for which to find a remote address * @param serverCall the server call to start
* *
* @return the remote address associated with the given local address, if any * @return {@code true} if the call should start normally or {@code false} if the call should be aborted because the
* underlying channel is closing
*/ */
public Optional<InetAddress> getRemoteAddress(final LocalAddress localAddress) { public boolean handleServerCallStart(final ServerCall<?, ?> serverCall) {
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress)) try {
.map(remoteChannel -> remoteChannel.attr(REMOTE_ADDRESS_ATTRIBUTE_KEY).get()); return getRemoteChannel(serverCall).attr(EPOCH_ATTRIBUTE_KEY).get().tryArrive();
} catch (final ChannelNotFoundException e) {
// This would only happen if the channel had already closed, which is certainly possible. In this case, the call
// should certainly not proceed.
return false;
}
} }
/** /**
* Returns the unparsed user agent provided by the client that opened the connection associated with the given local * Handles completion (successful or not) of a server call, decrementing the active call count for the remote channel
* address. This method may return an empty value if no active local connection is associated with the given local * associated with the given server call.
* address.
* *
* @param localAddress the local address for which to find a User-Agent string * @param serverCall the server call to complete
*
* @return the user agent string associated with the given local address
*/ */
public Optional<String> getRawUserAgent(final LocalAddress localAddress) { public void handleServerCallComplete(final ServerCall<?, ?> serverCall) {
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress)) try {
.map(remoteChannel -> remoteChannel.attr(RAW_USER_AGENT_ATTRIBUTE_KEY).get()); getRemoteChannel(serverCall).attr(EPOCH_ATTRIBUTE_KEY).get().depart();
} catch (final ChannelNotFoundException ignored) {
// In practice, we'd only get here if the channel has already closed, so we can just ignore the exception
} }
/**
* Returns the parsed user agent provided by the client that opened the connection associated with the given local
* address. This method may return an empty value if no active local connection is associated with the given local
* address or if the client's user-agent string was not recognized.
*
* @param localAddress the local address for which to find a User-Agent string
*
* @return the user agent associated with the given local address
*/
public Optional<UserAgent> getUserAgent(final LocalAddress localAddress) {
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
.map(remoteChannel -> remoteChannel.attr(PARSED_USER_AGENT_ATTRIBUTE_KEY).get());
} }
/** /**
@ -145,10 +157,13 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
final List<Channel> channelsToClose = final List<Channel> channelsToClose =
new ArrayList<>(remoteChannelsByAuthenticatedDevice.getOrDefault(authenticatedDevice, Collections.emptyList())); new ArrayList<>(remoteChannelsByAuthenticatedDevice.getOrDefault(authenticatedDevice, Collections.emptyList()));
channelsToClose.forEach(channel -> channelsToClose.forEach(channel -> channel.attr(EPOCH_ATTRIBUTE_KEY).get().close());
}
private static void closeRemoteChannel(final Channel channel) {
channel.writeAndFlush(new CloseWebSocketFrame(ApplicationWebSocketCloseReason.REAUTHENTICATION_REQUIRED channel.writeAndFlush(new CloseWebSocketFrame(ApplicationWebSocketCloseReason.REAUTHENTICATION_REQUIRED
.toWebSocketCloseStatus("Reauthentication required"))) .toWebSocketCloseStatus("Reauthentication required")))
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE)); .addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
} }
@VisibleForTesting @VisibleForTesting
@ -156,11 +171,32 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
return remoteChannelsByAuthenticatedDevice.get(authenticatedDevice); return remoteChannelsByAuthenticatedDevice.get(authenticatedDevice);
} }
private Channel getRemoteChannel(final ServerCall<?, ?> serverCall) throws ChannelNotFoundException {
return getRemoteChannel(getLocalAddress(serverCall));
}
@VisibleForTesting @VisibleForTesting
Channel getRemoteChannelByLocalAddress(final LocalAddress localAddress) { Channel getRemoteChannel(final LocalAddress localAddress) throws ChannelNotFoundException {
final Channel remoteChannel = remoteChannelsByLocalAddress.get(localAddress);
if (remoteChannel == null) {
throw new ChannelNotFoundException();
}
return remoteChannelsByLocalAddress.get(localAddress); return remoteChannelsByLocalAddress.get(localAddress);
} }
private static LocalAddress getLocalAddress(final ServerCall<?, ?> serverCall) {
// In this server, gRPC's "remote" channel is actually a local channel that proxies to a distinct Noise channel.
// The gRPC "remote" address is the "local address" for the proxy connection, and the local address uniquely maps to
// a proxied Noise channel.
if (!(serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress)) {
throw new IllegalArgumentException("Unexpected channel type: " + serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
}
return localAddress;
}
/** /**
* Handles successful completion of a WebSocket handshake and associates attributes and headers from the handshake * Handles successful completion of a WebSocket handshake and associates attributes and headers from the handshake
* request with the channel via which the handshake took place. * request with the channel via which the handshake took place.
@ -171,30 +207,23 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
* @param acceptLanguageHeader the value of the Accept-Language header provided in the handshake request; may be * @param acceptLanguageHeader the value of the Accept-Language header provided in the handshake request; may be
* {@code null} * {@code null}
*/ */
static void handleWebSocketHandshakeComplete(final Channel channel, static void handleHandshakeComplete(final Channel channel,
final InetAddress preferredRemoteAddress, final InetAddress preferredRemoteAddress,
@Nullable final String userAgentHeader, @Nullable final String userAgentHeader,
@Nullable final String acceptLanguageHeader) { @Nullable final String acceptLanguageHeader) {
channel.attr(GrpcClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).set(preferredRemoteAddress); @Nullable List<Locale.LanguageRange> acceptLanguages = Collections.emptyList();
if (StringUtils.isNotBlank(userAgentHeader)) {
channel.attr(GrpcClientConnectionManager.RAW_USER_AGENT_ATTRIBUTE_KEY).set(userAgentHeader);
try {
channel.attr(GrpcClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY)
.set(UserAgentUtil.parseUserAgentString(userAgentHeader));
} catch (final UnrecognizedUserAgentException ignored) {
}
}
if (StringUtils.isNotBlank(acceptLanguageHeader)) { if (StringUtils.isNotBlank(acceptLanguageHeader)) {
try { try {
channel.attr(GrpcClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).set(Locale.LanguageRange.parse(acceptLanguageHeader)); acceptLanguages = Locale.LanguageRange.parse(acceptLanguageHeader);
} catch (final IllegalArgumentException e) { } catch (final IllegalArgumentException e) {
log.debug("Invalid Accept-Language header from User-Agent {}: {}", userAgentHeader, acceptLanguageHeader, e); log.debug("Invalid Accept-Language header from User-Agent {}: {}", userAgentHeader, acceptLanguageHeader, e);
} }
} }
channel.attr(REQUEST_ATTRIBUTES_KEY)
.set(new RequestAttributes(preferredRemoteAddress, userAgentHeader, acceptLanguages));
} }
/** /**
@ -212,6 +241,9 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
maybeAuthenticatedDevice.ifPresent(authenticatedDevice -> maybeAuthenticatedDevice.ifPresent(authenticatedDevice ->
remoteChannel.attr(GrpcClientConnectionManager.AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).set(authenticatedDevice)); remoteChannel.attr(GrpcClientConnectionManager.AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).set(authenticatedDevice));
remoteChannel.attr(EPOCH_ATTRIBUTE_KEY)
.set(new ClosableEpoch(() -> closeRemoteChannel(remoteChannel)));
remoteChannelsByLocalAddress.put(localChannel.localAddress(), remoteChannel); remoteChannelsByLocalAddress.put(localChannel.localAddress(), remoteChannel);
getAuthenticatedDevice(remoteChannel).ifPresent(authenticatedDevice -> getAuthenticatedDevice(remoteChannel).ifPresent(authenticatedDevice ->

View File

@ -74,7 +74,7 @@ class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
preferredRemoteAddress = maybePreferredRemoteAddress.get(); preferredRemoteAddress = maybePreferredRemoteAddress.get();
} }
GrpcClientConnectionManager.handleWebSocketHandshakeComplete(context.channel(), GrpcClientConnectionManager.handleHandshakeComplete(context.channel(),
preferredRemoteAddress, preferredRemoteAddress,
handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.USER_AGENT), handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.USER_AGENT),
handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.ACCEPT_LANGUAGE)); handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.ACCEPT_LANGUAGE));

View File

@ -0,0 +1,44 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.metrics;
import org.apache.commons.lang3.StringUtils;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import java.util.Optional;
public class DevicePlatformUtil {
private DevicePlatformUtil() {
}
/**
* Returns the most likely client platform for a device.
*
* @param device the device for which to find a client platform
*
* @return the most likely client platform for the given device or empty if no likely platform could be determined
*/
public static Optional<ClientPlatform> getDevicePlatform(final Device device) {
final Optional<ClientPlatform> clientPlatform;
if (StringUtils.isNotBlank(device.getGcmId())) {
clientPlatform = Optional.of(ClientPlatform.ANDROID);
} else if (StringUtils.isNotBlank(device.getApnId())) {
clientPlatform = Optional.of(ClientPlatform.IOS);
} else {
clientPlatform = Optional.empty();
}
return clientPlatform.or(() -> Optional.ofNullable(
switch (device.getUserAgent()) {
case "OWA" -> ClientPlatform.ANDROID;
case "OWI", "OWP" -> ClientPlatform.IOS;
case "OWD" -> ClientPlatform.DESKTOP;
case null, default -> null;
}));
}
}

View File

@ -1,32 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.metrics;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import com.sun.management.OperatingSystemMXBean;
import io.micrometer.core.instrument.Gauge;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.binder.MeterBinder;
import java.lang.management.ManagementFactory;
public class FreeMemoryGauge implements MeterBinder {
private final OperatingSystemMXBean operatingSystemMXBean;
public FreeMemoryGauge() {
this.operatingSystemMXBean = (com.sun.management.OperatingSystemMXBean)
ManagementFactory.getOperatingSystemMXBean();
}
@Override
public void bindTo(final MeterRegistry registry) {
Gauge.builder(name(FreeMemoryGauge.class, "freeMemory"), operatingSystemMXBean,
OperatingSystemMXBean::getFreeMemorySize)
.register(registry);
}
}

View File

@ -120,10 +120,7 @@ public class MetricsUtil {
public static void registerSystemResourceMetrics(final Environment environment) { public static void registerSystemResourceMetrics(final Environment environment) {
new ProcessorMetrics().bindTo(Metrics.globalRegistry); new ProcessorMetrics().bindTo(Metrics.globalRegistry);
new FreeMemoryGauge().bindTo(Metrics.globalRegistry);
new FileDescriptorMetrics().bindTo(Metrics.globalRegistry); new FileDescriptorMetrics().bindTo(Metrics.globalRegistry);
new OperatingSystemMemoryGauge("Buffers").bindTo(Metrics.globalRegistry);
new OperatingSystemMemoryGauge("Cached").bindTo(Metrics.globalRegistry);
new JvmMemoryMetrics().bindTo(Metrics.globalRegistry); new JvmMemoryMetrics().bindTo(Metrics.globalRegistry);
new JvmThreadMetrics().bindTo(Metrics.globalRegistry); new JvmThreadMetrics().bindTo(Metrics.globalRegistry);

View File

@ -67,7 +67,7 @@ public class OpenWebSocketCounter {
try { try {
final ClientPlatform clientPlatform = final ClientPlatform clientPlatform =
UserAgentUtil.parseUserAgentString(context.getClient().getUserAgent()).getPlatform(); UserAgentUtil.parseUserAgentString(context.getClient().getUserAgent()).platform();
calculatedOpenWebSocketCounter = openWebsocketsByClientPlatform.get(clientPlatform); calculatedOpenWebSocketCounter = openWebsocketsByClientPlatform.get(clientPlatform);
calculatedDurationTimer = durationTimersByClientPlatform.get(clientPlatform); calculatedDurationTimer = durationTimersByClientPlatform.get(clientPlatform);

View File

@ -1,56 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.metrics;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import com.google.common.annotations.VisibleForTesting;
import io.micrometer.core.instrument.Gauge;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.binder.MeterBinder;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.Locale;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Stream;
public class OperatingSystemMemoryGauge implements MeterBinder {
private final String metricName;
private static final File MEMINFO_FILE = new File("/proc/meminfo");
private static final Pattern MEMORY_METRIC_PATTERN = Pattern.compile("^([^:]+):\\s+([0-9]+).*$");
public OperatingSystemMemoryGauge(final String metricName) {
this.metricName = metricName;
}
@Override
public void bindTo(MeterRegistry registry) {
final String metricName = this.metricName;
Gauge.builder(name(OperatingSystemMemoryGauge.class, metricName.toLowerCase(Locale.ROOT)), () -> {
try (final BufferedReader bufferedReader = new BufferedReader(new FileReader(MEMINFO_FILE))) {
return getValue(bufferedReader.lines(), metricName);
} catch (final IOException e) {
return 0L;
}
})
.register(registry);
}
@VisibleForTesting
static double getValue(final Stream<String> lines, final String metricName) {
return lines.map(MEMORY_METRIC_PATTERN::matcher)
.filter(Matcher::matches)
.filter(matcher -> metricName.equalsIgnoreCase(matcher.group(1)))
.map(matcher -> Double.parseDouble(matcher.group(2)))
.findFirst()
.orElse(0d);
}
}

View File

@ -9,6 +9,7 @@ import io.micrometer.core.instrument.Tag;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import org.apache.commons.lang3.StringUtils;
import org.whispersystems.textsecuregcm.WhisperServerVersion; import org.whispersystems.textsecuregcm.WhisperServerVersion;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
@ -48,15 +49,15 @@ public class UserAgentTagUtil {
} }
public static Tag getPlatformTag(@Nullable final UserAgent userAgent) { public static Tag getPlatformTag(@Nullable final UserAgent userAgent) {
return Tag.of(PLATFORM_TAG, userAgent != null ? userAgent.getPlatform().name().toLowerCase() : "unrecognized"); return Tag.of(PLATFORM_TAG, userAgent != null ? userAgent.platform().name().toLowerCase() : "unrecognized");
} }
public static Optional<Tag> getClientVersionTag(final String userAgentString, final ClientReleaseManager clientReleaseManager) { public static Optional<Tag> getClientVersionTag(final String userAgentString, final ClientReleaseManager clientReleaseManager) {
try { try {
final UserAgent userAgent = UserAgentUtil.parseUserAgentString(userAgentString); final UserAgent userAgent = UserAgentUtil.parseUserAgentString(userAgentString);
if (clientReleaseManager.isVersionActive(userAgent.getPlatform(), userAgent.getVersion())) { if (clientReleaseManager.isVersionActive(userAgent.platform(), userAgent.version())) {
return Optional.of(Tag.of(VERSION_TAG, userAgent.getVersion().toString())); return Optional.of(Tag.of(VERSION_TAG, userAgent.version().toString()));
} }
} catch (final UnrecognizedUserAgentException ignored) { } catch (final UnrecognizedUserAgentException ignored) {
} }
@ -70,10 +71,8 @@ public class UserAgentTagUtil {
try { try {
final UserAgent userAgent = UserAgentUtil.parseUserAgentString(userAgentString); final UserAgent userAgent = UserAgentUtil.parseUserAgentString(userAgentString);
platform = userAgent.getPlatform().name().toLowerCase(); platform = userAgent.platform().name().toLowerCase();
libsignal = userAgent.getAdditionalSpecifiers() libsignal = StringUtils.contains(userAgent.additionalSpecifiers(), "libsignal");
.map(additionalSpecifiers -> additionalSpecifiers.contains("libsignal"))
.orElse(false);
} catch (final UnrecognizedUserAgentException e) { } catch (final UnrecognizedUserAgentException e) {
platform = "unrecognized"; platform = "unrecognized";
libsignal = false; libsignal = false;

View File

@ -24,16 +24,19 @@ import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.signal.libsignal.protocol.util.Pair; import org.signal.libsignal.protocol.util.Pair;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.MessageController; import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevices; import org.whispersystems.textsecuregcm.controllers.MismatchedDevices;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException; import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
@ -52,10 +55,14 @@ public class MessageSender {
private final MessagesManager messagesManager; private final MessagesManager messagesManager;
private final PushNotificationManager pushNotificationManager; private final PushNotificationManager pushNotificationManager;
private final ExperimentEnrollmentManager experimentEnrollmentManager;
public static final String ANDROID_SKIP_LOW_URGENCY_PUSH_EXPERIMENT = "androidSkipLowUrgencyPush";
// Note that these names deliberately reference `MessageController` for metric continuity // Note that these names deliberately reference `MessageController` for metric continuity
private static final String REJECT_OVERSIZE_MESSAGE_COUNTER_NAME = name(MessageController.class, "rejectOversizeMessage"); private static final String REJECT_OVERSIZE_MESSAGE_COUNTER_NAME = name(MessageController.class, "rejectOversizeMessage");
private static final String CONTENT_SIZE_DISTRIBUTION_NAME = MetricsUtil.name(MessageController.class, "messageContentSize"); private static final String CONTENT_SIZE_DISTRIBUTION_NAME = MetricsUtil.name(MessageController.class, "messageContentSize");
private static final String EMPTY_MESSAGE_LIST_COUNTER_NAME = MetricsUtil.name(MessageSender.class, "emptyMessageList");
private static final String SEND_COUNTER_NAME = name(MessageSender.class, "sendMessage"); private static final String SEND_COUNTER_NAME = name(MessageSender.class, "sendMessage");
private static final String EPHEMERAL_TAG_NAME = "ephemeral"; private static final String EPHEMERAL_TAG_NAME = "ephemeral";
@ -64,6 +71,7 @@ public class MessageSender {
private static final String STORY_TAG_NAME = "story"; private static final String STORY_TAG_NAME = "story";
private static final String SEALED_SENDER_TAG_NAME = "sealedSender"; private static final String SEALED_SENDER_TAG_NAME = "sealedSender";
private static final String MULTI_RECIPIENT_TAG_NAME = "multiRecipient"; private static final String MULTI_RECIPIENT_TAG_NAME = "multiRecipient";
private static final String SYNC_MESSAGE_TAG_NAME = "sync";
@VisibleForTesting @VisibleForTesting
public static final int MAX_MESSAGE_SIZE = (int) DataSize.kibibytes(256).toBytes(); public static final int MAX_MESSAGE_SIZE = (int) DataSize.kibibytes(256).toBytes();
@ -71,9 +79,13 @@ public class MessageSender {
@VisibleForTesting @VisibleForTesting
static final byte NO_EXCLUDED_DEVICE_ID = -1; static final byte NO_EXCLUDED_DEVICE_ID = -1;
public MessageSender(final MessagesManager messagesManager, final PushNotificationManager pushNotificationManager) { public MessageSender(
final MessagesManager messagesManager,
final PushNotificationManager pushNotificationManager,
final ExperimentEnrollmentManager experimentEnrollmentManager) {
this.messagesManager = messagesManager; this.messagesManager = messagesManager;
this.pushNotificationManager = pushNotificationManager; this.pushNotificationManager = pushNotificationManager;
this.experimentEnrollmentManager = experimentEnrollmentManager;
} }
/** /**
@ -105,6 +117,13 @@ public class MessageSender {
throw new IllegalArgumentException("Destination account not identified by destination service identifier"); throw new IllegalArgumentException("Destination account not identified by destination service identifier");
} }
final Tag platformTag = UserAgentTagUtil.getPlatformTag(userAgent);
if (messagesByDeviceId.isEmpty()) {
Metrics.counter(EMPTY_MESSAGE_LIST_COUNTER_NAME,
Tags.of(SYNC_MESSAGE_TAG_NAME, String.valueOf(syncMessageSenderDeviceId.isPresent())).and(platformTag)).increment();
}
final byte excludedDeviceId; final byte excludedDeviceId;
if (syncMessageSenderDeviceId.isPresent()) { if (syncMessageSenderDeviceId.isPresent()) {
if (messagesByDeviceId.values().stream().anyMatch(message -> StringUtils.isBlank(message.getSourceServiceId()) || if (messagesByDeviceId.values().stream().anyMatch(message -> StringUtils.isBlank(message.getSourceServiceId()) ||
@ -137,7 +156,7 @@ public class MessageSender {
.forEach((deviceId, destinationPresent) -> { .forEach((deviceId, destinationPresent) -> {
final Envelope message = messagesByDeviceId.get(deviceId); final Envelope message = messagesByDeviceId.get(deviceId);
if (!destinationPresent && !message.getEphemeral()) { if (!destinationPresent && !message.getEphemeral() && !shouldSkipPush(destination, deviceId, message.getUrgent())) {
try { try {
pushNotificationManager.sendNewMessageNotification(destination, deviceId, message.getUrgent()); pushNotificationManager.sendNewMessageNotification(destination, deviceId, message.getUrgent());
} catch (final NotPushRegisteredException ignored) { } catch (final NotPushRegisteredException ignored) {
@ -150,13 +169,21 @@ public class MessageSender {
URGENT_TAG_NAME, String.valueOf(message.getUrgent()), URGENT_TAG_NAME, String.valueOf(message.getUrgent()),
STORY_TAG_NAME, String.valueOf(message.getStory()), STORY_TAG_NAME, String.valueOf(message.getStory()),
SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceServiceId()), SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceServiceId()),
SYNC_MESSAGE_TAG_NAME, String.valueOf(syncMessageSenderDeviceId.isPresent()),
MULTI_RECIPIENT_TAG_NAME, "false") MULTI_RECIPIENT_TAG_NAME, "false")
.and(UserAgentTagUtil.getPlatformTag(userAgent)); .and(platformTag);
Metrics.counter(SEND_COUNTER_NAME, tags).increment(); Metrics.counter(SEND_COUNTER_NAME, tags).increment();
}); });
} }
private boolean shouldSkipPush(final Account account, byte deviceId, boolean urgent) {
final boolean isAndroidFcm = account.getDevice(deviceId).map(Device::getGcmId).isPresent();
return !urgent
&& isAndroidFcm
&& experimentEnrollmentManager.isEnrolled(account.getUuid(), ANDROID_SKIP_LOW_URGENCY_PUSH_EXPERIMENT);
}
/** /**
* Sends messages to a group of recipients. If a destination device has a valid push notification token and does not * Sends messages to a group of recipients. If a destination device has a valid push notification token and does not
* have an active connection to a Signal server, then this method will also send a push notification to that device to * have an active connection to a Signal server, then this method will also send a push notification to that device to
@ -234,6 +261,7 @@ public class MessageSender {
URGENT_TAG_NAME, String.valueOf(isUrgent), URGENT_TAG_NAME, String.valueOf(isUrgent),
STORY_TAG_NAME, String.valueOf(isStory), STORY_TAG_NAME, String.valueOf(isStory),
SEALED_SENDER_TAG_NAME, "true", SEALED_SENDER_TAG_NAME, "true",
SYNC_MESSAGE_TAG_NAME, "false",
MULTI_RECIPIENT_TAG_NAME, "true") MULTI_RECIPIENT_TAG_NAME, "true")
.and(UserAgentTagUtil.getPlatformTag(userAgent)); .and(UserAgentTagUtil.getPlatformTag(userAgent));

View File

@ -38,8 +38,8 @@ public class SchedulingUtil {
final LocalTime preferredTime, final LocalTime preferredTime,
final Clock clock) { final Clock clock) {
final ZonedDateTime candidateNotificationTime = getZoneOffset(account, clock) final ZonedDateTime candidateNotificationTime = getZoneId(account, clock)
.map(zoneOffset -> ZonedDateTime.now(zoneOffset).with(preferredTime)) .map(zoneId -> ZonedDateTime.now(clock.withZone(zoneId)).with(preferredTime))
.orElseGet(() -> { .orElseGet(() -> {
// We couldn't find a reasonable timezone for the account for some reason, so make an educated guess at a // We couldn't find a reasonable timezone for the account for some reason, so make an educated guess at a
// reasonable time to send a notification based on the account's creation time. // reasonable time to send a notification based on the account's creation time.
@ -59,7 +59,7 @@ public class SchedulingUtil {
} }
@VisibleForTesting @VisibleForTesting
static Optional<ZoneOffset> getZoneOffset(final Account account, final Clock clock) { static Optional<ZoneId> getZoneId(final Account account, final Clock clock) {
try { try {
final Phonenumber.PhoneNumber phoneNumber = PhoneNumberUtil.getInstance().parse(account.getNumber(), null); final Phonenumber.PhoneNumber phoneNumber = PhoneNumberUtil.getInstance().parse(account.getNumber(), null);
@ -70,7 +70,7 @@ public class SchedulingUtil {
return Optional.empty(); return Optional.empty();
} }
final List<ZoneOffset> sortedZoneOffsets = timeZonesForNumber final List<ZoneId> sortedZoneOffsets = timeZonesForNumber
.stream() .stream()
.map(id -> { .map(id -> {
try { try {
@ -80,9 +80,6 @@ public class SchedulingUtil {
} }
}) })
.filter(Objects::nonNull) .filter(Objects::nonNull)
.map(ZoneId::getRules)
.distinct()
.map(zoneRules -> zoneRules.getOffset(clock.instant()))
.sorted() .sorted()
.toList(); .toList();

View File

@ -5,6 +5,7 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import java.time.Clock;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
@ -29,12 +30,16 @@ public class ChangeNumberManager {
private static final Logger logger = LoggerFactory.getLogger(ChangeNumberManager.class); private static final Logger logger = LoggerFactory.getLogger(ChangeNumberManager.class);
private final MessageSender messageSender; private final MessageSender messageSender;
private final AccountsManager accountsManager; private final AccountsManager accountsManager;
private final Clock clock;
public ChangeNumberManager( public ChangeNumberManager(
final MessageSender messageSender, final MessageSender messageSender,
final AccountsManager accountsManager) { final AccountsManager accountsManager,
final Clock clock) {
this.messageSender = messageSender; this.messageSender = messageSender;
this.accountsManager = accountsManager; this.accountsManager = accountsManager;
this.clock = clock;
} }
public Account changeNumber(final Account account, final String number, public Account changeNumber(final Account account, final String number,
@ -97,7 +102,7 @@ public class ChangeNumberManager {
final String senderUserAgent) throws MessageTooLargeException, MismatchedDevicesException { final String senderUserAgent) throws MessageTooLargeException, MismatchedDevicesException {
try { try {
final long serverTimestamp = System.currentTimeMillis(); final long serverTimestamp = clock.millis();
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(account.getUuid()); final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(account.getUuid());
final Map<Byte, Envelope> messagesByDeviceId = deviceMessages.stream() final Map<Byte, Envelope> messagesByDeviceId = deviceMessages.stream()
@ -114,8 +119,8 @@ public class ChangeNumberManager {
.setEphemeral(false) .setEphemeral(false)
.build())); .build()));
final Map<Byte, Integer> registrationIdsByDeviceId = account.getDevices().stream() final Map<Byte, Integer> registrationIdsByDeviceId = deviceMessages.stream()
.collect(Collectors.toMap(Device::getId, Device::getRegistrationId)); .collect(Collectors.toMap(IncomingMessage::destinationDeviceId, IncomingMessage::destinationRegistrationId));
messageSender.sendMessages(account, messageSender.sendMessages(account,
serviceIdentifier, serviceIdentifier,

View File

@ -12,10 +12,13 @@ import io.dropwizard.lifecycle.Managed;
import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.DistributionSummary; import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags;
import io.micrometer.core.instrument.Timer; import io.micrometer.core.instrument.Timer;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.util.List; import java.util.List;
import java.util.Locale;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLong;
@ -24,7 +27,10 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.metrics.DevicePlatformUtil;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
@ -37,6 +43,7 @@ public class MessagePersister implements Managed {
private final MessagesManager messagesManager; private final MessagesManager messagesManager;
private final AccountsManager accountsManager; private final AccountsManager accountsManager;
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager; private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private final ExperimentEnrollmentManager experimentEnrollmentManager;
private final Duration persistDelay; private final Duration persistDelay;
@ -44,6 +51,8 @@ public class MessagePersister implements Managed {
private volatile boolean running; private volatile boolean running;
private static final String OVERSIZED_QUEUE_COUNTER_NAME = name(MessagePersister.class, "persistQueueOversized"); private static final String OVERSIZED_QUEUE_COUNTER_NAME = name(MessagePersister.class, "persistQueueOversized");
private static final String PERSISTED_MESSAGE_COUNTER_NAME = name(MessagePersister.class, "persistMessage");
private static final String PERSISTED_BYTES_COUNTER_NAME = name(MessagePersister.class, "persistBytes");
private static final Timer GET_QUEUES_TIMER = Metrics.timer(name(MessagePersister.class, "getQueues")); private static final Timer GET_QUEUES_TIMER = Metrics.timer(name(MessagePersister.class, "getQueues"));
private static final Timer PERSIST_QUEUE_TIMER = Metrics.timer(name(MessagePersister.class, "persistQueue")); private static final Timer PERSIST_QUEUE_TIMER = Metrics.timer(name(MessagePersister.class, "persistQueue"));
@ -57,10 +66,7 @@ public class MessagePersister implements Managed {
.publishPercentileHistogram(true) .publishPercentileHistogram(true)
.register(Metrics.globalRegistry); .register(Metrics.globalRegistry);
private static final DistributionSummary QUEUE_SIZE_DISTRIBUTION_SUMMARY = DistributionSummary.builder( private static final String QUEUE_SIZE_DISTRIBUTION_SUMMARY_NAME = name(MessagePersister.class, "queueSize");
name(MessagePersister.class, "queueSize"))
.publishPercentileHistogram(true)
.register(Metrics.globalRegistry);
static final int QUEUE_BATCH_LIMIT = 100; static final int QUEUE_BATCH_LIMIT = 100;
static final int MESSAGE_BATCH_LIMIT = 100; static final int MESSAGE_BATCH_LIMIT = 100;
@ -75,6 +81,7 @@ public class MessagePersister implements Managed {
final MessagesManager messagesManager, final MessagesManager messagesManager,
final AccountsManager accountsManager, final AccountsManager accountsManager,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager, final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final ExperimentEnrollmentManager experimentEnrollmentManager,
final Duration persistDelay, final Duration persistDelay,
final int dedicatedProcessWorkerThreadCount) { final int dedicatedProcessWorkerThreadCount) {
@ -82,6 +89,7 @@ public class MessagePersister implements Managed {
this.messagesManager = messagesManager; this.messagesManager = messagesManager;
this.accountsManager = accountsManager; this.accountsManager = accountsManager;
this.dynamicConfigurationManager = dynamicConfigurationManager; this.dynamicConfigurationManager = dynamicConfigurationManager;
this.experimentEnrollmentManager = experimentEnrollmentManager;
this.persistDelay = persistDelay; this.persistDelay = persistDelay;
this.workerThreads = new Thread[dedicatedProcessWorkerThreadCount]; this.workerThreads = new Thread[dedicatedProcessWorkerThreadCount];
@ -139,6 +147,7 @@ public class MessagePersister implements Managed {
@VisibleForTesting @VisibleForTesting
int persistNextQueues(final Instant currentTime) { int persistNextQueues(final Instant currentTime) {
final int slot = messagesCache.getNextSlotToPersist(); final int slot = messagesCache.getNextSlotToPersist();
final String shard = messagesCache.shardForSlot(slot);
List<String> queuesToPersist; List<String> queuesToPersist;
int queuesPersisted = 0; int queuesPersisted = 0;
@ -162,10 +171,11 @@ public class MessagePersister implements Managed {
continue; continue;
} }
try { try {
persistQueue(maybeAccount.get(), maybeDevice.get()); persistQueue(maybeAccount.get(), maybeDevice.get(), shard);
} catch (final Exception e) { } catch (final Exception e) {
PERSIST_QUEUE_EXCEPTION_METER.increment(); PERSIST_QUEUE_EXCEPTION_METER.increment();
logger.warn("Failed to persist queue {}::{}; will schedule for retry", accountUuid, deviceId, e); logger.warn("Failed to persist queue {}::{} (slot {}, shard {}); will schedule for retry",
accountUuid, deviceId, slot, shard, e);
messagesCache.addQueueToPersist(accountUuid, deviceId); messagesCache.addQueueToPersist(accountUuid, deviceId);
@ -183,10 +193,14 @@ public class MessagePersister implements Managed {
} }
@VisibleForTesting @VisibleForTesting
void persistQueue(final Account account, final Device device) throws MessagePersistenceException { void persistQueue(final Account account, final Device device, final String shard) throws MessagePersistenceException {
final UUID accountUuid = account.getUuid(); final UUID accountUuid = account.getUuid();
final byte deviceId = device.getId(); final byte deviceId = device.getId();
final Tag platformTag = Tag.of("platform", DevicePlatformUtil.getDevicePlatform(device)
.map(platform -> platform.name().toLowerCase(Locale.ROOT))
.orElse("unknown"));
final Timer.Sample sample = Timer.start(); final Timer.Sample sample = Timer.start();
messagesCache.lockQueueForPersistence(accountUuid, deviceId); messagesCache.lockQueueForPersistence(accountUuid, deviceId);
@ -200,6 +214,16 @@ public class MessagePersister implements Managed {
do { do {
messages = messagesCache.getMessagesToPersist(accountUuid, deviceId, MESSAGE_BATCH_LIMIT); messages = messagesCache.getMessagesToPersist(accountUuid, deviceId, MESSAGE_BATCH_LIMIT);
final int urgentMessageCount = (int) messages.stream().filter(MessageProtos.Envelope::getUrgent).count();
final int nonUrgentMessageCount = messages.size() - urgentMessageCount;
final Tags tags = Tags.of(platformTag, Tag.of("shard", shard));
Metrics.counter(PERSISTED_MESSAGE_COUNTER_NAME, tags.and("urgent", "true")).increment(urgentMessageCount);
Metrics.counter(PERSISTED_MESSAGE_COUNTER_NAME, tags.and("urgent", "false")).increment(nonUrgentMessageCount);
Metrics.counter(PERSISTED_BYTES_COUNTER_NAME, tags)
.increment(messages.stream().mapToInt(MessageProtos.Envelope::getSerializedSize).sum());
int messagesRemovedFromCache = messagesManager.persistMessages(accountUuid, device, messages); int messagesRemovedFromCache = messagesManager.persistMessages(accountUuid, device, messages);
messageCount += messages.size(); messageCount += messages.size();
@ -215,7 +239,14 @@ public class MessagePersister implements Managed {
} while (!messages.isEmpty()); } while (!messages.isEmpty());
QUEUE_SIZE_DISTRIBUTION_SUMMARY.record(messageCount); final boolean inSkipExperiment = device.getGcmId() != null && experimentEnrollmentManager.isEnrolled(
accountUuid,
MessageSender.ANDROID_SKIP_LOW_URGENCY_PUSH_EXPERIMENT);
DistributionSummary.builder(QUEUE_SIZE_DISTRIBUTION_SUMMARY_NAME)
.tags(Tags.of(platformTag).and("lowUrgencySkip", Boolean.toString(inSkipExperiment)))
.publishPercentileHistogram(true)
.register(Metrics.globalRegistry)
.record(messageCount);
} catch (ItemCollectionSizeLimitExceededException e) { } catch (ItemCollectionSizeLimitExceededException e) {
final boolean isPrimary = deviceId == Device.PRIMARY_ID; final boolean isPrimary = deviceId == Device.PRIMARY_ID;
Metrics.counter(OVERSIZED_QUEUE_COUNTER_NAME, "primary", String.valueOf(isPrimary)).increment(); Metrics.counter(OVERSIZED_QUEUE_COUNTER_NAME, "primary", String.valueOf(isPrimary)).increment();
@ -234,7 +265,6 @@ public class MessagePersister implements Managed {
messagesCache.unlockQueueForPersistence(accountUuid, deviceId); messagesCache.unlockQueueForPersistence(accountUuid, deviceId);
sample.stop(PERSIST_QUEUE_TIMER); sample.stop(PERSIST_QUEUE_TIMER);
} }
} }
private void trimQueue(final Account account, byte deviceId) { private void trimQueue(final Account account, byte deviceId) {

View File

@ -15,6 +15,8 @@ import io.lettuce.core.Range;
import io.lettuce.core.ScoredValue; import io.lettuce.core.ScoredValue;
import io.lettuce.core.ZAddArgs; import io.lettuce.core.ZAddArgs;
import io.lettuce.core.cluster.SlotHash; import io.lettuce.core.cluster.SlotHash;
import io.lettuce.core.cluster.models.partitions.ClusterPartitionParser;
import io.lettuce.core.cluster.models.partitions.Partitions;
import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tag;
@ -668,6 +670,15 @@ public class MessagesCache {
.thenRun(() -> sample.stop(clearQueueTimer)); .thenRun(() -> sample.stop(clearQueueTimer));
} }
public String shardForSlot(int slot) {
try {
return redisCluster.withBinaryCluster(
connection -> connection.getPartitions().getPartitionBySlot(slot).getUri().getHost());
} catch (Throwable ignored) {
return "unknown";
}
}
int getNextSlotToPersist() { int getNextSlotToPersist() {
return (int) (redisCluster.withCluster(connection -> connection.sync().incr(NEXT_SLOT_TO_PERSIST_KEY)) return (int) (redisCluster.withCluster(connection -> connection.sync().incr(NEXT_SLOT_TO_PERSIST_KEY))
% SlotHash.SLOT_COUNT); % SlotHash.SLOT_COUNT);

View File

@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.subscriptions;
import com.google.common.base.Strings; import com.google.common.base.Strings;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.stripe.Stripe;
import com.stripe.StripeClient; import com.stripe.StripeClient;
import com.stripe.exception.CardException; import com.stripe.exception.CardException;
import com.stripe.exception.IdempotencyException; import com.stripe.exception.IdempotencyException;
@ -71,6 +72,7 @@ import javax.crypto.spec.SecretKeySpec;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.WhisperServerVersion;
import org.whispersystems.textsecuregcm.storage.PaymentTime; import org.whispersystems.textsecuregcm.storage.PaymentTime;
import org.whispersystems.textsecuregcm.storage.SubscriptionException; import org.whispersystems.textsecuregcm.storage.SubscriptionException;
import org.whispersystems.textsecuregcm.util.Conversions; import org.whispersystems.textsecuregcm.util.Conversions;
@ -97,6 +99,9 @@ public class StripeManager implements CustomerAwareSubscriptionPaymentProcessor
if (Strings.isNullOrEmpty(apiKey)) { if (Strings.isNullOrEmpty(apiKey)) {
throw new IllegalArgumentException("apiKey cannot be empty"); throw new IllegalArgumentException("apiKey cannot be empty");
} }
Stripe.setAppInfo("Signal-Server", WhisperServerVersion.getServerVersion());
this.stripeClient = new StripeClient(apiKey); this.stripeClient = new StripeClient(apiKey);
this.executor = Objects.requireNonNull(executor); this.executor = Objects.requireNonNull(executor);
this.idempotencyKeyGenerator = Objects.requireNonNull(idempotencyKeyGenerator); this.idempotencyKeyGenerator = Objects.requireNonNull(idempotencyKeyGenerator);

View File

@ -0,0 +1,92 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import com.google.common.annotations.VisibleForTesting;
import java.util.concurrent.atomic.AtomicInteger;
/**
* A closable epoch is a concurrency construct that measures the number of callers in some critical section. A closable
* epoch can be closed to prevent new callers from entering the critical section, and takes a specific action when the
* critical section is empty after closure.
*/
public class ClosableEpoch {
private final Runnable onCloseHandler;
private final AtomicInteger state = new AtomicInteger();
private static final int CLOSING_BIT_MASK = 0x00000001;
/**
* Constructs a new closable epoch that will execute the given handler when the epoch is closed and all callers have
* departed the critical section. The handler will be executed on the thread that calls {@link #close()} if the
* critical section is empty at the time of the call or on the last thread to call {@link #depart()} otherwise.
* Callers should provide handlers that delegate execution to a specific thread/executor if more precise control over
* which thread runs the handler is required.
*
* @param onCloseHandler a handler to run when the epoch is closed and all callers have departed the critical section
*/
public ClosableEpoch(final Runnable onCloseHandler) {
this.onCloseHandler = onCloseHandler;
}
/**
* Announces the arrival of a caller at the start of the critical section. If the caller is allowed to enter the
* critical section, the epoch's internal caller counter is incremented accordingly.
*
* @return {@code true} if the caller is allowed to enter the critical section or {@code false} if it is not allowed
* to enter the critical section because this epoch is closing
*/
public boolean tryArrive() {
// Increment the number of active callers if and only if we're not closing. We add 2 because the lowest bit encodes
// the "closing" state, and the bits above it encode the actual call count. More verbosely, we're doing
// `state += (1 << 1)` to avoid overwriting the closing state bit.
return !isClosing(state.updateAndGet(state -> isClosing(state) ? state : state + 2));
}
/**
* Announces the departure of a caller from the critical section. If the epoch is closing and the caller is the last
* to depart the critical section, then the epoch will fire its {@code onCloseHandler}.
*/
public void depart() {
// Decrement the active caller count unconditionally. As with `tryActive`, we work in increments of 2 to "dodge" the
// "is closing" bit. If the call count is zero and we're closing then `state` will just have the "closing" bit set.
if (state.addAndGet(-2) == CLOSING_BIT_MASK) {
onCloseHandler.run();
}
}
/**
* Closes this epoch, preventing new callers from entering the critical section. If the critical section is empty when
* this method is called, it will trigger the {@code onCloseHandler} immediately. Otherwise, the
* {@code onCloseHandler} will fire when the last caller departs the critical section.
*
* @throws IllegalStateException if this epoch is already closed; note that this exception is thrown on a
* "best-effort" basis to help callers detect bugs
*/
public void close() {
// Note that this is not airtight and is a "best-effort" check
if (isClosing(state.get())) {
throw new IllegalStateException("Epoch already closed");
}
// Set the "closing" bit. If the closing bit is the only bit set, then the call count is zero and we can call the
// "on close" handler.
if (state.updateAndGet(state -> state | CLOSING_BIT_MASK) == CLOSING_BIT_MASK) {
onCloseHandler.run();
}
}
@VisibleForTesting
int getActiveCallers() {
return state.get() >> 1;
}
private static boolean isClosing(final int state) {
return (state & CLOSING_BIT_MASK) != 0;
}
}

View File

@ -46,7 +46,7 @@ public class LoggingUnhandledExceptionMapper extends LoggingExceptionMapper<Thro
// streamline the user-agent if it is recognized // streamline the user-agent if it is recognized
final UserAgent ua = UserAgentUtil.parseUserAgentString(userAgent); final UserAgent ua = UserAgentUtil.parseUserAgentString(userAgent);
userAgent = String.format("%s %s", ua.getPlatform(), ua.getVersion()); userAgent = String.format("%s %s", ua.platform(), ua.version());
} catch (final UnrecognizedUserAgentException ignored) { } catch (final UnrecognizedUserAgentException ignored) {
} catch (final Exception e) { } catch (final Exception e) {

View File

@ -6,58 +6,8 @@
package org.whispersystems.textsecuregcm.util.ua; package org.whispersystems.textsecuregcm.util.ua;
import com.vdurmont.semver4j.Semver; import com.vdurmont.semver4j.Semver;
import javax.annotation.Nullable;
import java.util.Objects; import java.util.Objects;
import java.util.Optional;
public class UserAgent { public record UserAgent(ClientPlatform platform, Semver version, @Nullable String additionalSpecifiers) {
private final ClientPlatform platform;
private final Semver version;
private final String additionalSpecifiers;
public UserAgent(final ClientPlatform platform, final Semver version) {
this(platform, version, null);
}
public UserAgent(final ClientPlatform platform, final Semver version, final String additionalSpecifiers) {
this.platform = platform;
this.version = version;
this.additionalSpecifiers = additionalSpecifiers;
}
public ClientPlatform getPlatform() {
return platform;
}
public Semver getVersion() {
return version;
}
public Optional<String> getAdditionalSpecifiers() {
return Optional.ofNullable(additionalSpecifiers);
}
@Override
public boolean equals(final Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
final UserAgent userAgent = (UserAgent)o;
return platform == userAgent.platform &&
version.equals(userAgent.version) &&
Objects.equals(additionalSpecifiers, userAgent.additionalSpecifiers);
}
@Override
public int hashCode() {
return Objects.hash(platform, version, additionalSpecifiers);
}
@Override
public String toString() {
return "UserAgent{" +
"platform=" + platform +
", version=" + version +
", additionalSpecifiers='" + additionalSpecifiers + '\'' +
'}';
}
} }

View File

@ -5,7 +5,6 @@
package org.whispersystems.textsecuregcm.util.ua; package org.whispersystems.textsecuregcm.util.ua;
import com.google.common.annotations.VisibleForTesting;
import com.vdurmont.semver4j.Semver; import com.vdurmont.semver4j.Semver;
import java.util.regex.Matcher; import java.util.regex.Matcher;
import java.util.regex.Pattern; import java.util.regex.Pattern;
@ -21,10 +20,10 @@ public class UserAgentUtil {
} }
try { try {
final UserAgent standardUserAgent = parseStandardUserAgentString(userAgentString); final Matcher matcher = STANDARD_UA_PATTERN.matcher(userAgentString);
if (standardUserAgent != null) { if (matcher.matches()) {
return standardUserAgent; return new UserAgent(ClientPlatform.valueOf(matcher.group(1).toUpperCase()), new Semver(matcher.group(2)), StringUtils.stripToNull(matcher.group(4)));
} }
} catch (final Exception e) { } catch (final Exception e) {
throw new UnrecognizedUserAgentException(e); throw new UnrecognizedUserAgentException(e);
@ -32,15 +31,4 @@ public class UserAgentUtil {
throw new UnrecognizedUserAgentException(); throw new UnrecognizedUserAgentException();
} }
@VisibleForTesting
static UserAgent parseStandardUserAgentString(final String userAgentString) {
final Matcher matcher = STANDARD_UA_PATTERN.matcher(userAgentString);
if (matcher.matches()) {
return new UserAgent(ClientPlatform.valueOf(matcher.group(1).toUpperCase()), new Semver(matcher.group(2)), StringUtils.stripToNull(matcher.group(4)));
}
return null;
}
} }

View File

@ -12,6 +12,7 @@ import java.util.concurrent.ScheduledExecutorService;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor; import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.metrics.OpenWebSocketCounter; import org.whispersystems.textsecuregcm.metrics.OpenWebSocketCounter;
@ -45,6 +46,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
private final Scheduler messageDeliveryScheduler; private final Scheduler messageDeliveryScheduler;
private final ClientReleaseManager clientReleaseManager; private final ClientReleaseManager clientReleaseManager;
private final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor; private final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor;
private final ExperimentEnrollmentManager experimentEnrollmentManager;
private final OpenWebSocketCounter openAuthenticatedWebSocketCounter; private final OpenWebSocketCounter openAuthenticatedWebSocketCounter;
private final OpenWebSocketCounter openUnauthenticatedWebSocketCounter; private final OpenWebSocketCounter openUnauthenticatedWebSocketCounter;
@ -58,7 +60,8 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
ScheduledExecutorService scheduledExecutorService, ScheduledExecutorService scheduledExecutorService,
Scheduler messageDeliveryScheduler, Scheduler messageDeliveryScheduler,
ClientReleaseManager clientReleaseManager, ClientReleaseManager clientReleaseManager,
MessageDeliveryLoopMonitor messageDeliveryLoopMonitor) { MessageDeliveryLoopMonitor messageDeliveryLoopMonitor,
final ExperimentEnrollmentManager experimentEnrollmentManager) {
this.receiptSender = receiptSender; this.receiptSender = receiptSender;
this.messagesManager = messagesManager; this.messagesManager = messagesManager;
this.messageMetrics = messageMetrics; this.messageMetrics = messageMetrics;
@ -69,6 +72,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
this.messageDeliveryScheduler = messageDeliveryScheduler; this.messageDeliveryScheduler = messageDeliveryScheduler;
this.clientReleaseManager = clientReleaseManager; this.clientReleaseManager = clientReleaseManager;
this.messageDeliveryLoopMonitor = messageDeliveryLoopMonitor; this.messageDeliveryLoopMonitor = messageDeliveryLoopMonitor;
this.experimentEnrollmentManager = experimentEnrollmentManager;
openAuthenticatedWebSocketCounter = openAuthenticatedWebSocketCounter =
new OpenWebSocketCounter(OPEN_WEBSOCKET_GAUGE_NAME, CONNECTED_DURATION_TIMER_NAME, Tags.of(AUTHENTICATED_TAG_NAME, "true")); new OpenWebSocketCounter(OPEN_WEBSOCKET_GAUGE_NAME, CONNECTED_DURATION_TIMER_NAME, Tags.of(AUTHENTICATED_TAG_NAME, "true"));
@ -98,7 +102,8 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
scheduledExecutorService, scheduledExecutorService,
messageDeliveryScheduler, messageDeliveryScheduler,
clientReleaseManager, clientReleaseManager,
messageDeliveryLoopMonitor); messageDeliveryLoopMonitor,
experimentEnrollmentManager);
context.addWebsocketClosedListener((closingContext, statusCode, reason) -> { context.addWebsocketClosedListener((closingContext, statusCode, reason) -> {
// We begin the shutdown process by removing this client's "presence," which means it will again begin to // We begin the shutdown process by removing this client's "presence," which means it will again begin to

View File

@ -39,6 +39,7 @@ import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.controllers.MessageController; import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
@ -46,6 +47,7 @@ import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventListener; import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventListener;
import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
@ -120,6 +122,7 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
private final PushNotificationManager pushNotificationManager; private final PushNotificationManager pushNotificationManager;
private final PushNotificationScheduler pushNotificationScheduler; private final PushNotificationScheduler pushNotificationScheduler;
private final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor; private final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor;
private final ExperimentEnrollmentManager experimentEnrollmentManager;
private final AuthenticatedDevice auth; private final AuthenticatedDevice auth;
private final WebSocketClient client; private final WebSocketClient client;
@ -159,7 +162,8 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
ScheduledExecutorService scheduledExecutorService, ScheduledExecutorService scheduledExecutorService,
Scheduler messageDeliveryScheduler, Scheduler messageDeliveryScheduler,
ClientReleaseManager clientReleaseManager, ClientReleaseManager clientReleaseManager,
MessageDeliveryLoopMonitor messageDeliveryLoopMonitor) { MessageDeliveryLoopMonitor messageDeliveryLoopMonitor,
ExperimentEnrollmentManager experimentEnrollmentManager) {
this(receiptSender, this(receiptSender,
messagesManager, messagesManager,
@ -172,7 +176,7 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
scheduledExecutorService, scheduledExecutorService,
messageDeliveryScheduler, messageDeliveryScheduler,
clientReleaseManager, clientReleaseManager,
messageDeliveryLoopMonitor); messageDeliveryLoopMonitor, experimentEnrollmentManager);
} }
@VisibleForTesting @VisibleForTesting
@ -187,7 +191,8 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
ScheduledExecutorService scheduledExecutorService, ScheduledExecutorService scheduledExecutorService,
Scheduler messageDeliveryScheduler, Scheduler messageDeliveryScheduler,
ClientReleaseManager clientReleaseManager, ClientReleaseManager clientReleaseManager,
MessageDeliveryLoopMonitor messageDeliveryLoopMonitor) { MessageDeliveryLoopMonitor messageDeliveryLoopMonitor,
ExperimentEnrollmentManager experimentEnrollmentManager) {
this.receiptSender = receiptSender; this.receiptSender = receiptSender;
this.messagesManager = messagesManager; this.messagesManager = messagesManager;
@ -201,6 +206,7 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
this.messageDeliveryScheduler = messageDeliveryScheduler; this.messageDeliveryScheduler = messageDeliveryScheduler;
this.clientReleaseManager = clientReleaseManager; this.clientReleaseManager = clientReleaseManager;
this.messageDeliveryLoopMonitor = messageDeliveryLoopMonitor; this.messageDeliveryLoopMonitor = messageDeliveryLoopMonitor;
this.experimentEnrollmentManager = experimentEnrollmentManager;
} }
public void start() { public void start() {
@ -331,7 +337,13 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
// Cleared the queue! Send a queue empty message if we need to // Cleared the queue! Send a queue empty message if we need to
consecutiveRetries.set(0); consecutiveRetries.set(0);
if (sentInitialQueueEmptyMessage.compareAndSet(false, true)) { if (sentInitialQueueEmptyMessage.compareAndSet(false, true)) {
final Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent())); final boolean inSkipExperiment = auth.getAuthenticatedDevice().getGcmId() != null && experimentEnrollmentManager.isEnrolled(
auth.getAccount().getUuid(),
MessageSender.ANDROID_SKIP_LOW_URGENCY_PUSH_EXPERIMENT);
final Tags tags = Tags
.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent()))
.and("lowUrgencySkip", Boolean.toString(inSkipExperiment));
final long drainDuration = System.currentTimeMillis() - queueDrainStartTime.get(); final long drainDuration = System.currentTimeMillis() - queueDrainStartTime.get();
Metrics.summary(INITIAL_QUEUE_LENGTH_DISTRIBUTION_NAME, tags).record(sentMessageCounter.sum()); Metrics.summary(INITIAL_QUEUE_LENGTH_DISTRIBUTION_NAME, tags).record(sentMessageCounter.sum());

View File

@ -14,6 +14,7 @@ import java.time.Duration;
import net.sourceforge.argparse4j.inf.Namespace; import net.sourceforge.argparse4j.inf.Namespace;
import net.sourceforge.argparse4j.inf.Subparser; import net.sourceforge.argparse4j.inf.Subparser;
import org.whispersystems.textsecuregcm.WhisperServerConfiguration; import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.storage.MessagePersister; import org.whispersystems.textsecuregcm.storage.MessagePersister;
import org.whispersystems.textsecuregcm.util.logging.UncaughtExceptionHandler; import org.whispersystems.textsecuregcm.util.logging.UncaughtExceptionHandler;
@ -64,6 +65,7 @@ public class MessagePersisterServiceCommand extends ServerCommand<WhisperServerC
deps.messagesManager(), deps.messagesManager(),
deps.accountsManager(), deps.accountsManager(),
deps.dynamicConfigurationManager(), deps.dynamicConfigurationManager(),
new ExperimentEnrollmentManager(deps.dynamicConfigurationManager()),
Duration.ofMinutes(configuration.getMessageCacheConfiguration().getPersistDelayMinutes()), Duration.ofMinutes(configuration.getMessageCacheConfiguration().getPersistDelayMinutes()),
namespace.getInt(WORKER_COUNT)); namespace.getInt(WORKER_COUNT));

View File

@ -5,8 +5,11 @@
package org.whispersystems.textsecuregcm.auth; package org.whispersystems.textsecuregcm.auth;
import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; import static com.github.tomakehurst.wiremock.client.WireMock.created;
import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson;
import static com.github.tomakehurst.wiremock.client.WireMock.post; import static com.github.tomakehurst.wiremock.client.WireMock.post;
import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor;
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig; import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
@ -15,17 +18,19 @@ import static org.mockito.Mockito.when;
import com.github.tomakehurst.wiremock.junit5.WireMockExtension; import com.github.tomakehurst.wiremock.junit5.WireMockExtension;
import io.netty.resolver.dns.DnsNameResolver; import io.netty.resolver.dns.DnsNameResolver;
import io.netty.util.concurrent.Future; import io.netty.util.concurrent.GlobalEventExecutor;
import io.netty.util.concurrent.SucceededFuture;
import java.io.IOException; import java.io.IOException;
import java.net.InetAddress; import java.net.InetAddress;
import java.security.cert.CertificateException; import java.time.Duration;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.concurrent.CancellationException; import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -35,31 +40,41 @@ import org.whispersystems.textsecuregcm.configuration.RetryConfiguration;
public class CloudflareTurnCredentialsManagerTest { public class CloudflareTurnCredentialsManagerTest {
@RegisterExtension @RegisterExtension
private final WireMockExtension wireMock = WireMockExtension.newInstance() private static final WireMockExtension wireMock = WireMockExtension.newInstance()
.options(wireMockConfig().dynamicPort().dynamicHttpsPort()) .options(wireMockConfig().dynamicPort().dynamicHttpsPort())
.build(); .build();
private static final String GET_CREDENTIALS_PATH = "/v1/turn/keys/LMNOP/credentials/generate";
private static final String TURN_HOSTNAME = "localhost";
private ExecutorService httpExecutor; private ExecutorService httpExecutor;
private ScheduledExecutorService retryExecutor; private ScheduledExecutorService retryExecutor;
private DnsNameResolver dnsResolver; private DnsNameResolver dnsResolver;
private Future<List<InetAddress>> dnsResult;
private CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager = null; private CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager;
private static final String GET_CREDENTIALS_PATH = "/v1/turn/keys/LMNOP/credentials/generate";
private static final String TURN_HOSTNAME = "localhost";
private static final String API_TOKEN = RandomStringUtils.insecure().nextAlphanumeric(16);
private static final String USERNAME = RandomStringUtils.insecure().nextAlphanumeric(16);
private static final String CREDENTIAL = RandomStringUtils.insecure().nextAlphanumeric(16);
private static final List<String> CLOUDFLARE_TURN_URLS = List.of("turn:cf.example.com");
private static final Duration REQUESTED_CREDENTIAL_TTL = Duration.ofSeconds(100);
private static final Duration CLIENT_CREDENTIAL_TTL = REQUESTED_CREDENTIAL_TTL.dividedBy(2);
private static final List<String> IP_URL_PATTERNS = List.of("turn:%s", "turn:%s:80?transport=tcp", "turns:%s:443?transport=tcp");
@BeforeEach @BeforeEach
void setUp() throws CertificateException { void setUp() {
httpExecutor = Executors.newSingleThreadExecutor(); httpExecutor = Executors.newSingleThreadExecutor();
retryExecutor = Executors.newSingleThreadScheduledExecutor(); retryExecutor = Executors.newSingleThreadScheduledExecutor();
dnsResolver = mock(DnsNameResolver.class); dnsResolver = mock(DnsNameResolver.class);
dnsResult = mock(Future.class);
cloudflareTurnCredentialsManager = new CloudflareTurnCredentialsManager( cloudflareTurnCredentialsManager = new CloudflareTurnCredentialsManager(
"API_TOKEN", API_TOKEN,
"http://localhost:" + wireMock.getPort() + GET_CREDENTIALS_PATH, "http://localhost:" + wireMock.getPort() + GET_CREDENTIALS_PATH,
100, REQUESTED_CREDENTIAL_TTL,
List.of("turn:cf.example.com"), CLIENT_CREDENTIAL_TTL,
List.of("turn:%s", "turn:%s:80?transport=tcp", "turns:%s:443?transport=tcp"), CLOUDFLARE_TURN_URLS,
IP_URL_PATTERNS,
TURN_HOSTNAME, TURN_HOSTNAME,
2, 2,
new CircuitBreakerConfiguration(), new CircuitBreakerConfiguration(),
@ -73,26 +88,61 @@ public class CloudflareTurnCredentialsManagerTest {
@AfterEach @AfterEach
void tearDown() throws InterruptedException { void tearDown() throws InterruptedException {
httpExecutor.shutdown(); httpExecutor.shutdown();
httpExecutor.awaitTermination(1, TimeUnit.SECONDS);
retryExecutor.shutdown(); retryExecutor.shutdown();
//noinspection ResultOfMethodCallIgnored
httpExecutor.awaitTermination(1, TimeUnit.SECONDS);
//noinspection ResultOfMethodCallIgnored
retryExecutor.awaitTermination(1, TimeUnit.SECONDS); retryExecutor.awaitTermination(1, TimeUnit.SECONDS);
} }
@Test @Test
public void testSuccess() throws IOException, CancellationException, ExecutionException, InterruptedException { public void testSuccess() throws IOException, CancellationException {
wireMock.stubFor(post(urlEqualTo(GET_CREDENTIALS_PATH)) wireMock.stubFor(post(urlEqualTo(GET_CREDENTIALS_PATH))
.willReturn(aResponse().withStatus(201).withHeader("Content-Type", new String[]{"application/json"}).withBody("{\"iceServers\":{\"urls\":[\"turn:cloudflare.example.com:3478?transport=udp\"],\"username\":\"ABC\",\"credential\":\"XYZ\"}}"))); .willReturn(created()
when(dnsResult.get()) .withHeader("Content-Type", "application/json")
.thenReturn(List.of(InetAddress.getByName("127.0.0.1"), InetAddress.getByName("::1"))); .withBody("""
{
"iceServers": {
"urls": [
"turn:cloudflare.example.com:3478?transport=udp"
],
"username": "%s",
"credential": "%s"
}
}
""".formatted(USERNAME, CREDENTIAL))));
when(dnsResolver.resolveAll(TURN_HOSTNAME)) when(dnsResolver.resolveAll(TURN_HOSTNAME))
.thenReturn(dnsResult); .thenReturn(new SucceededFuture<>(GlobalEventExecutor.INSTANCE,
List.of(InetAddress.getByName("127.0.0.1"), InetAddress.getByName("::1"))));
TurnToken token = cloudflareTurnCredentialsManager.retrieveFromCloudflare(); TurnToken token = cloudflareTurnCredentialsManager.retrieveFromCloudflare();
assertThat(token.username()).isEqualTo("ABC"); wireMock.verify(postRequestedFor(urlEqualTo(GET_CREDENTIALS_PATH))
assertThat(token.password()).isEqualTo("XYZ"); .withHeader("Content-Type", equalTo("application/json"))
assertThat(token.hostname()).isEqualTo("localhost"); .withHeader("Authorization", equalTo("Bearer " + API_TOKEN))
assertThat(token.urlsWithIps()).containsAll(List.of("turn:127.0.0.1", "turn:127.0.0.1:80?transport=tcp", "turns:127.0.0.1:443?transport=tcp", "turn:[0:0:0:0:0:0:0:1]", "turn:[0:0:0:0:0:0:0:1]:80?transport=tcp", "turns:[0:0:0:0:0:0:0:1]:443?transport=tcp"));; .withRequestBody(equalToJson("""
assertThat(token.urls()).isEqualTo(List.of("turn:cf.example.com")); {
"ttl": %d
}
""".formatted(REQUESTED_CREDENTIAL_TTL.toSeconds()))));
assertThat(token.username()).isEqualTo(USERNAME);
assertThat(token.password()).isEqualTo(CREDENTIAL);
assertThat(token.hostname()).isEqualTo(TURN_HOSTNAME);
assertThat(token.urls()).isEqualTo(CLOUDFLARE_TURN_URLS);
assertThat(token.ttlSeconds()).isEqualTo(CLIENT_CREDENTIAL_TTL.toSeconds());
final List<String> expectedUrlsWithIps = new ArrayList<>();
for (final String ip : new String[] {"127.0.0.1", "[0:0:0:0:0:0:0:1]"}) {
for (final String pattern : IP_URL_PATTERNS) {
expectedUrlsWithIps.add(pattern.formatted(ip));
}
}
assertThat(token.urlsWithIps()).containsExactlyElementsOf(expectedUrlsWithIps);
} }
} }

View File

@ -3,6 +3,7 @@ package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.Status; import io.grpc.Status;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse; import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils; import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
@ -22,7 +23,7 @@ class ProhibitAuthenticationInterceptorTest extends AbstractAuthenticationInterc
} }
@Test @Test
void interceptCall() { void interceptCall() throws ChannelNotFoundException {
final GrpcClientConnectionManager grpcClientConnectionManager = getClientConnectionManager(); final GrpcClientConnectionManager grpcClientConnectionManager = getClientConnectionManager();
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty()); when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty());
@ -34,6 +35,10 @@ class ProhibitAuthenticationInterceptorTest extends AbstractAuthenticationInterc
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID); final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice)); when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice));
GrpcTestUtils.assertStatusException(Status.UNAUTHENTICATED, this::getAuthenticatedDevice); GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getAuthenticatedDevice);
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenThrow(ChannelNotFoundException.class);
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, this::getAuthenticatedDevice);
} }
} }

View File

@ -9,6 +9,7 @@ import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse; import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils; import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
@ -22,12 +23,12 @@ class RequireAuthenticationInterceptorTest extends AbstractAuthenticationInterce
} }
@Test @Test
void interceptCall() { void interceptCall() throws ChannelNotFoundException {
final GrpcClientConnectionManager grpcClientConnectionManager = getClientConnectionManager(); final GrpcClientConnectionManager grpcClientConnectionManager = getClientConnectionManager();
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty()); when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty());
GrpcTestUtils.assertStatusException(Status.UNAUTHENTICATED, this::getAuthenticatedDevice); GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getAuthenticatedDevice);
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID); final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice)); when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice));
@ -35,5 +36,9 @@ class RequireAuthenticationInterceptorTest extends AbstractAuthenticationInterce
final GetAuthenticatedDeviceResponse response = getAuthenticatedDevice(); final GetAuthenticatedDeviceResponse response = getAuthenticatedDevice();
assertEquals(UUIDUtil.toByteString(authenticatedDevice.accountIdentifier()), response.getAccountIdentifier()); assertEquals(UUIDUtil.toByteString(authenticatedDevice.accountIdentifier()), response.getAccountIdentifier());
assertEquals(authenticatedDevice.deviceId(), response.getDeviceId()); assertEquals(authenticatedDevice.deviceId(), response.getDeviceId());
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenThrow(ChannelNotFoundException.class);
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, this::getAuthenticatedDevice);
} }
} }

View File

@ -40,14 +40,15 @@ class CallRoutingControllerV2Test {
private static final TurnToken CLOUDFLARE_TURN_TOKEN = new TurnToken( private static final TurnToken CLOUDFLARE_TURN_TOKEN = new TurnToken(
"ABC", "ABC",
"XYZ", "XYZ",
43_200,
List.of("turn:cloudflare.example.com:3478?transport=udp"), List.of("turn:cloudflare.example.com:3478?transport=udp"),
null, null,
"cf.example.com"); "cf.example.com");
private static final RateLimiters rateLimiters = mock(RateLimiters.class); private static final RateLimiters rateLimiters = mock(RateLimiters.class);
private static final RateLimiter getCallEndpointLimiter = mock(RateLimiter.class); private static final RateLimiter getCallEndpointLimiter = mock(RateLimiter.class);
private static final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager = mock( private static final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager =
CloudflareTurnCredentialsManager.class); mock(CloudflareTurnCredentialsManager.class);
private static final ResourceExtension resources = ResourceExtension.builder() private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter()) .addProvider(AuthHelper.getAuthFilter())
@ -66,21 +67,14 @@ class CallRoutingControllerV2Test {
@AfterEach @AfterEach
void tearDown() { void tearDown() {
reset( rateLimiters, getCallEndpointLimiter); reset(rateLimiters, getCallEndpointLimiter);
}
void initializeMocksWith(TurnToken cloudflareToken) {
try {
when(cloudflareTurnCredentialsManager.retrieveFromCloudflare()).thenReturn(cloudflareToken);
} catch (IOException ignored) {
}
} }
@Test @Test
void testGetRelaysBothRouting() { void testGetRelaysBothRouting() throws IOException {
initializeMocksWith(CLOUDFLARE_TURN_TOKEN); when(cloudflareTurnCredentialsManager.retrieveFromCloudflare()).thenReturn(CLOUDFLARE_TURN_TOKEN);
try (Response rawResponse = resources.getJerseyTest() try (final Response rawResponse = resources.getJerseyTest()
.target(GET_CALL_RELAYS_PATH) .target(GET_CALL_RELAYS_PATH)
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
@ -88,11 +82,8 @@ class CallRoutingControllerV2Test {
assertThat(rawResponse.getStatus()).isEqualTo(200); assertThat(rawResponse.getStatus()).isEqualTo(200);
CallRoutingControllerV2.GetCallingRelaysResponse response = rawResponse.readEntity( assertThat(rawResponse.readEntity(GetCallingRelaysResponse.class).relays())
CallRoutingControllerV2.GetCallingRelaysResponse.class); .isEqualTo(List.of(CLOUDFLARE_TURN_TOKEN));
List<TurnToken> relays = response.relays();
assertThat(relays).isEqualTo(List.of(CLOUDFLARE_TURN_TOKEN));
} }
} }

View File

@ -41,6 +41,7 @@ import java.util.Optional;
import java.util.OptionalInt; import java.util.OptionalInt;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.stream.IntStream;
import java.util.stream.Stream; import java.util.stream.Stream;
import org.glassfish.jersey.server.ServerProperties; import org.glassfish.jersey.server.ServerProperties;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
@ -942,6 +943,45 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(400); assertThat(response.getStatus()).isEqualTo(400);
} }
@Test
void putKeysTooManySingleUseECKeys() {
final List<ECPreKey> preKeys = IntStream.range(31337, 31438).mapToObj(KeysHelper::ecPreKey).toList();
final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, AuthHelper.VALID_IDENTITY_KEY_PAIR);
final SetKeysRequest setKeysRequest = new SetKeysRequest(preKeys, signedPreKey, null, null);
Response response =
resources.getJerseyTest()
.target("/v2/keys")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(setKeysRequest, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(422);
verifyNoMoreInteractions(KEYS);
}
@Test
void putKeysTooManySingleUseKEMKeys() {
final List<KEMSignedPreKey> pqPreKeys = IntStream.range(31337, 31438)
.mapToObj(id -> KeysHelper.signedKEMPreKey(id, AuthHelper.VALID_IDENTITY_KEY_PAIR))
.toList();
final SetKeysRequest setKeysRequest = new SetKeysRequest(null, null, pqPreKeys, null);
Response response =
resources.getJerseyTest()
.target("/v2/keys")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(setKeysRequest, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(422);
verifyNoMoreInteractions(KEYS);
}
@Test @Test
void putKeysByPhoneNumberIdentifierTestV2() { void putKeysByPhoneNumberIdentifierTestV2() {
final ECPreKey preKey = KeysHelper.ecPreKey(31337); final ECPreKey preKey = KeysHelper.ecPreKey(31337);

View File

@ -1140,6 +1140,58 @@ class ProfileControllerTest {
new AccountBadge("TEST2", Instant.ofEpochSecond(42 + 86400), false), new AccountBadge("TEST2", Instant.ofEpochSecond(42 + 86400), false),
new AccountBadge("TEST3", Instant.ofEpochSecond(42 + 86400), false)); new AccountBadge("TEST3", Instant.ofEpochSecond(42 + 86400), false));
} }
}
@Test
void testSetProfileBadgeAfterUpdateTries() throws Exception {
final ProfileKeyCommitment commitment = new ProfileKey(new byte[32]).getCommitment(
new ServiceId.Aci(AuthHelper.VALID_UUID));
final byte[] name = TestRandomUtil.nextBytes(81);
final byte[] emoji = TestRandomUtil.nextBytes(60);
final byte[] about = TestRandomUtil.nextBytes(156);
final String version = versionHex("anotherversion");
clearInvocations(AuthHelper.VALID_ACCOUNT_TWO);
reset(accountsManager);
final int accountsManagerUpdateRetryCount = 2;
AccountsHelper.setupMockUpdateWithRetries(accountsManager, accountsManagerUpdateRetryCount);
// set up two invocations -- one for each AccountsManager#update try
when(AuthHelper.VALID_ACCOUNT_TWO.getBadges())
.thenReturn(List.of(
new AccountBadge("TEST2", Instant.ofEpochSecond(42 + 86400), true),
new AccountBadge("TEST3", Instant.ofEpochSecond(42 + 86400), true)
))
.thenReturn(List.of(
new AccountBadge("TEST2", Instant.ofEpochSecond(42 + 86400), true),
new AccountBadge("TEST3", Instant.ofEpochSecond(42 + 86400), true),
new AccountBadge("TEST4", Instant.ofEpochSecond(43 + 86400), true)
));
try (final Response response = resources.getJerseyTest()
.target("/v1/profile/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID_TWO, AuthHelper.VALID_PASSWORD_TWO))
.put(Entity.entity(new CreateProfileRequest(commitment, version, name, emoji, about, null, false, false,
Optional.of(List.of("TEST1")), null), MediaType.APPLICATION_JSON_TYPE))) {
assertThat(response.getStatus()).isEqualTo(200);
assertThat(response.hasEntity()).isFalse();
//noinspection unchecked
final ArgumentCaptor<List<AccountBadge>> badgeCaptor = ArgumentCaptor.forClass(List.class);
verify(AuthHelper.VALID_ACCOUNT_TWO, times(accountsManagerUpdateRetryCount)).setBadges(refEq(clock), badgeCaptor.capture());
// since the stubbing of getBadges() is brittle, we need to verify the number of invocations, to protect against upstream changes
verify(AuthHelper.VALID_ACCOUNT_TWO, times(accountsManagerUpdateRetryCount)).getBadges();
final List<AccountBadge> badges = badgeCaptor.getValue();
assertThat(badges).isNotNull().hasSize(4).containsOnly(
new AccountBadge("TEST1", Instant.ofEpochSecond(42 + 86400), true),
new AccountBadge("TEST2", Instant.ofEpochSecond(42 + 86400), false),
new AccountBadge("TEST3", Instant.ofEpochSecond(42 + 86400), false),
new AccountBadge("TEST4", Instant.ofEpochSecond(43 + 86400), false));
}
} }
@ParameterizedTest @ParameterizedTest

View File

@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.filters;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import com.google.common.net.InetAddresses;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import io.dropwizard.core.Application; import io.dropwizard.core.Application;
import io.dropwizard.core.Configuration; import io.dropwizard.core.Configuration;
@ -24,7 +25,6 @@ import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path; import jakarta.ws.rs.Path;
import jakarta.ws.rs.client.Client; import jakarta.ws.rs.client.Client;
import jakarta.ws.rs.core.Response; import jakarta.ws.rs.core.Response;
import java.net.InetAddress;
import java.util.Collections; import java.util.Collections;
import java.util.EnumSet; import java.util.EnumSet;
import java.util.Set; import java.util.Set;
@ -39,6 +39,7 @@ import org.signal.chat.rpc.EchoServiceGrpc;
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl; import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils; import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.MockRequestAttributesInterceptor; import org.whispersystems.textsecuregcm.grpc.MockRequestAttributesInterceptor;
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.util.InetAddressRange; import org.whispersystems.textsecuregcm.util.InetAddressRange;
@ExtendWith(DropwizardExtensionsSupport.class) @ExtendWith(DropwizardExtensionsSupport.class)
@ -157,7 +158,7 @@ class ExternalRequestFilterTest {
@BeforeEach @BeforeEach
void setUp() throws Exception { void setUp() throws Exception {
final MockRequestAttributesInterceptor mockRequestAttributesInterceptor = new MockRequestAttributesInterceptor(); final MockRequestAttributesInterceptor mockRequestAttributesInterceptor = new MockRequestAttributesInterceptor();
mockRequestAttributesInterceptor.setRemoteAddress(InetAddress.getByName("127.0.0.1")); mockRequestAttributesInterceptor.setRequestAttributes(new RequestAttributes(InetAddresses.forString("127.0.0.1"), null, null));
testServer = InProcessServerBuilder.forName("ExternalRequestFilterTest") testServer = InProcessServerBuilder.forName("ExternalRequestFilterTest")
.directExecutor() .directExecutor()

View File

@ -15,6 +15,7 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.google.common.net.HttpHeaders; import com.google.common.net.HttpHeaders;
import com.google.common.net.InetAddresses;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import com.vdurmont.semver4j.Semver; import com.vdurmont.semver4j.Semver;
import io.grpc.ManagedChannel; import io.grpc.ManagedChannel;
@ -40,11 +41,10 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRemoteDeprecationConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRemoteDeprecationConfiguration;
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl; import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
import org.whispersystems.textsecuregcm.grpc.MockRequestAttributesInterceptor; import org.whispersystems.textsecuregcm.grpc.MockRequestAttributesInterceptor;
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.grpc.StatusConstants; import org.whispersystems.textsecuregcm.grpc.StatusConstants;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
class RemoteDeprecationFilterTest { class RemoteDeprecationFilterTest {
@ -130,11 +130,7 @@ class RemoteDeprecationFilterTest {
@MethodSource(value="testFilter") @MethodSource(value="testFilter")
void testGrpcFilter(final String userAgentString, final boolean expectDeprecation) throws IOException, InterruptedException { void testGrpcFilter(final String userAgentString, final boolean expectDeprecation) throws IOException, InterruptedException {
final MockRequestAttributesInterceptor mockRequestAttributesInterceptor = new MockRequestAttributesInterceptor(); final MockRequestAttributesInterceptor mockRequestAttributesInterceptor = new MockRequestAttributesInterceptor();
mockRequestAttributesInterceptor.setRequestAttributes(new RequestAttributes(InetAddresses.forString("127.0.0.1"), userAgentString, null));
try {
mockRequestAttributesInterceptor.setUserAgent(UserAgentUtil.parseUserAgentString(userAgentString));
} catch (UnrecognizedUserAgentException ignored) {
}
final Server testServer = InProcessServerBuilder.forName("RemoteDeprecationFilterTest") final Server testServer = InProcessServerBuilder.forName("RemoteDeprecationFilterTest")
.directExecutor() .directExecutor()

View File

@ -72,7 +72,8 @@ class AccountsAnonymousGrpcServiceTest extends
when(rateLimiter.validateReactive(anyString())).thenReturn(Mono.empty()); when(rateLimiter.validateReactive(anyString())).thenReturn(Mono.empty());
getMockRequestAttributesInterceptor().setRemoteAddress(InetAddresses.forString("127.0.0.1")); getMockRequestAttributesInterceptor().setRequestAttributes(
new RequestAttributes(InetAddresses.forString("127.0.0.1"), null, null));
return new AccountsAnonymousGrpcService(accountsManager, rateLimiters); return new AccountsAnonymousGrpcService(accountsManager, rateLimiters);
} }

View File

@ -0,0 +1,88 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.Status;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
class ChannelShutdownInterceptorTest {
private GrpcClientConnectionManager grpcClientConnectionManager;
private ChannelShutdownInterceptor channelShutdownInterceptor;
private ServerCallHandler<String, String> nextCallHandler;
private static final Metadata HEADERS = new Metadata();
@BeforeEach
void setUp() {
grpcClientConnectionManager = mock(GrpcClientConnectionManager.class);
channelShutdownInterceptor = new ChannelShutdownInterceptor(grpcClientConnectionManager);
//noinspection unchecked
nextCallHandler = mock(ServerCallHandler.class);
//noinspection unchecked
when(nextCallHandler.startCall(any(), any())).thenReturn(mock(ServerCall.Listener.class));
}
@Test
void interceptCallComplete() {
@SuppressWarnings("unchecked") final ServerCall<String, String> serverCall = mock(ServerCall.class);
when(grpcClientConnectionManager.handleServerCallStart(serverCall)).thenReturn(true);
final ServerCall.Listener<String> serverCallListener =
channelShutdownInterceptor.interceptCall(serverCall, HEADERS, nextCallHandler);
serverCallListener.onComplete();
verify(grpcClientConnectionManager).handleServerCallStart(serverCall);
verify(grpcClientConnectionManager).handleServerCallComplete(serverCall);
verify(serverCall, never()).close(any(), any());
}
@Test
void interceptCallCancelled() {
@SuppressWarnings("unchecked") final ServerCall<String, String> serverCall = mock(ServerCall.class);
when(grpcClientConnectionManager.handleServerCallStart(serverCall)).thenReturn(true);
final ServerCall.Listener<String> serverCallListener =
channelShutdownInterceptor.interceptCall(serverCall, HEADERS, nextCallHandler);
serverCallListener.onCancel();
verify(grpcClientConnectionManager).handleServerCallStart(serverCall);
verify(grpcClientConnectionManager).handleServerCallComplete(serverCall);
verify(serverCall, never()).close(any(), any());
}
@Test
void interceptCallChannelClosing() {
@SuppressWarnings("unchecked") final ServerCall<String, String> serverCall = mock(ServerCall.class);
when(grpcClientConnectionManager.handleServerCallStart(serverCall)).thenReturn(false);
channelShutdownInterceptor.interceptCall(serverCall, HEADERS, nextCallHandler);
verify(grpcClientConnectionManager).handleServerCallStart(serverCall);
verify(grpcClientConnectionManager, never()).handleServerCallComplete(serverCall);
verify(serverCall).close(eq(Status.UNAVAILABLE), any());
}
}

View File

@ -12,14 +12,38 @@ import org.signal.chat.rpc.EchoServiceGrpc;
public class EchoServiceImpl extends EchoServiceGrpc.EchoServiceImplBase { public class EchoServiceImpl extends EchoServiceGrpc.EchoServiceImplBase {
@Override @Override
public void echo(EchoRequest req, StreamObserver<EchoResponse> responseObserver) { public void echo(final EchoRequest echoRequest, final StreamObserver<EchoResponse> responseObserver) {
responseObserver.onNext(EchoResponse.newBuilder().setPayload(req.getPayload()).build()); responseObserver.onNext(buildResponse(echoRequest));
responseObserver.onCompleted(); responseObserver.onCompleted();
} }
@Override @Override
public void echo2(EchoRequest req, StreamObserver<EchoResponse> responseObserver) { public void echo2(final EchoRequest echoRequest, final StreamObserver<EchoResponse> responseObserver) {
responseObserver.onNext(EchoResponse.newBuilder().setPayload(req.getPayload()).build()); responseObserver.onNext(buildResponse(echoRequest));
responseObserver.onCompleted(); responseObserver.onCompleted();
} }
@Override
public StreamObserver<EchoRequest> echoStream(final StreamObserver<EchoResponse> responseObserver) {
return new StreamObserver<>() {
@Override
public void onNext(final EchoRequest echoRequest) {
responseObserver.onNext(buildResponse(echoRequest));
}
@Override
public void onError(final Throwable throwable) {
responseObserver.onError(throwable);
}
@Override
public void onCompleted() {
responseObserver.onCompleted();
}
};
}
private static EchoResponse buildResponse(final EchoRequest echoRequest) {
return EchoResponse.newBuilder().setPayload(echoRequest.getPayload()).build();
}
} }

View File

@ -5,6 +5,7 @@
package org.whispersystems.textsecuregcm.grpc; package org.whispersystems.textsecuregcm.grpc;
import com.google.common.net.InetAddresses;
import io.grpc.Context; import io.grpc.Context;
import io.grpc.Contexts; import io.grpc.Contexts;
import io.grpc.Metadata; import io.grpc.Metadata;
@ -19,25 +20,10 @@ import org.whispersystems.textsecuregcm.util.ua.UserAgent;
public class MockRequestAttributesInterceptor implements ServerInterceptor { public class MockRequestAttributesInterceptor implements ServerInterceptor {
@Nullable private RequestAttributes requestAttributes = new RequestAttributes(InetAddresses.forString("127.0.0.1"), null, null);
private InetAddress remoteAddress;
@Nullable public void setRequestAttributes(final RequestAttributes requestAttributes) {
private UserAgent userAgent; this.requestAttributes = requestAttributes;
@Nullable
private List<Locale.LanguageRange> acceptLanguage;
public void setRemoteAddress(@Nullable final InetAddress remoteAddress) {
this.remoteAddress = remoteAddress;
}
public void setUserAgent(@Nullable final UserAgent userAgent) {
this.userAgent = userAgent;
}
public void setAcceptLanguage(@Nullable final List<Locale.LanguageRange> acceptLanguage) {
this.acceptLanguage = acceptLanguage;
} }
@Override @Override
@ -45,20 +31,7 @@ public class MockRequestAttributesInterceptor implements ServerInterceptor {
final Metadata headers, final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) { final ServerCallHandler<ReqT, RespT> next) {
Context context = Context.current(); return Contexts.interceptCall(Context.current()
.withValue(RequestAttributesUtil.REQUEST_ATTRIBUTES_CONTEXT_KEY, requestAttributes), serverCall, headers, next);
if (remoteAddress != null) {
context = context.withValue(RequestAttributesUtil.REMOTE_ADDRESS_CONTEXT_KEY, remoteAddress);
}
if (userAgent != null) {
context = context.withValue(RequestAttributesUtil.USER_AGENT_CONTEXT_KEY, userAgent);
}
if (acceptLanguage != null) {
context = context.withValue(RequestAttributesUtil.ACCEPT_LANGUAGE_CONTEXT_KEY, acceptLanguage);
}
return Contexts.interceptCall(context, serverCall, headers, next);
} }
} }

View File

@ -15,6 +15,7 @@ import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException; import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
import com.google.common.net.InetAddresses;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import io.grpc.Status; import io.grpc.Status;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
@ -75,8 +76,6 @@ import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper; import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper;
import org.whispersystems.textsecuregcm.util.TestRandomUtil; import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil; import org.whispersystems.textsecuregcm.util.UUIDUtil;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<ProfileAnonymousGrpcService, ProfileAnonymousGrpc.ProfileAnonymousBlockingStub> { public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<ProfileAnonymousGrpcService, ProfileAnonymousGrpc.ProfileAnonymousBlockingStub> {
@ -96,13 +95,9 @@ public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<ProfileA
@Override @Override
protected ProfileAnonymousGrpcService createServiceBeforeEachTest() { protected ProfileAnonymousGrpcService createServiceBeforeEachTest() {
getMockRequestAttributesInterceptor().setAcceptLanguage(Locale.LanguageRange.parse("en-us")); getMockRequestAttributesInterceptor().setRequestAttributes(new RequestAttributes(InetAddresses.forString("127.0.0.1"),
"Signal-Android/1.2.3",
try { Locale.LanguageRange.parse("en-us")));
getMockRequestAttributesInterceptor().setUserAgent(UserAgentUtil.parseUserAgentString("Signal-Android/1.2.3"));
} catch (final UnrecognizedUserAgentException e) {
throw new IllegalArgumentException(e);
}
return new ProfileAnonymousGrpcService( return new ProfileAnonymousGrpcService(
accountsManager, accountsManager,

View File

@ -15,13 +15,16 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.refEq;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded; import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException; import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
import com.google.common.net.InetAddresses;
import com.google.i18n.phonenumbers.PhoneNumberUtil; import com.google.i18n.phonenumbers.PhoneNumberUtil;
import com.google.i18n.phonenumbers.Phonenumber; import com.google.i18n.phonenumbers.Phonenumber;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
@ -30,6 +33,7 @@ import java.nio.charset.StandardCharsets;
import java.time.Clock; import java.time.Clock;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.time.ZoneId;
import java.time.temporal.ChronoUnit; import java.time.temporal.ChronoUnit;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
@ -93,18 +97,18 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.s3.PolicySigner; import org.whispersystems.textsecuregcm.s3.PolicySigner;
import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator; import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountBadge;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DeviceCapability; import org.whispersystems.textsecuregcm.storage.DeviceCapability;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.ProfilesManager; import org.whispersystems.textsecuregcm.storage.ProfilesManager;
import org.whispersystems.textsecuregcm.storage.VersionedProfile; import org.whispersystems.textsecuregcm.storage.VersionedProfile;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper; import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper;
import org.whispersystems.textsecuregcm.util.MockUtils; import org.whispersystems.textsecuregcm.util.MockUtils;
import org.whispersystems.textsecuregcm.util.TestRandomUtil; import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil; import org.whispersystems.textsecuregcm.util.UUIDUtil;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
@ -144,6 +148,8 @@ public class ProfileGrpcServiceTest extends SimpleBaseGrpcTest<ProfileGrpcServic
@Mock @Mock
private ServerZkProfileOperations serverZkProfileOperations; private ServerZkProfileOperations serverZkProfileOperations;
private Clock clock;
@Override @Override
protected ProfileGrpcService createServiceBeforeEachTest() { protected ProfileGrpcService createServiceBeforeEachTest() {
@SuppressWarnings("unchecked") final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager = mock(DynamicConfigurationManager.class); @SuppressWarnings("unchecked") final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
@ -170,13 +176,9 @@ public class ProfileGrpcServiceTest extends SimpleBaseGrpcTest<ProfileGrpcServic
PhoneNumberUtil.getInstance().getExampleNumber("US"), PhoneNumberUtil.getInstance().getExampleNumber("US"),
PhoneNumberUtil.PhoneNumberFormat.E164); PhoneNumberUtil.PhoneNumberFormat.E164);
getMockRequestAttributesInterceptor().setAcceptLanguage(Locale.LanguageRange.parse("en-us")); getMockRequestAttributesInterceptor().setRequestAttributes(new RequestAttributes(InetAddresses.forString("127.0.0.1"),
"Signal-Android/1.2.3",
try { Locale.LanguageRange.parse("en-us")));
getMockRequestAttributesInterceptor().setUserAgent(UserAgentUtil.parseUserAgentString("Signal-Android/1.2.3"));
} catch (final UnrecognizedUserAgentException e) {
throw new IllegalArgumentException(e);
}
when(rateLimiters.getProfileLimiter()).thenReturn(rateLimiter); when(rateLimiters.getProfileLimiter()).thenReturn(rateLimiter);
when(rateLimiter.validateReactive(any(UUID.class))).thenReturn(Mono.empty()); when(rateLimiter.validateReactive(any(UUID.class))).thenReturn(Mono.empty());
@ -203,8 +205,10 @@ public class ProfileGrpcServiceTest extends SimpleBaseGrpcTest<ProfileGrpcServic
when(asyncS3client.deleteObject(any(DeleteObjectRequest.class))).thenReturn(CompletableFuture.completedFuture(null)); when(asyncS3client.deleteObject(any(DeleteObjectRequest.class))).thenReturn(CompletableFuture.completedFuture(null));
clock = Clock.fixed(Instant.ofEpochSecond(42), ZoneId.of("Etc/UTC"));
return new ProfileGrpcService( return new ProfileGrpcService(
Clock.systemUTC(), clock,
accountsManager, accountsManager,
profilesManager, profilesManager,
dynamicConfigurationManager, dynamicConfigurationManager,
@ -392,6 +396,42 @@ public class ProfileGrpcServiceTest extends SimpleBaseGrpcTest<ProfileGrpcServic
} }
} }
@Test
void setProfileBadges() throws InvalidInputException {
final byte[] commitment = new ProfileKey(new byte[32]).getCommitment(new ServiceId.Aci(AUTHENTICATED_ACI)).serialize();
final SetProfileRequest request = SetProfileRequest.newBuilder()
.setVersion(VERSION)
.setName(ByteString.copyFrom(VALID_NAME))
.setAvatarChange(AvatarChange.AVATAR_CHANGE_UNCHANGED)
.addAllBadgeIds(List.of("TEST3"))
.setCommitment(ByteString.copyFrom(commitment))
.build();
final int accountsManagerUpdateRetryCount = 2;
AccountsHelper.setupMockUpdateWithRetries(accountsManager, accountsManagerUpdateRetryCount);
// set up two invocations -- one for each AccountsManager#update try
when(account.getBadges())
.thenReturn(List.of(new AccountBadge("TEST3", Instant.ofEpochSecond(41), false)))
.thenReturn(List.of(new AccountBadge("TEST2", Instant.ofEpochSecond(41), true),
new AccountBadge("TEST3", Instant.ofEpochSecond(41), false)));
//noinspection ResultOfMethodCallIgnored
authenticatedServiceStub().setProfile(request);
//noinspection unchecked
final ArgumentCaptor<List<AccountBadge>> badgeCaptor = ArgumentCaptor.forClass(List.class);
verify(account, times(2)).setBadges(refEq(clock), badgeCaptor.capture());
// since the stubbing of getBadges() is brittle, we need to verify the number of invocations, to protect against upstream changes
verify(account, times(accountsManagerUpdateRetryCount)).getBadges();
assertEquals(List.of(
new AccountBadge("TEST3", Instant.ofEpochSecond(41), true),
new AccountBadge("TEST2", Instant.ofEpochSecond(41), false)),
badgeCaptor.getValue());
}
@ParameterizedTest @ParameterizedTest
@EnumSource(value = org.signal.chat.common.IdentityType.class, names = {"IDENTITY_TYPE_ACI", "IDENTITY_TYPE_PNI"}) @EnumSource(value = org.signal.chat.common.IdentityType.class, names = {"IDENTITY_TYPE_ACI", "IDENTITY_TYPE_PNI"})
void getUnversionedProfile(final IdentityType identityType) { void getUnversionedProfile(final IdentityType identityType) {

View File

@ -6,7 +6,6 @@ import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.signal.chat.rpc.GetRequestAttributesRequest; import org.signal.chat.rpc.GetRequestAttributesRequest;
import org.signal.chat.rpc.GetRequestAttributesResponse; import org.signal.chat.rpc.GetRequestAttributesResponse;
import org.signal.chat.rpc.RequestAttributesGrpc; import org.signal.chat.rpc.RequestAttributesGrpc;
import org.signal.chat.rpc.UserAgent;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticationUtil; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticationUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil; import org.whispersystems.textsecuregcm.util.UUIDUtil;
@ -19,21 +18,15 @@ public class RequestAttributesServiceImpl extends RequestAttributesGrpc.RequestA
final GetRequestAttributesResponse.Builder responseBuilder = GetRequestAttributesResponse.newBuilder(); final GetRequestAttributesResponse.Builder responseBuilder = GetRequestAttributesResponse.newBuilder();
RequestAttributesUtil.getAcceptableLanguages().ifPresent(acceptableLanguages -> RequestAttributesUtil.getAcceptableLanguages()
acceptableLanguages.forEach(languageRange -> responseBuilder.addAcceptableLanguages(languageRange.toString()))); .forEach(languageRange -> responseBuilder.addAcceptableLanguages(languageRange.toString()));
RequestAttributesUtil.getAvailableAcceptedLocales().forEach(locale -> RequestAttributesUtil.getAvailableAcceptedLocales().forEach(locale ->
responseBuilder.addAvailableAcceptedLocales(locale.toLanguageTag())); responseBuilder.addAvailableAcceptedLocales(locale.toLanguageTag()));
responseBuilder.setRemoteAddress(RequestAttributesUtil.getRemoteAddress().getHostAddress()); responseBuilder.setRemoteAddress(RequestAttributesUtil.getRemoteAddress().getHostAddress());
RequestAttributesUtil.getUserAgent().ifPresent(userAgent -> responseBuilder.setUserAgent(UserAgent.newBuilder() RequestAttributesUtil.getUserAgent().ifPresent(responseBuilder::setUserAgent);
.setPlatform(userAgent.getPlatform().toString())
.setVersion(userAgent.getVersion().toString())
.setAdditionalSpecifiers(userAgent.getAdditionalSpecifiers().orElse(""))
.build()));
RequestAttributesUtil.getRawUserAgent().ifPresent(responseBuilder::setRawUserAgent);
responseObserver.onNext(responseBuilder.build()); responseObserver.onNext(responseBuilder.build());
responseObserver.onCompleted(); responseObserver.onCompleted();

View File

@ -3,172 +3,84 @@ package org.whispersystems.textsecuregcm.grpc;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.common.net.InetAddresses; import com.google.common.net.InetAddresses;
import io.grpc.ManagedChannel; import io.grpc.Context;
import io.grpc.Server; import java.net.InetAddress;
import io.grpc.Status; import java.util.Collections;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.NettyServerBuilder;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.Locale; import java.util.Locale;
import java.util.Optional; import java.util.Optional;
import org.junit.jupiter.api.AfterAll; import java.util.concurrent.Callable;
import org.junit.jupiter.api.AfterEach; import javax.annotation.Nullable;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.signal.chat.rpc.GetRequestAttributesRequest;
import org.signal.chat.rpc.GetRequestAttributesResponse;
import org.signal.chat.rpc.RequestAttributesGrpc;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
class RequestAttributesUtilTest { class RequestAttributesUtilTest {
private static DefaultEventLoopGroup eventLoopGroup; private static final InetAddress REMOTE_ADDRESS = InetAddresses.forString("127.0.0.1");
private GrpcClientConnectionManager grpcClientConnectionManager; @Test
void getAcceptableLanguages() throws Exception {
assertEquals(Collections.emptyList(),
callWithRequestAttributes(buildRequestAttributes(Collections.emptyList()),
RequestAttributesUtil::getAcceptableLanguages));
private Server server; assertEquals(Locale.LanguageRange.parse("en,ja"),
private ManagedChannel managedChannel; callWithRequestAttributes(buildRequestAttributes(Locale.LanguageRange.parse("en,ja")),
RequestAttributesUtil::getAcceptableLanguages));
@BeforeAll
static void setUpBeforeAll() {
eventLoopGroup = new DefaultEventLoopGroup();
}
@BeforeEach
void setUp() throws IOException {
final LocalAddress serverAddress = new LocalAddress("test-request-metadata-server");
grpcClientConnectionManager = mock(GrpcClientConnectionManager.class);
when(grpcClientConnectionManager.getRemoteAddress(any()))
.thenReturn(Optional.of(InetAddresses.forString("127.0.0.1")));
// `RequestAttributesInterceptor` operates on `LocalAddresses`, so we need to do some slightly fancy plumbing to make
// sure that we're using local channels and addresses
server = NettyServerBuilder.forAddress(serverAddress)
.channelType(LocalServerChannel.class)
.bossEventLoopGroup(eventLoopGroup)
.workerEventLoopGroup(eventLoopGroup)
.intercept(new RequestAttributesInterceptor(grpcClientConnectionManager))
.addService(new RequestAttributesServiceImpl())
.build()
.start();
managedChannel = NettyChannelBuilder.forAddress(serverAddress)
.channelType(LocalChannel.class)
.eventLoopGroup(eventLoopGroup)
.usePlaintext()
.build();
}
@AfterEach
void tearDown() {
managedChannel.shutdown();
server.shutdown();
}
@AfterAll
static void tearDownAfterAll() throws InterruptedException {
eventLoopGroup.shutdownGracefully().await();
} }
@Test @Test
void getAcceptableLanguages() { void getAvailableAcceptedLocales() throws Exception {
when(grpcClientConnectionManager.getAcceptableLanguages(any())) assertEquals(Collections.emptyList(),
.thenReturn(Optional.empty()); callWithRequestAttributes(buildRequestAttributes(Collections.emptyList()),
RequestAttributesUtil::getAvailableAcceptedLocales));
assertTrue(getRequestAttributes().getAcceptableLanguagesList().isEmpty()); final List<Locale> availableAcceptedLocales =
callWithRequestAttributes(buildRequestAttributes(Locale.LanguageRange.parse("en,ja")),
RequestAttributesUtil::getAvailableAcceptedLocales);
when(grpcClientConnectionManager.getAcceptableLanguages(any())) assertFalse(availableAcceptedLocales.isEmpty());
.thenReturn(Optional.of(Locale.LanguageRange.parse("en,ja")));
assertEquals(List.of("en", "ja"), getRequestAttributes().getAcceptableLanguagesList()); availableAcceptedLocales.forEach(locale ->
assertTrue("en".equals(locale.getLanguage()) || "ja".equals(locale.getLanguage())));
} }
@Test @Test
void getAvailableAcceptedLocales() { void getRemoteAddress() throws Exception {
when(grpcClientConnectionManager.getAcceptableLanguages(any())) assertEquals(REMOTE_ADDRESS,
.thenReturn(Optional.empty()); callWithRequestAttributes(new RequestAttributes(REMOTE_ADDRESS, null, null),
RequestAttributesUtil::getRemoteAddress));
assertTrue(getRequestAttributes().getAvailableAcceptedLocalesList().isEmpty());
when(grpcClientConnectionManager.getAcceptableLanguages(any()))
.thenReturn(Optional.of(Locale.LanguageRange.parse("en,ja")));
final GetRequestAttributesResponse response = getRequestAttributes();
assertFalse(response.getAvailableAcceptedLocalesList().isEmpty());
response.getAvailableAcceptedLocalesList().forEach(languageTag -> {
final Locale locale = Locale.forLanguageTag(languageTag);
assertTrue("en".equals(locale.getLanguage()) || "ja".equals(locale.getLanguage()));
});
} }
@Test @Test
void getRemoteAddress() { void getUserAgent() throws Exception {
when(grpcClientConnectionManager.getRemoteAddress(any())) assertEquals(Optional.empty(),
.thenReturn(Optional.empty()); callWithRequestAttributes(buildRequestAttributes((String) null),
RequestAttributesUtil::getUserAgent));
GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getRequestAttributes); assertEquals(Optional.of("Signal-Desktop/1.2.3 Linux"),
callWithRequestAttributes(buildRequestAttributes("Signal-Desktop/1.2.3 Linux"),
final String remoteAddressString = "6.7.8.9"; RequestAttributesUtil::getUserAgent));
when(grpcClientConnectionManager.getRemoteAddress(any()))
.thenReturn(Optional.of(InetAddresses.forString(remoteAddressString)));
assertEquals(remoteAddressString, getRequestAttributes().getRemoteAddress());
} }
@Test private static <V> V callWithRequestAttributes(final RequestAttributes requestAttributes, final Callable<V> callable) throws Exception {
void getUserAgent() throws UnrecognizedUserAgentException { return Context.current()
when(grpcClientConnectionManager.getUserAgent(any())) .withValue(RequestAttributesUtil.REQUEST_ATTRIBUTES_CONTEXT_KEY, requestAttributes)
.thenReturn(Optional.empty()); .call(callable);
assertFalse(getRequestAttributes().hasUserAgent());
final UserAgent userAgent = UserAgentUtil.parseUserAgentString("Signal-Desktop/1.2.3 Linux");
when(grpcClientConnectionManager.getUserAgent(any()))
.thenReturn(Optional.of(userAgent));
final GetRequestAttributesResponse response = getRequestAttributes();
assertTrue(response.hasUserAgent());
assertEquals("DESKTOP", response.getUserAgent().getPlatform());
assertEquals("1.2.3", response.getUserAgent().getVersion());
assertEquals("Linux", response.getUserAgent().getAdditionalSpecifiers());
} }
@Test private static RequestAttributes buildRequestAttributes(final String userAgent) {
void getRawUserAgent() { return buildRequestAttributes(userAgent, Collections.emptyList());
when(grpcClientConnectionManager.getRawUserAgent(any()))
.thenReturn(Optional.empty());
assertTrue(getRequestAttributes().getRawUserAgent().isBlank());
final String userAgentString = "Signal-Desktop/1.2.3 Linux";
when(grpcClientConnectionManager.getRawUserAgent(any()))
.thenReturn(Optional.of(userAgentString));
assertEquals(userAgentString, getRequestAttributes().getRawUserAgent());
} }
private GetRequestAttributesResponse getRequestAttributes() { private static RequestAttributes buildRequestAttributes(final List<Locale.LanguageRange> acceptLanguage) {
return RequestAttributesGrpc.newBlockingStub(managedChannel) return buildRequestAttributes(null, acceptLanguage);
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()); }
private static RequestAttributes buildRequestAttributes(@Nullable final String userAgent,
final List<Locale.LanguageRange> acceptLanguage) {
return new RequestAttributes(REMOTE_ADDRESS, userAgent, acceptLanguage);
} }
} }

View File

@ -1,7 +1,11 @@
package org.whispersystems.textsecuregcm.grpc.net; package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import com.google.common.net.InetAddresses; import com.google.common.net.InetAddresses;
import com.vdurmont.semver4j.Semver;
import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap; import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel; import io.netty.channel.Channel;
@ -12,6 +16,12 @@ import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel; import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel; import io.netty.channel.local.LocalServerChannel;
import java.net.InetAddress;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.UUID;
import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeAll;
@ -21,20 +31,9 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import javax.annotation.Nullable;
import java.net.InetAddress;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.UUID;
import static org.junit.jupiter.api.Assertions.*;
class GrpcClientConnectionManagerTest { class GrpcClientConnectionManagerTest {
@ -103,7 +102,7 @@ class GrpcClientConnectionManagerTest {
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, maybeAuthenticatedDevice); grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, maybeAuthenticatedDevice);
assertEquals(maybeAuthenticatedDevice, assertEquals(maybeAuthenticatedDevice,
grpcClientConnectionManager.getAuthenticatedDevice(localChannel.localAddress())); grpcClientConnectionManager.getAuthenticatedDevice(remoteChannel));
} }
private static List<Optional<AuthenticatedDevice>> getAuthenticatedDevice() { private static List<Optional<AuthenticatedDevice>> getAuthenticatedDevice() {
@ -114,170 +113,115 @@ class GrpcClientConnectionManagerTest {
} }
@Test @Test
void getAcceptableLanguages() { void getRequestAttributes() {
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty()); grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertEquals(Optional.empty(), assertThrows(IllegalStateException.class, () -> grpcClientConnectionManager.getRequestAttributes(remoteChannel));
grpcClientConnectionManager.getAcceptableLanguages(localChannel.localAddress()));
final List<Locale.LanguageRange> acceptLanguageRanges = Locale.LanguageRange.parse("en,ja"); final RequestAttributes requestAttributes = new RequestAttributes(InetAddresses.forString("6.7.8.9"), null, null);
remoteChannel.attr(GrpcClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).set(acceptLanguageRanges); remoteChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY).set(requestAttributes);
assertEquals(Optional.of(acceptLanguageRanges), assertEquals(requestAttributes, grpcClientConnectionManager.getRequestAttributes(remoteChannel));
grpcClientConnectionManager.getAcceptableLanguages(localChannel.localAddress()));
} }
@Test @Test
void getRemoteAddress() { void closeConnection() throws InterruptedException, ChannelNotFoundException {
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertEquals(Optional.empty(),
grpcClientConnectionManager.getRemoteAddress(localChannel.localAddress()));
final InetAddress remoteAddress = InetAddresses.forString("6.7.8.9");
remoteChannel.attr(GrpcClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).set(remoteAddress);
assertEquals(Optional.of(remoteAddress),
grpcClientConnectionManager.getRemoteAddress(localChannel.localAddress()));
}
@Test
void getUserAgent() throws UnrecognizedUserAgentException {
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertEquals(Optional.empty(),
grpcClientConnectionManager.getUserAgent(localChannel.localAddress()));
final UserAgent userAgent = UserAgentUtil.parseUserAgentString("Signal-Desktop/1.2.3 Linux");
remoteChannel.attr(GrpcClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY).set(userAgent);
assertEquals(Optional.of(userAgent),
grpcClientConnectionManager.getUserAgent(localChannel.localAddress()));
}
@Test
void closeConnection() throws InterruptedException {
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID); final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice)); grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice));
assertTrue(remoteChannel.isOpen()); assertTrue(remoteChannel.isOpen());
assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertEquals(List.of(remoteChannel), assertEquals(List.of(remoteChannel),
grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
remoteChannel.close().await(); remoteChannel.close().await();
assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); assertThrows(ChannelNotFoundException.class,
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
} }
@Test @ParameterizedTest
void handleWebSocketHandshakeCompleteRemoteAddress() { @MethodSource
void handleHandshakeCompleteRequestAttributes(final InetAddress preferredRemoteAddress,
final String userAgentHeader,
final String acceptLanguageHeader,
final RequestAttributes expectedRequestAttributes) {
final EmbeddedChannel embeddedChannel = new EmbeddedChannel(); final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
final InetAddress preferredRemoteAddress = InetAddresses.forString("192.168.1.1"); GrpcClientConnectionManager.handleHandshakeComplete(embeddedChannel,
GrpcClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel,
preferredRemoteAddress, preferredRemoteAddress,
null,
null);
assertEquals(preferredRemoteAddress,
embeddedChannel.attr(GrpcClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).get());
}
@ParameterizedTest
@MethodSource
void handleWebSocketHandshakeCompleteUserAgent(@Nullable final String userAgentHeader,
@Nullable final UserAgent expectedParsedUserAgent) {
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
GrpcClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel,
InetAddresses.forString("127.0.0.1"),
userAgentHeader, userAgentHeader,
null);
assertEquals(userAgentHeader,
embeddedChannel.attr(GrpcClientConnectionManager.RAW_USER_AGENT_ATTRIBUTE_KEY).get());
assertEquals(expectedParsedUserAgent,
embeddedChannel.attr(GrpcClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY).get());
}
private static List<Arguments> handleWebSocketHandshakeCompleteUserAgent() {
return List.of(
// Recognized user-agent
Arguments.of("Signal-Desktop/1.2.3 Linux", new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "Linux")),
// Unrecognized user-agent
Arguments.of("Not a valid user-agent string", null),
// Missing user-agent
Arguments.of(null, null)
);
}
@ParameterizedTest
@MethodSource
void handleWebSocketHandshakeCompleteAcceptLanguage(@Nullable final String acceptLanguageHeader,
@Nullable final List<Locale.LanguageRange> expectedLanguageRanges) {
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
GrpcClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel,
InetAddresses.forString("127.0.0.1"),
null,
acceptLanguageHeader); acceptLanguageHeader);
assertEquals(expectedLanguageRanges, assertEquals(expectedRequestAttributes,
embeddedChannel.attr(GrpcClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).get()); embeddedChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY).get());
} }
private static List<Arguments> handleWebSocketHandshakeCompleteAcceptLanguage() { private static List<Arguments> handleHandshakeCompleteRequestAttributes() {
final InetAddress preferredRemoteAddress = InetAddresses.forString("192.168.1.1");
return List.of( return List.of(
// Parseable list Arguments.argumentSet("Null User-Agent and Accept-Language headers",
Arguments.of("ja,en;q=0.4", Locale.LanguageRange.parse("ja,en;q=0.4")), preferredRemoteAddress, null, null,
new RequestAttributes(preferredRemoteAddress, null, Collections.emptyList())),
// Unparsable list Arguments.argumentSet("Recognized User-Agent and null Accept-Language header",
Arguments.of("This is not a valid language preference list", null), preferredRemoteAddress, "Signal-Desktop/1.2.3 Linux", null,
new RequestAttributes(preferredRemoteAddress, "Signal-Desktop/1.2.3 Linux", Collections.emptyList())),
// Missing list Arguments.argumentSet("Unparsable User-Agent and null Accept-Language header",
Arguments.of(null, null) preferredRemoteAddress, "Not a valid user-agent string", null,
new RequestAttributes(preferredRemoteAddress, "Not a valid user-agent string", Collections.emptyList())),
Arguments.argumentSet("Null User-Agent and parsable Accept-Language header",
preferredRemoteAddress, null, "ja,en;q=0.4",
new RequestAttributes(preferredRemoteAddress, null, Locale.LanguageRange.parse("ja,en;q=0.4"))),
Arguments.argumentSet("Null User-Agent and unparsable Accept-Language header",
preferredRemoteAddress, null, "This is not a valid language preference list",
new RequestAttributes(preferredRemoteAddress, null, Collections.emptyList()))
); );
} }
@Test @Test
void handleConnectionEstablishedAuthenticated() throws InterruptedException { void handleConnectionEstablishedAuthenticated() throws InterruptedException, ChannelNotFoundException {
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID); final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); assertThrows(ChannelNotFoundException.class,
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice)); grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice));
assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertEquals(List.of(remoteChannel), grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); assertEquals(List.of(remoteChannel), grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
remoteChannel.close().await(); remoteChannel.close().await();
assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); assertThrows(ChannelNotFoundException.class,
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
} }
@Test @Test
void handleConnectionEstablishedAnonymous() throws InterruptedException { void handleConnectionEstablishedAnonymous() throws InterruptedException, ChannelNotFoundException {
assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); assertThrows(ChannelNotFoundException.class,
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty()); grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
remoteChannel.close().await(); remoteChannel.close().await();
assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); assertThrows(ChannelNotFoundException.class,
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
} }
} }

View File

@ -1,6 +1,7 @@
package org.whispersystems.textsecuregcm.grpc.net; package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.ArgumentMatchers.anyByte;
@ -8,10 +9,12 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString;
import io.grpc.ManagedChannel; import io.grpc.ManagedChannel;
import io.grpc.ServerBuilder; import io.grpc.ServerBuilder;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.netty.NettyChannelBuilder; import io.grpc.netty.NettyChannelBuilder;
import io.grpc.stub.StreamObserver;
import io.netty.channel.DefaultEventLoopGroup; import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel; import io.netty.channel.local.LocalChannel;
@ -61,6 +64,9 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import org.signal.chat.rpc.EchoRequest;
import org.signal.chat.rpc.EchoResponse;
import org.signal.chat.rpc.EchoServiceGrpc;
import org.signal.chat.rpc.GetAuthenticatedDeviceRequest; import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse; import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.signal.chat.rpc.GetRequestAttributesRequest; import org.signal.chat.rpc.GetRequestAttributesRequest;
@ -71,6 +77,8 @@ import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor; import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor; import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.grpc.ChannelShutdownInterceptor;
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils; import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor; import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor;
import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl; import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl;
@ -83,6 +91,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
private static NioEventLoopGroup nioEventLoopGroup; private static NioEventLoopGroup nioEventLoopGroup;
private static DefaultEventLoopGroup defaultEventLoopGroup; private static DefaultEventLoopGroup defaultEventLoopGroup;
private static ExecutorService delegatedTaskExecutor; private static ExecutorService delegatedTaskExecutor;
private static ExecutorService serverCallExecutor;
private static X509Certificate serverTlsCertificate; private static X509Certificate serverTlsCertificate;
@ -136,7 +145,8 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
static void setUpBeforeAll() throws CertificateException { static void setUpBeforeAll() throws CertificateException {
nioEventLoopGroup = new NioEventLoopGroup(); nioEventLoopGroup = new NioEventLoopGroup();
defaultEventLoopGroup = new DefaultEventLoopGroup(); defaultEventLoopGroup = new DefaultEventLoopGroup();
delegatedTaskExecutor = Executors.newSingleThreadExecutor(); delegatedTaskExecutor = Executors.newVirtualThreadPerTaskExecutor();
serverCallExecutor = Executors.newVirtualThreadPerTaskExecutor();
final CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509"); final CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509");
serverTlsCertificate = (X509Certificate) certificateFactory.generateCertificate( serverTlsCertificate = (X509Certificate) certificateFactory.generateCertificate(
@ -171,7 +181,11 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, defaultEventLoopGroup) { authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, defaultEventLoopGroup) {
@Override @Override
protected void configureServer(final ServerBuilder<?> serverBuilder) { protected void configureServer(final ServerBuilder<?> serverBuilder) {
serverBuilder.addService(new RequestAttributesServiceImpl()) serverBuilder
.executor(serverCallExecutor)
.addService(new RequestAttributesServiceImpl())
.addService(new EchoServiceImpl())
.intercept(new ChannelShutdownInterceptor(grpcClientConnectionManager))
.intercept(new RequestAttributesInterceptor(grpcClientConnectionManager)) .intercept(new RequestAttributesInterceptor(grpcClientConnectionManager))
.intercept(new RequireAuthenticationInterceptor(grpcClientConnectionManager)); .intercept(new RequireAuthenticationInterceptor(grpcClientConnectionManager));
} }
@ -182,7 +196,9 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, defaultEventLoopGroup) { anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, defaultEventLoopGroup) {
@Override @Override
protected void configureServer(final ServerBuilder<?> serverBuilder) { protected void configureServer(final ServerBuilder<?> serverBuilder) {
serverBuilder.addService(new RequestAttributesServiceImpl()) serverBuilder
.executor(serverCallExecutor)
.addService(new RequestAttributesServiceImpl())
.intercept(new RequestAttributesInterceptor(grpcClientConnectionManager)) .intercept(new RequestAttributesInterceptor(grpcClientConnectionManager))
.intercept(new ProhibitAuthenticationInterceptor(grpcClientConnectionManager)); .intercept(new ProhibitAuthenticationInterceptor(grpcClientConnectionManager));
} }
@ -235,6 +251,10 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
delegatedTaskExecutor.shutdown(); delegatedTaskExecutor.shutdown();
//noinspection ResultOfMethodCallIgnored //noinspection ResultOfMethodCallIgnored
delegatedTaskExecutor.awaitTermination(1, TimeUnit.SECONDS); delegatedTaskExecutor.awaitTermination(1, TimeUnit.SECONDS);
serverCallExecutor.shutdown();
//noinspection ResultOfMethodCallIgnored
serverCallExecutor.awaitTermination(1, TimeUnit.SECONDS);
} }
@ParameterizedTest @ParameterizedTest
@ -523,10 +543,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
assertEquals(remoteAddress, response.getRemoteAddress()); assertEquals(remoteAddress, response.getRemoteAddress());
assertEquals(List.of(acceptLanguage), response.getAcceptableLanguagesList()); assertEquals(List.of(acceptLanguage), response.getAcceptableLanguagesList());
assertEquals(userAgent, response.getUserAgent());
assertEquals("DESKTOP", response.getUserAgent().getPlatform());
assertEquals("1.2.3", response.getUserAgent().getVersion());
assertEquals("Linux", response.getUserAgent().getAdditionalSpecifiers());
} finally { } finally {
channel.shutdown(); channel.shutdown();
} }
@ -582,6 +599,89 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
} }
} }
@Test
void waitForCallCompletion() throws InterruptedException {
final CountDownLatch connectionCloseLatch = new CountDownLatch(1);
final AtomicInteger serverCloseStatusCode = new AtomicInteger(0);
final AtomicBoolean closedByServer = new AtomicBoolean(false);
final WebSocketCloseListener webSocketCloseListener = new WebSocketCloseListener() {
@Override
public void handleWebSocketClosedByClient(final int statusCode) {
serverCloseStatusCode.set(statusCode);
closedByServer.set(false);
connectionCloseLatch.countDown();
}
@Override
public void handleWebSocketClosedByServer(final int statusCode) {
serverCloseStatusCode.set(statusCode);
closedByServer.set(true);
connectionCloseLatch.countDown();
}
};
try (final NoiseWebSocketTunnelClient client = authenticated()
.setWebSocketCloseListener(webSocketCloseListener)
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
assertEquals(DEVICE_ID, response.getDeviceId());
final CountDownLatch responseCountDownLatch = new CountDownLatch(1);
// Start an open-ended server call and leave it in a non-complete state
final StreamObserver<EchoRequest> echoRequestStreamObserver = EchoServiceGrpc.newStub(channel).echoStream(
new StreamObserver<>() {
@Override
public void onNext(final EchoResponse echoResponse) {
responseCountDownLatch.countDown();
}
@Override
public void onError(final Throwable throwable) {
}
@Override
public void onCompleted() {
}
});
// Requests are transmitted asynchronously; it's possible that we'll issue the "close connection" request before
// the request even starts. Make sure we've done at least one request/response pair to ensure that the call has
// truly started before requesting connection closure.
echoRequestStreamObserver.onNext(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("Test")).build());
assertTrue(responseCountDownLatch.await(1, TimeUnit.SECONDS));
grpcClientConnectionManager.closeConnection(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID));
assertFalse(connectionCloseLatch.await(1, TimeUnit.SECONDS),
"Channel should not close until active requests have finished");
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, () -> EchoServiceGrpc.newBlockingStub(channel)
.echo(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("Test")).build()));
// Complete the open-ended server call
echoRequestStreamObserver.onCompleted();
assertTrue(connectionCloseLatch.await(1, TimeUnit.SECONDS),
"Channel should close once active requests have finished");
assertTrue(closedByServer.get());
assertEquals(4004, serverCloseStatusCode.get());
} finally {
channel.shutdown();
}
}
}
private NoiseWebSocketTunnelClient.Builder anonymous() { private NoiseWebSocketTunnelClient.Builder anonymous() {
return new NoiseWebSocketTunnelClient return new NoiseWebSocketTunnelClient
.Builder(tlsNoiseWebSocketTunnelServer.getLocalAddress(), nioEventLoopGroup, serverKeyPair.getPublicKey()) .Builder(tlsNoiseWebSocketTunnelServer.getLocalAddress(), nioEventLoopGroup, serverKeyPair.getPublicKey())

View File

@ -4,6 +4,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.params.provider.Arguments.argumentSet;
import static org.junit.jupiter.params.provider.Arguments.arguments; import static org.junit.jupiter.params.provider.Arguments.arguments;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
@ -16,6 +17,7 @@ import io.netty.channel.local.LocalAddress;
import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.util.Attribute;
import java.net.InetAddress; import java.net.InetAddress;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.SocketAddress; import java.net.SocketAddress;
@ -31,6 +33,7 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest { class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
@ -134,8 +137,13 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
embeddedChannel.setRemoteAddress(remoteAddress); embeddedChannel.setRemoteAddress(remoteAddress);
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent); embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
assertEquals(expectedRemoteAddress, assertEquals(expectedRemoteAddress,
embeddedChannel.attr(GrpcClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).get()); Optional.ofNullable(embeddedChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY))
.map(Attribute::get)
.map(RequestAttributes::remoteAddress)
.orElse(null));
} }
private static List<Arguments> getRemoteAddress() { private static List<Arguments> getRemoteAddress() {
@ -144,53 +152,53 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
final InetAddress proxyAddress = InetAddresses.forString("4.3.2.1"); final InetAddress proxyAddress = InetAddresses.forString("4.3.2.1");
return List.of( return List.of(
// Recognized proxy, single forwarded-for address argumentSet("Recognized proxy, single forwarded-for address",
Arguments.of(new DefaultHttpHeaders() new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET) .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()), .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
remoteAddress, remoteAddress,
clientAddress), clientAddress),
// Recognized proxy, multiple forwarded-for addresses argumentSet("Recognized proxy, multiple forwarded-for addresses",
Arguments.of(new DefaultHttpHeaders() new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET) .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress() + "," + proxyAddress.getHostAddress()), .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress() + "," + proxyAddress.getHostAddress()),
remoteAddress, remoteAddress,
proxyAddress), proxyAddress),
// No recognized proxy header, single forwarded-for address argumentSet("No recognized proxy header, single forwarded-for address",
Arguments.of(new DefaultHttpHeaders() new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()), .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
remoteAddress, remoteAddress,
remoteAddress.getAddress()), remoteAddress.getAddress()),
// No recognized proxy header, no forwarded-for address argumentSet("No recognized proxy header, no forwarded-for address",
Arguments.of(new DefaultHttpHeaders(), new DefaultHttpHeaders(),
remoteAddress, remoteAddress,
remoteAddress.getAddress()), remoteAddress.getAddress()),
// Incorrect proxy header, single forwarded-for address argumentSet("Incorrect proxy header, single forwarded-for address",
Arguments.of(new DefaultHttpHeaders() new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET + "-incorrect") .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET + "-incorrect")
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()), .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
remoteAddress, remoteAddress,
remoteAddress.getAddress()), remoteAddress.getAddress()),
// Recognized proxy, no forwarded-for address argumentSet("Recognized proxy, no forwarded-for address",
Arguments.of(new DefaultHttpHeaders() new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET), .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET),
remoteAddress, remoteAddress,
remoteAddress.getAddress()), remoteAddress.getAddress()),
// Recognized proxy, bogus forwarded-for address argumentSet("Recognized proxy, bogus forwarded-for address",
Arguments.of(new DefaultHttpHeaders() new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET) .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, "not a valid address"), .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, "not a valid address"),
remoteAddress, remoteAddress,
null), null),
// No forwarded-for address, non-InetSocketAddress remote address argumentSet("No forwarded-for address, non-InetSocketAddress remote address",
Arguments.of(new DefaultHttpHeaders() new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET), .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET),
new LocalAddress("local-address"), new LocalAddress("local-address"),
null) null)

View File

@ -1,91 +0,0 @@
/*
* Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.metrics;
import static org.junit.jupiter.api.Assertions.assertEquals;
import java.util.stream.Stream;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
class OperatingSystemMemoryGaugeTest {
private static final String MEMINFO =
"""
MemTotal: 16052208 kB
MemFree: 4568468 kB
MemAvailable: 7702848 kB
Buffers: 636372 kB
Cached: 5019116 kB
SwapCached: 6692 kB
Active: 7746436 kB
Inactive: 2729876 kB
Active(anon): 5580980 kB
Inactive(anon): 1648108 kB
Active(file): 2165456 kB
Inactive(file): 1081768 kB
Unevictable: 443948 kB
Mlocked: 4924 kB
SwapTotal: 1003516 kB
SwapFree: 935932 kB
Dirty: 28308 kB
Writeback: 0 kB
AnonPages: 5258396 kB
Mapped: 1530740 kB
Shmem: 2419340 kB
KReclaimable: 229392 kB
Slab: 408156 kB
SReclaimable: 229392 kB
SUnreclaim: 178764 kB
KernelStack: 17360 kB
PageTables: 50436 kB
NFS_Unstable: 0 kB
Bounce: 0 kB
WritebackTmp: 0 kB
CommitLimit: 9029620 kB
Committed_AS: 16681884 kB
VmallocTotal: 34359738367 kB
VmallocUsed: 41944 kB
VmallocChunk: 0 kB
Percpu: 4240 kB
HardwareCorrupted: 0 kB
AnonHugePages: 0 kB
ShmemHugePages: 0 kB
ShmemPmdMapped: 0 kB
FileHugePages: 0 kB
FilePmdMapped: 0 kB
CmaTotal: 0 kB
CmaFree: 0 kB
HugePages_Total: 0
HugePages_Free: 7
HugePages_Rsvd: 0
HugePages_Surp: 0
Hugepagesize: 2048 kB
Hugetlb: 0 kB
DirectMap4k: 481804 kB
DirectMap2M: 14901248 kB
DirectMap1G: 2097152 kB
""";
@ParameterizedTest
@MethodSource
void testGetValue(final String metricName, final long expectedValue) {
assertEquals(expectedValue, OperatingSystemMemoryGauge.getValue(MEMINFO.lines(), metricName));
}
@SuppressWarnings("unused")
private static Stream<Arguments> testGetValue() {
return Stream.of(
Arguments.of("MemTotal", 16052208L),
Arguments.of("Active(anon)", 5580980L),
Arguments.of("Committed_AS", 16681884L),
Arguments.of("HugePages_Free", 7L),
Arguments.of("NonsenseMetric", 0L)
);
}
}

View File

@ -12,6 +12,7 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -39,6 +40,7 @@ import org.whispersystems.textsecuregcm.controllers.MismatchedDevices;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException; import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
@ -54,13 +56,67 @@ class MessageSenderTest {
private MessagesManager messagesManager; private MessagesManager messagesManager;
private PushNotificationManager pushNotificationManager; private PushNotificationManager pushNotificationManager;
private MessageSender messageSender; private MessageSender messageSender;
private ExperimentEnrollmentManager experimentEnrollmentManager;
@BeforeEach @BeforeEach
void setUp() { void setUp() {
messagesManager = mock(MessagesManager.class); messagesManager = mock(MessagesManager.class);
pushNotificationManager = mock(PushNotificationManager.class); pushNotificationManager = mock(PushNotificationManager.class);
experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class);
messageSender = new MessageSender(messagesManager, pushNotificationManager); messageSender = new MessageSender(messagesManager, pushNotificationManager, experimentEnrollmentManager);
}
@CartesianTest
void pushSkippedExperiment(
@CartesianTest.Values(booleans = {true, false}) final boolean hasGcmToken,
@CartesianTest.Values(booleans = {true, false}) final boolean isUrgent,
@CartesianTest.Values(booleans = {true, false}) final boolean inExperiment) throws NotPushRegisteredException {
final boolean shouldSkip = hasGcmToken && !isUrgent && inExperiment;
final UUID accountIdentifier = UUID.randomUUID();
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier);
final byte deviceId = Device.PRIMARY_ID;
final int registrationId = 17;
final Account account = mock(Account.class);
final Device device = mock(Device.class);
final MessageProtos.Envelope message = MessageProtos.Envelope.newBuilder()
.setEphemeral(false)
.setUrgent(isUrgent)
.build();
when(account.getUuid()).thenReturn(accountIdentifier);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true);
when(account.getDevices()).thenReturn(List.of(device));
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(device.getId()).thenReturn(deviceId);
when(device.getRegistrationId()).thenReturn(registrationId);
if (hasGcmToken) {
when(device.getGcmId()).thenReturn("gcm-token");
} else {
when(device.getApnId()).thenReturn("apn-token");
}
when(messagesManager.insert(any(), any())).thenReturn(Map.of(deviceId, false));
when(experimentEnrollmentManager.isEnrolled(accountIdentifier, MessageSender.ANDROID_SKIP_LOW_URGENCY_PUSH_EXPERIMENT))
.thenReturn(inExperiment);
assertDoesNotThrow(() -> messageSender.sendMessages(account,
serviceIdentifier,
Map.of(device.getId(), message),
Map.of(device.getId(), registrationId),
Optional.empty(),
null));
if (shouldSkip) {
verifyNoInteractions(pushNotificationManager);
} else {
verify(pushNotificationManager).sendNewMessageNotification(account, deviceId, isUrgent);
}
} }
@CartesianTest @CartesianTest

View File

@ -6,8 +6,10 @@ import static org.mockito.Mockito.when;
import com.google.i18n.phonenumbers.PhoneNumberUtil; import com.google.i18n.phonenumbers.PhoneNumberUtil;
import java.time.Clock; import java.time.Clock;
import java.time.LocalDateTime;
import java.time.LocalTime; import java.time.LocalTime;
import java.time.ZoneId; import java.time.ZoneId;
import java.time.ZoneOffset;
import java.time.ZonedDateTime; import java.time.ZonedDateTime;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
@ -65,4 +67,26 @@ class SchedulingUtilTest {
Clock.fixed(afterNotificationTime.toInstant(), ZoneId.systemDefault()))); Clock.fixed(afterNotificationTime.toInstant(), ZoneId.systemDefault())));
} }
} }
@Test
void getNextRecommendedNotificationTimeDaylightSavings() {
final Account account = mock(Account.class);
// The account has a phone number that can be resolved to a region with known timezones
when(account.getNumber()).thenReturn(PhoneNumberUtil.getInstance().format(
PhoneNumberUtil.getInstance().getExampleNumber("DE"), PhoneNumberUtil.PhoneNumberFormat.E164));
final LocalDateTime afterNotificationTime = LocalDateTime.of(2025, 3, 29, 15, 0);
final ZoneId berlinZoneId = ZoneId.of("Europe/Berlin");
final ZoneOffset berlineZoneOffset = berlinZoneId.getRules().getOffset(afterNotificationTime);
// Daylight Savings Time started on 2025-03-30 at 2:00AM in Germany.
// Instantiating a ZonedDateTime with a zone ID factors in daylight savings when we adjust the time.
final ZonedDateTime afterNotificationTimeWithZoneId = ZonedDateTime.of(afterNotificationTime, berlinZoneId);
assertEquals(
afterNotificationTimeWithZoneId.with(LocalTime.of(14, 0)).plusDays(1).toInstant(),
SchedulingUtil.getNextRecommendedNotificationTime(account, LocalTime.of(14, 0),
Clock.fixed(afterNotificationTime.toInstant(berlineZoneOffset), berlinZoneId)));
}
} }

View File

@ -9,13 +9,16 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
@ -35,8 +38,10 @@ import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.TestClock;
public class ChangeNumberManagerTest { public class ChangeNumberManagerTest {
private AccountsManager accountsManager; private AccountsManager accountsManager;
@ -45,11 +50,13 @@ public class ChangeNumberManagerTest {
private Map<Account, UUID> updatedPhoneNumberIdentifiersByAccount; private Map<Account, UUID> updatedPhoneNumberIdentifiersByAccount;
private static final TestClock CLOCK = TestClock.pinned(Instant.now());
@BeforeEach @BeforeEach
void setUp() throws Exception { void setUp() throws Exception {
accountsManager = mock(AccountsManager.class); accountsManager = mock(AccountsManager.class);
messageSender = mock(MessageSender.class); messageSender = mock(MessageSender.class);
changeNumberManager = new ChangeNumberManager(messageSender, accountsManager); changeNumberManager = new ChangeNumberManager(messageSender, accountsManager, CLOCK);
updatedPhoneNumberIdentifiersByAccount = new HashMap<>(); updatedPhoneNumberIdentifiersByAccount = new HashMap<>();
@ -132,45 +139,59 @@ public class ChangeNumberManagerTest {
when(account.getUuid()).thenReturn(aci); when(account.getUuid()).thenReturn(aci);
when(account.getPhoneNumberIdentifier()).thenReturn(pni); when(account.getPhoneNumberIdentifier()).thenReturn(pni);
final Device d2 = mock(Device.class); final Device primaryDevice = mock(Device.class);
final byte deviceId2 = 2; when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(d2.getId()).thenReturn(deviceId2); when(primaryDevice.getRegistrationId()).thenReturn(7);
when(account.getDevice(deviceId2)).thenReturn(Optional.of(d2)); final Device linkedDevice = mock(Device.class);
when(account.getDevices()).thenReturn(List.of(d2)); final byte linkedDeviceId = Device.PRIMARY_ID + 1;
final int linkedDeviceRegistrationId = 17;
when(linkedDevice.getId()).thenReturn(linkedDeviceId);
when(linkedDevice.getRegistrationId()).thenReturn(linkedDeviceRegistrationId);
when(account.getDevice(anyByte())).thenReturn(Optional.empty());
when(account.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(primaryDevice));
when(account.getDevice(linkedDeviceId)).thenReturn(Optional.of(linkedDevice));
when(account.getDevices()).thenReturn(List.of(primaryDevice, linkedDevice));
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final Map<Byte, ECSignedPreKey> prekeys = Map.of(Device.PRIMARY_ID, final Map<Byte, ECSignedPreKey> prekeys = Map.of(Device.PRIMARY_ID,
KeysHelper.signedECPreKey(1, pniIdentityKeyPair), KeysHelper.signedECPreKey(1, pniIdentityKeyPair),
deviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); linkedDeviceId, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, deviceId2, 19); final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, linkedDeviceId, 19);
final IncomingMessage msg = mock(IncomingMessage.class); final IncomingMessage msg = mock(IncomingMessage.class);
when(msg.destinationDeviceId()).thenReturn(deviceId2); when(msg.type()).thenReturn(1);
when(msg.destinationDeviceId()).thenReturn(linkedDeviceId);
when(msg.destinationRegistrationId()).thenReturn(linkedDeviceRegistrationId);
when(msg.content()).thenReturn(new byte[]{1}); when(msg.content()).thenReturn(new byte[]{1});
changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, null, List.of(msg), registrationIds, null); changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, null, List.of(msg), registrationIds, null);
verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, null, registrationIds); verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, null, registrationIds);
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor = final MessageProtos.Envelope expectedEnvelope = MessageProtos.Envelope.newBuilder()
ArgumentCaptor.forClass(Map.class); .setType(MessageProtos.Envelope.Type.forNumber(msg.type()))
.setClientTimestamp(CLOCK.millis())
.setServerTimestamp(CLOCK.millis())
.setDestinationServiceId(new AciServiceIdentifier(aci).toServiceIdentifierString())
.setContent(ByteString.copyFrom(msg.content()))
.setSourceServiceId(new AciServiceIdentifier(aci).toServiceIdentifierString())
.setSourceDevice(Device.PRIMARY_ID)
.setUpdatedPni(updatedPhoneNumberIdentifiersByAccount.get(account).toString())
.setUrgent(true)
.setEphemeral(false)
.build();
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any(), any()); verify(messageSender).sendMessages(argThat(a -> a.getUuid().equals(aci)),
eq(new AciServiceIdentifier(aci)),
assertEquals(1, envelopeCaptor.getValue().size()); eq(Map.of(linkedDeviceId, expectedEnvelope)),
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); eq(Map.of(linkedDeviceId, linkedDeviceRegistrationId)),
eq(Optional.of(Device.PRIMARY_ID)),
final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2); any());
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice());
assertEquals(updatedPhoneNumberIdentifiersByAccount.get(account), UUID.fromString(envelope.getUpdatedPni()));
} }
@Test @Test
void changeNumberSetPrimaryDevicePrekeyPqAndSendMessages() throws Exception { void changeNumberSetPrimaryDevicePrekeyPqAndSendMessages() throws Exception {
final String originalE164 = "+18005551234"; final String originalE164 = "+18005551234";

View File

@ -32,6 +32,7 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventListener; import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventListener;
import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventManager; import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventManager;
@ -97,7 +98,7 @@ class MessagePersisterIntegrationTest {
webSocketConnectionEventManager.start(); webSocketConnectionEventManager.start();
messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager,
dynamicConfigurationManager, PERSIST_DELAY, 1); dynamicConfigurationManager, mock(ExperimentEnrollmentManager.class), PERSIST_DELAY, 1);
account = mock(Account.class); account = mock(Account.class);

View File

@ -54,6 +54,7 @@ import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessagePersisterConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessagePersisterConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper; import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
@ -118,7 +119,7 @@ class MessagePersisterTest {
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC()); messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC());
messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager,
dynamicConfigurationManager, PERSIST_DELAY, 1); dynamicConfigurationManager, mock(ExperimentEnrollmentManager.class), PERSIST_DELAY, 1);
when(messagesManager.clear(any(UUID.class), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); when(messagesManager.clear(any(UUID.class), anyByte())).thenReturn(CompletableFuture.completedFuture(null));
@ -257,7 +258,7 @@ class MessagePersisterTest {
assertTimeoutPreemptively(Duration.ofSeconds(1), () -> assertTimeoutPreemptively(Duration.ofSeconds(1), () ->
assertThrows(MessagePersistenceException.class, assertThrows(MessagePersistenceException.class,
() -> messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE))); () -> messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE, "test")));
} }
@Test @Test
@ -298,7 +299,7 @@ class MessagePersisterTest {
when(messagesManager.persistMessages(any(UUID.class), any(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build()); when(messagesManager.persistMessages(any(UUID.class), any(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build());
assertTimeoutPreemptively(Duration.ofSeconds(1), () -> assertTimeoutPreemptively(Duration.ofSeconds(1), () ->
messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE)); messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE, "test"));
verify(accountsManager, exactly()).removeDevice(destinationAccount, DESTINATION_DEVICE_ID); verify(accountsManager, exactly()).removeDevice(destinationAccount, DESTINATION_DEVICE_ID);
} }
@ -400,7 +401,7 @@ class MessagePersisterTest {
when(messagesManager.persistMessages(any(UUID.class), any(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build()); when(messagesManager.persistMessages(any(UUID.class), any(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build());
when(accountsManager.removeDevice(destinationAccount, DESTINATION_DEVICE_ID)).thenReturn(CompletableFuture.failedFuture(new TimeoutException())); when(accountsManager.removeDevice(destinationAccount, DESTINATION_DEVICE_ID)).thenReturn(CompletableFuture.failedFuture(new TimeoutException()));
assertThrows(CompletionException.class, () -> messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE)); assertThrows(CompletionException.class, () -> messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE, "test"));
} }
@SuppressWarnings("SameParameterValue") @SuppressWarnings("SameParameterValue")

View File

@ -31,8 +31,8 @@ import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
public class AccountsHelper { public class AccountsHelper {
@ -62,6 +62,71 @@ public class AccountsHelper {
setupMockUpdate(mockAccountsManager, false); setupMockUpdate(mockAccountsManager, false);
} }
/**
* Sets up stubbing for:
* <ul>
* <li>{@link AccountsManager#update(Account, Consumer)}</li>
* <li>{@link AccountsManager#updateAsync(Account, Consumer)}</li>
* <li>{@link AccountsManager#updateDevice(Account, byte, Consumer)}</li>
* <li>{@link AccountsManager#updateDeviceAsync(Account, byte, Consumer)}</li>
* </ul>
*
* with multiple calls to the {@link Consumer<Account>}. This simulates retries from {@link org.whispersystems.textsecuregcm.storage.ContestedOptimisticLockException}.
* Callers will typically set up stubbing for relevant {@link Account} methods with multiple {@link org.mockito.stubbing.OngoingStubbing#thenReturn(Object)}
* calls:
* <pre>
* // example stubbing
* when(account.getNextDeviceId())
* .thenReturn(2)
* .thenReturn(3);
* </pre>
*/
@SuppressWarnings("unchecked")
public static void setupMockUpdateWithRetries(final AccountsManager mockAccountsManager, final int retryCount) {
when(mockAccountsManager.update(any(), any())).thenAnswer(answer -> {
final Account account = answer.getArgument(0, Account.class);
for (int i = 0; i < retryCount; i++) {
answer.getArgument(1, Consumer.class).accept(account);
}
return copyAndMarkStale(account);
});
when(mockAccountsManager.updateAsync(any(), any())).thenAnswer(answer -> {
final Account account = answer.getArgument(0, Account.class);
for (int i = 0; i < retryCount; i++) {
answer.getArgument(1, Consumer.class).accept(account);
}
return CompletableFuture.completedFuture(copyAndMarkStale(account));
});
when(mockAccountsManager.updateDevice(any(), anyByte(), any())).thenAnswer(answer -> {
final Account account = answer.getArgument(0, Account.class);
final byte deviceId = answer.getArgument(1, Byte.class);
for (int i = 0; i < retryCount; i++) {
account.getDevice(deviceId).ifPresent(answer.getArgument(2, Consumer.class));
}
return copyAndMarkStale(account);
});
when(mockAccountsManager.updateDeviceAsync(any(), anyByte(), any())).thenAnswer(answer -> {
final Account account = answer.getArgument(0, Account.class);
final byte deviceId = answer.getArgument(1, Byte.class);
for (int i = 0; i < retryCount; i++) {
account.getDevice(deviceId).ifPresent(answer.getArgument(2, Consumer.class));
}
return CompletableFuture.completedFuture(copyAndMarkStale(account));
});
}
@SuppressWarnings("unchecked")
private static void setupMockUpdate(final AccountsManager mockAccountsManager, final boolean markStale) { private static void setupMockUpdate(final AccountsManager mockAccountsManager, final boolean markStale) {
when(mockAccountsManager.update(any(), any())).thenAnswer(answer -> { when(mockAccountsManager.update(any(), any())).thenAnswer(answer -> {
final Account account = answer.getArgument(0, Account.class); final Account account = answer.getArgument(0, Account.class);

View File

@ -0,0 +1,101 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.atomic.AtomicBoolean;
import static org.junit.jupiter.api.Assertions.*;
class ClosableEpochTest {
@Test
void close() {
{
final AtomicBoolean closed = new AtomicBoolean(false);
final ClosableEpoch closableEpoch = new ClosableEpoch(() -> closed.set(true));
assertTrue(closableEpoch.tryArrive(), "New callers should be allowed to arrive before closure");
assertEquals(1, closableEpoch.getActiveCallers());
closableEpoch.close();
assertFalse(closableEpoch.tryArrive(), "New callers should not be allowed to arrive after closure");
assertEquals(1, closableEpoch.getActiveCallers());
assertFalse(closed.get(), "Close handler should not fire until all callers have departed");
closableEpoch.depart();
assertTrue(closed.get(), "Close handler should fire after last caller departs");
assertEquals(0, closableEpoch.getActiveCallers());
assertThrows(IllegalStateException.class, closableEpoch::close,
"Double-closing a epoch should throw an exception");
}
{
final AtomicBoolean closed = new AtomicBoolean(false);
final ClosableEpoch closableEpoch = new ClosableEpoch(() -> closed.set(true));
closableEpoch.close();
assertTrue(closed.get(), "Empty epoch should fire close handler immediately on closure");
assertEquals(0, closableEpoch.getActiveCallers());
}
}
@Test
@Timeout(value = 5, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
void closeConcurrent() throws InterruptedException {
final AtomicBoolean closed = new AtomicBoolean(false);
final ClosableEpoch closableEpoch = new ClosableEpoch(() -> {
synchronized (closed) {
closed.set(true);
closed.notifyAll();
}
});
final int threadCount = 128;
final CyclicBarrier cyclicBarrier = new CyclicBarrier(threadCount);
// Spawn a bunch of threads doing some simulated work. Close the epoch roughly halfway through. Some threads should
// successfully enter the critical section and others should be rejected.
for (int t = 0; t < threadCount; t++) {
final boolean shouldClose = t == threadCount / 2;
Thread.ofVirtual().start(() -> {
try {
// Wait for all threads to reach the proverbial starting line
cyclicBarrier.await();
} catch (final InterruptedException | BrokenBarrierException ignored) {
}
if (shouldClose) {
closableEpoch.close();
}
if (closableEpoch.tryArrive()) {
// Perform some simulated "work"
try {
Thread.sleep(1);
} catch (final InterruptedException ignored) {
} finally {
closableEpoch.depart();
}
}
});
}
while (!closed.get()) {
synchronized (closed) {
closed.wait();
}
}
assertEquals(0, closableEpoch.getActiveCallers());
}
}

View File

@ -13,28 +13,20 @@ import java.util.stream.Stream;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import javax.annotation.Nullable;
class UserAgentUtilTest { class UserAgentUtilTest {
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource("argumentsForTestParseStandardUserAgentString")
void testParseBogusUserAgentString(final String userAgentString) { void testParseStandardUserAgentString(final String userAgentString, @Nullable final UserAgent expectedUserAgent)
throws UnrecognizedUserAgentException {
if (expectedUserAgent != null) {
assertEquals(expectedUserAgent, UserAgentUtil.parseUserAgentString(userAgentString));
} else {
assertThrows(UnrecognizedUserAgentException.class, () -> UserAgentUtil.parseUserAgentString(userAgentString)); assertThrows(UnrecognizedUserAgentException.class, () -> UserAgentUtil.parseUserAgentString(userAgentString));
} }
@SuppressWarnings("unused")
private static Stream<String> testParseBogusUserAgentString() {
return Stream.of(
null,
"This is obviously not a reasonable User-Agent string.",
"Signal-Android/4.6-8.3.unreasonableversionstring-17"
);
}
@ParameterizedTest
@MethodSource("argumentsForTestParseStandardUserAgentString")
void testParseStandardUserAgentString(final String userAgentString, final UserAgent expectedUserAgent) {
assertEquals(expectedUserAgent, UserAgentUtil.parseStandardUserAgentString(userAgentString));
} }
private static Stream<Arguments> argumentsForTestParseStandardUserAgentString() { private static Stream<Arguments> argumentsForTestParseStandardUserAgentString() {
@ -42,18 +34,18 @@ class UserAgentUtilTest {
Arguments.of("This is obviously not a reasonable User-Agent string.", null), Arguments.of("This is obviously not a reasonable User-Agent string.", null),
Arguments.of("Signal-Android/4.68.3 Android/25", Arguments.of("Signal-Android/4.68.3 Android/25",
new UserAgent(ClientPlatform.ANDROID, new Semver("4.68.3"), "Android/25")), new UserAgent(ClientPlatform.ANDROID, new Semver("4.68.3"), "Android/25")),
Arguments.of("Signal-Android/4.68.3", new UserAgent(ClientPlatform.ANDROID, new Semver("4.68.3"))), Arguments.of("Signal-Android/4.68.3", new UserAgent(ClientPlatform.ANDROID, new Semver("4.68.3"), null)),
Arguments.of("Signal-Desktop/1.2.3 Linux", new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "Linux")), Arguments.of("Signal-Desktop/1.2.3 Linux", new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "Linux")),
Arguments.of("Signal-Desktop/1.2.3 macOS", new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "macOS")), Arguments.of("Signal-Desktop/1.2.3 macOS", new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "macOS")),
Arguments.of("Signal-Desktop/1.2.3 Windows", Arguments.of("Signal-Desktop/1.2.3 Windows",
new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "Windows")), new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "Windows")),
Arguments.of("Signal-Desktop/1.2.3", new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"))), Arguments.of("Signal-Desktop/1.2.3", new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), null)),
Arguments.of("Signal-Desktop/1.32.0-beta.3", Arguments.of("Signal-Desktop/1.32.0-beta.3",
new UserAgent(ClientPlatform.DESKTOP, new Semver("1.32.0-beta.3"))), new UserAgent(ClientPlatform.DESKTOP, new Semver("1.32.0-beta.3"), null)),
Arguments.of("Signal-iOS/3.9.0 (iPhone; iOS 12.2; Scale/3.00)", Arguments.of("Signal-iOS/3.9.0 (iPhone; iOS 12.2; Scale/3.00)",
new UserAgent(ClientPlatform.IOS, new Semver("3.9.0"), "(iPhone; iOS 12.2; Scale/3.00)")), new UserAgent(ClientPlatform.IOS, new Semver("3.9.0"), "(iPhone; iOS 12.2; Scale/3.00)")),
Arguments.of("Signal-iOS/3.9.0 iOS/14.2", new UserAgent(ClientPlatform.IOS, new Semver("3.9.0"), "iOS/14.2")), Arguments.of("Signal-iOS/3.9.0 iOS/14.2", new UserAgent(ClientPlatform.IOS, new Semver("3.9.0"), "iOS/14.2")),
Arguments.of("Signal-iOS/3.9.0", new UserAgent(ClientPlatform.IOS, new Semver("3.9.0"))), Arguments.of("Signal-iOS/3.9.0", new UserAgent(ClientPlatform.IOS, new Semver("3.9.0"), null)),
Arguments.of("Signal-Android/7.11.23-nightly-1982-06-28-07-07-07 tonic/0.31", Arguments.of("Signal-Android/7.11.23-nightly-1982-06-28-07-07-07 tonic/0.31",
new UserAgent(ClientPlatform.ANDROID, new Semver("7.11.23-nightly-1982-06-28-07-07-07"), "tonic/0.31")), new UserAgent(ClientPlatform.ANDROID, new Semver("7.11.23-nightly-1982-06-28-07-07-07"), "tonic/0.31")),
Arguments.of("Signal-Android/7.11.23-nightly-1982-06-28-07-07-07 Android/42 tonic/0.31", Arguments.of("Signal-Android/7.11.23-nightly-1982-06-28-07-07-07 Android/42 tonic/0.31",

View File

@ -47,6 +47,7 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor; import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager;
@ -141,7 +142,8 @@ class WebSocketConnectionIntegrationTest {
scheduledExecutorService, scheduledExecutorService,
messageDeliveryScheduler, messageDeliveryScheduler,
clientReleaseManager, clientReleaseManager,
mock(MessageDeliveryLoopMonitor.class)); mock(MessageDeliveryLoopMonitor.class),
mock(ExperimentEnrollmentManager.class));
final List<MessageProtos.Envelope> expectedMessages = new ArrayList<>(persistedMessageCount + cachedMessageCount); final List<MessageProtos.Envelope> expectedMessages = new ArrayList<>(persistedMessageCount + cachedMessageCount);
@ -229,7 +231,8 @@ class WebSocketConnectionIntegrationTest {
scheduledExecutorService, scheduledExecutorService,
messageDeliveryScheduler, messageDeliveryScheduler,
clientReleaseManager, clientReleaseManager,
mock(MessageDeliveryLoopMonitor.class)); mock(MessageDeliveryLoopMonitor.class),
mock(ExperimentEnrollmentManager.class));
final int persistedMessageCount = 207; final int persistedMessageCount = 207;
final int cachedMessageCount = 173; final int cachedMessageCount = 173;
@ -299,7 +302,8 @@ class WebSocketConnectionIntegrationTest {
scheduledExecutorService, scheduledExecutorService,
messageDeliveryScheduler, messageDeliveryScheduler,
clientReleaseManager, clientReleaseManager,
mock(MessageDeliveryLoopMonitor.class)); mock(MessageDeliveryLoopMonitor.class),
mock(ExperimentEnrollmentManager.class));
final int persistedMessageCount = 207; final int persistedMessageCount = 207;
final int cachedMessageCount = 173; final int cachedMessageCount = 173;

View File

@ -54,6 +54,7 @@ import org.junit.jupiter.api.Test;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor; import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
@ -125,7 +126,8 @@ class WebSocketConnectionTest {
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, messagesManager, AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, messagesManager,
new MessageMetrics(), mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), new MessageMetrics(), mock(PushNotificationManager.class), mock(PushNotificationScheduler.class),
mock(WebSocketConnectionEventManager.class), retrySchedulingExecutor, mock(WebSocketConnectionEventManager.class), retrySchedulingExecutor,
messageDeliveryScheduler, clientReleaseManager, mock(MessageDeliveryLoopMonitor.class)); messageDeliveryScheduler, clientReleaseManager, mock(MessageDeliveryLoopMonitor.class),
mock(ExperimentEnrollmentManager.class));
WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class); WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD)))) when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
@ -629,7 +631,8 @@ class WebSocketConnectionTest {
private WebSocketConnection webSocketConnection(final WebSocketClient client) { private WebSocketConnection webSocketConnection(final WebSocketClient client) {
return new WebSocketConnection(receiptSender, messagesManager, new MessageMetrics(), return new WebSocketConnection(receiptSender, messagesManager, new MessageMetrics(),
mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), auth, client, mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), auth, client,
retrySchedulingExecutor, Schedulers.immediate(), clientReleaseManager, mock(MessageDeliveryLoopMonitor.class)); retrySchedulingExecutor, Schedulers.immediate(), clientReleaseManager,
mock(MessageDeliveryLoopMonitor.class), mock(ExperimentEnrollmentManager.class));
} }
@Test @Test

View File

@ -13,6 +13,7 @@ package org.signal.chat.rpc;
service EchoService { service EchoService {
rpc echo (EchoRequest) returns (EchoResponse) {} rpc echo (EchoRequest) returns (EchoResponse) {}
rpc echo2 (EchoRequest) returns (EchoResponse) {} rpc echo2 (EchoRequest) returns (EchoResponse) {}
rpc echoStream (stream EchoRequest) returns (stream EchoResponse) {}
} }
message EchoRequest { message EchoRequest {

View File

@ -23,14 +23,7 @@ message GetRequestAttributesResponse {
repeated string acceptable_languages = 1; repeated string acceptable_languages = 1;
repeated string available_accepted_locales = 2; repeated string available_accepted_locales = 2;
string remote_address = 3; string remote_address = 3;
string raw_user_agent = 4; string user_agent = 4;
UserAgent user_agent = 5;
}
message UserAgent {
string platform = 1;
string version = 2;
string additional_specifiers = 3;
} }
message GetAuthenticatedDeviceRequest { message GetAuthenticatedDeviceRequest {

View File

@ -470,7 +470,8 @@ turn:
cloudflare: cloudflare:
apiToken: secret://turn.cloudflare.apiToken apiToken: secret://turn.cloudflare.apiToken
endpoint: https://rtc.live.cloudflare.com/v1/turn/keys/LMNOP/credentials/generate endpoint: https://rtc.live.cloudflare.com/v1/turn/keys/LMNOP/credentials/generate
ttl: 86400 requestedCredentialTtl: PT24H
clientCredentialTtl: PT12H
urls: urls:
- turn:turn.example.com:80 - turn:turn.example.com:80
urlsWithIps: urlsWithIps:

@ -1 +1 @@
Subproject commit 8f566196d763c8eb1f3c8fcefd5be3c35ff8d148 Subproject commit 9b664d54d5d9a04a2c005ad376910001c11e47e4