Compare commits

..

No commits in common. "main" and "v20250408.0.0" have entirely different histories.

109 changed files with 1236 additions and 4674 deletions

View File

@ -1,7 +1,6 @@
name: Service CI name: Service CI
on: on:
pull_request:
push: push:
branches-ignore: branches-ignore:
- gh-pages - gh-pages

31
pom.xml
View File

@ -46,7 +46,7 @@
<!-- can be updated to latest version with Dropwizard 5 (Jetty 12); will then need to disable telemetry --> <!-- can be updated to latest version with Dropwizard 5 (Jetty 12); will then need to disable telemetry -->
<dynamodblocal.version>2.2.1</dynamodblocal.version> <dynamodblocal.version>2.2.1</dynamodblocal.version>
<google-cloud-libraries.version>26.57.0</google-cloud-libraries.version> <google-cloud-libraries.version>26.57.0</google-cloud-libraries.version>
<grpc.version>1.70.0</grpc.version> <!-- should be kept in sync with the value from Google libraries-bom --> <grpc.version>1.69.0</grpc.version> <!-- should be kept in sync with the value from Google libraries-bom -->
<gson.version>2.12.1</gson.version> <gson.version>2.12.1</gson.version>
<!-- several libraries (AWS, Google Cloud) use Apache http components transitively, and we need to align them --> <!-- several libraries (AWS, Google Cloud) use Apache http components transitively, and we need to align them -->
<httpcore.version>4.4.16</httpcore.version> <httpcore.version>4.4.16</httpcore.version>
@ -65,15 +65,14 @@
<luajava.version>3.5.0</luajava.version> <luajava.version>3.5.0</luajava.version>
<micrometer.version>1.14.5</micrometer.version> <micrometer.version>1.14.5</micrometer.version>
<netty.version>4.1.119.Final</netty.version> <netty.version>4.1.119.Final</netty.version>
<!-- Must be less than or equal to the value from Google libraries-bom which controls the protobuf runtime version. <!-- Must be greater than or equal to the value from Google libraries-bom
See https://protobuf.dev/support/cross-version-runtime-guarantee/. --> since some of its libraries generate code. See https://protobuf.dev/support/cross-version-runtime-guarantee/. -->
<protoc.version>4.29.4</protoc.version> <protobuf.version>3.25.5</protobuf.version>
<pushy.version>0.15.4</pushy.version> <pushy.version>0.15.4</pushy.version>
<reactive.grpc.version>1.2.4</reactive.grpc.version> <reactive.grpc.version>1.2.4</reactive.grpc.version>
<reactor-bom.version>2024.0.4</reactor-bom.version> <!-- 3.7.4, see https://github.com/reactor/reactor#bom-versioning-scheme --> <reactor-bom.version>2024.0.4</reactor-bom.version> <!-- 3.7.4, see https://github.com/reactor/reactor#bom-versioning-scheme -->
<resilience4j.version>2.3.0</resilience4j.version> <resilience4j.version>2.3.0</resilience4j.version>
<semver4j.version>3.1.0</semver4j.version> <semver4j.version>3.1.0</semver4j.version>
<simple-grpc.version>0.1.0</simple-grpc.version>
<slf4j.version>2.0.17</slf4j.version> <slf4j.version>2.0.17</slf4j.version>
<stripe.version>23.10.0</stripe.version> <stripe.version>23.10.0</stripe.version>
<swagger.version>2.2.27</swagger.version> <swagger.version>2.2.27</swagger.version>
@ -127,7 +126,7 @@
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.google.cloud</groupId> <groupId>com.google.cloud</groupId>
<artifactId>libraries-bom</artifactId> <artifactId>libraries-bom-protobuf3</artifactId>
<version>${google-cloud-libraries.version}</version> <version>${google-cloud-libraries.version}</version>
<type>pom</type> <type>pom</type>
<scope>import</scope> <scope>import</scope>
@ -175,6 +174,11 @@
<artifactId>pushy-dropwizard-metrics-listener</artifactId> <artifactId>pushy-dropwizard-metrics-listener</artifactId>
<version>${pushy.version}</version> <version>${pushy.version}</version>
</dependency> </dependency>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
<version>${protobuf.version}</version>
</dependency>
<dependency> <dependency>
<groupId>com.googlecode.libphonenumber</groupId> <groupId>com.googlecode.libphonenumber</groupId>
<artifactId>libphonenumber</artifactId> <artifactId>libphonenumber</artifactId>
@ -263,11 +267,6 @@
<artifactId>libsignal-server</artifactId> <artifactId>libsignal-server</artifactId>
<version>0.67.6</version> <version>0.67.6</version>
</dependency> </dependency>
<dependency>
<groupId>org.signal</groupId>
<artifactId>simple-grpc-runtime</artifactId>
<version>${simple-grpc.version}</version>
</dependency>
<dependency> <dependency>
<groupId>org.signal.forks</groupId> <groupId>org.signal.forks</groupId>
<artifactId>noise-java</artifactId> <artifactId>noise-java</artifactId>
@ -438,7 +437,7 @@
<version>0.6.1</version> <version>0.6.1</version>
<configuration> <configuration>
<checkStaleness>false</checkStaleness> <checkStaleness>false</checkStaleness>
<protocArtifact>com.google.protobuf:protoc:${protoc.version}:exe:${os.detected.classifier}</protocArtifact> <protocArtifact>com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier}</protocArtifact>
<pluginId>grpc-java</pluginId> <pluginId>grpc-java</pluginId>
<pluginArtifact>io.grpc:protoc-gen-grpc-java:${grpc.version}:exe:${os.detected.classifier}</pluginArtifact> <pluginArtifact>io.grpc:protoc-gen-grpc-java:${grpc.version}:exe:${os.detected.classifier}</pluginArtifact>
@ -450,14 +449,6 @@
<version>${reactive.grpc.version}</version> <version>${reactive.grpc.version}</version>
<mainClass>com.salesforce.reactorgrpc.ReactorGrpcGenerator</mainClass> <mainClass>com.salesforce.reactorgrpc.ReactorGrpcGenerator</mainClass>
</protocPlugin> </protocPlugin>
<protocPlugin>
<id>simple</id>
<groupId>org.signal</groupId>
<artifactId>simple-grpc-generator</artifactId>
<version>${simple-grpc.version}</version>
<mainClass>org.signal.grpc.simple.SimpleGrpcGenerator</mainClass>
</protocPlugin>
</protocPlugins> </protocPlugins>
</configuration> </configuration>
<executions> <executions>

View File

@ -482,8 +482,7 @@ turn:
- turn:%s - turn:%s
- turn:%s:80?transport=tcp - turn:%s:80?transport=tcp
- turns:%s:443?transport=tcp - turns:%s:443?transport=tcp
requestedCredentialTtl: PT24H ttl: 86400
clientCredentialTtl: PT12H
hostname: turn.cloudflare.example.com hostname: turn.cloudflare.example.com
numHttpClients: 1 numHttpClients: 1

View File

@ -80,11 +80,6 @@
<artifactId>noise-java</artifactId> <artifactId>noise-java</artifactId>
</dependency> </dependency>
<dependency>
<groupId>org.signal</groupId>
<artifactId>simple-grpc-runtime</artifactId>
</dependency>
<dependency> <dependency>
<groupId>io.dropwizard</groupId> <groupId>io.dropwizard</groupId>
<artifactId>dropwizard-core</artifactId> <artifactId>dropwizard-core</artifactId>

View File

@ -668,13 +668,12 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
final AccountAuthenticator accountAuthenticator = new AccountAuthenticator(accountsManager); final AccountAuthenticator accountAuthenticator = new AccountAuthenticator(accountsManager);
final MessageSender messageSender = new MessageSender(messagesManager, pushNotificationManager, experimentEnrollmentManager); final MessageSender messageSender = new MessageSender(messagesManager, pushNotificationManager);
final ReceiptSender receiptSender = new ReceiptSender(accountsManager, messageSender, receiptSenderExecutor); final ReceiptSender receiptSender = new ReceiptSender(accountsManager, messageSender, receiptSenderExecutor);
final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager = new CloudflareTurnCredentialsManager( final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager = new CloudflareTurnCredentialsManager(
config.getTurnConfiguration().cloudflare().apiToken().value(), config.getTurnConfiguration().cloudflare().apiToken().value(),
config.getTurnConfiguration().cloudflare().endpoint(), config.getTurnConfiguration().cloudflare().endpoint(),
config.getTurnConfiguration().cloudflare().requestedCredentialTtl(), config.getTurnConfiguration().cloudflare().ttl(),
config.getTurnConfiguration().cloudflare().clientCredentialTtl(),
config.getTurnConfiguration().cloudflare().urls(), config.getTurnConfiguration().cloudflare().urls(),
config.getTurnConfiguration().cloudflare().urlsWithIps(), config.getTurnConfiguration().cloudflare().urlsWithIps(),
config.getTurnConfiguration().cloudflare().hostname(), config.getTurnConfiguration().cloudflare().hostname(),
@ -694,7 +693,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
PushChallengeManager pushChallengeManager = new PushChallengeManager(pushNotificationManager, PushChallengeManager pushChallengeManager = new PushChallengeManager(pushNotificationManager,
pushChallengeDynamoDb); pushChallengeDynamoDb);
ChangeNumberManager changeNumberManager = new ChangeNumberManager(messageSender, accountsManager, Clock.systemUTC()); ChangeNumberManager changeNumberManager = new ChangeNumberManager(messageSender, accountsManager);
HttpClient currencyClient = HttpClient.newBuilder().version(HttpClient.Version.HTTP_2).connectTimeout(Duration.ofSeconds(10)).build(); HttpClient currencyClient = HttpClient.newBuilder().version(HttpClient.Version.HTTP_2).connectTimeout(Duration.ofSeconds(10)).build();
FixerClient fixerClient = config.getPaymentsServiceConfiguration().externalClients() FixerClient fixerClient = config.getPaymentsServiceConfiguration().externalClients()
@ -988,7 +987,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
webSocketEnvironment.setConnectListener( webSocketEnvironment.setConnectListener(
new AuthenticatedConnectListener(receiptSender, messagesManager, messageMetrics, pushNotificationManager, new AuthenticatedConnectListener(receiptSender, messagesManager, messageMetrics, pushNotificationManager,
pushNotificationScheduler, webSocketConnectionEventManager, websocketScheduledExecutor, pushNotificationScheduler, webSocketConnectionEventManager, websocketScheduledExecutor,
messageDeliveryScheduler, clientReleaseManager, messageDeliveryLoopMonitor, experimentEnrollmentManager)); messageDeliveryScheduler, clientReleaseManager, messageDeliveryLoopMonitor));
webSocketEnvironment.jersey() webSocketEnvironment.jersey()
.register(new WebsocketRefreshApplicationEventListener(accountsManager, disconnectionRequestManager)); .register(new WebsocketRefreshApplicationEventListener(accountsManager, disconnectionRequestManager));
webSocketEnvironment.jersey().register(new RateLimitByIpFilter(rateLimiters)); webSocketEnvironment.jersey().register(new RateLimitByIpFilter(rateLimiters));

View File

@ -15,7 +15,6 @@ import java.net.Inet6Address;
import java.net.URI; import java.net.URI;
import java.net.http.HttpRequest; import java.net.http.HttpRequest;
import java.net.http.HttpResponse; import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionException;
@ -40,18 +39,16 @@ public class CloudflareTurnCredentialsManager {
private final List<String> cloudflareTurnUrls; private final List<String> cloudflareTurnUrls;
private final List<String> cloudflareTurnUrlsWithIps; private final List<String> cloudflareTurnUrlsWithIps;
private final String cloudflareTurnHostname; private final String cloudflareTurnHostname;
private final HttpRequest getCredentialsRequest; private final HttpRequest request;
private final FaultTolerantHttpClient cloudflareTurnClient; private final FaultTolerantHttpClient cloudflareTurnClient;
private final DnsNameResolver dnsNameResolver; private final DnsNameResolver dnsNameResolver;
private final Duration clientCredentialTtl; record CredentialRequest(long ttl) {}
private record CredentialRequest(long ttl) {} record CloudflareTurnResponse(IceServer iceServers) {
private record CloudflareTurnResponse(IceServer iceServers) { record IceServer(
private record IceServer(
String username, String username,
String credential, String credential,
List<String> urls) { List<String> urls) {
@ -59,17 +56,10 @@ public class CloudflareTurnCredentialsManager {
} }
public CloudflareTurnCredentialsManager(final String cloudflareTurnApiToken, public CloudflareTurnCredentialsManager(final String cloudflareTurnApiToken,
final String cloudflareTurnEndpoint, final String cloudflareTurnEndpoint, final long cloudflareTurnTtl, final List<String> cloudflareTurnUrls,
final Duration requestedCredentialTtl, final List<String> cloudflareTurnUrlsWithIps, final String cloudflareTurnHostname,
final Duration clientCredentialTtl, final int cloudflareTurnNumHttpClients, final CircuitBreakerConfiguration circuitBreaker,
final List<String> cloudflareTurnUrls, final ExecutorService executor, final RetryConfiguration retry, final ScheduledExecutorService retryExecutor,
final List<String> cloudflareTurnUrlsWithIps,
final String cloudflareTurnHostname,
final int cloudflareTurnNumHttpClients,
final CircuitBreakerConfiguration circuitBreaker,
final ExecutorService executor,
final RetryConfiguration retry,
final ScheduledExecutorService retryExecutor,
final DnsNameResolver dnsNameResolver) { final DnsNameResolver dnsNameResolver) {
this.cloudflareTurnClient = FaultTolerantHttpClient.newBuilder() this.cloudflareTurnClient = FaultTolerantHttpClient.newBuilder()
@ -85,24 +75,17 @@ public class CloudflareTurnCredentialsManager {
this.cloudflareTurnHostname = cloudflareTurnHostname; this.cloudflareTurnHostname = cloudflareTurnHostname;
this.dnsNameResolver = dnsNameResolver; this.dnsNameResolver = dnsNameResolver;
final String credentialsRequestBody;
try { try {
credentialsRequestBody = final String body = SystemMapper.jsonMapper().writeValueAsString(new CredentialRequest(cloudflareTurnTtl));
SystemMapper.jsonMapper().writeValueAsString(new CredentialRequest(requestedCredentialTtl.toSeconds())); this.request = HttpRequest.newBuilder()
} catch (final JsonProcessingException e) { .uri(URI.create(cloudflareTurnEndpoint))
.header("Content-Type", "application/json")
.header("Authorization", String.format("Bearer %s", cloudflareTurnApiToken))
.POST(HttpRequest.BodyPublishers.ofString(body))
.build();
} catch (JsonProcessingException e) {
throw new IllegalArgumentException(e); throw new IllegalArgumentException(e);
} }
// We repeat the same request to Cloudflare every time, so we can construct it once and re-use it
this.getCredentialsRequest = HttpRequest.newBuilder()
.uri(URI.create(cloudflareTurnEndpoint))
.header("Content-Type", "application/json")
.header("Authorization", String.format("Bearer %s", cloudflareTurnApiToken))
.POST(HttpRequest.BodyPublishers.ofString(credentialsRequestBody))
.build();
this.clientCredentialTtl = clientCredentialTtl;
} }
public TurnToken retrieveFromCloudflare() throws IOException { public TurnToken retrieveFromCloudflare() throws IOException {
@ -122,7 +105,7 @@ public class CloudflareTurnCredentialsManager {
final Timer.Sample sample = Timer.start(); final Timer.Sample sample = Timer.start();
final HttpResponse<String> response; final HttpResponse<String> response;
try { try {
response = cloudflareTurnClient.sendAsync(getCredentialsRequest, HttpResponse.BodyHandlers.ofString()).join(); response = cloudflareTurnClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()).join();
sample.stop(Timer.builder(CREDENTIAL_FETCH_TIMER_NAME) sample.stop(Timer.builder(CREDENTIAL_FETCH_TIMER_NAME)
.publishPercentileHistogram(true) .publishPercentileHistogram(true)
.tags("outcome", "success") .tags("outcome", "success")
@ -147,7 +130,6 @@ public class CloudflareTurnCredentialsManager {
return new TurnToken( return new TurnToken(
cloudflareTurnResponse.iceServers().username(), cloudflareTurnResponse.iceServers().username(),
cloudflareTurnResponse.iceServers().credential(), cloudflareTurnResponse.iceServers().credential(),
clientCredentialTtl.toSeconds(),
cloudflareTurnUrls == null ? Collections.emptyList() : cloudflareTurnUrls, cloudflareTurnUrls == null ? Collections.emptyList() : cloudflareTurnUrls,
cloudflareTurnComposedUrls, cloudflareTurnComposedUrls,
cloudflareTurnHostname cloudflareTurnHostname

View File

@ -5,15 +5,13 @@
package org.whispersystems.textsecuregcm.auth; package org.whispersystems.textsecuregcm.auth;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.List;
import javax.annotation.Nonnull; import javax.annotation.Nonnull;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.List;
public record TurnToken( public record TurnToken(
String username, String username,
String password, String password,
@JsonProperty("ttl") long ttlSeconds,
@Nonnull List<String> urls, @Nonnull List<String> urls,
@Nonnull List<String> urlsWithIps, @Nonnull List<String> urlsWithIps,
@Nullable String hostname) { @Nullable String hostname) {

View File

@ -1,22 +1,34 @@
package org.whispersystems.textsecuregcm.auth.grpc; package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall; import io.grpc.ServerCall;
import io.grpc.ServerInterceptor; import io.grpc.ServerInterceptor;
import java.util.Optional; import io.grpc.Status;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException; import io.netty.channel.local.LocalAddress;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import java.util.Optional;
abstract class AbstractAuthenticationInterceptor implements ServerInterceptor { abstract class AbstractAuthenticationInterceptor implements ServerInterceptor {
private final GrpcClientConnectionManager grpcClientConnectionManager; private final GrpcClientConnectionManager grpcClientConnectionManager;
private static final Metadata EMPTY_TRAILERS = new Metadata();
AbstractAuthenticationInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) { AbstractAuthenticationInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
this.grpcClientConnectionManager = grpcClientConnectionManager; this.grpcClientConnectionManager = grpcClientConnectionManager;
} }
protected Optional<AuthenticatedDevice> getAuthenticatedDevice(final ServerCall<?, ?> call) protected Optional<AuthenticatedDevice> getAuthenticatedDevice(final ServerCall<?, ?> call) {
throws ChannelNotFoundException { if (call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress) {
return grpcClientConnectionManager.getAuthenticatedDevice(localAddress);
} else {
throw new AssertionError("Unexpected channel type: " + call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
}
}
return grpcClientConnectionManager.getAuthenticatedDevice(call); protected <ReqT, RespT> ServerCall.Listener<ReqT> closeAsUnauthenticated(final ServerCall<ReqT, RespT> call) {
call.close(Status.UNAUTHENTICATED, EMPTY_TRAILERS);
return new ServerCall.Listener<>() {};
} }
} }

View File

@ -3,17 +3,12 @@ package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.ServerCall; import io.grpc.ServerCall;
import io.grpc.ServerCallHandler; import io.grpc.ServerCallHandler;
import io.grpc.Status;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.ServerInterceptorUtil;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
/** /**
* A "prohibit authentication" interceptor ensures that requests to endpoints that should be invoked anonymously do not * A "prohibit authentication" interceptor ensures that requests to endpoints that should be invoked anonymously do not
* originate from a channel that is associated with an authenticated device. Calls with an associated authenticated * originate from a channel that is associated with an authenticated device. Calls with an associated authenticated
* device are closed with an {@code UNAUTHENTICATED} status. If a call's authentication status cannot be determined * device are closed with an {@code UNAUTHENTICATED} status.
* (i.e. because the underlying remote channel closed before the {@code ServerCall} started), the interceptor will
* reject the call with a status of {@code UNAVAILABLE}.
*/ */
public class ProhibitAuthenticationInterceptor extends AbstractAuthenticationInterceptor { public class ProhibitAuthenticationInterceptor extends AbstractAuthenticationInterceptor {
@ -26,15 +21,8 @@ public class ProhibitAuthenticationInterceptor extends AbstractAuthenticationInt
final Metadata headers, final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) { final ServerCallHandler<ReqT, RespT> next) {
try { return getAuthenticatedDevice(call)
return getAuthenticatedDevice(call) .map(ignored -> closeAsUnauthenticated(call))
// Status.INTERNAL may seem a little surprising here, but if a caller is reaching an authentication-prohibited .orElseGet(() -> next.startCall(call, headers));
// service via an authenticated connection, then that's actually a server configuration issue and not a
// problem with the client's request.
.map(ignored -> ServerInterceptorUtil.closeWithStatus(call, Status.INTERNAL))
.orElseGet(() -> next.startCall(call, headers));
} catch (final ChannelNotFoundException e) {
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
}
} }
} }

View File

@ -5,16 +5,12 @@ import io.grpc.Contexts;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.ServerCall; import io.grpc.ServerCall;
import io.grpc.ServerCallHandler; import io.grpc.ServerCallHandler;
import io.grpc.Status;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.ServerInterceptorUtil;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
/** /**
* A "require authentication" interceptor requires that requests be issued from a connection that is associated with an * A "require authentication" interceptor requires that requests be issued from a connection that is associated with an
* authenticated device. Calls without an associated authenticated device are closed with an {@code UNAUTHENTICATED} * authenticated device. Calls without an associated authenticated device are closed with an {@code UNAUTHENTICATED}
* status. If a call's authentication status cannot be determined (i.e. because the underlying remote channel closed * status.
* before the {@code ServerCall} started), the interceptor will reject the call with a status of {@code UNAVAILABLE}.
*/ */
public class RequireAuthenticationInterceptor extends AbstractAuthenticationInterceptor { public class RequireAuthenticationInterceptor extends AbstractAuthenticationInterceptor {
@ -27,17 +23,10 @@ public class RequireAuthenticationInterceptor extends AbstractAuthenticationInte
final Metadata headers, final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) { final ServerCallHandler<ReqT, RespT> next) {
try { return getAuthenticatedDevice(call)
return getAuthenticatedDevice(call) .map(authenticatedDevice -> Contexts.interceptCall(Context.current()
.map(authenticatedDevice -> Contexts.interceptCall(Context.current() .withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE, authenticatedDevice),
.withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE, authenticatedDevice), call, headers, next))
call, headers, next)) .orElseGet(() -> closeAsUnauthenticated(call));
// Status.INTERNAL may seem a little surprising here, but if a caller is reaching an authentication-required
// service via an unauthenticated connection, then that's actually a server configuration issue and not a
// problem with the client's request.
.orElseGet(() -> ServerInterceptorUtil.closeWithStatus(call, Status.INTERNAL));
} catch (final ChannelNotFoundException e) {
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
}
} }
} }

View File

@ -6,36 +6,16 @@
package org.whispersystems.textsecuregcm.configuration; package org.whispersystems.textsecuregcm.configuration;
import jakarta.validation.Valid; import jakarta.validation.Valid;
import jakarta.validation.constraints.AssertTrue;
import jakarta.validation.constraints.NotBlank; import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotEmpty; import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull; import jakarta.validation.constraints.NotNull;
import java.time.Duration;
import java.util.List; import java.util.List;
import jakarta.validation.constraints.Positive; import jakarta.validation.constraints.Positive;
import org.whispersystems.textsecuregcm.configuration.secrets.SecretString; import org.whispersystems.textsecuregcm.configuration.secrets.SecretString;
/**
* Configuration properties for Cloudflare TURN integration.
*
* @param apiToken the API token to use when requesting TURN tokens from Cloudflare
* @param endpoint the URI of the Cloudflare API endpoint that vends TURN tokens
* @param requestedCredentialTtl the lifetime of TURN tokens to request from Cloudflare
* @param clientCredentialTtl the time clients may cache a TURN token; must be less than or equal to {@link #requestedCredentialTtl}
* @param urls a collection of TURN URLs to include verbatim in responses to clients
* @param urlsWithIps a collection of {@link String#format(String, Object...)} patterns to be populated with resolved IP
* addresses for {@link #hostname} in responses to clients; each pattern must include a single
* {@code %s} placeholder for the IP address
* @param circuitBreaker a circuit breaker for requests to Cloudflare
* @param retry a retry policy for requests to Cloudflare
* @param hostname the hostname to resolve to IP addresses for use with {@link #urlsWithIps}; also transmitted to
* clients for use as an SNI when connecting to pre-resolved hosts
* @param numHttpClients the number of parallel HTTP clients to use to communicate with Cloudflare
*/
public record CloudflareTurnConfiguration(@NotNull SecretString apiToken, public record CloudflareTurnConfiguration(@NotNull SecretString apiToken,
@NotBlank String endpoint, @NotBlank String endpoint,
@NotNull Duration requestedCredentialTtl, @NotBlank long ttl,
@NotNull Duration clientCredentialTtl,
@NotNull @NotEmpty @Valid List<@NotBlank String> urls, @NotNull @NotEmpty @Valid List<@NotBlank String> urls,
@NotNull @NotEmpty @Valid List<@NotBlank String> urlsWithIps, @NotNull @NotEmpty @Valid List<@NotBlank String> urlsWithIps,
@NotNull @Valid CircuitBreakerConfiguration circuitBreaker, @NotNull @Valid CircuitBreakerConfiguration circuitBreaker,
@ -55,9 +35,4 @@ public record CloudflareTurnConfiguration(@NotNull SecretString apiToken,
retry = new RetryConfiguration(); retry = new RetryConfiguration();
} }
} }
@AssertTrue
public boolean isClientTtlShorterThanRequestedTtl() {
return clientCredentialTtl.compareTo(requestedCredentialTtl) <= 0;
}
} }

View File

@ -15,12 +15,16 @@ import io.swagger.v3.oas.annotations.responses.ApiResponse;
import jakarta.ws.rs.GET; import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path; import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces; import jakarta.ws.rs.Produces;
import jakarta.ws.rs.container.ContainerRequestContext;
import jakarta.ws.rs.core.Context;
import jakarta.ws.rs.core.MediaType; import jakarta.ws.rs.core.MediaType;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.UUID; import java.util.UUID;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.CloudflareTurnCredentialsManager; import org.whispersystems.textsecuregcm.auth.CloudflareTurnCredentialsManager;
import org.whispersystems.textsecuregcm.auth.TurnToken;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.websocket.auth.ReadOnly; import org.whispersystems.websocket.auth.ReadOnly;
@ -28,16 +32,14 @@ import org.whispersystems.websocket.auth.ReadOnly;
@Path("/v2/calling") @Path("/v2/calling")
public class CallRoutingControllerV2 { public class CallRoutingControllerV2 {
private static final Counter CLOUDFLARE_TURN_ERROR_COUNTER = Metrics.counter(name(CallRoutingControllerV2.class, "cloudflareTurnError"));
private final RateLimiters rateLimiters; private final RateLimiters rateLimiters;
private final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager; private final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager;
private static final Counter CLOUDFLARE_TURN_ERROR_COUNTER =
Metrics.counter(name(CallRoutingControllerV2.class, "cloudflareTurnError"));
public CallRoutingControllerV2( public CallRoutingControllerV2(
final RateLimiters rateLimiters, final RateLimiters rateLimiters,
final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager) { final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager
) {
this.rateLimiters = rateLimiters; this.rateLimiters = rateLimiters;
this.cloudflareTurnCredentialsManager = cloudflareTurnCredentialsManager; this.cloudflareTurnCredentialsManager = cloudflareTurnCredentialsManager;
} }
@ -56,17 +58,25 @@ public class CallRoutingControllerV2 {
@ApiResponse(responseCode = "401", description = "Account authentication check failed.") @ApiResponse(responseCode = "401", description = "Account authentication check failed.")
@ApiResponse(responseCode = "422", description = "Invalid request format.") @ApiResponse(responseCode = "422", description = "Invalid request format.")
@ApiResponse(responseCode = "429", description = "Rate limited.") @ApiResponse(responseCode = "429", description = "Rate limited.")
public GetCallingRelaysResponse getCallingRelays(final @ReadOnly @Auth AuthenticatedDevice auth) public GetCallingRelaysResponse getCallingRelays(
throws RateLimitExceededException, IOException { final @ReadOnly @Auth AuthenticatedDevice auth
) throws RateLimitExceededException, IOException {
final UUID aci = auth.getAccount().getUuid(); UUID aci = auth.getAccount().getUuid();
rateLimiters.getCallEndpointLimiter().validate(aci); rateLimiters.getCallEndpointLimiter().validate(aci);
List<TurnToken> tokens = new ArrayList<>();
try { try {
return new GetCallingRelaysResponse(List.of(cloudflareTurnCredentialsManager.retrieveFromCloudflare())); tokens.add(cloudflareTurnCredentialsManager.retrieveFromCloudflare());
} catch (final Exception e) { } catch (Exception e) {
CLOUDFLARE_TURN_ERROR_COUNTER.increment(); CallRoutingControllerV2.CLOUDFLARE_TURN_ERROR_COUNTER.increment();
throw e; throw e;
} }
return new GetCallingRelaysResponse(tokens);
}
public record GetCallingRelaysResponse(
List<TurnToken> relays
) {
} }
} }

View File

@ -44,6 +44,7 @@ import java.util.EnumMap;
import java.util.List; import java.util.List;
import java.util.Locale; import java.util.Locale;
import java.util.Map; import java.util.Map;
import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionException;
@ -51,6 +52,7 @@ import java.util.concurrent.CompletionStage;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.glassfish.jersey.server.ContainerRequest; import org.glassfish.jersey.server.ContainerRequest;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.BasicAuthorizationHeader; import org.whispersystems.textsecuregcm.auth.BasicAuthorizationHeader;
@ -72,7 +74,6 @@ import org.whispersystems.textsecuregcm.entities.TransferArchiveUploadedRequest;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.limits.RateLimitedByIp; import org.whispersystems.textsecuregcm.limits.RateLimitedByIp;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.DevicePlatformUtil;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
@ -401,7 +402,7 @@ public class DeviceController {
private AtomicInteger getCounterForLinkedDeviceListeners(final String userAgent) { private AtomicInteger getCounterForLinkedDeviceListeners(final String userAgent) {
try { try {
return linkedDeviceListenersByPlatform.get(UserAgentUtil.parseUserAgentString(userAgent).platform()); return linkedDeviceListenersByPlatform.get(UserAgentUtil.parseUserAgentString(userAgent).getPlatform());
} catch (final UnrecognizedUserAgentException ignored) { } catch (final UnrecognizedUserAgentException ignored) {
return linkedDeviceListenersForUnrecognizedPlatforms; return linkedDeviceListenersForUnrecognizedPlatforms;
} }
@ -599,9 +600,25 @@ public class DeviceController {
} }
private static io.micrometer.core.instrument.Tag primaryPlatformTag(final Account account) { private static io.micrometer.core.instrument.Tag primaryPlatformTag(final Account account) {
final Device primaryDevice = account.getPrimaryDevice();
Optional<ClientPlatform> clientPlatform = Optional.empty();
if (StringUtils.isNotBlank(primaryDevice.getGcmId())) {
clientPlatform = Optional.of(ClientPlatform.ANDROID);
} else if (StringUtils.isNotBlank(primaryDevice.getApnId())) {
clientPlatform = Optional.of(ClientPlatform.IOS);
}
clientPlatform = clientPlatform.or(() -> Optional.ofNullable(
switch (primaryDevice.getUserAgent()) {
case "OWA" -> ClientPlatform.ANDROID;
case "OWI", "OWP" -> ClientPlatform.IOS;
case "OWD" -> ClientPlatform.DESKTOP;
case null, default -> null;
}));
return io.micrometer.core.instrument.Tag.of( return io.micrometer.core.instrument.Tag.of(
"primaryPlatform", "primaryPlatform",
DevicePlatformUtil.getDevicePlatform(account.getPrimaryDevice()) clientPlatform
.map(p -> p.name().toLowerCase(Locale.ROOT)) .map(p -> p.name().toLowerCase(Locale.ROOT))
.orElse("unknown")); .orElse("unknown"));
} }

View File

@ -97,20 +97,21 @@ public class DonationController {
return redeemedReceiptsManager.put( return redeemedReceiptsManager.put(
receiptSerial, receiptExpiration.getEpochSecond(), receiptLevel, auth.getAccount().getUuid()) receiptSerial, receiptExpiration.getEpochSecond(), receiptLevel, auth.getAccount().getUuid())
.thenCompose(receiptMatched -> { .thenCompose(receiptMatched -> {
if (!receiptMatched) { if (!receiptMatched) {
return CompletableFuture.completedFuture(
Response.status(Status.BAD_REQUEST).entity("receipt serial is already redeemed")
.type(MediaType.TEXT_PLAIN_TYPE).build());
}
return CompletableFuture.completedFuture( return accountsManager.getByAccountIdentifierAsync(auth.getAccount().getUuid())
Response.status(Status.BAD_REQUEST).entity("receipt serial is already redeemed") .thenCompose(optionalAccount ->
.type(MediaType.TEXT_PLAIN_TYPE).build()); optionalAccount.map(account -> accountsManager.updateAsync(account, a -> {
}
return accountsManager.updateAsync(auth.getAccount(), a -> {
a.addBadge(clock, new AccountBadge(badgeId, receiptExpiration, request.isVisible())); a.addBadge(clock, new AccountBadge(badgeId, receiptExpiration, request.isVisible()));
if (request.isPrimary()) { if (request.isPrimary()) {
a.makeBadgePrimaryIfExists(clock, badgeId); a.makeBadgePrimaryIfExists(clock, badgeId);
} }
}) })).orElse(CompletableFuture.completedFuture(null)))
.thenApply(ignored -> Response.ok().build()); .thenApply(ignored -> Response.ok().build());
}); });
}).thenCompose(Function.identity()); }).thenCompose(Function.identity());
} }

View File

@ -1,13 +0,0 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import org.whispersystems.textsecuregcm.auth.TurnToken;
import java.util.List;
public record GetCallingRelaysResponse(List<TurnToken> relays) {
}

View File

@ -152,7 +152,7 @@ public class KeysController {
final List<CompletableFuture<Void>> storeFutures = new ArrayList<>(4); final List<CompletableFuture<Void>> storeFutures = new ArrayList<>(4);
if (!setKeysRequest.preKeys().isEmpty()) { if (setKeysRequest.preKeys() != null && !setKeysRequest.preKeys().isEmpty()) {
Metrics.counter(STORE_KEYS_COUNTER_NAME, Metrics.counter(STORE_KEYS_COUNTER_NAME,
Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "ec"))) Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "ec")))
.increment(); .increment();
@ -168,7 +168,7 @@ public class KeysController {
storeFutures.add(keysManager.storeEcSignedPreKeys(identifier, device.getId(), setKeysRequest.signedPreKey())); storeFutures.add(keysManager.storeEcSignedPreKeys(identifier, device.getId(), setKeysRequest.signedPreKey()));
} }
if (!setKeysRequest.pqPreKeys().isEmpty()) { if (setKeysRequest.pqPreKeys() != null && !setKeysRequest.pqPreKeys().isEmpty()) {
Metrics.counter(STORE_KEYS_COUNTER_NAME, Metrics.counter(STORE_KEYS_COUNTER_NAME,
Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "kyber"))) Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "kyber")))
.increment(); .increment();
@ -192,7 +192,11 @@ public class KeysController {
final IdentityKey identityKey, final IdentityKey identityKey,
@Nullable final String userAgent) { @Nullable final String userAgent) {
final List<SignedPreKey<?>> signedPreKeys = new ArrayList<>(setKeysRequest.pqPreKeys()); final List<SignedPreKey<?>> signedPreKeys = new ArrayList<>();
if (setKeysRequest.pqPreKeys() != null) {
signedPreKeys.addAll(setKeysRequest.pqPreKeys());
}
if (setKeysRequest.pqLastResortPreKey() != null) { if (setKeysRequest.pqLastResortPreKey() != null) {
signedPreKeys.add(setKeysRequest.pqLastResortPreKey()); signedPreKeys.add(setKeysRequest.pqLastResortPreKey());

View File

@ -43,6 +43,7 @@ import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.core.Response.Status; import jakarta.ws.rs.core.Response.Status;
import java.time.Clock; import java.time.Clock;
import java.time.Duration; import java.time.Duration;
import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
@ -95,7 +96,6 @@ import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider; import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider;
import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.MessageTooLargeException; import org.whispersystems.textsecuregcm.push.MessageTooLargeException;
import org.whispersystems.textsecuregcm.push.MessageUtil;
import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.push.ReceiptSender;
@ -105,6 +105,7 @@ import org.whispersystems.textsecuregcm.spam.SpamChecker;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers; import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager; import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
@ -113,7 +114,11 @@ import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.websocket.WebSocketConnection; import org.whispersystems.textsecuregcm.websocket.WebSocketConnection;
import org.whispersystems.websocket.WebsocketHeaders; import org.whispersystems.websocket.WebsocketHeaders;
import org.whispersystems.websocket.auth.ReadOnly; import org.whispersystems.websocket.auth.ReadOnly;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Scheduler;
import reactor.util.function.Tuple2;
import reactor.util.function.Tuples;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@Path("/v1/messages") @Path("/v1/messages")
@ -140,6 +145,8 @@ public class MessageController {
private final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor; private final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor;
private final Clock clock; private final Clock clock;
private static final int MAX_FETCH_ACCOUNT_CONCURRENCY = 8;
private static final CompletableFuture<?>[] EMPTY_FUTURE_ARRAY = new CompletableFuture<?>[0]; private static final CompletableFuture<?>[] EMPTY_FUTURE_ARRAY = new CompletableFuture<?>[0];
private static final String OUTGOING_MESSAGE_LIST_SIZE_BYTES_DISTRIBUTION_NAME = name(MessageController.class, "outgoingMessageListSizeBytes"); private static final String OUTGOING_MESSAGE_LIST_SIZE_BYTES_DISTRIBUTION_NAME = name(MessageController.class, "outgoingMessageListSizeBytes");
@ -436,16 +443,11 @@ public class MessageController {
final Map<Byte, Integer> registrationIdsByDeviceId = messages.messages().stream() final Map<Byte, Integer> registrationIdsByDeviceId = messages.messages().stream()
.collect(Collectors.toMap(IncomingMessage::destinationDeviceId, IncomingMessage::destinationRegistrationId)); .collect(Collectors.toMap(IncomingMessage::destinationDeviceId, IncomingMessage::destinationRegistrationId));
final Optional<Byte> syncMessageSenderDeviceId = messageType == MessageType.SYNC
? Optional.ofNullable(sender).map(authenticatedDevice -> authenticatedDevice.getAuthenticatedDevice().getId())
: Optional.empty();
try { try {
messageSender.sendMessages(destination, messageSender.sendMessages(destination,
destinationIdentifier, destinationIdentifier,
messagesByDeviceId, messagesByDeviceId,
registrationIdsByDeviceId, registrationIdsByDeviceId,
syncMessageSenderDeviceId,
userAgent); userAgent);
} catch (final MismatchedDevicesException e) { } catch (final MismatchedDevicesException e) {
if (!e.getMismatchedDevices().staleDeviceIds().isEmpty()) { if (!e.getMismatchedDevices().staleDeviceIds().isEmpty()) {
@ -561,9 +563,7 @@ public class MessageController {
final ContainerRequestContext context) { final ContainerRequestContext context) {
// Perform fast, inexpensive checks before attempting to resolve recipients // Perform fast, inexpensive checks before attempting to resolve recipients
if (MessageUtil.hasDuplicateDevices(multiRecipientMessage)) { validateNoDuplicateDevices(multiRecipientMessage);
throw new BadRequestException("Multi-recipient message contains duplicate recipient");
}
if (groupSendTokenHeader == null && combinedUnidentifiedSenderAccessKeys == null) { if (groupSendTokenHeader == null && combinedUnidentifiedSenderAccessKeys == null) {
throw new NotAuthorizedException("A group send endorsement token or unidentified access key is required for non-story messages"); throw new NotAuthorizedException("A group send endorsement token or unidentified access key is required for non-story messages");
@ -582,14 +582,7 @@ public class MessageController {
// At this point, the caller has at least superficially provided the information needed to send a multi-recipient // At this point, the caller has at least superficially provided the information needed to send a multi-recipient
// message. Attempt to resolve the destination service identifiers to Signal accounts. // message. Attempt to resolve the destination service identifiers to Signal accounts.
final Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolvedRecipients = final Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolvedRecipients =
MessageUtil.resolveRecipients(accountsManager, multiRecipientMessage); resolveRecipients(multiRecipientMessage, groupSendTokenHeader == null);
final List<ServiceIdentifier> unresolvedRecipientServiceIdentifiers =
MessageUtil.getUnresolvedRecipients(multiRecipientMessage, resolvedRecipients);
if (groupSendTokenHeader == null && !unresolvedRecipientServiceIdentifiers.isEmpty()) {
throw new NotFoundException();
}
// Access keys are checked against the UAK in the resolved accounts, so we have to check after resolving accounts above. // Access keys are checked against the UAK in the resolved accounts, so we have to check after resolving accounts above.
// Group send endorsements are checked earlier; for stories, we don't check permissions at all because only clients check them // Group send endorsements are checked earlier; for stories, we don't check permissions at all because only clients check them
@ -605,6 +598,17 @@ public class MessageController {
urgent, urgent,
context); context);
final List<ServiceIdentifier> unresolvedRecipientServiceIdentifiers;
if (groupSendTokenHeader != null) {
unresolvedRecipientServiceIdentifiers = multiRecipientMessage.getRecipients().entrySet().stream()
.filter(entry -> !resolvedRecipients.containsKey(entry.getValue()))
.map(entry -> ServiceIdentifier.fromLibsignal(entry.getKey()))
.toList();
} else {
unresolvedRecipientServiceIdentifiers = List.of();
}
return new SendMultiRecipientMessageResponse(unresolvedRecipientServiceIdentifiers); return new SendMultiRecipientMessageResponse(unresolvedRecipientServiceIdentifiers);
} }
@ -616,14 +620,12 @@ public class MessageController {
final ContainerRequestContext context) { final ContainerRequestContext context) {
// Perform fast, inexpensive checks before attempting to resolve recipients // Perform fast, inexpensive checks before attempting to resolve recipients
if (MessageUtil.hasDuplicateDevices(multiRecipientMessage)) { validateNoDuplicateDevices(multiRecipientMessage);
throw new BadRequestException("Multi-recipient message contains duplicate recipient");
}
// At this point, the caller has at least superficially provided the information needed to send a multi-recipient // At this point, the caller has at least superficially provided the information needed to send a multi-recipient
// message. Attempt to resolve the destination service identifiers to Signal accounts. // message. Attempt to resolve the destination service identifiers to Signal accounts.
final Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolvedRecipients = final Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolvedRecipients =
MessageUtil.resolveRecipients(accountsManager, multiRecipientMessage); resolveRecipients(multiRecipientMessage, false);
// We might filter out all the recipients of a story (if none exist). // We might filter out all the recipients of a story (if none exist).
// In this case there is no error so we should just return 200 now. // In this case there is no error so we should just return 200 now.
@ -907,4 +909,43 @@ public class MessageController {
return Response.status(Status.ACCEPTED) return Response.status(Status.ACCEPTED)
.build(); .build();
} }
private static void validateNoDuplicateDevices(final SealedSenderMultiRecipientMessage multiRecipientMessage) {
final boolean[] usedDeviceIds = new boolean[Device.MAXIMUM_DEVICE_ID + 1];
for (final SealedSenderMultiRecipientMessage.Recipient recipient : multiRecipientMessage.getRecipients().values()) {
if (recipient.getDevices().length == 1) {
// A recipient can't have repeated devices if they only have one device
continue;
}
Arrays.fill(usedDeviceIds, false);
for (final byte deviceId : recipient.getDevices()) {
if (usedDeviceIds[deviceId]) {
throw new BadRequestException();
}
usedDeviceIds[deviceId] = true;
}
}
}
private Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolveRecipients(final SealedSenderMultiRecipientMessage multiRecipientMessage,
final boolean throwOnNotFound) {
return Flux.fromIterable(multiRecipientMessage.getRecipients().entrySet())
.flatMap(serviceIdAndRecipient -> {
final ServiceIdentifier serviceIdentifier =
ServiceIdentifier.fromLibsignal(serviceIdAndRecipient.getKey());
return Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(serviceIdentifier))
.flatMap(Mono::justOrEmpty)
.switchIfEmpty(throwOnNotFound ? Mono.error(NotFoundException::new) : Mono.empty())
.map(account -> Tuples.of(serviceIdAndRecipient.getValue(), account));
}, MAX_FETCH_ACCOUNT_CONCURRENCY)
.collectMap(Tuple2::getT1, Tuple2::getT2)
.blockOptional()
.orElse(Collections.emptyMap());
}
} }

View File

@ -428,7 +428,7 @@ public class OneTimeDonationController {
@Nullable @Nullable
private static ClientPlatform getClientPlatform(@Nullable final String userAgentString) { private static ClientPlatform getClientPlatform(@Nullable final String userAgentString) {
try { try {
return UserAgentUtil.parseUserAgentString(userAgentString).platform(); return UserAgentUtil.parseUserAgentString(userAgentString).getPlatform();
} catch (final UnrecognizedUserAgentException e) { } catch (final UnrecognizedUserAgentException e) {
return null; return null;
} }

View File

@ -47,7 +47,6 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.glassfish.jersey.server.ManagedAsync; import org.glassfish.jersey.server.ManagedAsync;
import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ServiceId; import org.signal.libsignal.protocol.ServiceId;
@ -124,7 +123,6 @@ public class ProfileController {
private static final String EXPIRING_PROFILE_KEY_CREDENTIAL_TYPE = "expiringProfileKey"; private static final String EXPIRING_PROFILE_KEY_CREDENTIAL_TYPE = "expiringProfileKey";
private static final String VERSION_NOT_FOUND_COUNTER_NAME = name(ProfileController.class, "versionNotFound"); private static final String VERSION_NOT_FOUND_COUNTER_NAME = name(ProfileController.class, "versionNotFound");
private static final String DUPLICATE_AUTHENTICATION_COUNTER_NAME = name(ProfileController.class, "duplicateAuthentication");
public ProfileController( public ProfileController(
Clock clock, Clock clock,
@ -206,12 +204,11 @@ public class ProfileController {
.build())); .build()));
} }
final List<AccountBadge> updatedBadges = request.badges()
.map(badges -> ProfileHelper.mergeBadgeIdsWithExistingAccountBadges(clock, badgeConfigurationMap, badges, auth.getAccount().getBadges()))
.orElseGet(() -> auth.getAccount().getBadges());
accountsManager.update(auth.getAccount(), a -> { accountsManager.update(auth.getAccount(), a -> {
final List<AccountBadge> updatedBadges = request.badges()
.map(badges -> ProfileHelper.mergeBadgeIdsWithExistingAccountBadges(clock, badgeConfigurationMap, badges, a.getBadges()))
.orElseGet(a::getBadges);
a.setBadges(clock, updatedBadges); a.setBadges(clock, updatedBadges);
a.setCurrentProfileVersion(request.version()); a.setCurrentProfileVersion(request.version());
}); });
@ -232,12 +229,11 @@ public class ProfileController {
@HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional<Anonymous> accessKey, @HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional<Anonymous> accessKey,
@Context ContainerRequestContext containerRequestContext, @Context ContainerRequestContext containerRequestContext,
@PathParam("identifier") AciServiceIdentifier accountIdentifier, @PathParam("identifier") AciServiceIdentifier accountIdentifier,
@PathParam("version") String version, @PathParam("version") String version)
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent)
throws RateLimitExceededException { throws RateLimitExceededException {
final Optional<Account> maybeRequester = auth.map(AuthenticatedDevice::getAccount); final Optional<Account> maybeRequester = auth.map(AuthenticatedDevice::getAccount);
final Account targetAccount = verifyPermissionToReceiveProfile(maybeRequester, accessKey, accountIdentifier, "getVersionedProfile", userAgent); final Account targetAccount = verifyPermissionToReceiveProfile(maybeRequester, accessKey, accountIdentifier);
return buildVersionedProfileResponse(targetAccount, return buildVersionedProfileResponse(targetAccount,
version, version,
@ -256,8 +252,7 @@ public class ProfileController {
@PathParam("identifier") AciServiceIdentifier accountIdentifier, @PathParam("identifier") AciServiceIdentifier accountIdentifier,
@PathParam("version") String version, @PathParam("version") String version,
@PathParam("credentialRequest") String credentialRequest, @PathParam("credentialRequest") String credentialRequest,
@QueryParam("credentialType") String credentialType, @QueryParam("credentialType") String credentialType)
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent)
throws RateLimitExceededException { throws RateLimitExceededException {
if (!EXPIRING_PROFILE_KEY_CREDENTIAL_TYPE.equals(credentialType)) { if (!EXPIRING_PROFILE_KEY_CREDENTIAL_TYPE.equals(credentialType)) {
@ -265,7 +260,7 @@ public class ProfileController {
} }
final Optional<Account> maybeRequester = auth.map(AuthenticatedDevice::getAccount); final Optional<Account> maybeRequester = auth.map(AuthenticatedDevice::getAccount);
final Account targetAccount = verifyPermissionToReceiveProfile(maybeRequester, accessKey, accountIdentifier, "credentialRequest", userAgent); final Account targetAccount = verifyPermissionToReceiveProfile(maybeRequester, accessKey, accountIdentifier);
final boolean isSelf = maybeRequester.map(requester -> ProfileHelper.isSelfProfileRequest(requester.getUuid(), accountIdentifier)).orElse(false); final boolean isSelf = maybeRequester.map(requester -> ProfileHelper.isSelfProfileRequest(requester.getUuid(), accountIdentifier)).orElse(false);
return buildExpiringProfileKeyCredentialProfileResponse(targetAccount, return buildExpiringProfileKeyCredentialProfileResponse(targetAccount,
@ -287,7 +282,8 @@ public class ProfileController {
@HeaderParam(HeaderUtils.GROUP_SEND_TOKEN) Optional<GroupSendTokenHeader> groupSendToken, @HeaderParam(HeaderUtils.GROUP_SEND_TOKEN) Optional<GroupSendTokenHeader> groupSendToken,
@Context ContainerRequestContext containerRequestContext, @Context ContainerRequestContext containerRequestContext,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent, @HeaderParam(HttpHeaders.USER_AGENT) String userAgent,
@PathParam("identifier") ServiceIdentifier identifier) @PathParam("identifier") ServiceIdentifier identifier,
@QueryParam("ca") boolean useCaCertificate)
throws RateLimitExceededException { throws RateLimitExceededException {
final Optional<Account> maybeRequester = auth.map(AuthenticatedDevice::getAccount); final Optional<Account> maybeRequester = auth.map(AuthenticatedDevice::getAccount);
@ -306,7 +302,7 @@ public class ProfileController {
} }
} else { } else {
targetAccount = verifyPermissionToReceiveProfile( targetAccount = verifyPermissionToReceiveProfile(
maybeRequester, accessKey.filter(ignored -> identifier.identityType() == IdentityType.ACI), identifier, "getUnversionedProfile", userAgent); maybeRequester, accessKey.filter(ignored -> identifier.identityType() == IdentityType.ACI), identifier);
} }
return switch (identifier.identityType()) { return switch (identifier.identityType()) {
case ACI -> buildBaseProfileResponseForAccountIdentity(targetAccount, case ACI -> buildBaseProfileResponseForAccountIdentity(targetAccount,
@ -389,7 +385,7 @@ public class ProfileController {
profileKeyCredentialResponse = ProfileHelper.getExpiringProfileKeyCredential(HexFormat.of().parseHex(encodedCredentialRequest), profileKeyCredentialResponse = ProfileHelper.getExpiringProfileKeyCredential(HexFormat.of().parseHex(encodedCredentialRequest),
profile, new ServiceId.Aci(account.getUuid()), zkProfileOperations); profile, new ServiceId.Aci(account.getUuid()), zkProfileOperations);
} catch (VerificationFailedException | InvalidInputException e) { } catch (VerificationFailedException | InvalidInputException e) {
throw new BadRequestException(e); throw new BadRequestException(Response.status(Response.Status.BAD_REQUEST).build(), e);
} }
return profileKeyCredentialResponse; return profileKeyCredentialResponse;
}) })
@ -477,15 +473,7 @@ public class ProfileController {
*/ */
private Account verifyPermissionToReceiveProfile(final Optional<Account> maybeRequester, private Account verifyPermissionToReceiveProfile(final Optional<Account> maybeRequester,
final Optional<Anonymous> maybeAccessKey, final Optional<Anonymous> maybeAccessKey,
final ServiceIdentifier accountIdentifier, final ServiceIdentifier accountIdentifier) throws RateLimitExceededException {
final String endpoint,
@Nullable final String userAgent) throws RateLimitExceededException {
if (maybeRequester.isPresent() && maybeAccessKey.isPresent()) {
Metrics.counter(DUPLICATE_AUTHENTICATION_COUNTER_NAME,
Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), io.micrometer.core.instrument.Tag.of("endpoint", endpoint)))
.increment();
}
if (maybeRequester.isPresent()) { if (maybeRequester.isPresent()) {
rateLimiters.getProfileLimiter().validate(maybeRequester.get().getUuid()); rateLimiters.getProfileLimiter().validate(maybeRequester.get().getUuid());

View File

@ -755,7 +755,7 @@ public class SubscriptionController {
@Nullable @Nullable
private static ClientPlatform getClientPlatform(@Nullable final String userAgentString) { private static ClientPlatform getClientPlatform(@Nullable final String userAgentString) {
try { try {
return UserAgentUtil.parseUserAgentString(userAgentString).platform(); return UserAgentUtil.parseUserAgentString(userAgentString).getPlatform();
} catch (final UnrecognizedUserAgentException e) { } catch (final UnrecognizedUserAgentException e) {
return null; return null;
} }

View File

@ -5,15 +5,11 @@
package org.whispersystems.textsecuregcm.entities; package org.whispersystems.textsecuregcm.entities;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;
import jakarta.validation.Valid; import jakarta.validation.Valid;
import java.util.List; import java.util.List;
public record SetKeysRequest( public record SetKeysRequest(
@NotNull
@Valid @Valid
@Size(max=100)
@Schema(description = """ @Schema(description = """
A list of unsigned elliptic-curve prekeys to use for this device. If present and not empty, replaces all stored A list of unsigned elliptic-curve prekeys to use for this device. If present and not empty, replaces all stored
unsigned EC prekeys for the device; if absent or empty, any stored unsigned EC prekeys for the device are not unsigned EC prekeys for the device; if absent or empty, any stored unsigned EC prekeys for the device are not
@ -29,9 +25,7 @@ public record SetKeysRequest(
""") """)
ECSignedPreKey signedPreKey, ECSignedPreKey signedPreKey,
@NotNull
@Valid @Valid
@Size(max=100)
@Schema(description = """ @Schema(description = """
A list of signed post-quantum one-time prekeys to use for this device. Each key must have a valid signature from A list of signed post-quantum one-time prekeys to use for this device. Each key must have a valid signature from
the identity key in this request. If present and not empty, replaces all stored unsigned PQ prekeys for the the identity key in this request. If present and not empty, replaces all stored unsigned PQ prekeys for the
@ -46,16 +40,4 @@ public record SetKeysRequest(
deleted. If present, must have a valid signature from the identity key in this request. deleted. If present, must have a valid signature from the identity key in this request.
""") """)
KEMSignedPreKey pqLastResortPreKey) { KEMSignedPreKey pqLastResortPreKey) {
public SetKeysRequest {
// Its a little counter-intuitive, but this compact constructor allows a default value
// to be used when one isnt specified, allowing the field to still be
// validated as @NotNull
if (preKeys == null) {
preKeys = List.of();
}
if (pqPreKeys == null) {
pqPreKeys = List.of();
}
}
} }

View File

@ -81,16 +81,7 @@ public class RemoteDeprecationFilter implements Filter, ServerInterceptor {
final Metadata headers, final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) { final ServerCallHandler<ReqT, RespT> next) {
@Nullable final UserAgent userAgent = RequestAttributesUtil.getUserAgent() if (shouldBlock(RequestAttributesUtil.getUserAgent().orElse(null))) {
.map(userAgentString -> {
try {
return UserAgentUtil.parseUserAgentString(userAgentString);
} catch (final UnrecognizedUserAgentException e) {
return null;
}
}).orElse(null);
if (shouldBlock(userAgent)) {
call.close(StatusConstants.UPGRADE_NEEDED_STATUS, new Metadata()); call.close(StatusConstants.UPGRADE_NEEDED_STATUS, new Metadata());
return new ServerCall.Listener<>() {}; return new ServerCall.Listener<>() {};
} else { } else {
@ -117,28 +108,28 @@ public class RemoteDeprecationFilter implements Filter, ServerInterceptor {
return true; return true;
} }
if (blockedVersionsByPlatform.containsKey(userAgent.platform())) { if (blockedVersionsByPlatform.containsKey(userAgent.getPlatform())) {
if (blockedVersionsByPlatform.get(userAgent.platform()).contains(userAgent.version())) { if (blockedVersionsByPlatform.get(userAgent.getPlatform()).contains(userAgent.getVersion())) {
recordDeprecation(userAgent, BLOCKED_CLIENT_REASON); recordDeprecation(userAgent, BLOCKED_CLIENT_REASON);
shouldBlock = true; shouldBlock = true;
} }
} }
if (minimumVersionsByPlatform.containsKey(userAgent.platform())) { if (minimumVersionsByPlatform.containsKey(userAgent.getPlatform())) {
if (userAgent.version().isLowerThan(minimumVersionsByPlatform.get(userAgent.platform()))) { if (userAgent.getVersion().isLowerThan(minimumVersionsByPlatform.get(userAgent.getPlatform()))) {
recordDeprecation(userAgent, EXPIRED_CLIENT_REASON); recordDeprecation(userAgent, EXPIRED_CLIENT_REASON);
shouldBlock = true; shouldBlock = true;
} }
} }
if (versionsPendingBlockByPlatform.containsKey(userAgent.platform())) { if (versionsPendingBlockByPlatform.containsKey(userAgent.getPlatform())) {
if (versionsPendingBlockByPlatform.get(userAgent.platform()).contains(userAgent.version())) { if (versionsPendingBlockByPlatform.get(userAgent.getPlatform()).contains(userAgent.getVersion())) {
recordPendingDeprecation(userAgent, BLOCKED_CLIENT_REASON); recordPendingDeprecation(userAgent, BLOCKED_CLIENT_REASON);
} }
} }
if (versionsPendingDeprecationByPlatform.containsKey(userAgent.platform())) { if (versionsPendingDeprecationByPlatform.containsKey(userAgent.getPlatform())) {
if (userAgent.version().isLowerThan(versionsPendingDeprecationByPlatform.get(userAgent.platform()))) { if (userAgent.getVersion().isLowerThan(versionsPendingDeprecationByPlatform.get(userAgent.getPlatform()))) {
recordPendingDeprecation(userAgent, EXPIRED_CLIENT_REASON); recordPendingDeprecation(userAgent, EXPIRED_CLIENT_REASON);
} }
} }
@ -148,13 +139,13 @@ public class RemoteDeprecationFilter implements Filter, ServerInterceptor {
private void recordDeprecation(final UserAgent userAgent, final String reason) { private void recordDeprecation(final UserAgent userAgent, final String reason) {
Metrics.counter(DEPRECATED_CLIENT_COUNTER_NAME, Metrics.counter(DEPRECATED_CLIENT_COUNTER_NAME,
PLATFORM_TAG, userAgent != null ? userAgent.platform().name().toLowerCase() : "unrecognized", PLATFORM_TAG, userAgent != null ? userAgent.getPlatform().name().toLowerCase() : "unrecognized",
REASON_TAG_NAME, reason).increment(); REASON_TAG_NAME, reason).increment();
} }
private void recordPendingDeprecation(final UserAgent userAgent, final String reason) { private void recordPendingDeprecation(final UserAgent userAgent, final String reason) {
Metrics.counter(PENDING_DEPRECATION_COUNTER_NAME, Metrics.counter(PENDING_DEPRECATION_COUNTER_NAME,
PLATFORM_TAG, userAgent.platform().name().toLowerCase(), PLATFORM_TAG, userAgent.getPlatform().name().toLowerCase(),
REASON_TAG_NAME, reason).increment(); REASON_TAG_NAME, reason).increment();
} }
} }

View File

@ -15,6 +15,8 @@ import jakarta.ws.rs.container.ContainerRequestFilter;
import jakarta.ws.rs.core.SecurityContext; import jakarta.ws.rs.core.SecurityContext;
import java.io.IOException; import java.io.IOException;
import java.util.Map; import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
@ -68,8 +70,8 @@ public class RestDeprecationFilter implements ContainerRequestFilter {
try { try {
final UserAgent userAgent = UserAgentUtil.parseUserAgentString(userAgentString); final UserAgent userAgent = UserAgentUtil.parseUserAgentString(userAgentString);
final ClientPlatform platform = userAgent.platform(); final ClientPlatform platform = userAgent.getPlatform();
final Semver version = userAgent.version(); final Semver version = userAgent.getVersion();
if (!minimumRestFreeVersion.containsKey(platform)) { if (!minimumRestFreeVersion.containsKey(platform)) {
return; return;
} }

View File

@ -1,12 +0,0 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
/**
* Indicates that a remote channel was not found for a given server call or remote address.
*/
public class ChannelNotFoundException extends Exception {
}

View File

@ -1,55 +0,0 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Context;
import io.grpc.ForwardingServerCallListener;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.netty.channel.local.LocalAddress;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
/**
* Then channel shutdown interceptor rejects new requests if a channel is shutting down and works in tandem with
* {@link GrpcClientConnectionManager} to maintain an active call count for each channel otherwise.
*/
public class ChannelShutdownInterceptor implements ServerInterceptor {
private final GrpcClientConnectionManager grpcClientConnectionManager;
public ChannelShutdownInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
this.grpcClientConnectionManager = grpcClientConnectionManager;
}
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
if (!grpcClientConnectionManager.handleServerCallStart(call)) {
// Don't allow new calls if the connection is getting ready to close
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
}
return new ForwardingServerCallListener.SimpleForwardingServerCallListener<>(next.startCall(call, headers)) {
@Override
public void onComplete() {
grpcClientConnectionManager.handleServerCallComplete(call);
super.onComplete();
}
@Override
public void onCancel() {
grpcClientConnectionManager.handleServerCallComplete(call);
super.onCancel();
}
};
}
}

View File

@ -7,9 +7,8 @@ package org.whispersystems.textsecuregcm.grpc;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.StatusException;
import java.time.Clock; import java.time.Clock;
import java.util.Collection;
import java.util.List; import java.util.List;
import org.signal.libsignal.protocol.ServiceId; import org.signal.libsignal.protocol.ServiceId;
import org.signal.libsignal.zkgroup.InvalidInputException; import org.signal.libsignal.zkgroup.InvalidInputException;
@ -19,6 +18,8 @@ import org.signal.libsignal.zkgroup.groupsend.GroupSendDerivedKeyPair;
import org.signal.libsignal.zkgroup.groupsend.GroupSendFullToken; import org.signal.libsignal.zkgroup.groupsend.GroupSendFullToken;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import reactor.core.publisher.Mono;
public class GroupSendTokenUtil { public class GroupSendTokenUtil {
private final ServerSecretParams serverSecretParams; private final ServerSecretParams serverSecretParams;
@ -29,22 +30,16 @@ public class GroupSendTokenUtil {
this.clock = clock; this.clock = clock;
} }
public void checkGroupSendToken(final ByteString serializedGroupSendToken, public Mono<Void> checkGroupSendToken(final ByteString serializedGroupSendToken, List<ServiceIdentifier> serviceIdentifiers) {
final ServiceIdentifier serviceIdentifier) throws StatusException {
checkGroupSendToken(serializedGroupSendToken, List.of(serviceIdentifier.toLibsignal()));
}
public void checkGroupSendToken(final ByteString serializedGroupSendToken,
final Collection<ServiceId> serviceIds) throws StatusException {
try { try {
final GroupSendFullToken token = new GroupSendFullToken(serializedGroupSendToken.toByteArray()); final GroupSendFullToken token = new GroupSendFullToken(serializedGroupSendToken.toByteArray());
final List<ServiceId> serviceIds = serviceIdentifiers.stream().map(ServiceIdentifier::toLibsignal).toList();
token.verify(serviceIds, clock.instant(), GroupSendDerivedKeyPair.forExpiration(token.getExpiration(), serverSecretParams)); token.verify(serviceIds, clock.instant(), GroupSendDerivedKeyPair.forExpiration(token.getExpiration(), serverSecretParams));
} catch (final InvalidInputException e) { return Mono.empty();
throw Status.INVALID_ARGUMENT.asException(); } catch (InvalidInputException e) {
return Mono.error(Status.INVALID_ARGUMENT.asException());
} catch (VerificationFailedException e) { } catch (VerificationFailedException e) {
throw Status.UNAUTHENTICATED.asException(); return Mono.error(Status.UNAUTHENTICATED.asException());
} }
} }
} }

View File

@ -7,11 +7,12 @@ package org.whispersystems.textsecuregcm.grpc;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.StatusException;
import java.security.MessageDigest; import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException; import java.security.NoSuchAlgorithmException;
import java.time.Clock; import java.time.Clock;
import java.util.Arrays; import java.util.Arrays;
import java.util.List;
import org.signal.chat.keys.CheckIdentityKeyRequest; import org.signal.chat.keys.CheckIdentityKeyRequest;
import org.signal.chat.keys.CheckIdentityKeyResponse; import org.signal.chat.keys.CheckIdentityKeyResponse;
import org.signal.chat.keys.GetPreKeysAnonymousRequest; import org.signal.chat.keys.GetPreKeysAnonymousRequest;
@ -51,24 +52,16 @@ public class KeysAnonymousGrpcService extends ReactorKeysAnonymousGrpc.KeysAnony
: KeysGrpcHelper.ALL_DEVICES; : KeysGrpcHelper.ALL_DEVICES;
return switch (request.getAuthorizationCase()) { return switch (request.getAuthorizationCase()) {
case GROUP_SEND_TOKEN -> { case GROUP_SEND_TOKEN ->
try { groupSendTokenUtil.checkGroupSendToken(request.getGroupSendToken(), List.of(serviceIdentifier))
groupSendTokenUtil.checkGroupSendToken(request.getGroupSendToken(), serviceIdentifier); .then(lookUpAccount(serviceIdentifier, Status.NOT_FOUND))
yield lookUpAccount(serviceIdentifier, Status.NOT_FOUND)
.flatMap(targetAccount -> KeysGrpcHelper.getPreKeys(targetAccount, serviceIdentifier.identityType(), deviceId, keysManager)); .flatMap(targetAccount -> KeysGrpcHelper.getPreKeys(targetAccount, serviceIdentifier.identityType(), deviceId, keysManager));
} catch (final StatusException e) {
yield Mono.error(e);
}
}
case UNIDENTIFIED_ACCESS_KEY -> case UNIDENTIFIED_ACCESS_KEY ->
lookUpAccount(serviceIdentifier, Status.UNAUTHENTICATED) lookUpAccount(serviceIdentifier, Status.UNAUTHENTICATED)
.flatMap(targetAccount -> .flatMap(targetAccount ->
UnidentifiedAccessUtil.checkUnidentifiedAccess(targetAccount, request.getUnidentifiedAccessKey().toByteArray()) UnidentifiedAccessUtil.checkUnidentifiedAccess(targetAccount, request.getUnidentifiedAccessKey().toByteArray())
? KeysGrpcHelper.getPreKeys(targetAccount, serviceIdentifier.identityType(), deviceId, keysManager) ? KeysGrpcHelper.getPreKeys(targetAccount, serviceIdentifier.identityType(), deviceId, keysManager)
: Mono.error(Status.UNAUTHENTICATED.asException())); : Mono.error(Status.UNAUTHENTICATED.asException()));
default -> Mono.error(Status.INVALID_ARGUMENT.asException()); default -> Mono.error(Status.INVALID_ARGUMENT.asException());
}; };
} }

View File

@ -1,302 +0,0 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import com.google.protobuf.ByteString;
import io.grpc.Status;
import io.grpc.StatusException;
import java.time.Clock;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.signal.chat.messages.IndividualRecipientMessageBundle;
import org.signal.chat.messages.MultiRecipientMismatchedDevices;
import org.signal.chat.messages.SendMessageResponse;
import org.signal.chat.messages.SendMultiRecipientMessageRequest;
import org.signal.chat.messages.SendMultiRecipientMessageResponse;
import org.signal.chat.messages.SendMultiRecipientStoryRequest;
import org.signal.chat.messages.SendSealedSenderMessageRequest;
import org.signal.chat.messages.SendStoryMessageRequest;
import org.signal.chat.messages.SimpleMessagesAnonymousGrpc;
import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.protocol.InvalidVersionException;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil;
import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.CardinalityEstimator;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.MessageTooLargeException;
import org.whispersystems.textsecuregcm.push.MessageUtil;
import org.whispersystems.textsecuregcm.spam.GrpcResponse;
import org.whispersystems.textsecuregcm.spam.MessageType;
import org.whispersystems.textsecuregcm.spam.SpamCheckResult;
import org.whispersystems.textsecuregcm.spam.SpamChecker;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
public class MessagesAnonymousGrpcService extends SimpleMessagesAnonymousGrpc.MessagesAnonymousImplBase {
private final AccountsManager accountsManager;
private final RateLimiters rateLimiters;
private final MessageSender messageSender;
private final GroupSendTokenUtil groupSendTokenUtil;
private final CardinalityEstimator messageByteLimitEstimator;
private final SpamChecker spamChecker;
private final Clock clock;
private static final SendMessageResponse SEND_MESSAGE_SUCCESS_RESPONSE = SendMessageResponse.newBuilder().build();
public MessagesAnonymousGrpcService(final AccountsManager accountsManager,
final RateLimiters rateLimiters,
final MessageSender messageSender,
final GroupSendTokenUtil groupSendTokenUtil,
final CardinalityEstimator messageByteLimitEstimator,
final SpamChecker spamChecker,
final Clock clock) {
this.accountsManager = accountsManager;
this.rateLimiters = rateLimiters;
this.messageSender = messageSender;
this.messageByteLimitEstimator = messageByteLimitEstimator;
this.spamChecker = spamChecker;
this.clock = clock;
this.groupSendTokenUtil = groupSendTokenUtil;
}
@Override
public SendMessageResponse sendSingleRecipientMessage(final SendSealedSenderMessageRequest request)
throws StatusException, RateLimitExceededException {
final ServiceIdentifier destinationServiceIdentifier =
ServiceIdentifierUtil.fromGrpcServiceIdentifier(request.getDestination());
final Account destination = accountsManager.getByServiceIdentifier(destinationServiceIdentifier)
.orElseThrow(Status.UNAUTHENTICATED::asException);
switch (request.getAuthorizationCase()) {
case UNIDENTIFIED_ACCESS_KEY -> {
if (!UnidentifiedAccessUtil.checkUnidentifiedAccess(destination, request.getUnidentifiedAccessKey().toByteArray())) {
throw Status.UNAUTHENTICATED.asException();
}
}
case GROUP_SEND_TOKEN ->
groupSendTokenUtil.checkGroupSendToken(request.getGroupSendToken(), destinationServiceIdentifier);
case AUTHORIZATION_NOT_SET -> throw Status.UNAUTHENTICATED.asException();
}
return sendIndividualMessage(destination,
destinationServiceIdentifier,
request.getMessages(),
request.getEphemeral(),
request.getUrgent(),
false);
}
@Override
public SendMessageResponse sendStory(final SendStoryMessageRequest request)
throws StatusException, RateLimitExceededException {
final ServiceIdentifier destinationServiceIdentifier =
ServiceIdentifierUtil.fromGrpcServiceIdentifier(request.getDestination());
final Optional<Account> maybeDestination = accountsManager.getByServiceIdentifier(destinationServiceIdentifier);
if (maybeDestination.isEmpty()) {
// Don't reveal to unauthenticated callers whether a destination account actually exists
return SEND_MESSAGE_SUCCESS_RESPONSE;
}
final Account destination = maybeDestination.get();
rateLimiters.getStoriesLimiter().validate(destination.getIdentifier(IdentityType.ACI));
return sendIndividualMessage(destination,
destinationServiceIdentifier,
request.getMessages(),
false,
request.getUrgent(),
true);
}
private SendMessageResponse sendIndividualMessage(final Account destination,
final ServiceIdentifier destinationServiceIdentifier,
final IndividualRecipientMessageBundle messages,
final boolean ephemeral,
final boolean urgent,
final boolean story) throws StatusException, RateLimitExceededException {
final SpamCheckResult<GrpcResponse<SendMessageResponse>> spamCheckResult =
spamChecker.checkForIndividualRecipientSpamGrpc(
story ? MessageType.INDIVIDUAL_STORY : MessageType.INDIVIDUAL_SEALED_SENDER,
Optional.empty(),
Optional.of(destination),
destinationServiceIdentifier);
if (spamCheckResult.response().isPresent()) {
return spamCheckResult.response().get().getResponseOrThrowStatus();
}
try {
final int totalPayloadLength = messages.getMessagesMap().values().stream()
.mapToInt(message -> message.getPayload().size())
.sum();
rateLimiters.getInboundMessageBytes().validate(destinationServiceIdentifier.uuid(), totalPayloadLength);
} catch (final RateLimitExceededException e) {
messageByteLimitEstimator.add(destinationServiceIdentifier.uuid().toString());
throw e;
}
final Map<Byte, MessageProtos.Envelope> messagesByDeviceId = messages.getMessagesMap().entrySet()
.stream()
.collect(Collectors.toMap(
entry -> DeviceIdUtil.validate(entry.getKey()),
entry -> {
final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder()
.setType(MessageProtos.Envelope.Type.UNIDENTIFIED_SENDER)
.setClientTimestamp(messages.getTimestamp())
.setServerTimestamp(clock.millis())
.setDestinationServiceId(destinationServiceIdentifier.toServiceIdentifierString())
.setEphemeral(ephemeral)
.setUrgent(urgent)
.setStory(story)
.setContent(entry.getValue().getPayload());
spamCheckResult.token().ifPresent(reportSpamToken ->
envelopeBuilder.setReportSpamToken(ByteString.copyFrom(reportSpamToken)));
return envelopeBuilder.build();
}
));
final Map<Byte, Integer> registrationIdsByDeviceId = messages.getMessagesMap().entrySet().stream()
.collect(Collectors.toMap(
entry -> entry.getKey().byteValue(),
entry -> entry.getValue().getRegistrationId()));
return MessagesGrpcHelper.sendMessage(messageSender,
destination,
destinationServiceIdentifier,
messagesByDeviceId,
registrationIdsByDeviceId,
Optional.empty());
}
@Override
public SendMultiRecipientMessageResponse sendMultiRecipientMessage(final SendMultiRecipientMessageRequest request)
throws StatusException {
final SealedSenderMultiRecipientMessage multiRecipientMessage =
parseAndValidateMultiRecipientMessage(request.getMessage().getPayload().toByteArray());
groupSendTokenUtil.checkGroupSendToken(request.getGroupSendToken(), multiRecipientMessage.getRecipients().keySet());
return sendMultiRecipientMessage(multiRecipientMessage,
request.getMessage().getTimestamp(),
request.getEphemeral(),
request.getUrgent(),
false);
}
@Override
public SendMultiRecipientMessageResponse sendMultiRecipientStory(final SendMultiRecipientStoryRequest request)
throws StatusException {
final SealedSenderMultiRecipientMessage multiRecipientMessage =
parseAndValidateMultiRecipientMessage(request.getMessage().getPayload().toByteArray());
return sendMultiRecipientMessage(multiRecipientMessage,
request.getMessage().getTimestamp(),
false,
request.getUrgent(),
true)
.toBuilder()
// Don't identify unresolved recipients for stories
.clearUnresolvedRecipients()
.build();
}
private SendMultiRecipientMessageResponse sendMultiRecipientMessage(
final SealedSenderMultiRecipientMessage multiRecipientMessage,
final long timestamp,
final boolean ephemeral,
final boolean urgent,
final boolean story) throws StatusException {
final SpamCheckResult<GrpcResponse<SendMultiRecipientMessageResponse>> spamCheckResult =
spamChecker.checkForMultiRecipientSpamGrpc(story
? MessageType.MULTI_RECIPIENT_STORY
: MessageType.MULTI_RECIPIENT_SEALED_SENDER);
if (spamCheckResult.response().isPresent()) {
return spamCheckResult.response().get().getResponseOrThrowStatus();
}
// At this point, the caller has at least superficially provided the information needed to send a multi-recipient
// message. Attempt to resolve the destination service identifiers to Signal accounts.
final Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolvedRecipients =
MessageUtil.resolveRecipients(accountsManager, multiRecipientMessage);
try {
messageSender.sendMultiRecipientMessage(multiRecipientMessage,
resolvedRecipients,
timestamp,
story,
ephemeral,
urgent,
RequestAttributesUtil.getUserAgent().orElse(null));
final SendMultiRecipientMessageResponse.Builder responseBuilder = SendMultiRecipientMessageResponse.newBuilder();
MessageUtil.getUnresolvedRecipients(multiRecipientMessage, resolvedRecipients).stream()
.map(ServiceIdentifierUtil::toGrpcServiceIdentifier)
.forEach(responseBuilder::addUnresolvedRecipients);
return responseBuilder.build();
} catch (final MessageTooLargeException e) {
throw Status.INVALID_ARGUMENT
.withDescription("Message for an individual recipient was too large")
.withCause(e)
.asRuntimeException();
} catch (final MultiRecipientMismatchedDevicesException e) {
final MultiRecipientMismatchedDevices.Builder mismatchedDevicesBuilder =
MultiRecipientMismatchedDevices.newBuilder();
e.getMismatchedDevicesByServiceIdentifier().forEach((serviceIdentifier, mismatchedDevices) ->
mismatchedDevicesBuilder.addMismatchedDevices(MessagesGrpcHelper.buildMismatchedDevices(serviceIdentifier, mismatchedDevices)));
return SendMultiRecipientMessageResponse.newBuilder()
.setMismatchedDevices(mismatchedDevicesBuilder)
.build();
}
}
private SealedSenderMultiRecipientMessage parseAndValidateMultiRecipientMessage(
final byte[] serializedMultiRecipientMessage) throws StatusException {
final SealedSenderMultiRecipientMessage multiRecipientMessage;
try {
multiRecipientMessage = SealedSenderMultiRecipientMessage.parse(serializedMultiRecipientMessage);
} catch (final InvalidMessageException | InvalidVersionException e) {
throw Status.INVALID_ARGUMENT.withCause(e).asException();
}
// Check that the request is well-formed and doesn't contain repeated entries for the same device for the same
// recipient
if (MessageUtil.hasDuplicateDevices(multiRecipientMessage)) {
throw Status.INVALID_ARGUMENT.withDescription("Multi-recipient message contains duplicate recipient").asException();
}
return multiRecipientMessage;
}
}

View File

@ -1,91 +0,0 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Status;
import io.grpc.StatusException;
import java.util.Map;
import java.util.Optional;
import org.signal.chat.messages.MismatchedDevices;
import org.signal.chat.messages.SendMessageResponse;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.MessageTooLargeException;
import org.whispersystems.textsecuregcm.storage.Account;
public class MessagesGrpcHelper {
private static final SendMessageResponse SEND_MESSAGE_SUCCESS_RESPONSE = SendMessageResponse.newBuilder().build();
/**
* Sends a "bundle" of messages to an individual destination account, mapping common exceptions to appropriate gRPC
* statuses.
*
* @param messageSender the {@code MessageSender} instance to use to send the messages
* @param destination the destination account for the messages
* @param destinationServiceIdentifier the service identifier for the destination account
* @param messagesByDeviceId a map of device IDs to message payloads
* @param registrationIdsByDeviceId a map of device IDs to device registration IDs
* @param syncMessageSenderDeviceId if the message is a sync message (i.e. a message to other devices linked to the
* caller's own account), contains the ID of the device that sent the message
*
* @return a response object to send to callers
*
* @throws StatusException if the message bundle could not be sent due to an out-of-date device set or an invalid
* message payload
* @throws RateLimitExceededException if the message bundle could not be sent due to a violated rated limit
*/
public static SendMessageResponse sendMessage(final MessageSender messageSender,
final Account destination,
final ServiceIdentifier destinationServiceIdentifier,
final Map<Byte, MessageProtos.Envelope> messagesByDeviceId,
final Map<Byte, Integer> registrationIdsByDeviceId,
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<Byte> syncMessageSenderDeviceId)
throws StatusException, RateLimitExceededException {
try {
messageSender.sendMessages(destination,
destinationServiceIdentifier,
messagesByDeviceId,
registrationIdsByDeviceId,
syncMessageSenderDeviceId,
RequestAttributesUtil.getUserAgent().orElse(null));
return SEND_MESSAGE_SUCCESS_RESPONSE;
} catch (final MismatchedDevicesException e) {
return SendMessageResponse.newBuilder()
.setMismatchedDevices(buildMismatchedDevices(destinationServiceIdentifier, e.getMismatchedDevices()))
.build();
} catch (final MessageTooLargeException e) {
throw Status.INVALID_ARGUMENT.withDescription("Message too large").withCause(e).asException();
}
}
/**
* Translates an internal {@link org.whispersystems.textsecuregcm.controllers.MismatchedDevices} entity to a gRPC
* {@link MismatchedDevices} entity.
*
* @param serviceIdentifier the service identifier to which the mismatched device response applies
* @param mismatchedDevices the mismatched device entity to translate to gRPC
*
* @return a gRPC {@code MismatchedDevices} representation of the given mismatched devices
*/
public static MismatchedDevices buildMismatchedDevices(final ServiceIdentifier serviceIdentifier,
final org.whispersystems.textsecuregcm.controllers.MismatchedDevices mismatchedDevices) {
final MismatchedDevices.Builder mismatchedDevicesBuilder = MismatchedDevices.newBuilder()
.setServiceIdentifier(ServiceIdentifierUtil.toGrpcServiceIdentifier(serviceIdentifier));
mismatchedDevices.missingDeviceIds().forEach(mismatchedDevicesBuilder::addMissingDevices);
mismatchedDevices.extraDeviceIds().forEach(mismatchedDevicesBuilder::addExtraDevices);
mismatchedDevices.staleDeviceIds().forEach(mismatchedDevicesBuilder::addStaleDevices);
return mismatchedDevicesBuilder.build();
}
}

View File

@ -1,188 +0,0 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import com.google.protobuf.ByteString;
import io.grpc.Status;
import io.grpc.StatusException;
import java.time.Clock;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.signal.chat.messages.AuthenticatedSenderMessageType;
import org.signal.chat.messages.IndividualRecipientMessageBundle;
import org.signal.chat.messages.SendAuthenticatedSenderMessageRequest;
import org.signal.chat.messages.SendMessageResponse;
import org.signal.chat.messages.SendSyncMessageRequest;
import org.signal.chat.messages.SimpleMessagesGrpc;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticationUtil;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.CardinalityEstimator;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.spam.GrpcResponse;
import org.whispersystems.textsecuregcm.spam.MessageType;
import org.whispersystems.textsecuregcm.spam.SpamCheckResult;
import org.whispersystems.textsecuregcm.spam.SpamChecker;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
public class MessagesGrpcService extends SimpleMessagesGrpc.MessagesImplBase {
private final AccountsManager accountsManager;
private final RateLimiters rateLimiters;
private final MessageSender messageSender;
private final CardinalityEstimator messageByteLimitEstimator;
private final SpamChecker spamChecker;
private final Clock clock;
public MessagesGrpcService(final AccountsManager accountsManager,
final RateLimiters rateLimiters,
final MessageSender messageSender,
final CardinalityEstimator messageByteLimitEstimator,
final SpamChecker spamChecker,
final Clock clock) {
this.accountsManager = accountsManager;
this.rateLimiters = rateLimiters;
this.messageSender = messageSender;
this.messageByteLimitEstimator = messageByteLimitEstimator;
this.spamChecker = spamChecker;
this.clock = clock;
}
@Override
public SendMessageResponse sendMessage(final SendAuthenticatedSenderMessageRequest request)
throws StatusException, RateLimitExceededException {
final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedDevice();
final AciServiceIdentifier senderServiceIdentifier = new AciServiceIdentifier(authenticatedDevice.accountIdentifier());
final Account sender =
accountsManager.getByServiceIdentifier(senderServiceIdentifier).orElseThrow(Status.UNAUTHENTICATED::asException);
final ServiceIdentifier destinationServiceIdentifier =
ServiceIdentifierUtil.fromGrpcServiceIdentifier(request.getDestination());
if (sender.isIdentifiedBy(destinationServiceIdentifier)) {
throw Status.INVALID_ARGUMENT
.withDescription("Use `sendSyncMessage` to send messages to own account")
.asException();
}
final Account destination = accountsManager.getByServiceIdentifier(destinationServiceIdentifier)
.orElseThrow(Status.NOT_FOUND::asException);
rateLimiters.getMessagesLimiter().validate(authenticatedDevice.accountIdentifier(), destination.getUuid());
return sendMessage(destination,
destinationServiceIdentifier,
authenticatedDevice,
request.getType(),
MessageType.INDIVIDUAL_IDENTIFIED_SENDER,
request.getMessages(),
request.getEphemeral(),
request.getUrgent());
}
@Override
public SendMessageResponse sendSyncMessage(final SendSyncMessageRequest request)
throws StatusException, RateLimitExceededException {
final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedDevice();
final AciServiceIdentifier senderServiceIdentifier = new AciServiceIdentifier(authenticatedDevice.accountIdentifier());
final Account sender =
accountsManager.getByServiceIdentifier(senderServiceIdentifier).orElseThrow(Status.UNAUTHENTICATED::asException);
return sendMessage(sender,
senderServiceIdentifier,
authenticatedDevice,
request.getType(),
MessageType.SYNC,
request.getMessages(),
false,
request.getUrgent());
}
private SendMessageResponse sendMessage(final Account destination,
final ServiceIdentifier destinationServiceIdentifier,
final AuthenticatedDevice sender,
final AuthenticatedSenderMessageType envelopeType,
final MessageType messageType,
final IndividualRecipientMessageBundle messages,
final boolean ephemeral,
final boolean urgent) throws StatusException, RateLimitExceededException {
try {
final int totalPayloadLength = messages.getMessagesMap().values().stream()
.mapToInt(message -> message.getPayload().size())
.sum();
rateLimiters.getInboundMessageBytes().validate(destinationServiceIdentifier.uuid(), totalPayloadLength);
} catch (final RateLimitExceededException e) {
messageByteLimitEstimator.add(destinationServiceIdentifier.uuid().toString());
throw e;
}
final SpamCheckResult<GrpcResponse<SendMessageResponse>> spamCheckResult =
spamChecker.checkForIndividualRecipientSpamGrpc(messageType,
Optional.of(sender),
Optional.of(destination),
destinationServiceIdentifier);
if (spamCheckResult.response().isPresent()) {
return spamCheckResult.response().get().getResponseOrThrowStatus();
}
final Map<Byte, MessageProtos.Envelope> messagesByDeviceId = messages.getMessagesMap().entrySet()
.stream()
.collect(Collectors.toMap(
entry -> DeviceIdUtil.validate(entry.getKey()),
entry -> {
final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder()
.setType(getEnvelopeType(envelopeType))
.setClientTimestamp(messages.getTimestamp())
.setServerTimestamp(clock.millis())
.setDestinationServiceId(destinationServiceIdentifier.toServiceIdentifierString())
.setSourceServiceId(new AciServiceIdentifier(sender.accountIdentifier()).toServiceIdentifierString())
.setSourceDevice(sender.deviceId())
.setEphemeral(ephemeral)
.setUrgent(urgent)
.setContent(entry.getValue().getPayload());
spamCheckResult.token().ifPresent(reportSpamToken ->
envelopeBuilder.setReportSpamToken(ByteString.copyFrom(reportSpamToken)));
return envelopeBuilder.build();
}
));
final Map<Byte, Integer> registrationIdsByDeviceId = messages.getMessagesMap().entrySet().stream()
.collect(Collectors.toMap(
entry -> entry.getKey().byteValue(),
entry -> entry.getValue().getRegistrationId()));
return MessagesGrpcHelper.sendMessage(messageSender,
destination,
destinationServiceIdentifier,
messagesByDeviceId,
registrationIdsByDeviceId,
messageType == MessageType.SYNC ? Optional.of(sender.deviceId()) : Optional.empty());
}
private static MessageProtos.Envelope.Type getEnvelopeType(final AuthenticatedSenderMessageType type) {
return switch (type) {
case DOUBLE_RATCHET -> MessageProtos.Envelope.Type.CIPHERTEXT;
case PREKEY_MESSAGE -> MessageProtos.Envelope.Type.PREKEY_BUNDLE;
case PLAINTEXT_CONTENT -> MessageProtos.Envelope.Type.PLAINTEXT_CONTENT;
case UNSPECIFIED, UNRECOGNIZED ->
throw Status.INVALID_ARGUMENT.withDescription("Unrecognized envelope type").asRuntimeException();
};
}
}

View File

@ -6,8 +6,10 @@
package org.whispersystems.textsecuregcm.grpc; package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.StatusException;
import java.time.Clock; import java.time.Clock;
import java.util.List;
import org.signal.chat.profile.CredentialType; import org.signal.chat.profile.CredentialType;
import org.signal.chat.profile.GetExpiringProfileKeyCredentialAnonymousRequest; import org.signal.chat.profile.GetExpiringProfileKeyCredentialAnonymousRequest;
import org.signal.chat.profile.GetExpiringProfileKeyCredentialResponse; import org.signal.chat.profile.GetExpiringProfileKeyCredentialResponse;
@ -57,17 +59,11 @@ public class ProfileAnonymousGrpcService extends ReactorProfileAnonymousGrpc.Pro
} }
final Mono<Account> account = switch (request.getAuthenticationCase()) { final Mono<Account> account = switch (request.getAuthenticationCase()) {
case GROUP_SEND_TOKEN -> { case GROUP_SEND_TOKEN ->
try { groupSendTokenUtil.checkGroupSendToken(request.getGroupSendToken(), List.of(targetIdentifier))
groupSendTokenUtil.checkGroupSendToken(request.getGroupSendToken(), targetIdentifier); .then(Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(targetIdentifier)))
yield Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(targetIdentifier))
.flatMap(Mono::justOrEmpty) .flatMap(Mono::justOrEmpty)
.switchIfEmpty(Mono.error(Status.NOT_FOUND.asException())); .switchIfEmpty(Mono.error(Status.NOT_FOUND.asException()));
} catch (final StatusException e) {
yield Mono.error(e);
}
}
case UNIDENTIFIED_ACCESS_KEY -> case UNIDENTIFIED_ACCESS_KEY ->
getTargetAccountAndValidateUnidentifiedAccess(targetIdentifier, request.getUnidentifiedAccessKey().toByteArray()); getTargetAccountAndValidateUnidentifiedAccess(targetIdentifier, request.getUnidentifiedAccessKey().toByteArray());
default -> Mono.error(Status.INVALID_ARGUMENT.asException()); default -> Mono.error(Status.INVALID_ARGUMENT.asException());

View File

@ -145,13 +145,11 @@ public class ProfileGrpcService extends ReactorProfileGrpc.ProfileImplBase {
request.getCommitment().toByteArray()))); request.getCommitment().toByteArray())));
final List<Mono<?>> updates = new ArrayList<>(2); final List<Mono<?>> updates = new ArrayList<>(2);
final List<AccountBadge> updatedBadges = Optional.of(request.getBadgeIdsList())
.map(badges -> ProfileHelper.mergeBadgeIdsWithExistingAccountBadges(clock, badgeConfigurationMap, badges, account.getBadges()))
.orElseGet(account::getBadges);
updates.add(Mono.fromFuture(() -> accountsManager.updateAsync(account, a -> { updates.add(Mono.fromFuture(() -> accountsManager.updateAsync(account, a -> {
final List<AccountBadge> updatedBadges = Optional.of(request.getBadgeIdsList())
.map(badges -> ProfileHelper.mergeBadgeIdsWithExistingAccountBadges(clock, badgeConfigurationMap, badges, a.getBadges()))
.orElseGet(a::getBadges);
a.setBadges(clock, updatedBadges); a.setBadges(clock, updatedBadges);
a.setCurrentProfileVersion(request.getVersion()); a.setCurrentProfileVersion(request.getVersion());
}))); })));

View File

@ -1,16 +0,0 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import java.net.InetAddress;
import java.util.List;
import java.util.Locale;
import javax.annotation.Nullable;
public record RequestAttributes(InetAddress remoteAddress,
@Nullable String userAgent,
List<Locale.LanguageRange> acceptLanguage) {
}

View File

@ -2,25 +2,28 @@ package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Context; import io.grpc.Context;
import io.grpc.Contexts; import io.grpc.Contexts;
import io.grpc.Grpc;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.ServerCall; import io.grpc.ServerCall;
import io.grpc.ServerCallHandler; import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor; import io.grpc.ServerInterceptor;
import io.grpc.Status; import io.grpc.Status;
import io.netty.channel.local.LocalAddress;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import java.net.InetAddress;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
/**
* The request attributes interceptor makes request attributes from the underlying remote channel available to service
* implementations by attaching them to a {@link Context} attribute that can be read via {@link RequestAttributesUtil}.
* All server calls should have request attributes, and calls will be rejected with a status of {@code UNAVAILABLE} if
* request attributes are unavailable (i.e. the underlying channel closed before the {@code ServerCall} started).
*
* @see RequestAttributesUtil
*/
public class RequestAttributesInterceptor implements ServerInterceptor { public class RequestAttributesInterceptor implements ServerInterceptor {
private final GrpcClientConnectionManager grpcClientConnectionManager; private final GrpcClientConnectionManager grpcClientConnectionManager;
private static final Logger log = LoggerFactory.getLogger(RequestAttributesInterceptor.class);
public RequestAttributesInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) { public RequestAttributesInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
this.grpcClientConnectionManager = grpcClientConnectionManager; this.grpcClientConnectionManager = grpcClientConnectionManager;
} }
@ -30,12 +33,52 @@ public class RequestAttributesInterceptor implements ServerInterceptor {
final Metadata headers, final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) { final ServerCallHandler<ReqT, RespT> next) {
try { if (call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress) {
return Contexts.interceptCall(Context.current() Context context = Context.current();
.withValue(RequestAttributesUtil.REQUEST_ATTRIBUTES_CONTEXT_KEY,
grpcClientConnectionManager.getRequestAttributes(call)), call, headers, next); {
} catch (final ChannelNotFoundException e) { final Optional<InetAddress> maybeRemoteAddress = grpcClientConnectionManager.getRemoteAddress(localAddress);
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
if (maybeRemoteAddress.isEmpty()) {
// We should never have a call from a party whose remote address we can't identify
log.warn("No remote address available");
call.close(Status.INTERNAL, new Metadata());
return new ServerCall.Listener<>() {};
}
context = context.withValue(RequestAttributesUtil.REMOTE_ADDRESS_CONTEXT_KEY, maybeRemoteAddress.get());
}
{
final Optional<List<Locale.LanguageRange>> maybeAcceptLanguage =
grpcClientConnectionManager.getAcceptableLanguages(localAddress);
if (maybeAcceptLanguage.isPresent()) {
context = context.withValue(RequestAttributesUtil.ACCEPT_LANGUAGE_CONTEXT_KEY, maybeAcceptLanguage.get());
}
}
{
final Optional<String> maybeRawUserAgent =
grpcClientConnectionManager.getRawUserAgent(localAddress);
if (maybeRawUserAgent.isPresent()) {
context = context.withValue(RequestAttributesUtil.RAW_USER_AGENT_CONTEXT_KEY, maybeRawUserAgent.get());
}
}
{
final Optional<UserAgent> maybeUserAgent = grpcClientConnectionManager.getUserAgent(localAddress);
if (maybeUserAgent.isPresent()) {
context = context.withValue(RequestAttributesUtil.USER_AGENT_CONTEXT_KEY, maybeUserAgent.get());
}
}
return Contexts.interceptCall(context, call, headers, next);
} else {
throw new AssertionError("Unexpected channel type: " + call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
} }
} }
} }

View File

@ -3,13 +3,18 @@ package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Context; import io.grpc.Context;
import java.net.InetAddress; import java.net.InetAddress;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Locale; import java.util.Locale;
import java.util.Optional; import java.util.Optional;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
public class RequestAttributesUtil { public class RequestAttributesUtil {
static final Context.Key<RequestAttributes> REQUEST_ATTRIBUTES_CONTEXT_KEY = Context.key("request-attributes"); static final Context.Key<List<Locale.LanguageRange>> ACCEPT_LANGUAGE_CONTEXT_KEY = Context.key("accept-language");
static final Context.Key<InetAddress> REMOTE_ADDRESS_CONTEXT_KEY = Context.key("remote-address");
static final Context.Key<String> RAW_USER_AGENT_CONTEXT_KEY = Context.key("unparsed-user-agent");
static final Context.Key<UserAgent> USER_AGENT_CONTEXT_KEY = Context.key("parsed-user-agent");
private static final List<Locale> AVAILABLE_LOCALES = Arrays.asList(Locale.getAvailableLocales()); private static final List<Locale> AVAILABLE_LOCALES = Arrays.asList(Locale.getAvailableLocales());
@ -18,8 +23,8 @@ public class RequestAttributesUtil {
* *
* @return the acceptable languages listed by the remote client; may be empty if unparseable or not specified * @return the acceptable languages listed by the remote client; may be empty if unparseable or not specified
*/ */
public static List<Locale.LanguageRange> getAcceptableLanguages() { public static Optional<List<Locale.LanguageRange>> getAcceptableLanguages() {
return REQUEST_ATTRIBUTES_CONTEXT_KEY.get().acceptLanguage(); return Optional.ofNullable(ACCEPT_LANGUAGE_CONTEXT_KEY.get());
} }
/** /**
@ -30,7 +35,9 @@ public class RequestAttributesUtil {
* @return a list of distinct locales acceptable to the remote client and available in this JVM * @return a list of distinct locales acceptable to the remote client and available in this JVM
*/ */
public static List<Locale> getAvailableAcceptedLocales() { public static List<Locale> getAvailableAcceptedLocales() {
return Locale.filter(getAcceptableLanguages(), AVAILABLE_LOCALES); return getAcceptableLanguages()
.map(languageRanges -> Locale.filter(languageRanges, AVAILABLE_LOCALES))
.orElseGet(Collections::emptyList);
} }
/** /**
@ -39,7 +46,16 @@ public class RequestAttributesUtil {
* @return the remote address of the remote client * @return the remote address of the remote client
*/ */
public static InetAddress getRemoteAddress() { public static InetAddress getRemoteAddress() {
return REQUEST_ATTRIBUTES_CONTEXT_KEY.get().remoteAddress(); return REMOTE_ADDRESS_CONTEXT_KEY.get();
}
/**
* Returns the parsed user-agent of the remote client in the current gRPC request context.
*
* @return the parsed user-agent of the remote client; may be empty if unparseable or not specified
*/
public static Optional<UserAgent> getUserAgent() {
return Optional.ofNullable(USER_AGENT_CONTEXT_KEY.get());
} }
/** /**
@ -47,7 +63,7 @@ public class RequestAttributesUtil {
* *
* @return the unparsed user-agent of the remote client; may be empty if not specified * @return the unparsed user-agent of the remote client; may be empty if not specified
*/ */
public static Optional<String> getUserAgent() { public static Optional<String> getRawUserAgent() {
return Optional.ofNullable(REQUEST_ATTRIBUTES_CONTEXT_KEY.get().userAgent()); return Optional.ofNullable(RAW_USER_AGENT_CONTEXT_KEY.get());
} }
} }

View File

@ -1,39 +0,0 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.Status;
public class ServerInterceptorUtil {
@SuppressWarnings("rawtypes")
private static final ServerCall.Listener NO_OP_LISTENER = new ServerCall.Listener<>() {};
private static final Metadata EMPTY_TRAILERS = new Metadata();
private ServerInterceptorUtil() {
}
/**
* Closes the given server call with the given status, returning a no-op listener.
*
* @param call the server call to close
* @param status the status with which to close the call
*
* @return a no-op server call listener
*
* @param <ReqT> the type of request object handled by the server call
* @param <RespT> the type of response object returned by the server call
*/
public static <ReqT, RespT> ServerCall.Listener<ReqT> closeWithStatus(final ServerCall<ReqT, RespT> call, final Status status) {
call.close(status, EMPTY_TRAILERS);
//noinspection unchecked
return NO_OP_LISTENER;
}
}

View File

@ -8,7 +8,7 @@ package org.whispersystems.textsecuregcm.grpc;
import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.internalError; import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.internalError;
import com.google.protobuf.Descriptors; import com.google.protobuf.Descriptors;
import com.google.protobuf.Message; import com.google.protobuf.GeneratedMessageV3;
import io.grpc.ForwardingServerCallListener; import io.grpc.ForwardingServerCallListener;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.ServerCall; import io.grpc.ServerCall;
@ -75,7 +75,7 @@ public class ValidatingInterceptor implements ServerInterceptor {
} }
private void validateMessage(final Object message) throws StatusException { private void validateMessage(final Object message) throws StatusException {
if (message instanceof Message msg) { if (message instanceof GeneratedMessageV3 msg) {
try { try {
for (final Descriptors.FieldDescriptor fd: msg.getDescriptorForType().getFields()) { for (final Descriptors.FieldDescriptor fd: msg.getDescriptorForType().getFields()) {
for (final Map.Entry<Descriptors.FieldDescriptor, Object> entry: fd.getOptions().getAllFields().entrySet()) { for (final Map.Entry<Descriptors.FieldDescriptor, Object> entry: fd.getOptions().getAllFields().entrySet()) {

View File

@ -12,10 +12,8 @@ import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCountUtil;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Optional;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
/** /**
* An "establish local connection" handler waits for a Noise handshake to complete upstream in the pipeline, buffering * An "establish local connection" handler waits for a Noise handshake to complete upstream in the pipeline, buffering
@ -50,12 +48,12 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
@Override @Override
public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) { public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) {
if (event instanceof NoiseIdentityDeterminedEvent(final Optional<AuthenticatedDevice> authenticatedDevice)) { if (event instanceof NoiseIdentityDeterminedEvent noiseIdentityDeterminedEvent) {
// We assume that we'll only get a completed handshake event if the handshake met all authentication requirements // We assume that we'll only get a completed handshake event if the handshake met all authentication requirements
// for the requested service. If the handshake doesn't have an authenticated device, we assume we're trying to // for the requested service. If the handshake doesn't have an authenticated device, we assume we're trying to
// connect to the anonymous service. If it does have an authenticated device, we assume we're aiming for the // connect to the anonymous service. If it does have an authenticated device, we assume we're aiming for the
// authenticated service. // authenticated service.
final LocalAddress grpcServerAddress = authenticatedDevice.isPresent() final LocalAddress grpcServerAddress = noiseIdentityDeterminedEvent.authenticatedDevice().isPresent()
? authenticatedGrpcServerAddress ? authenticatedGrpcServerAddress
: anonymousGrpcServerAddress; : anonymousGrpcServerAddress;
@ -74,7 +72,7 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
if (localChannelFuture.isSuccess()) { if (localChannelFuture.isSuccess()) {
grpcClientConnectionManager.handleConnectionEstablished((LocalChannel) localChannelFuture.channel(), grpcClientConnectionManager.handleConnectionEstablished((LocalChannel) localChannelFuture.channel(),
remoteChannelContext.channel(), remoteChannelContext.channel(),
authenticatedDevice); noiseIdentityDeterminedEvent.authenticatedDevice());
// Close the local connection if the remote channel closes and vice versa // Close the local connection if the remote channel closes and vice versa
remoteChannelContext.channel().closeFuture().addListener(closeFuture -> localChannelFuture.channel().close()); remoteChannelContext.channel().closeFuture().addListener(closeFuture -> localChannelFuture.channel().close());

View File

@ -1,8 +1,6 @@
package org.whispersystems.textsecuregcm.grpc.net; package org.whispersystems.textsecuregcm.grpc.net;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import io.grpc.Grpc;
import io.grpc.ServerCall;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelFutureListener;
import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalAddress;
@ -25,26 +23,15 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener; import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException; import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.grpc.RequestAttributes; import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ClosableEpoch; import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
/** /**
* A client connection manager associates a local connection to a local gRPC server with a remote connection through a * A client connection manager associates a local connection to a local gRPC server with a remote connection through a
* Noise tunnel. It provides access to metadata associated with the remote connection, including the authenticated * Noise-over-WebSocket tunnel. It provides access to metadata associated with the remote connection, including the
* identity of the device that opened the connection (for non-anonymous connections). It can also close connections * authenticated identity of the device that opened the connection (for non-anonymous connections). It can also close
* associated with a given device if that device's credentials have changed and clients must reauthenticate. * connections associated with a given device if that device's credentials have changed and clients must reauthenticate.
* <p>
* In general, all {@link ServerCall}s <em>must</em> have a local address that in turn <em>should</em> be resolvable to
* a remote channel, which <em>must</em> have associated request attributes and authentication status. It is possible
* that a server call's local address may not be resolvable to a remote channel if the remote channel closed in the
* narrow window between a server call being created and the start of call execution, in which case accessor methods
* in this class will throw a {@link ChannelNotFoundException}.
* <p>
* A gRPC client connection manager's methods for getting request attributes accept {@link ServerCall} entities to
* identify connections. In general, these methods should only be called from {@link io.grpc.ServerInterceptor}s.
* Methods for requesting connection closure accept an {@link AuthenticatedDevice} to identify the connection and may
* be called from any application code.
*/ */
public class GrpcClientConnectionManager implements DisconnectionRequestListener { public class GrpcClientConnectionManager implements DisconnectionRequestListener {
@ -56,93 +43,94 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
AttributeKey.valueOf(GrpcClientConnectionManager.class, "authenticatedDevice"); AttributeKey.valueOf(GrpcClientConnectionManager.class, "authenticatedDevice");
@VisibleForTesting @VisibleForTesting
public static final AttributeKey<RequestAttributes> REQUEST_ATTRIBUTES_KEY = static final AttributeKey<InetAddress> REMOTE_ADDRESS_ATTRIBUTE_KEY =
AttributeKey.valueOf(GrpcClientConnectionManager.class, "requestAttributes"); AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "remoteAddress");
@VisibleForTesting @VisibleForTesting
static final AttributeKey<ClosableEpoch> EPOCH_ATTRIBUTE_KEY = static final AttributeKey<String> RAW_USER_AGENT_ATTRIBUTE_KEY =
AttributeKey.valueOf(GrpcClientConnectionManager.class, "epoch"); AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "rawUserAgent");
@VisibleForTesting
static final AttributeKey<UserAgent> PARSED_USER_AGENT_ATTRIBUTE_KEY =
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "userAgent");
@VisibleForTesting
static final AttributeKey<List<Locale.LanguageRange>> ACCEPT_LANGUAGE_ATTRIBUTE_KEY =
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "acceptLanguage");
private static final Logger log = LoggerFactory.getLogger(GrpcClientConnectionManager.class); private static final Logger log = LoggerFactory.getLogger(GrpcClientConnectionManager.class);
/** /**
* Returns the authenticated device associated with the given server call, if any. If the connection is anonymous * Returns the authenticated device associated with the given local address, if any. An authenticated device is
* (i.e. unauthenticated), the returned value will be empty. * available if and only if the given local address maps to an active local connection and that connection is
* authenticated (i.e. not anonymous).
* *
* @param serverCall the gRPC server call for which to find an authenticated device * @param localAddress the local address for which to find an authenticated device
* *
* @return the authenticated device associated with the given local address, if any * @return the authenticated device associated with the given local address, if any
*
* @throws ChannelNotFoundException if the server call is not associated with a known channel; in practice, this
* generally indicates that the channel has closed while request processing is still in progress
*/ */
public Optional<AuthenticatedDevice> getAuthenticatedDevice(final ServerCall<?, ?> serverCall) public Optional<AuthenticatedDevice> getAuthenticatedDevice(final LocalAddress localAddress) {
throws ChannelNotFoundException { return getAuthenticatedDevice(remoteChannelsByLocalAddress.get(localAddress));
return getAuthenticatedDevice(getRemoteChannel(serverCall));
} }
@VisibleForTesting private Optional<AuthenticatedDevice> getAuthenticatedDevice(@Nullable final Channel remoteChannel) {
Optional<AuthenticatedDevice> getAuthenticatedDevice(final Channel remoteChannel) { return Optional.ofNullable(remoteChannel)
return Optional.ofNullable(remoteChannel.attr(AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).get()); .map(channel -> channel.attr(AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).get());
} }
/** /**
* Returns the request attributes associated with the given server call. * Returns the parsed acceptable languages associated with the given local address, if any. Acceptable languages may
* be unavailable if the local connection associated with the given local address has already closed, if the client
* did not provide a list of acceptable languages, or the list provided by the client could not be parsed.
* *
* @param serverCall the gRPC server call for which to retrieve request attributes * @param localAddress the local address for which to find acceptable languages
* *
* @return the request attributes associated with the given server call * @return the acceptable languages associated with the given local address, if any
*
* @throws ChannelNotFoundException if the server call is not associated with a known channel; in practice, this
* generally indicates that the channel has closed while request processing is still in progress
*/ */
public RequestAttributes getRequestAttributes(final ServerCall<?, ?> serverCall) throws ChannelNotFoundException { public Optional<List<Locale.LanguageRange>> getAcceptableLanguages(final LocalAddress localAddress) {
return getRequestAttributes(getRemoteChannel(serverCall)); return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
} .map(remoteChannel -> remoteChannel.attr(ACCEPT_LANGUAGE_ATTRIBUTE_KEY).get());
@VisibleForTesting
RequestAttributes getRequestAttributes(final Channel remoteChannel) {
final RequestAttributes requestAttributes = remoteChannel.attr(REQUEST_ATTRIBUTES_KEY).get();
if (requestAttributes == null) {
throw new IllegalStateException("Channel does not have request attributes");
}
return requestAttributes;
} }
/** /**
* Handles the start of a server call, incrementing the active call count for the remote channel associated with the * Returns the remote address associated with the given local address, if any. A remote address may be unavailable if
* given server call. * the local connection associated with the given local address has already closed.
* *
* @param serverCall the server call to start * @param localAddress the local address for which to find a remote address
* *
* @return {@code true} if the call should start normally or {@code false} if the call should be aborted because the * @return the remote address associated with the given local address, if any
* underlying channel is closing
*/ */
public boolean handleServerCallStart(final ServerCall<?, ?> serverCall) { public Optional<InetAddress> getRemoteAddress(final LocalAddress localAddress) {
try { return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
return getRemoteChannel(serverCall).attr(EPOCH_ATTRIBUTE_KEY).get().tryArrive(); .map(remoteChannel -> remoteChannel.attr(REMOTE_ADDRESS_ATTRIBUTE_KEY).get());
} catch (final ChannelNotFoundException e) {
// This would only happen if the channel had already closed, which is certainly possible. In this case, the call
// should certainly not proceed.
return false;
}
} }
/** /**
* Handles completion (successful or not) of a server call, decrementing the active call count for the remote channel * Returns the unparsed user agent provided by the client that opened the connection associated with the given local
* associated with the given server call. * address. This method may return an empty value if no active local connection is associated with the given local
* address.
* *
* @param serverCall the server call to complete * @param localAddress the local address for which to find a User-Agent string
*
* @return the user agent string associated with the given local address
*/ */
public void handleServerCallComplete(final ServerCall<?, ?> serverCall) { public Optional<String> getRawUserAgent(final LocalAddress localAddress) {
try { return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
getRemoteChannel(serverCall).attr(EPOCH_ATTRIBUTE_KEY).get().depart(); .map(remoteChannel -> remoteChannel.attr(RAW_USER_AGENT_ATTRIBUTE_KEY).get());
} catch (final ChannelNotFoundException ignored) { }
// In practice, we'd only get here if the channel has already closed, so we can just ignore the exception
} /**
* Returns the parsed user agent provided by the client that opened the connection associated with the given local
* address. This method may return an empty value if no active local connection is associated with the given local
* address or if the client's user-agent string was not recognized.
*
* @param localAddress the local address for which to find a User-Agent string
*
* @return the user agent associated with the given local address
*/
public Optional<UserAgent> getUserAgent(final LocalAddress localAddress) {
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
.map(remoteChannel -> remoteChannel.attr(PARSED_USER_AGENT_ATTRIBUTE_KEY).get());
} }
/** /**
@ -151,19 +139,11 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
* @param authenticatedDevice the authenticated device for which to close connections * @param authenticatedDevice the authenticated device for which to close connections
*/ */
public void closeConnection(final AuthenticatedDevice authenticatedDevice) { public void closeConnection(final AuthenticatedDevice authenticatedDevice) {
// Channels will actually get removed from the list/map by their closeFuture listeners. We copy the list to avoid // Channels will actually get removed from the list/map by their closeFuture listeners
// concurrent modification; it's possible (though practically unlikely) that a channel can close and remove itself remoteChannelsByAuthenticatedDevice.getOrDefault(authenticatedDevice, Collections.emptyList()).forEach(channel ->
// from the list while we're still iterating, resulting in a `ConcurrentModificationException`. channel.writeAndFlush(new CloseWebSocketFrame(ApplicationWebSocketCloseReason.REAUTHENTICATION_REQUIRED
final List<Channel> channelsToClose = .toWebSocketCloseStatus("Reauthentication required")))
new ArrayList<>(remoteChannelsByAuthenticatedDevice.getOrDefault(authenticatedDevice, Collections.emptyList())); .addListener(ChannelFutureListener.CLOSE_ON_FAILURE));
channelsToClose.forEach(channel -> channel.attr(EPOCH_ATTRIBUTE_KEY).get().close());
}
private static void closeRemoteChannel(final Channel channel) {
channel.writeAndFlush(new CloseWebSocketFrame(ApplicationWebSocketCloseReason.REAUTHENTICATION_REQUIRED
.toWebSocketCloseStatus("Reauthentication required")))
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
} }
@VisibleForTesting @VisibleForTesting
@ -171,32 +151,11 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
return remoteChannelsByAuthenticatedDevice.get(authenticatedDevice); return remoteChannelsByAuthenticatedDevice.get(authenticatedDevice);
} }
private Channel getRemoteChannel(final ServerCall<?, ?> serverCall) throws ChannelNotFoundException {
return getRemoteChannel(getLocalAddress(serverCall));
}
@VisibleForTesting @VisibleForTesting
Channel getRemoteChannel(final LocalAddress localAddress) throws ChannelNotFoundException { Channel getRemoteChannelByLocalAddress(final LocalAddress localAddress) {
final Channel remoteChannel = remoteChannelsByLocalAddress.get(localAddress);
if (remoteChannel == null) {
throw new ChannelNotFoundException();
}
return remoteChannelsByLocalAddress.get(localAddress); return remoteChannelsByLocalAddress.get(localAddress);
} }
private static LocalAddress getLocalAddress(final ServerCall<?, ?> serverCall) {
// In this server, gRPC's "remote" channel is actually a local channel that proxies to a distinct Noise channel.
// The gRPC "remote" address is the "local address" for the proxy connection, and the local address uniquely maps to
// a proxied Noise channel.
if (!(serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress)) {
throw new IllegalArgumentException("Unexpected channel type: " + serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
}
return localAddress;
}
/** /**
* Handles successful completion of a WebSocket handshake and associates attributes and headers from the handshake * Handles successful completion of a WebSocket handshake and associates attributes and headers from the handshake
* request with the channel via which the handshake took place. * request with the channel via which the handshake took place.
@ -207,23 +166,30 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
* @param acceptLanguageHeader the value of the Accept-Language header provided in the handshake request; may be * @param acceptLanguageHeader the value of the Accept-Language header provided in the handshake request; may be
* {@code null} * {@code null}
*/ */
static void handleHandshakeComplete(final Channel channel, static void handleWebSocketHandshakeComplete(final Channel channel,
final InetAddress preferredRemoteAddress, final InetAddress preferredRemoteAddress,
@Nullable final String userAgentHeader, @Nullable final String userAgentHeader,
@Nullable final String acceptLanguageHeader) { @Nullable final String acceptLanguageHeader) {
@Nullable List<Locale.LanguageRange> acceptLanguages = Collections.emptyList(); channel.attr(GrpcClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).set(preferredRemoteAddress);
if (StringUtils.isNotBlank(userAgentHeader)) {
channel.attr(GrpcClientConnectionManager.RAW_USER_AGENT_ATTRIBUTE_KEY).set(userAgentHeader);
try {
channel.attr(GrpcClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY)
.set(UserAgentUtil.parseUserAgentString(userAgentHeader));
} catch (final UnrecognizedUserAgentException ignored) {
}
}
if (StringUtils.isNotBlank(acceptLanguageHeader)) { if (StringUtils.isNotBlank(acceptLanguageHeader)) {
try { try {
acceptLanguages = Locale.LanguageRange.parse(acceptLanguageHeader); channel.attr(GrpcClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).set(Locale.LanguageRange.parse(acceptLanguageHeader));
} catch (final IllegalArgumentException e) { } catch (final IllegalArgumentException e) {
log.debug("Invalid Accept-Language header from User-Agent {}: {}", userAgentHeader, acceptLanguageHeader, e); log.debug("Invalid Accept-Language header from User-Agent {}: {}", userAgentHeader, acceptLanguageHeader, e);
} }
} }
channel.attr(REQUEST_ATTRIBUTES_KEY)
.set(new RequestAttributes(preferredRemoteAddress, userAgentHeader, acceptLanguages));
} }
/** /**
@ -241,9 +207,6 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
maybeAuthenticatedDevice.ifPresent(authenticatedDevice -> maybeAuthenticatedDevice.ifPresent(authenticatedDevice ->
remoteChannel.attr(GrpcClientConnectionManager.AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).set(authenticatedDevice)); remoteChannel.attr(GrpcClientConnectionManager.AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).set(authenticatedDevice));
remoteChannel.attr(EPOCH_ATTRIBUTE_KEY)
.set(new ClosableEpoch(() -> closeRemoteChannel(remoteChannel)));
remoteChannelsByLocalAddress.put(localChannel.localAddress(), remoteChannel); remoteChannelsByLocalAddress.put(localChannel.localAddress(), remoteChannel);
getAuthenticatedDevice(remoteChannel).ifPresent(authenticatedDevice -> getAuthenticatedDevice(remoteChannel).ifPresent(authenticatedDevice ->

View File

@ -74,7 +74,7 @@ class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
preferredRemoteAddress = maybePreferredRemoteAddress.get(); preferredRemoteAddress = maybePreferredRemoteAddress.get();
} }
GrpcClientConnectionManager.handleHandshakeComplete(context.channel(), GrpcClientConnectionManager.handleWebSocketHandshakeComplete(context.channel(),
preferredRemoteAddress, preferredRemoteAddress,
handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.USER_AGENT), handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.USER_AGENT),
handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.ACCEPT_LANGUAGE)); handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.ACCEPT_LANGUAGE));

View File

@ -11,6 +11,7 @@ import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.in
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import com.google.protobuf.Descriptors; import com.google.protobuf.Descriptors;
import com.google.protobuf.GeneratedMessageV3;
import com.google.protobuf.Message; import com.google.protobuf.Message;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.StatusException; import io.grpc.StatusException;
@ -48,7 +49,7 @@ public abstract class BaseFieldValidator<T> implements FieldValidator {
public void validate( public void validate(
final Object extensionValue, final Object extensionValue,
final Descriptors.FieldDescriptor fd, final Descriptors.FieldDescriptor fd,
final Message msg) throws StatusException { final GeneratedMessageV3 msg) throws StatusException {
try { try {
final T extensionValueTyped = resolveExtensionValue(extensionValue); final T extensionValueTyped = resolveExtensionValue(extensionValue);
@ -115,7 +116,7 @@ public abstract class BaseFieldValidator<T> implements FieldValidator {
protected void validateRepeatedField( protected void validateRepeatedField(
final T extensionValue, final T extensionValue,
final Descriptors.FieldDescriptor fd, final Descriptors.FieldDescriptor fd,
final Message msg) throws StatusException { final GeneratedMessageV3 msg) throws StatusException {
throw internalError("`validateRepeatedField` method needs to be implemented"); throw internalError("`validateRepeatedField` method needs to be implemented");
} }

View File

@ -9,7 +9,7 @@ import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.in
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import com.google.protobuf.Descriptors; import com.google.protobuf.Descriptors;
import com.google.protobuf.Message; import com.google.protobuf.GeneratedMessageV3;
import io.grpc.StatusException; import io.grpc.StatusException;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
@ -53,7 +53,7 @@ public class ExactlySizeFieldValidator extends BaseFieldValidator<Set<Integer>>
protected void validateRepeatedField( protected void validateRepeatedField(
final Set<Integer> permittedSizes, final Set<Integer> permittedSizes,
final Descriptors.FieldDescriptor fd, final Descriptors.FieldDescriptor fd,
final Message msg) throws StatusException { final GeneratedMessageV3 msg) throws StatusException {
final int size = msg.getRepeatedFieldCount(fd); final int size = msg.getRepeatedFieldCount(fd);
if (permittedSizes.contains(size)) { if (permittedSizes.contains(size)) {
return; return;

View File

@ -6,11 +6,11 @@
package org.whispersystems.textsecuregcm.grpc.validators; package org.whispersystems.textsecuregcm.grpc.validators;
import com.google.protobuf.Descriptors; import com.google.protobuf.Descriptors;
import com.google.protobuf.Message; import com.google.protobuf.GeneratedMessageV3;
import io.grpc.StatusException; import io.grpc.StatusException;
public interface FieldValidator { public interface FieldValidator {
void validate(Object extensionValue, Descriptors.FieldDescriptor fd, Message msg) void validate(Object extensionValue, Descriptors.FieldDescriptor fd, GeneratedMessageV3 msg)
throws StatusException; throws StatusException;
} }

View File

@ -9,7 +9,7 @@ import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.in
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import com.google.protobuf.Descriptors; import com.google.protobuf.Descriptors;
import com.google.protobuf.Message; import com.google.protobuf.GeneratedMessageV3;
import io.grpc.StatusException; import io.grpc.StatusException;
import java.util.Set; import java.util.Set;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
@ -52,7 +52,7 @@ public class NonEmptyFieldValidator extends BaseFieldValidator<Boolean> {
protected void validateRepeatedField( protected void validateRepeatedField(
final Boolean extensionValue, final Boolean extensionValue,
final Descriptors.FieldDescriptor fd, final Descriptors.FieldDescriptor fd,
final Message msg) throws StatusException { final GeneratedMessageV3 msg) throws StatusException {
if (msg.getRepeatedFieldCount(fd) > 0) { if (msg.getRepeatedFieldCount(fd) > 0) {
return; return;
} }

View File

@ -5,7 +5,7 @@
package org.whispersystems.textsecuregcm.grpc.validators; package org.whispersystems.textsecuregcm.grpc.validators;
public record Range(long min, long max) { public record Range(int min, int max) {
public Range { public Range {
if (min > max) { if (min > max) {
throw new IllegalArgumentException("invalid range values: expected min <= max but have [%d, %d],".formatted(min, max)); throw new IllegalArgumentException("invalid range values: expected min <= max but have [%d, %d],".formatted(min, max));

View File

@ -39,8 +39,8 @@ public class RangeFieldValidator extends BaseFieldValidator<Range> {
@Override @Override
protected Range resolveExtensionValue(final Object extensionValue) throws StatusException { protected Range resolveExtensionValue(final Object extensionValue) throws StatusException {
final ValueRangeConstraint rangeConstraint = (ValueRangeConstraint) extensionValue; final ValueRangeConstraint rangeConstraint = (ValueRangeConstraint) extensionValue;
final long min = rangeConstraint.hasMin() ? rangeConstraint.getMin() : Long.MIN_VALUE; final int min = rangeConstraint.hasMin() ? rangeConstraint.getMin() : Integer.MIN_VALUE;
final long max = rangeConstraint.hasMax() ? rangeConstraint.getMax() : Long.MAX_VALUE; final int max = rangeConstraint.hasMax() ? rangeConstraint.getMax() : Integer.MAX_VALUE;
return new Range(min, max); return new Range(min, max);
} }

View File

@ -9,7 +9,7 @@ import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.in
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import com.google.protobuf.Descriptors; import com.google.protobuf.Descriptors;
import com.google.protobuf.Message; import com.google.protobuf.GeneratedMessageV3;
import io.grpc.StatusException; import io.grpc.StatusException;
import java.util.Set; import java.util.Set;
import org.signal.chat.require.SizeConstraint; import org.signal.chat.require.SizeConstraint;
@ -48,7 +48,7 @@ public class SizeFieldValidator extends BaseFieldValidator<Range> {
} }
@Override @Override
protected void validateRepeatedField(final Range range, final Descriptors.FieldDescriptor fd, final Message msg) throws StatusException { protected void validateRepeatedField(final Range range, final Descriptors.FieldDescriptor fd, final GeneratedMessageV3 msg) throws StatusException {
final int size = msg.getRepeatedFieldCount(fd); final int size = msg.getRepeatedFieldCount(fd);
if (size < range.min() || size > range.max()) { if (size < range.min() || size > range.max()) {
throw invalidArgument("field value is [%d] but expected to be within the [%d, %d] range".formatted( throw invalidArgument("field value is [%d] but expected to be within the [%d, %d] range".formatted(

View File

@ -1,44 +0,0 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.metrics;
import org.apache.commons.lang3.StringUtils;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import java.util.Optional;
public class DevicePlatformUtil {
private DevicePlatformUtil() {
}
/**
* Returns the most likely client platform for a device.
*
* @param device the device for which to find a client platform
*
* @return the most likely client platform for the given device or empty if no likely platform could be determined
*/
public static Optional<ClientPlatform> getDevicePlatform(final Device device) {
final Optional<ClientPlatform> clientPlatform;
if (StringUtils.isNotBlank(device.getGcmId())) {
clientPlatform = Optional.of(ClientPlatform.ANDROID);
} else if (StringUtils.isNotBlank(device.getApnId())) {
clientPlatform = Optional.of(ClientPlatform.IOS);
} else {
clientPlatform = Optional.empty();
}
return clientPlatform.or(() -> Optional.ofNullable(
switch (device.getUserAgent()) {
case "OWA" -> ClientPlatform.ANDROID;
case "OWI", "OWP" -> ClientPlatform.IOS;
case "OWD" -> ClientPlatform.DESKTOP;
case null, default -> null;
}));
}
}

View File

@ -0,0 +1,32 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.metrics;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import com.sun.management.OperatingSystemMXBean;
import io.micrometer.core.instrument.Gauge;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.binder.MeterBinder;
import java.lang.management.ManagementFactory;
public class FreeMemoryGauge implements MeterBinder {
private final OperatingSystemMXBean operatingSystemMXBean;
public FreeMemoryGauge() {
this.operatingSystemMXBean = (com.sun.management.OperatingSystemMXBean)
ManagementFactory.getOperatingSystemMXBean();
}
@Override
public void bindTo(final MeterRegistry registry) {
Gauge.builder(name(FreeMemoryGauge.class, "freeMemory"), operatingSystemMXBean,
OperatingSystemMXBean::getFreeMemorySize)
.register(registry);
}
}

View File

@ -120,7 +120,10 @@ public class MetricsUtil {
public static void registerSystemResourceMetrics(final Environment environment) { public static void registerSystemResourceMetrics(final Environment environment) {
new ProcessorMetrics().bindTo(Metrics.globalRegistry); new ProcessorMetrics().bindTo(Metrics.globalRegistry);
new FreeMemoryGauge().bindTo(Metrics.globalRegistry);
new FileDescriptorMetrics().bindTo(Metrics.globalRegistry); new FileDescriptorMetrics().bindTo(Metrics.globalRegistry);
new OperatingSystemMemoryGauge("Buffers").bindTo(Metrics.globalRegistry);
new OperatingSystemMemoryGauge("Cached").bindTo(Metrics.globalRegistry);
new JvmMemoryMetrics().bindTo(Metrics.globalRegistry); new JvmMemoryMetrics().bindTo(Metrics.globalRegistry);
new JvmThreadMetrics().bindTo(Metrics.globalRegistry); new JvmThreadMetrics().bindTo(Metrics.globalRegistry);

View File

@ -67,7 +67,7 @@ public class OpenWebSocketCounter {
try { try {
final ClientPlatform clientPlatform = final ClientPlatform clientPlatform =
UserAgentUtil.parseUserAgentString(context.getClient().getUserAgent()).platform(); UserAgentUtil.parseUserAgentString(context.getClient().getUserAgent()).getPlatform();
calculatedOpenWebSocketCounter = openWebsocketsByClientPlatform.get(clientPlatform); calculatedOpenWebSocketCounter = openWebsocketsByClientPlatform.get(clientPlatform);
calculatedDurationTimer = durationTimersByClientPlatform.get(clientPlatform); calculatedDurationTimer = durationTimersByClientPlatform.get(clientPlatform);

View File

@ -0,0 +1,56 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.metrics;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import com.google.common.annotations.VisibleForTesting;
import io.micrometer.core.instrument.Gauge;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.binder.MeterBinder;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.Locale;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Stream;
public class OperatingSystemMemoryGauge implements MeterBinder {
private final String metricName;
private static final File MEMINFO_FILE = new File("/proc/meminfo");
private static final Pattern MEMORY_METRIC_PATTERN = Pattern.compile("^([^:]+):\\s+([0-9]+).*$");
public OperatingSystemMemoryGauge(final String metricName) {
this.metricName = metricName;
}
@Override
public void bindTo(MeterRegistry registry) {
final String metricName = this.metricName;
Gauge.builder(name(OperatingSystemMemoryGauge.class, metricName.toLowerCase(Locale.ROOT)), () -> {
try (final BufferedReader bufferedReader = new BufferedReader(new FileReader(MEMINFO_FILE))) {
return getValue(bufferedReader.lines(), metricName);
} catch (final IOException e) {
return 0L;
}
})
.register(registry);
}
@VisibleForTesting
static double getValue(final Stream<String> lines, final String metricName) {
return lines.map(MEMORY_METRIC_PATTERN::matcher)
.filter(Matcher::matches)
.filter(matcher -> metricName.equalsIgnoreCase(matcher.group(1)))
.map(matcher -> Double.parseDouble(matcher.group(2)))
.findFirst()
.orElse(0d);
}
}

View File

@ -9,7 +9,6 @@ import io.micrometer.core.instrument.Tag;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import org.apache.commons.lang3.StringUtils;
import org.whispersystems.textsecuregcm.WhisperServerVersion; import org.whispersystems.textsecuregcm.WhisperServerVersion;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
@ -49,15 +48,15 @@ public class UserAgentTagUtil {
} }
public static Tag getPlatformTag(@Nullable final UserAgent userAgent) { public static Tag getPlatformTag(@Nullable final UserAgent userAgent) {
return Tag.of(PLATFORM_TAG, userAgent != null ? userAgent.platform().name().toLowerCase() : "unrecognized"); return Tag.of(PLATFORM_TAG, userAgent != null ? userAgent.getPlatform().name().toLowerCase() : "unrecognized");
} }
public static Optional<Tag> getClientVersionTag(final String userAgentString, final ClientReleaseManager clientReleaseManager) { public static Optional<Tag> getClientVersionTag(final String userAgentString, final ClientReleaseManager clientReleaseManager) {
try { try {
final UserAgent userAgent = UserAgentUtil.parseUserAgentString(userAgentString); final UserAgent userAgent = UserAgentUtil.parseUserAgentString(userAgentString);
if (clientReleaseManager.isVersionActive(userAgent.platform(), userAgent.version())) { if (clientReleaseManager.isVersionActive(userAgent.getPlatform(), userAgent.getVersion())) {
return Optional.of(Tag.of(VERSION_TAG, userAgent.version().toString())); return Optional.of(Tag.of(VERSION_TAG, userAgent.getVersion().toString()));
} }
} catch (final UnrecognizedUserAgentException ignored) { } catch (final UnrecognizedUserAgentException ignored) {
} }
@ -71,8 +70,10 @@ public class UserAgentTagUtil {
try { try {
final UserAgent userAgent = UserAgentUtil.parseUserAgentString(userAgentString); final UserAgent userAgent = UserAgentUtil.parseUserAgentString(userAgentString);
platform = userAgent.platform().name().toLowerCase(); platform = userAgent.getPlatform().name().toLowerCase();
libsignal = StringUtils.contains(userAgent.additionalSpecifiers(), "libsignal"); libsignal = userAgent.getAdditionalSpecifiers()
.map(additionalSpecifiers -> additionalSpecifiers.contains("libsignal"))
.orElse(false);
} catch (final UnrecognizedUserAgentException e) { } catch (final UnrecognizedUserAgentException e) {
platform = "unrecognized"; platform = "unrecognized";
libsignal = false; libsignal = false;

View File

@ -20,25 +20,22 @@ import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.signal.libsignal.protocol.util.Pair; import org.signal.libsignal.protocol.util.Pair;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.MessageController; import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevices; import org.whispersystems.textsecuregcm.controllers.MismatchedDevices;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException; import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import javax.annotation.Nullable;
/** /**
* A MessageSender sends Signal messages to destination devices. Messages may be "normal" user-to-user messages, * A MessageSender sends Signal messages to destination devices. Messages may be "normal" user-to-user messages,
@ -55,14 +52,10 @@ public class MessageSender {
private final MessagesManager messagesManager; private final MessagesManager messagesManager;
private final PushNotificationManager pushNotificationManager; private final PushNotificationManager pushNotificationManager;
private final ExperimentEnrollmentManager experimentEnrollmentManager;
public static final String ANDROID_SKIP_LOW_URGENCY_PUSH_EXPERIMENT = "androidSkipLowUrgencyPush";
// Note that these names deliberately reference `MessageController` for metric continuity // Note that these names deliberately reference `MessageController` for metric continuity
private static final String REJECT_OVERSIZE_MESSAGE_COUNTER_NAME = name(MessageController.class, "rejectOversizeMessage"); private static final String REJECT_OVERSIZE_MESSAGE_COUNTER_NAME = name(MessageController.class, "rejectOversizeMessage");
private static final String CONTENT_SIZE_DISTRIBUTION_NAME = MetricsUtil.name(MessageController.class, "messageContentSize"); private static final String CONTENT_SIZE_DISTRIBUTION_NAME = MetricsUtil.name(MessageController.class, "messageContentSize");
private static final String EMPTY_MESSAGE_LIST_COUNTER_NAME = MetricsUtil.name(MessageSender.class, "emptyMessageList");
private static final String SEND_COUNTER_NAME = name(MessageSender.class, "sendMessage"); private static final String SEND_COUNTER_NAME = name(MessageSender.class, "sendMessage");
private static final String EPHEMERAL_TAG_NAME = "ephemeral"; private static final String EPHEMERAL_TAG_NAME = "ephemeral";
@ -71,7 +64,6 @@ public class MessageSender {
private static final String STORY_TAG_NAME = "story"; private static final String STORY_TAG_NAME = "story";
private static final String SEALED_SENDER_TAG_NAME = "sealedSender"; private static final String SEALED_SENDER_TAG_NAME = "sealedSender";
private static final String MULTI_RECIPIENT_TAG_NAME = "multiRecipient"; private static final String MULTI_RECIPIENT_TAG_NAME = "multiRecipient";
private static final String SYNC_MESSAGE_TAG_NAME = "sync";
@VisibleForTesting @VisibleForTesting
public static final int MAX_MESSAGE_SIZE = (int) DataSize.kibibytes(256).toBytes(); public static final int MAX_MESSAGE_SIZE = (int) DataSize.kibibytes(256).toBytes();
@ -79,13 +71,9 @@ public class MessageSender {
@VisibleForTesting @VisibleForTesting
static final byte NO_EXCLUDED_DEVICE_ID = -1; static final byte NO_EXCLUDED_DEVICE_ID = -1;
public MessageSender( public MessageSender(final MessagesManager messagesManager, final PushNotificationManager pushNotificationManager) {
final MessagesManager messagesManager,
final PushNotificationManager pushNotificationManager,
final ExperimentEnrollmentManager experimentEnrollmentManager) {
this.messagesManager = messagesManager; this.messagesManager = messagesManager;
this.pushNotificationManager = pushNotificationManager; this.pushNotificationManager = pushNotificationManager;
this.experimentEnrollmentManager = experimentEnrollmentManager;
} }
/** /**
@ -97,8 +85,6 @@ public class MessageSender {
* @param destinationIdentifier the service identifier to which the messages are addressed * @param destinationIdentifier the service identifier to which the messages are addressed
* @param messagesByDeviceId a map of device IDs to message payloads * @param messagesByDeviceId a map of device IDs to message payloads
* @param registrationIdsByDeviceId a map of device IDs to device registration IDs * @param registrationIdsByDeviceId a map of device IDs to device registration IDs
* @param syncMessageSenderDeviceId if the message is a sync message (i.e. a message to other devices linked to the
* caller's own account), contains the ID of the device that sent the message
* @param userAgent the User-Agent string for the sender; may be {@code null} if not known * @param userAgent the User-Agent string for the sender; may be {@code null} if not known
* *
* @throws MismatchedDevicesException if the given bundle of messages did not include a message for all required * @throws MismatchedDevicesException if the given bundle of messages did not include a message for all required
@ -110,53 +96,39 @@ public class MessageSender {
final ServiceIdentifier destinationIdentifier, final ServiceIdentifier destinationIdentifier,
final Map<Byte, Envelope> messagesByDeviceId, final Map<Byte, Envelope> messagesByDeviceId,
final Map<Byte, Integer> registrationIdsByDeviceId, final Map<Byte, Integer> registrationIdsByDeviceId,
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<Byte> syncMessageSenderDeviceId,
@Nullable final String userAgent) throws MismatchedDevicesException, MessageTooLargeException { @Nullable final String userAgent) throws MismatchedDevicesException, MessageTooLargeException {
if (messagesByDeviceId.isEmpty()) {
return;
}
if (!destination.isIdentifiedBy(destinationIdentifier)) { if (!destination.isIdentifiedBy(destinationIdentifier)) {
throw new IllegalArgumentException("Destination account not identified by destination service identifier"); throw new IllegalArgumentException("Destination account not identified by destination service identifier");
} }
final Tag platformTag = UserAgentTagUtil.getPlatformTag(userAgent); final Envelope firstMessage = messagesByDeviceId.values().iterator().next();
if (messagesByDeviceId.isEmpty()) { final boolean isSyncMessage = StringUtils.isNotBlank(firstMessage.getSourceServiceId()) &&
Metrics.counter(EMPTY_MESSAGE_LIST_COUNTER_NAME, destination.isIdentifiedBy(ServiceIdentifier.valueOf(firstMessage.getSourceServiceId()));
Tags.of(SYNC_MESSAGE_TAG_NAME, String.valueOf(syncMessageSenderDeviceId.isPresent())).and(platformTag)).increment();
}
final byte excludedDeviceId; final boolean isStory = firstMessage.getStory();
if (syncMessageSenderDeviceId.isPresent()) {
if (messagesByDeviceId.values().stream().anyMatch(message -> StringUtils.isBlank(message.getSourceServiceId()) ||
!destination.isIdentifiedBy(ServiceIdentifier.valueOf(message.getSourceServiceId())))) {
throw new IllegalArgumentException("Sync message sender device ID specified, but one or more messages are not addressed to sender"); validateIndividualMessageContentLength(messagesByDeviceId.values(), isSyncMessage, isStory, userAgent);
}
excludedDeviceId = syncMessageSenderDeviceId.get();
} else {
if (messagesByDeviceId.values().stream().anyMatch(message -> StringUtils.isNotBlank(message.getSourceServiceId()) &&
destination.isIdentifiedBy(ServiceIdentifier.valueOf(message.getSourceServiceId())))) {
throw new IllegalArgumentException("Sync message sender device ID not specified, but one or more messages are addressed to sender");
}
excludedDeviceId = NO_EXCLUDED_DEVICE_ID;
}
final Optional<MismatchedDevices> maybeMismatchedDevices = getMismatchedDevices(destination, final Optional<MismatchedDevices> maybeMismatchedDevices = getMismatchedDevices(destination,
destinationIdentifier, destinationIdentifier,
registrationIdsByDeviceId, registrationIdsByDeviceId,
excludedDeviceId); isSyncMessage ? (byte) firstMessage.getSourceDevice() : NO_EXCLUDED_DEVICE_ID);
if (maybeMismatchedDevices.isPresent()) { if (maybeMismatchedDevices.isPresent()) {
throw new MismatchedDevicesException(maybeMismatchedDevices.get()); throw new MismatchedDevicesException(maybeMismatchedDevices.get());
} }
validateIndividualMessageContentLength(messagesByDeviceId.values(), syncMessageSenderDeviceId.isPresent(), userAgent);
messagesManager.insert(destination.getIdentifier(IdentityType.ACI), messagesByDeviceId) messagesManager.insert(destination.getIdentifier(IdentityType.ACI), messagesByDeviceId)
.forEach((deviceId, destinationPresent) -> { .forEach((deviceId, destinationPresent) -> {
final Envelope message = messagesByDeviceId.get(deviceId); final Envelope message = messagesByDeviceId.get(deviceId);
if (!destinationPresent && !message.getEphemeral() && !shouldSkipPush(destination, deviceId, message.getUrgent())) { if (!destinationPresent && !message.getEphemeral()) {
try { try {
pushNotificationManager.sendNewMessageNotification(destination, deviceId, message.getUrgent()); pushNotificationManager.sendNewMessageNotification(destination, deviceId, message.getUrgent());
} catch (final NotPushRegisteredException ignored) { } catch (final NotPushRegisteredException ignored) {
@ -169,21 +141,13 @@ public class MessageSender {
URGENT_TAG_NAME, String.valueOf(message.getUrgent()), URGENT_TAG_NAME, String.valueOf(message.getUrgent()),
STORY_TAG_NAME, String.valueOf(message.getStory()), STORY_TAG_NAME, String.valueOf(message.getStory()),
SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceServiceId()), SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceServiceId()),
SYNC_MESSAGE_TAG_NAME, String.valueOf(syncMessageSenderDeviceId.isPresent()),
MULTI_RECIPIENT_TAG_NAME, "false") MULTI_RECIPIENT_TAG_NAME, "false")
.and(platformTag); .and(UserAgentTagUtil.getPlatformTag(userAgent));
Metrics.counter(SEND_COUNTER_NAME, tags).increment(); Metrics.counter(SEND_COUNTER_NAME, tags).increment();
}); });
} }
private boolean shouldSkipPush(final Account account, byte deviceId, boolean urgent) {
final boolean isAndroidFcm = account.getDevice(deviceId).map(Device::getGcmId).isPresent();
return !urgent
&& isAndroidFcm
&& experimentEnrollmentManager.isEnrolled(account.getUuid(), ANDROID_SKIP_LOW_URGENCY_PUSH_EXPERIMENT);
}
/** /**
* Sends messages to a group of recipients. If a destination device has a valid push notification token and does not * Sends messages to a group of recipients. If a destination device has a valid push notification token and does not
* have an active connection to a Signal server, then this method will also send a push notification to that device to * have an active connection to a Signal server, then this method will also send a push notification to that device to
@ -261,7 +225,6 @@ public class MessageSender {
URGENT_TAG_NAME, String.valueOf(isUrgent), URGENT_TAG_NAME, String.valueOf(isUrgent),
STORY_TAG_NAME, String.valueOf(isStory), STORY_TAG_NAME, String.valueOf(isStory),
SEALED_SENDER_TAG_NAME, "true", SEALED_SENDER_TAG_NAME, "true",
SYNC_MESSAGE_TAG_NAME, "false",
MULTI_RECIPIENT_TAG_NAME, "true") MULTI_RECIPIENT_TAG_NAME, "true")
.and(UserAgentTagUtil.getPlatformTag(userAgent)); .and(UserAgentTagUtil.getPlatformTag(userAgent));
@ -345,13 +308,14 @@ public class MessageSender {
private static void validateIndividualMessageContentLength(final Iterable<Envelope> messages, private static void validateIndividualMessageContentLength(final Iterable<Envelope> messages,
final boolean isSyncMessage, final boolean isSyncMessage,
final boolean isStory,
@Nullable final String userAgent) throws MessageTooLargeException { @Nullable final String userAgent) throws MessageTooLargeException {
for (final Envelope message : messages) { for (final Envelope message : messages) {
MessageSender.validateContentLength(message.getContent().size(), MessageSender.validateContentLength(message.getContent().size(),
false, false,
isSyncMessage, isSyncMessage,
message.getStory(), isStory,
userAgent); userAgent);
} }
} }

View File

@ -1,128 +0,0 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.push;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.util.function.Tuple2;
import reactor.util.function.Tuples;
public class MessageUtil {
public static final int DEFAULT_MAX_FETCH_ACCOUNT_CONCURRENCY = 8;
private MessageUtil() {
}
/**
* Finds account records for all recipients named in the given multi-recipient manager. Note that the returned map
* of recipients to account records will omit entries for recipients that could not be resolved to active accounts;
* callers that require full resolution should check for a missing entries and take appropriate action.
*
* @param accountsManager the {@code AccountsManager} instance to use to find account records
* @param multiRecipientMessage the message for which to resolve recipients
*
* @return a map of recipients to account records
*
* @see #getUnresolvedRecipients(SealedSenderMultiRecipientMessage, Map)
*/
public static Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolveRecipients(
final AccountsManager accountsManager,
final SealedSenderMultiRecipientMessage multiRecipientMessage) {
return resolveRecipients(accountsManager, multiRecipientMessage, DEFAULT_MAX_FETCH_ACCOUNT_CONCURRENCY);
}
/**
* Finds account records for all recipients named in the given multi-recipient manager. Note that the returned map
* of recipients to account records will omit entries for recipients that could not be resolved to active accounts;
* callers that require full resolution should check for a missing entries and take appropriate action.
*
* @param accountsManager the {@code AccountsManager} instance to use to find account records
* @param multiRecipientMessage the message for which to resolve recipients
* @param maxFetchAccountConcurrency the maximum number of concurrent account-retrieval operations
*
* @return a map of recipients to account records
*
* @see #getUnresolvedRecipients(SealedSenderMultiRecipientMessage, Map)
*/
public static Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolveRecipients(
final AccountsManager accountsManager,
final SealedSenderMultiRecipientMessage multiRecipientMessage,
final int maxFetchAccountConcurrency) {
return Flux.fromIterable(multiRecipientMessage.getRecipients().entrySet())
.flatMap(serviceIdAndRecipient -> {
final ServiceIdentifier serviceIdentifier =
ServiceIdentifier.fromLibsignal(serviceIdAndRecipient.getKey());
return Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(serviceIdentifier))
.flatMap(Mono::justOrEmpty)
.map(account -> Tuples.of(serviceIdAndRecipient.getValue(), account));
}, maxFetchAccountConcurrency)
.collectMap(Tuple2::getT1, Tuple2::getT2)
.blockOptional()
.orElse(Collections.emptyMap());
}
/**
* Returns a list of recipients missing from the map of resolved recipients for a multi-recipient message.
*
* @param multiRecipientMessage the multi-recipient message
* @param resolvedRecipients the map of resolved recipients to check for missing entries
*
* @return a list of {@code ServiceIdentifiers} belonging to multi-recipient message recipients that are not present
* in the given map of {@code resolvedRecipients}
*/
public static List<ServiceIdentifier> getUnresolvedRecipients(
final SealedSenderMultiRecipientMessage multiRecipientMessage,
final Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolvedRecipients) {
return multiRecipientMessage.getRecipients().entrySet().stream()
.filter(entry -> !resolvedRecipients.containsKey(entry.getValue()))
.map(entry -> ServiceIdentifier.fromLibsignal(entry.getKey()))
.toList();
}
/**
* Checks if a multi-recipient message contains duplicate recipients.
*
* @param multiRecipientMessage the message to check for duplicate recipients
*
* @return {@code true} if the message contains duplicate recipients or {@code false} otherwise
*/
public static boolean hasDuplicateDevices(final SealedSenderMultiRecipientMessage multiRecipientMessage) {
final boolean[] usedDeviceIds = new boolean[Device.MAXIMUM_DEVICE_ID + 1];
for (final SealedSenderMultiRecipientMessage.Recipient recipient : multiRecipientMessage.getRecipients().values()) {
if (recipient.getDevices().length == 1) {
// A recipient can't have repeated devices if they only have one device
continue;
}
Arrays.fill(usedDeviceIds, false);
for (final byte deviceId : recipient.getDevices()) {
if (usedDeviceIds[deviceId]) {
return true;
}
usedDeviceIds[deviceId] = true;
}
}
return false;
}
}

View File

@ -8,7 +8,6 @@ package org.whispersystems.textsecuregcm.push;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics; import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics;
import java.util.Map; import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.slf4j.Logger; import org.slf4j.Logger;
@ -71,7 +70,6 @@ public class ReceiptSender {
destinationIdentifier, destinationIdentifier,
messagesByDeviceId, messagesByDeviceId,
registrationIdsByDeviceId, registrationIdsByDeviceId,
Optional.empty(),
UserAgentTagUtil.SERVER_UA); UserAgentTagUtil.SERVER_UA);
} catch (final Exception e) { } catch (final Exception e) {
logger.warn("Could not send delivery receipt", e); logger.warn("Could not send delivery receipt", e);

View File

@ -38,8 +38,8 @@ public class SchedulingUtil {
final LocalTime preferredTime, final LocalTime preferredTime,
final Clock clock) { final Clock clock) {
final ZonedDateTime candidateNotificationTime = getZoneId(account, clock) final ZonedDateTime candidateNotificationTime = getZoneOffset(account, clock)
.map(zoneId -> ZonedDateTime.now(clock.withZone(zoneId)).with(preferredTime)) .map(zoneOffset -> ZonedDateTime.now(zoneOffset).with(preferredTime))
.orElseGet(() -> { .orElseGet(() -> {
// We couldn't find a reasonable timezone for the account for some reason, so make an educated guess at a // We couldn't find a reasonable timezone for the account for some reason, so make an educated guess at a
// reasonable time to send a notification based on the account's creation time. // reasonable time to send a notification based on the account's creation time.
@ -59,7 +59,7 @@ public class SchedulingUtil {
} }
@VisibleForTesting @VisibleForTesting
static Optional<ZoneId> getZoneId(final Account account, final Clock clock) { static Optional<ZoneOffset> getZoneOffset(final Account account, final Clock clock) {
try { try {
final Phonenumber.PhoneNumber phoneNumber = PhoneNumberUtil.getInstance().parse(account.getNumber(), null); final Phonenumber.PhoneNumber phoneNumber = PhoneNumberUtil.getInstance().parse(account.getNumber(), null);
@ -70,7 +70,7 @@ public class SchedulingUtil {
return Optional.empty(); return Optional.empty();
} }
final List<ZoneId> sortedZoneOffsets = timeZonesForNumber final List<ZoneOffset> sortedZoneOffsets = timeZonesForNumber
.stream() .stream()
.map(id -> { .map(id -> {
try { try {
@ -80,6 +80,9 @@ public class SchedulingUtil {
} }
}) })
.filter(Objects::nonNull) .filter(Objects::nonNull)
.map(ZoneId::getRules)
.distinct()
.map(zoneRules -> zoneRules.getOffset(clock.instant()))
.sorted() .sorted()
.toList(); .toList();

View File

@ -6,27 +6,19 @@
package org.whispersystems.textsecuregcm.spam; package org.whispersystems.textsecuregcm.spam;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.StatusException;
import javax.annotation.Nullable;
import java.util.Optional; import java.util.Optional;
/** /**
* A combination of a gRPC status and response message to communicate to callers that a message has been flagged as * A combination of a gRPC status and response message to communicate to callers that a message has been flagged as
* potential spam. * potential spam.
* *
* @param status The gRPC status for this response. If the status is {@link Status#OK}, then a response object will be
* available via {@link #response}. Otherwise, callers should transmit the status as an error to clients.
* @param response a response object to send to clients; will be present if {@link #status} is not {@link Status#OK}
*
* @param <R> the type of response object * @param <R> the type of response object
*/ */
public class GrpcResponse<R> { public record GrpcResponse<R>(Status status, Optional<R> response) {
private final Status status;
@Nullable
private final R response;
private GrpcResponse(final Status status, @Nullable final R response) {
this.status = status;
this.response = response;
}
/** /**
* Constructs a new response object with the given status and no response message. * Constructs a new response object with the given status and no response message.
@ -38,7 +30,7 @@ public class GrpcResponse<R> {
* @param <R> the type of response object * @param <R> the type of response object
*/ */
public static <R> GrpcResponse<R> withStatus(final Status status) { public static <R> GrpcResponse<R> withStatus(final Status status) {
return new GrpcResponse<>(status, null); return new GrpcResponse<>(status, Optional.empty());
} }
/** /**
@ -51,22 +43,6 @@ public class GrpcResponse<R> {
* @param <R> the type of response object * @param <R> the type of response object
*/ */
public static <R> GrpcResponse<R> withResponse(final R response) { public static <R> GrpcResponse<R> withResponse(final R response) {
return new GrpcResponse<>(Status.OK, response); return new GrpcResponse<>(Status.OK, Optional.of(response));
}
/**
* Returns the message body contained within this response or throws the contained status as a {@link StatusException}
* if no message body is specified.
*
* @return the message body contained within this response
*
* @throws StatusException if no message body is specified
*/
public R getResponseOrThrowStatus() throws StatusException {
if (response != null) {
return response;
}
throw status.asException();
} }
} }

View File

@ -5,10 +5,8 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import java.time.Clock;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.apache.commons.lang3.ObjectUtils; import org.apache.commons.lang3.ObjectUtils;
@ -30,16 +28,12 @@ public class ChangeNumberManager {
private static final Logger logger = LoggerFactory.getLogger(ChangeNumberManager.class); private static final Logger logger = LoggerFactory.getLogger(ChangeNumberManager.class);
private final MessageSender messageSender; private final MessageSender messageSender;
private final AccountsManager accountsManager; private final AccountsManager accountsManager;
private final Clock clock;
public ChangeNumberManager( public ChangeNumberManager(
final MessageSender messageSender, final MessageSender messageSender,
final AccountsManager accountsManager, final AccountsManager accountsManager) {
final Clock clock) {
this.messageSender = messageSender; this.messageSender = messageSender;
this.accountsManager = accountsManager; this.accountsManager = accountsManager;
this.clock = clock;
} }
public Account changeNumber(final Account account, final String number, public Account changeNumber(final Account account, final String number,
@ -102,7 +96,7 @@ public class ChangeNumberManager {
final String senderUserAgent) throws MessageTooLargeException, MismatchedDevicesException { final String senderUserAgent) throws MessageTooLargeException, MismatchedDevicesException {
try { try {
final long serverTimestamp = clock.millis(); final long serverTimestamp = System.currentTimeMillis();
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(account.getUuid()); final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(account.getUuid());
final Map<Byte, Envelope> messagesByDeviceId = deviceMessages.stream() final Map<Byte, Envelope> messagesByDeviceId = deviceMessages.stream()
@ -119,15 +113,10 @@ public class ChangeNumberManager {
.setEphemeral(false) .setEphemeral(false)
.build())); .build()));
final Map<Byte, Integer> registrationIdsByDeviceId = deviceMessages.stream() final Map<Byte, Integer> registrationIdsByDeviceId = account.getDevices().stream()
.collect(Collectors.toMap(IncomingMessage::destinationDeviceId, IncomingMessage::destinationRegistrationId)); .collect(Collectors.toMap(Device::getId, Device::getRegistrationId));
messageSender.sendMessages(account, messageSender.sendMessages(account, serviceIdentifier, messagesByDeviceId, registrationIdsByDeviceId, senderUserAgent);
serviceIdentifier,
messagesByDeviceId,
registrationIdsByDeviceId,
Optional.of(Device.PRIMARY_ID),
senderUserAgent);
} catch (final RuntimeException e) { } catch (final RuntimeException e) {
logger.warn("Changed number but could not send all device messages on {}", account.getUuid(), e); logger.warn("Changed number but could not send all device messages on {}", account.getUuid(), e);
throw e; throw e;

View File

@ -12,13 +12,10 @@ import io.dropwizard.lifecycle.Managed;
import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.DistributionSummary; import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags;
import io.micrometer.core.instrument.Timer; import io.micrometer.core.instrument.Timer;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.util.List; import java.util.List;
import java.util.Locale;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLong;
@ -27,10 +24,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.metrics.DevicePlatformUtil;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
@ -43,7 +37,6 @@ public class MessagePersister implements Managed {
private final MessagesManager messagesManager; private final MessagesManager messagesManager;
private final AccountsManager accountsManager; private final AccountsManager accountsManager;
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager; private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private final ExperimentEnrollmentManager experimentEnrollmentManager;
private final Duration persistDelay; private final Duration persistDelay;
@ -51,8 +44,6 @@ public class MessagePersister implements Managed {
private volatile boolean running; private volatile boolean running;
private static final String OVERSIZED_QUEUE_COUNTER_NAME = name(MessagePersister.class, "persistQueueOversized"); private static final String OVERSIZED_QUEUE_COUNTER_NAME = name(MessagePersister.class, "persistQueueOversized");
private static final String PERSISTED_MESSAGE_COUNTER_NAME = name(MessagePersister.class, "persistMessage");
private static final String PERSISTED_BYTES_COUNTER_NAME = name(MessagePersister.class, "persistBytes");
private static final Timer GET_QUEUES_TIMER = Metrics.timer(name(MessagePersister.class, "getQueues")); private static final Timer GET_QUEUES_TIMER = Metrics.timer(name(MessagePersister.class, "getQueues"));
private static final Timer PERSIST_QUEUE_TIMER = Metrics.timer(name(MessagePersister.class, "persistQueue")); private static final Timer PERSIST_QUEUE_TIMER = Metrics.timer(name(MessagePersister.class, "persistQueue"));
@ -66,7 +57,10 @@ public class MessagePersister implements Managed {
.publishPercentileHistogram(true) .publishPercentileHistogram(true)
.register(Metrics.globalRegistry); .register(Metrics.globalRegistry);
private static final String QUEUE_SIZE_DISTRIBUTION_SUMMARY_NAME = name(MessagePersister.class, "queueSize"); private static final DistributionSummary QUEUE_SIZE_DISTRIBUTION_SUMMARY = DistributionSummary.builder(
name(MessagePersister.class, "queueSize"))
.publishPercentileHistogram(true)
.register(Metrics.globalRegistry);
static final int QUEUE_BATCH_LIMIT = 100; static final int QUEUE_BATCH_LIMIT = 100;
static final int MESSAGE_BATCH_LIMIT = 100; static final int MESSAGE_BATCH_LIMIT = 100;
@ -81,7 +75,6 @@ public class MessagePersister implements Managed {
final MessagesManager messagesManager, final MessagesManager messagesManager,
final AccountsManager accountsManager, final AccountsManager accountsManager,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager, final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final ExperimentEnrollmentManager experimentEnrollmentManager,
final Duration persistDelay, final Duration persistDelay,
final int dedicatedProcessWorkerThreadCount) { final int dedicatedProcessWorkerThreadCount) {
@ -89,7 +82,6 @@ public class MessagePersister implements Managed {
this.messagesManager = messagesManager; this.messagesManager = messagesManager;
this.accountsManager = accountsManager; this.accountsManager = accountsManager;
this.dynamicConfigurationManager = dynamicConfigurationManager; this.dynamicConfigurationManager = dynamicConfigurationManager;
this.experimentEnrollmentManager = experimentEnrollmentManager;
this.persistDelay = persistDelay; this.persistDelay = persistDelay;
this.workerThreads = new Thread[dedicatedProcessWorkerThreadCount]; this.workerThreads = new Thread[dedicatedProcessWorkerThreadCount];
@ -147,7 +139,6 @@ public class MessagePersister implements Managed {
@VisibleForTesting @VisibleForTesting
int persistNextQueues(final Instant currentTime) { int persistNextQueues(final Instant currentTime) {
final int slot = messagesCache.getNextSlotToPersist(); final int slot = messagesCache.getNextSlotToPersist();
final String shard = messagesCache.shardForSlot(slot);
List<String> queuesToPersist; List<String> queuesToPersist;
int queuesPersisted = 0; int queuesPersisted = 0;
@ -171,11 +162,10 @@ public class MessagePersister implements Managed {
continue; continue;
} }
try { try {
persistQueue(maybeAccount.get(), maybeDevice.get(), shard); persistQueue(maybeAccount.get(), maybeDevice.get());
} catch (final Exception e) { } catch (final Exception e) {
PERSIST_QUEUE_EXCEPTION_METER.increment(); PERSIST_QUEUE_EXCEPTION_METER.increment();
logger.warn("Failed to persist queue {}::{} (slot {}, shard {}); will schedule for retry", logger.warn("Failed to persist queue {}::{}; will schedule for retry", accountUuid, deviceId, e);
accountUuid, deviceId, slot, shard, e);
messagesCache.addQueueToPersist(accountUuid, deviceId); messagesCache.addQueueToPersist(accountUuid, deviceId);
@ -193,14 +183,10 @@ public class MessagePersister implements Managed {
} }
@VisibleForTesting @VisibleForTesting
void persistQueue(final Account account, final Device device, final String shard) throws MessagePersistenceException { void persistQueue(final Account account, final Device device) throws MessagePersistenceException {
final UUID accountUuid = account.getUuid(); final UUID accountUuid = account.getUuid();
final byte deviceId = device.getId(); final byte deviceId = device.getId();
final Tag platformTag = Tag.of("platform", DevicePlatformUtil.getDevicePlatform(device)
.map(platform -> platform.name().toLowerCase(Locale.ROOT))
.orElse("unknown"));
final Timer.Sample sample = Timer.start(); final Timer.Sample sample = Timer.start();
messagesCache.lockQueueForPersistence(accountUuid, deviceId); messagesCache.lockQueueForPersistence(accountUuid, deviceId);
@ -214,16 +200,6 @@ public class MessagePersister implements Managed {
do { do {
messages = messagesCache.getMessagesToPersist(accountUuid, deviceId, MESSAGE_BATCH_LIMIT); messages = messagesCache.getMessagesToPersist(accountUuid, deviceId, MESSAGE_BATCH_LIMIT);
final int urgentMessageCount = (int) messages.stream().filter(MessageProtos.Envelope::getUrgent).count();
final int nonUrgentMessageCount = messages.size() - urgentMessageCount;
final Tags tags = Tags.of(platformTag, Tag.of("shard", shard));
Metrics.counter(PERSISTED_MESSAGE_COUNTER_NAME, tags.and("urgent", "true")).increment(urgentMessageCount);
Metrics.counter(PERSISTED_MESSAGE_COUNTER_NAME, tags.and("urgent", "false")).increment(nonUrgentMessageCount);
Metrics.counter(PERSISTED_BYTES_COUNTER_NAME, tags)
.increment(messages.stream().mapToInt(MessageProtos.Envelope::getSerializedSize).sum());
int messagesRemovedFromCache = messagesManager.persistMessages(accountUuid, device, messages); int messagesRemovedFromCache = messagesManager.persistMessages(accountUuid, device, messages);
messageCount += messages.size(); messageCount += messages.size();
@ -239,14 +215,7 @@ public class MessagePersister implements Managed {
} while (!messages.isEmpty()); } while (!messages.isEmpty());
final boolean inSkipExperiment = device.getGcmId() != null && experimentEnrollmentManager.isEnrolled( QUEUE_SIZE_DISTRIBUTION_SUMMARY.record(messageCount);
accountUuid,
MessageSender.ANDROID_SKIP_LOW_URGENCY_PUSH_EXPERIMENT);
DistributionSummary.builder(QUEUE_SIZE_DISTRIBUTION_SUMMARY_NAME)
.tags(Tags.of(platformTag).and("lowUrgencySkip", Boolean.toString(inSkipExperiment)))
.publishPercentileHistogram(true)
.register(Metrics.globalRegistry)
.record(messageCount);
} catch (ItemCollectionSizeLimitExceededException e) { } catch (ItemCollectionSizeLimitExceededException e) {
final boolean isPrimary = deviceId == Device.PRIMARY_ID; final boolean isPrimary = deviceId == Device.PRIMARY_ID;
Metrics.counter(OVERSIZED_QUEUE_COUNTER_NAME, "primary", String.valueOf(isPrimary)).increment(); Metrics.counter(OVERSIZED_QUEUE_COUNTER_NAME, "primary", String.valueOf(isPrimary)).increment();
@ -265,6 +234,7 @@ public class MessagePersister implements Managed {
messagesCache.unlockQueueForPersistence(accountUuid, deviceId); messagesCache.unlockQueueForPersistence(accountUuid, deviceId);
sample.stop(PERSIST_QUEUE_TIMER); sample.stop(PERSIST_QUEUE_TIMER);
} }
} }
private void trimQueue(final Account account, byte deviceId) { private void trimQueue(final Account account, byte deviceId) {

View File

@ -15,8 +15,6 @@ import io.lettuce.core.Range;
import io.lettuce.core.ScoredValue; import io.lettuce.core.ScoredValue;
import io.lettuce.core.ZAddArgs; import io.lettuce.core.ZAddArgs;
import io.lettuce.core.cluster.SlotHash; import io.lettuce.core.cluster.SlotHash;
import io.lettuce.core.cluster.models.partitions.ClusterPartitionParser;
import io.lettuce.core.cluster.models.partitions.Partitions;
import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tag;
@ -670,15 +668,6 @@ public class MessagesCache {
.thenRun(() -> sample.stop(clearQueueTimer)); .thenRun(() -> sample.stop(clearQueueTimer));
} }
public String shardForSlot(int slot) {
try {
return redisCluster.withBinaryCluster(
connection -> connection.getPartitions().getPartitionBySlot(slot).getUri().getHost());
} catch (Throwable ignored) {
return "unknown";
}
}
int getNextSlotToPersist() { int getNextSlotToPersist() {
return (int) (redisCluster.withCluster(connection -> connection.sync().incr(NEXT_SLOT_TO_PERSIST_KEY)) return (int) (redisCluster.withCluster(connection -> connection.sync().incr(NEXT_SLOT_TO_PERSIST_KEY))
% SlotHash.SLOT_COUNT); % SlotHash.SLOT_COUNT);

View File

@ -7,7 +7,6 @@ package org.whispersystems.textsecuregcm.subscriptions;
import com.google.common.base.Strings; import com.google.common.base.Strings;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.stripe.Stripe;
import com.stripe.StripeClient; import com.stripe.StripeClient;
import com.stripe.exception.CardException; import com.stripe.exception.CardException;
import com.stripe.exception.IdempotencyException; import com.stripe.exception.IdempotencyException;
@ -72,7 +71,6 @@ import javax.crypto.spec.SecretKeySpec;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.WhisperServerVersion;
import org.whispersystems.textsecuregcm.storage.PaymentTime; import org.whispersystems.textsecuregcm.storage.PaymentTime;
import org.whispersystems.textsecuregcm.storage.SubscriptionException; import org.whispersystems.textsecuregcm.storage.SubscriptionException;
import org.whispersystems.textsecuregcm.util.Conversions; import org.whispersystems.textsecuregcm.util.Conversions;
@ -99,9 +97,6 @@ public class StripeManager implements CustomerAwareSubscriptionPaymentProcessor
if (Strings.isNullOrEmpty(apiKey)) { if (Strings.isNullOrEmpty(apiKey)) {
throw new IllegalArgumentException("apiKey cannot be empty"); throw new IllegalArgumentException("apiKey cannot be empty");
} }
Stripe.setAppInfo("Signal-Server", WhisperServerVersion.getServerVersion());
this.stripeClient = new StripeClient(apiKey); this.stripeClient = new StripeClient(apiKey);
this.executor = Objects.requireNonNull(executor); this.executor = Objects.requireNonNull(executor);
this.idempotencyKeyGenerator = Objects.requireNonNull(idempotencyKeyGenerator); this.idempotencyKeyGenerator = Objects.requireNonNull(idempotencyKeyGenerator);

View File

@ -1,92 +0,0 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import com.google.common.annotations.VisibleForTesting;
import java.util.concurrent.atomic.AtomicInteger;
/**
* A closable epoch is a concurrency construct that measures the number of callers in some critical section. A closable
* epoch can be closed to prevent new callers from entering the critical section, and takes a specific action when the
* critical section is empty after closure.
*/
public class ClosableEpoch {
private final Runnable onCloseHandler;
private final AtomicInteger state = new AtomicInteger();
private static final int CLOSING_BIT_MASK = 0x00000001;
/**
* Constructs a new closable epoch that will execute the given handler when the epoch is closed and all callers have
* departed the critical section. The handler will be executed on the thread that calls {@link #close()} if the
* critical section is empty at the time of the call or on the last thread to call {@link #depart()} otherwise.
* Callers should provide handlers that delegate execution to a specific thread/executor if more precise control over
* which thread runs the handler is required.
*
* @param onCloseHandler a handler to run when the epoch is closed and all callers have departed the critical section
*/
public ClosableEpoch(final Runnable onCloseHandler) {
this.onCloseHandler = onCloseHandler;
}
/**
* Announces the arrival of a caller at the start of the critical section. If the caller is allowed to enter the
* critical section, the epoch's internal caller counter is incremented accordingly.
*
* @return {@code true} if the caller is allowed to enter the critical section or {@code false} if it is not allowed
* to enter the critical section because this epoch is closing
*/
public boolean tryArrive() {
// Increment the number of active callers if and only if we're not closing. We add 2 because the lowest bit encodes
// the "closing" state, and the bits above it encode the actual call count. More verbosely, we're doing
// `state += (1 << 1)` to avoid overwriting the closing state bit.
return !isClosing(state.updateAndGet(state -> isClosing(state) ? state : state + 2));
}
/**
* Announces the departure of a caller from the critical section. If the epoch is closing and the caller is the last
* to depart the critical section, then the epoch will fire its {@code onCloseHandler}.
*/
public void depart() {
// Decrement the active caller count unconditionally. As with `tryActive`, we work in increments of 2 to "dodge" the
// "is closing" bit. If the call count is zero and we're closing then `state` will just have the "closing" bit set.
if (state.addAndGet(-2) == CLOSING_BIT_MASK) {
onCloseHandler.run();
}
}
/**
* Closes this epoch, preventing new callers from entering the critical section. If the critical section is empty when
* this method is called, it will trigger the {@code onCloseHandler} immediately. Otherwise, the
* {@code onCloseHandler} will fire when the last caller departs the critical section.
*
* @throws IllegalStateException if this epoch is already closed; note that this exception is thrown on a
* "best-effort" basis to help callers detect bugs
*/
public void close() {
// Note that this is not airtight and is a "best-effort" check
if (isClosing(state.get())) {
throw new IllegalStateException("Epoch already closed");
}
// Set the "closing" bit. If the closing bit is the only bit set, then the call count is zero and we can call the
// "on close" handler.
if (state.updateAndGet(state -> state | CLOSING_BIT_MASK) == CLOSING_BIT_MASK) {
onCloseHandler.run();
}
}
@VisibleForTesting
int getActiveCallers() {
return state.get() >> 1;
}
private static boolean isClosing(final int state) {
return (state & CLOSING_BIT_MASK) != 0;
}
}

View File

@ -46,7 +46,7 @@ public class LoggingUnhandledExceptionMapper extends LoggingExceptionMapper<Thro
// streamline the user-agent if it is recognized // streamline the user-agent if it is recognized
final UserAgent ua = UserAgentUtil.parseUserAgentString(userAgent); final UserAgent ua = UserAgentUtil.parseUserAgentString(userAgent);
userAgent = String.format("%s %s", ua.platform(), ua.version()); userAgent = String.format("%s %s", ua.getPlatform(), ua.getVersion());
} catch (final UnrecognizedUserAgentException ignored) { } catch (final UnrecognizedUserAgentException ignored) {
} catch (final Exception e) { } catch (final Exception e) {

View File

@ -6,8 +6,58 @@
package org.whispersystems.textsecuregcm.util.ua; package org.whispersystems.textsecuregcm.util.ua;
import com.vdurmont.semver4j.Semver; import com.vdurmont.semver4j.Semver;
import javax.annotation.Nullable;
import java.util.Objects; import java.util.Objects;
import java.util.Optional;
public record UserAgent(ClientPlatform platform, Semver version, @Nullable String additionalSpecifiers) { public class UserAgent {
private final ClientPlatform platform;
private final Semver version;
private final String additionalSpecifiers;
public UserAgent(final ClientPlatform platform, final Semver version) {
this(platform, version, null);
}
public UserAgent(final ClientPlatform platform, final Semver version, final String additionalSpecifiers) {
this.platform = platform;
this.version = version;
this.additionalSpecifiers = additionalSpecifiers;
}
public ClientPlatform getPlatform() {
return platform;
}
public Semver getVersion() {
return version;
}
public Optional<String> getAdditionalSpecifiers() {
return Optional.ofNullable(additionalSpecifiers);
}
@Override
public boolean equals(final Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
final UserAgent userAgent = (UserAgent)o;
return platform == userAgent.platform &&
version.equals(userAgent.version) &&
Objects.equals(additionalSpecifiers, userAgent.additionalSpecifiers);
}
@Override
public int hashCode() {
return Objects.hash(platform, version, additionalSpecifiers);
}
@Override
public String toString() {
return "UserAgent{" +
"platform=" + platform +
", version=" + version +
", additionalSpecifiers='" + additionalSpecifiers + '\'' +
'}';
}
} }

View File

@ -5,6 +5,7 @@
package org.whispersystems.textsecuregcm.util.ua; package org.whispersystems.textsecuregcm.util.ua;
import com.google.common.annotations.VisibleForTesting;
import com.vdurmont.semver4j.Semver; import com.vdurmont.semver4j.Semver;
import java.util.regex.Matcher; import java.util.regex.Matcher;
import java.util.regex.Pattern; import java.util.regex.Pattern;
@ -20,10 +21,10 @@ public class UserAgentUtil {
} }
try { try {
final Matcher matcher = STANDARD_UA_PATTERN.matcher(userAgentString); final UserAgent standardUserAgent = parseStandardUserAgentString(userAgentString);
if (matcher.matches()) { if (standardUserAgent != null) {
return new UserAgent(ClientPlatform.valueOf(matcher.group(1).toUpperCase()), new Semver(matcher.group(2)), StringUtils.stripToNull(matcher.group(4))); return standardUserAgent;
} }
} catch (final Exception e) { } catch (final Exception e) {
throw new UnrecognizedUserAgentException(e); throw new UnrecognizedUserAgentException(e);
@ -31,4 +32,15 @@ public class UserAgentUtil {
throw new UnrecognizedUserAgentException(); throw new UnrecognizedUserAgentException();
} }
@VisibleForTesting
static UserAgent parseStandardUserAgentString(final String userAgentString) {
final Matcher matcher = STANDARD_UA_PATTERN.matcher(userAgentString);
if (matcher.matches()) {
return new UserAgent(ClientPlatform.valueOf(matcher.group(1).toUpperCase()), new Semver(matcher.group(2)), StringUtils.stripToNull(matcher.group(4)));
}
return null;
}
} }

View File

@ -12,7 +12,6 @@ import java.util.concurrent.ScheduledExecutorService;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor; import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.metrics.OpenWebSocketCounter; import org.whispersystems.textsecuregcm.metrics.OpenWebSocketCounter;
@ -46,7 +45,6 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
private final Scheduler messageDeliveryScheduler; private final Scheduler messageDeliveryScheduler;
private final ClientReleaseManager clientReleaseManager; private final ClientReleaseManager clientReleaseManager;
private final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor; private final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor;
private final ExperimentEnrollmentManager experimentEnrollmentManager;
private final OpenWebSocketCounter openAuthenticatedWebSocketCounter; private final OpenWebSocketCounter openAuthenticatedWebSocketCounter;
private final OpenWebSocketCounter openUnauthenticatedWebSocketCounter; private final OpenWebSocketCounter openUnauthenticatedWebSocketCounter;
@ -60,8 +58,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
ScheduledExecutorService scheduledExecutorService, ScheduledExecutorService scheduledExecutorService,
Scheduler messageDeliveryScheduler, Scheduler messageDeliveryScheduler,
ClientReleaseManager clientReleaseManager, ClientReleaseManager clientReleaseManager,
MessageDeliveryLoopMonitor messageDeliveryLoopMonitor, MessageDeliveryLoopMonitor messageDeliveryLoopMonitor) {
final ExperimentEnrollmentManager experimentEnrollmentManager) {
this.receiptSender = receiptSender; this.receiptSender = receiptSender;
this.messagesManager = messagesManager; this.messagesManager = messagesManager;
this.messageMetrics = messageMetrics; this.messageMetrics = messageMetrics;
@ -72,7 +69,6 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
this.messageDeliveryScheduler = messageDeliveryScheduler; this.messageDeliveryScheduler = messageDeliveryScheduler;
this.clientReleaseManager = clientReleaseManager; this.clientReleaseManager = clientReleaseManager;
this.messageDeliveryLoopMonitor = messageDeliveryLoopMonitor; this.messageDeliveryLoopMonitor = messageDeliveryLoopMonitor;
this.experimentEnrollmentManager = experimentEnrollmentManager;
openAuthenticatedWebSocketCounter = openAuthenticatedWebSocketCounter =
new OpenWebSocketCounter(OPEN_WEBSOCKET_GAUGE_NAME, CONNECTED_DURATION_TIMER_NAME, Tags.of(AUTHENTICATED_TAG_NAME, "true")); new OpenWebSocketCounter(OPEN_WEBSOCKET_GAUGE_NAME, CONNECTED_DURATION_TIMER_NAME, Tags.of(AUTHENTICATED_TAG_NAME, "true"));
@ -102,8 +98,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
scheduledExecutorService, scheduledExecutorService,
messageDeliveryScheduler, messageDeliveryScheduler,
clientReleaseManager, clientReleaseManager,
messageDeliveryLoopMonitor, messageDeliveryLoopMonitor);
experimentEnrollmentManager);
context.addWebsocketClosedListener((closingContext, statusCode, reason) -> { context.addWebsocketClosedListener((closingContext, statusCode, reason) -> {
// We begin the shutdown process by removing this client's "presence," which means it will again begin to // We begin the shutdown process by removing this client's "presence," which means it will again begin to

View File

@ -39,7 +39,6 @@ import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.controllers.MessageController; import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
@ -47,7 +46,6 @@ import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventListener; import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventListener;
import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
@ -122,7 +120,6 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
private final PushNotificationManager pushNotificationManager; private final PushNotificationManager pushNotificationManager;
private final PushNotificationScheduler pushNotificationScheduler; private final PushNotificationScheduler pushNotificationScheduler;
private final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor; private final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor;
private final ExperimentEnrollmentManager experimentEnrollmentManager;
private final AuthenticatedDevice auth; private final AuthenticatedDevice auth;
private final WebSocketClient client; private final WebSocketClient client;
@ -162,8 +159,7 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
ScheduledExecutorService scheduledExecutorService, ScheduledExecutorService scheduledExecutorService,
Scheduler messageDeliveryScheduler, Scheduler messageDeliveryScheduler,
ClientReleaseManager clientReleaseManager, ClientReleaseManager clientReleaseManager,
MessageDeliveryLoopMonitor messageDeliveryLoopMonitor, MessageDeliveryLoopMonitor messageDeliveryLoopMonitor) {
ExperimentEnrollmentManager experimentEnrollmentManager) {
this(receiptSender, this(receiptSender,
messagesManager, messagesManager,
@ -176,7 +172,7 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
scheduledExecutorService, scheduledExecutorService,
messageDeliveryScheduler, messageDeliveryScheduler,
clientReleaseManager, clientReleaseManager,
messageDeliveryLoopMonitor, experimentEnrollmentManager); messageDeliveryLoopMonitor);
} }
@VisibleForTesting @VisibleForTesting
@ -191,8 +187,7 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
ScheduledExecutorService scheduledExecutorService, ScheduledExecutorService scheduledExecutorService,
Scheduler messageDeliveryScheduler, Scheduler messageDeliveryScheduler,
ClientReleaseManager clientReleaseManager, ClientReleaseManager clientReleaseManager,
MessageDeliveryLoopMonitor messageDeliveryLoopMonitor, MessageDeliveryLoopMonitor messageDeliveryLoopMonitor) {
ExperimentEnrollmentManager experimentEnrollmentManager) {
this.receiptSender = receiptSender; this.receiptSender = receiptSender;
this.messagesManager = messagesManager; this.messagesManager = messagesManager;
@ -206,7 +201,6 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
this.messageDeliveryScheduler = messageDeliveryScheduler; this.messageDeliveryScheduler = messageDeliveryScheduler;
this.clientReleaseManager = clientReleaseManager; this.clientReleaseManager = clientReleaseManager;
this.messageDeliveryLoopMonitor = messageDeliveryLoopMonitor; this.messageDeliveryLoopMonitor = messageDeliveryLoopMonitor;
this.experimentEnrollmentManager = experimentEnrollmentManager;
} }
public void start() { public void start() {
@ -337,13 +331,7 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
// Cleared the queue! Send a queue empty message if we need to // Cleared the queue! Send a queue empty message if we need to
consecutiveRetries.set(0); consecutiveRetries.set(0);
if (sentInitialQueueEmptyMessage.compareAndSet(false, true)) { if (sentInitialQueueEmptyMessage.compareAndSet(false, true)) {
final boolean inSkipExperiment = auth.getAuthenticatedDevice().getGcmId() != null && experimentEnrollmentManager.isEnrolled( final Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent()));
auth.getAccount().getUuid(),
MessageSender.ANDROID_SKIP_LOW_URGENCY_PUSH_EXPERIMENT);
final Tags tags = Tags
.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent()))
.and("lowUrgencySkip", Boolean.toString(inSkipExperiment));
final long drainDuration = System.currentTimeMillis() - queueDrainStartTime.get(); final long drainDuration = System.currentTimeMillis() - queueDrainStartTime.get();
Metrics.summary(INITIAL_QUEUE_LENGTH_DISTRIBUTION_NAME, tags).record(sentMessageCounter.sum()); Metrics.summary(INITIAL_QUEUE_LENGTH_DISTRIBUTION_NAME, tags).record(sentMessageCounter.sum());

View File

@ -14,7 +14,6 @@ import java.time.Duration;
import net.sourceforge.argparse4j.inf.Namespace; import net.sourceforge.argparse4j.inf.Namespace;
import net.sourceforge.argparse4j.inf.Subparser; import net.sourceforge.argparse4j.inf.Subparser;
import org.whispersystems.textsecuregcm.WhisperServerConfiguration; import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.storage.MessagePersister; import org.whispersystems.textsecuregcm.storage.MessagePersister;
import org.whispersystems.textsecuregcm.util.logging.UncaughtExceptionHandler; import org.whispersystems.textsecuregcm.util.logging.UncaughtExceptionHandler;
@ -65,7 +64,6 @@ public class MessagePersisterServiceCommand extends ServerCommand<WhisperServerC
deps.messagesManager(), deps.messagesManager(),
deps.accountsManager(), deps.accountsManager(),
deps.dynamicConfigurationManager(), deps.dynamicConfigurationManager(),
new ExperimentEnrollmentManager(deps.dynamicConfigurationManager()),
Duration.ofMinutes(configuration.getMessageCacheConfiguration().getPersistDelayMinutes()), Duration.ofMinutes(configuration.getMessageCacheConfiguration().getPersistDelayMinutes()),
namespace.getInt(WORKER_COUNT)); namespace.getInt(WORKER_COUNT));

View File

@ -24,12 +24,10 @@ service Messages {
* destination account. * destination account.
* *
* This RPC may fail with a `NOT_FOUND` status if the destination account was * This RPC may fail with a `NOT_FOUND` status if the destination account was
* not found. It may also fail with an `INVALID_ARGUMENT` status if the * not found. It may also fail with a `RESOURCE_EXHAUSTED` status if a rate
* destination account is the same as the authenticated caller (callers should * limit for sending messages has been exceeded, in which case a `retry-after`
* use `SendSyncMessage` to send messages to themselves). It may also fail * header containing an ISO 8601 duration string may be present in the
* with a `RESOURCE_EXHAUSTED` status if a rate limit for sending messages has * response trailers.
* been exceeded, in which case a `retry-after` header containing an ISO 8601
* duration string may be present in the response trailers.
* *
* Note that message delivery may not succeed even if this RPC returns an `OK` * Note that message delivery may not succeed even if this RPC returns an `OK`
* status; callers must check the response object to verify that the message * status; callers must check the response object to verify that the message
@ -144,12 +142,9 @@ message IndividualRecipientMessageBundle {
/** /**
* The time, in milliseconds since the epoch, at which this message was * The time, in milliseconds since the epoch, at which this message was
* originally sent from the perspective of the sender. Note that the maximum * originally sent from the perspective of the sender.
* allowable timestamp for JavaScript clients is less than Long.MAX_VALUE; see
* https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Date#the_epoch_timestamps_and_invalid_date
* for additional details and discussion.
*/ */
uint64 timestamp = 1 [(require.range).min = 1, (require.range).max = 8640000000000000]; uint64 timestamp = 1;
/** /**
* A map of device IDs to individual messages. Generally, callers must include * A map of device IDs to individual messages. Generally, callers must include
@ -332,12 +327,9 @@ message MultiRecipientMessage {
/** /**
* The time, in milliseconds since the epoch, at which this message was * The time, in milliseconds since the epoch, at which this message was
* originally sent from the perspective of the sender. Note that the maximum * originally sent from the perspective of the sender.
* allowable timestamp for JavaScript clients is less than Long.MAX_VALUE; see
* https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Date#the_epoch_timestamps_and_invalid_date
* for additional details and discussion.
*/ */
uint64 timestamp = 1 [(require.range).min = 1, (require.range).max = 8640000000000000]; uint64 timestamp = 1;
/** /**
* The serialized multi-recipient message payload. * The serialized multi-recipient message payload.

View File

@ -149,8 +149,8 @@ message SizeConstraint {
} }
message ValueRangeConstraint { message ValueRangeConstraint {
optional int64 min = 1; optional int32 min = 1;
optional int64 max = 2; optional int32 max = 2;
} }
extend google.protobuf.ServiceOptions { extend google.protobuf.ServiceOptions {

View File

@ -5,11 +5,8 @@
package org.whispersystems.textsecuregcm.auth; package org.whispersystems.textsecuregcm.auth;
import static com.github.tomakehurst.wiremock.client.WireMock.created; import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson;
import static com.github.tomakehurst.wiremock.client.WireMock.post; import static com.github.tomakehurst.wiremock.client.WireMock.post;
import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor;
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig; import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
@ -18,19 +15,17 @@ import static org.mockito.Mockito.when;
import com.github.tomakehurst.wiremock.junit5.WireMockExtension; import com.github.tomakehurst.wiremock.junit5.WireMockExtension;
import io.netty.resolver.dns.DnsNameResolver; import io.netty.resolver.dns.DnsNameResolver;
import io.netty.util.concurrent.GlobalEventExecutor; import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.SucceededFuture;
import java.io.IOException; import java.io.IOException;
import java.net.InetAddress; import java.net.InetAddress;
import java.time.Duration; import java.security.cert.CertificateException;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.concurrent.CancellationException; import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -40,41 +35,31 @@ import org.whispersystems.textsecuregcm.configuration.RetryConfiguration;
public class CloudflareTurnCredentialsManagerTest { public class CloudflareTurnCredentialsManagerTest {
@RegisterExtension @RegisterExtension
private static final WireMockExtension wireMock = WireMockExtension.newInstance() private final WireMockExtension wireMock = WireMockExtension.newInstance()
.options(wireMockConfig().dynamicPort().dynamicHttpsPort()) .options(wireMockConfig().dynamicPort().dynamicHttpsPort())
.build(); .build();
private static final String GET_CREDENTIALS_PATH = "/v1/turn/keys/LMNOP/credentials/generate";
private static final String TURN_HOSTNAME = "localhost";
private ExecutorService httpExecutor; private ExecutorService httpExecutor;
private ScheduledExecutorService retryExecutor; private ScheduledExecutorService retryExecutor;
private DnsNameResolver dnsResolver; private DnsNameResolver dnsResolver;
private Future<List<InetAddress>> dnsResult;
private CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager; private CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager = null;
private static final String GET_CREDENTIALS_PATH = "/v1/turn/keys/LMNOP/credentials/generate";
private static final String TURN_HOSTNAME = "localhost";
private static final String API_TOKEN = RandomStringUtils.insecure().nextAlphanumeric(16);
private static final String USERNAME = RandomStringUtils.insecure().nextAlphanumeric(16);
private static final String CREDENTIAL = RandomStringUtils.insecure().nextAlphanumeric(16);
private static final List<String> CLOUDFLARE_TURN_URLS = List.of("turn:cf.example.com");
private static final Duration REQUESTED_CREDENTIAL_TTL = Duration.ofSeconds(100);
private static final Duration CLIENT_CREDENTIAL_TTL = REQUESTED_CREDENTIAL_TTL.dividedBy(2);
private static final List<String> IP_URL_PATTERNS = List.of("turn:%s", "turn:%s:80?transport=tcp", "turns:%s:443?transport=tcp");
@BeforeEach @BeforeEach
void setUp() { void setUp() throws CertificateException {
httpExecutor = Executors.newSingleThreadExecutor(); httpExecutor = Executors.newSingleThreadExecutor();
retryExecutor = Executors.newSingleThreadScheduledExecutor(); retryExecutor = Executors.newSingleThreadScheduledExecutor();
dnsResolver = mock(DnsNameResolver.class); dnsResolver = mock(DnsNameResolver.class);
dnsResult = mock(Future.class);
cloudflareTurnCredentialsManager = new CloudflareTurnCredentialsManager( cloudflareTurnCredentialsManager = new CloudflareTurnCredentialsManager(
API_TOKEN, "API_TOKEN",
"http://localhost:" + wireMock.getPort() + GET_CREDENTIALS_PATH, "http://localhost:" + wireMock.getPort() + GET_CREDENTIALS_PATH,
REQUESTED_CREDENTIAL_TTL, 100,
CLIENT_CREDENTIAL_TTL, List.of("turn:cf.example.com"),
CLOUDFLARE_TURN_URLS, List.of("turn:%s", "turn:%s:80?transport=tcp", "turns:%s:443?transport=tcp"),
IP_URL_PATTERNS,
TURN_HOSTNAME, TURN_HOSTNAME,
2, 2,
new CircuitBreakerConfiguration(), new CircuitBreakerConfiguration(),
@ -88,61 +73,26 @@ public class CloudflareTurnCredentialsManagerTest {
@AfterEach @AfterEach
void tearDown() throws InterruptedException { void tearDown() throws InterruptedException {
httpExecutor.shutdown(); httpExecutor.shutdown();
retryExecutor.shutdown();
//noinspection ResultOfMethodCallIgnored
httpExecutor.awaitTermination(1, TimeUnit.SECONDS); httpExecutor.awaitTermination(1, TimeUnit.SECONDS);
retryExecutor.shutdown();
//noinspection ResultOfMethodCallIgnored
retryExecutor.awaitTermination(1, TimeUnit.SECONDS); retryExecutor.awaitTermination(1, TimeUnit.SECONDS);
} }
@Test @Test
public void testSuccess() throws IOException, CancellationException { public void testSuccess() throws IOException, CancellationException, ExecutionException, InterruptedException {
wireMock.stubFor(post(urlEqualTo(GET_CREDENTIALS_PATH)) wireMock.stubFor(post(urlEqualTo(GET_CREDENTIALS_PATH))
.willReturn(created() .willReturn(aResponse().withStatus(201).withHeader("Content-Type", new String[]{"application/json"}).withBody("{\"iceServers\":{\"urls\":[\"turn:cloudflare.example.com:3478?transport=udp\"],\"username\":\"ABC\",\"credential\":\"XYZ\"}}")));
.withHeader("Content-Type", "application/json") when(dnsResult.get())
.withBody(""" .thenReturn(List.of(InetAddress.getByName("127.0.0.1"), InetAddress.getByName("::1")));
{
"iceServers": {
"urls": [
"turn:cloudflare.example.com:3478?transport=udp"
],
"username": "%s",
"credential": "%s"
}
}
""".formatted(USERNAME, CREDENTIAL))));
when(dnsResolver.resolveAll(TURN_HOSTNAME)) when(dnsResolver.resolveAll(TURN_HOSTNAME))
.thenReturn(new SucceededFuture<>(GlobalEventExecutor.INSTANCE, .thenReturn(dnsResult);
List.of(InetAddress.getByName("127.0.0.1"), InetAddress.getByName("::1"))));
TurnToken token = cloudflareTurnCredentialsManager.retrieveFromCloudflare(); TurnToken token = cloudflareTurnCredentialsManager.retrieveFromCloudflare();
wireMock.verify(postRequestedFor(urlEqualTo(GET_CREDENTIALS_PATH)) assertThat(token.username()).isEqualTo("ABC");
.withHeader("Content-Type", equalTo("application/json")) assertThat(token.password()).isEqualTo("XYZ");
.withHeader("Authorization", equalTo("Bearer " + API_TOKEN)) assertThat(token.hostname()).isEqualTo("localhost");
.withRequestBody(equalToJson(""" assertThat(token.urlsWithIps()).containsAll(List.of("turn:127.0.0.1", "turn:127.0.0.1:80?transport=tcp", "turns:127.0.0.1:443?transport=tcp", "turn:[0:0:0:0:0:0:0:1]", "turn:[0:0:0:0:0:0:0:1]:80?transport=tcp", "turns:[0:0:0:0:0:0:0:1]:443?transport=tcp"));;
{ assertThat(token.urls()).isEqualTo(List.of("turn:cf.example.com"));
"ttl": %d
}
""".formatted(REQUESTED_CREDENTIAL_TTL.toSeconds()))));
assertThat(token.username()).isEqualTo(USERNAME);
assertThat(token.password()).isEqualTo(CREDENTIAL);
assertThat(token.hostname()).isEqualTo(TURN_HOSTNAME);
assertThat(token.urls()).isEqualTo(CLOUDFLARE_TURN_URLS);
assertThat(token.ttlSeconds()).isEqualTo(CLIENT_CREDENTIAL_TTL.toSeconds());
final List<String> expectedUrlsWithIps = new ArrayList<>();
for (final String ip : new String[] {"127.0.0.1", "[0:0:0:0:0:0:0:1]"}) {
for (final String pattern : IP_URL_PATTERNS) {
expectedUrlsWithIps.add(pattern.formatted(ip));
}
}
assertThat(token.urlsWithIps()).containsExactlyElementsOf(expectedUrlsWithIps);
} }
} }

View File

@ -3,7 +3,6 @@ package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.Status; import io.grpc.Status;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse; import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils; import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
@ -23,7 +22,7 @@ class ProhibitAuthenticationInterceptorTest extends AbstractAuthenticationInterc
} }
@Test @Test
void interceptCall() throws ChannelNotFoundException { void interceptCall() {
final GrpcClientConnectionManager grpcClientConnectionManager = getClientConnectionManager(); final GrpcClientConnectionManager grpcClientConnectionManager = getClientConnectionManager();
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty()); when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty());
@ -35,10 +34,6 @@ class ProhibitAuthenticationInterceptorTest extends AbstractAuthenticationInterc
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID); final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice)); when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice));
GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getAuthenticatedDevice); GrpcTestUtils.assertStatusException(Status.UNAUTHENTICATED, this::getAuthenticatedDevice);
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenThrow(ChannelNotFoundException.class);
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, this::getAuthenticatedDevice);
} }
} }

View File

@ -9,7 +9,6 @@ import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse; import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils; import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
@ -23,12 +22,12 @@ class RequireAuthenticationInterceptorTest extends AbstractAuthenticationInterce
} }
@Test @Test
void interceptCall() throws ChannelNotFoundException { void interceptCall() {
final GrpcClientConnectionManager grpcClientConnectionManager = getClientConnectionManager(); final GrpcClientConnectionManager grpcClientConnectionManager = getClientConnectionManager();
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty()); when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty());
GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getAuthenticatedDevice); GrpcTestUtils.assertStatusException(Status.UNAUTHENTICATED, this::getAuthenticatedDevice);
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID); final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice)); when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice));
@ -36,9 +35,5 @@ class RequireAuthenticationInterceptorTest extends AbstractAuthenticationInterce
final GetAuthenticatedDeviceResponse response = getAuthenticatedDevice(); final GetAuthenticatedDeviceResponse response = getAuthenticatedDevice();
assertEquals(UUIDUtil.toByteString(authenticatedDevice.accountIdentifier()), response.getAccountIdentifier()); assertEquals(UUIDUtil.toByteString(authenticatedDevice.accountIdentifier()), response.getAccountIdentifier());
assertEquals(authenticatedDevice.deviceId(), response.getDeviceId()); assertEquals(authenticatedDevice.deviceId(), response.getDeviceId());
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenThrow(ChannelNotFoundException.class);
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, this::getAuthenticatedDevice);
} }
} }

View File

@ -40,15 +40,14 @@ class CallRoutingControllerV2Test {
private static final TurnToken CLOUDFLARE_TURN_TOKEN = new TurnToken( private static final TurnToken CLOUDFLARE_TURN_TOKEN = new TurnToken(
"ABC", "ABC",
"XYZ", "XYZ",
43_200,
List.of("turn:cloudflare.example.com:3478?transport=udp"), List.of("turn:cloudflare.example.com:3478?transport=udp"),
null, null,
"cf.example.com"); "cf.example.com");
private static final RateLimiters rateLimiters = mock(RateLimiters.class); private static final RateLimiters rateLimiters = mock(RateLimiters.class);
private static final RateLimiter getCallEndpointLimiter = mock(RateLimiter.class); private static final RateLimiter getCallEndpointLimiter = mock(RateLimiter.class);
private static final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager = private static final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager = mock(
mock(CloudflareTurnCredentialsManager.class); CloudflareTurnCredentialsManager.class);
private static final ResourceExtension resources = ResourceExtension.builder() private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter()) .addProvider(AuthHelper.getAuthFilter())
@ -67,14 +66,21 @@ class CallRoutingControllerV2Test {
@AfterEach @AfterEach
void tearDown() { void tearDown() {
reset(rateLimiters, getCallEndpointLimiter); reset( rateLimiters, getCallEndpointLimiter);
}
void initializeMocksWith(TurnToken cloudflareToken) {
try {
when(cloudflareTurnCredentialsManager.retrieveFromCloudflare()).thenReturn(cloudflareToken);
} catch (IOException ignored) {
}
} }
@Test @Test
void testGetRelaysBothRouting() throws IOException { void testGetRelaysBothRouting() {
when(cloudflareTurnCredentialsManager.retrieveFromCloudflare()).thenReturn(CLOUDFLARE_TURN_TOKEN); initializeMocksWith(CLOUDFLARE_TURN_TOKEN);
try (final Response rawResponse = resources.getJerseyTest() try (Response rawResponse = resources.getJerseyTest()
.target(GET_CALL_RELAYS_PATH) .target(GET_CALL_RELAYS_PATH)
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
@ -82,8 +88,11 @@ class CallRoutingControllerV2Test {
assertThat(rawResponse.getStatus()).isEqualTo(200); assertThat(rawResponse.getStatus()).isEqualTo(200);
assertThat(rawResponse.readEntity(GetCallingRelaysResponse.class).relays()) CallRoutingControllerV2.GetCallingRelaysResponse response = rawResponse.readEntity(
.isEqualTo(List.of(CLOUDFLARE_TURN_TOKEN)); CallRoutingControllerV2.GetCallingRelaysResponse.class);
List<TurnToken> relays = response.relays();
assertThat(relays).isEqualTo(List.of(CLOUDFLARE_TURN_TOKEN));
} }
} }

View File

@ -41,7 +41,6 @@ import java.util.Optional;
import java.util.OptionalInt; import java.util.OptionalInt;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.stream.IntStream;
import java.util.stream.Stream; import java.util.stream.Stream;
import org.glassfish.jersey.server.ServerProperties; import org.glassfish.jersey.server.ServerProperties;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
@ -943,45 +942,6 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(400); assertThat(response.getStatus()).isEqualTo(400);
} }
@Test
void putKeysTooManySingleUseECKeys() {
final List<ECPreKey> preKeys = IntStream.range(31337, 31438).mapToObj(KeysHelper::ecPreKey).toList();
final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, AuthHelper.VALID_IDENTITY_KEY_PAIR);
final SetKeysRequest setKeysRequest = new SetKeysRequest(preKeys, signedPreKey, null, null);
Response response =
resources.getJerseyTest()
.target("/v2/keys")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(setKeysRequest, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(422);
verifyNoMoreInteractions(KEYS);
}
@Test
void putKeysTooManySingleUseKEMKeys() {
final List<KEMSignedPreKey> pqPreKeys = IntStream.range(31337, 31438)
.mapToObj(id -> KeysHelper.signedKEMPreKey(id, AuthHelper.VALID_IDENTITY_KEY_PAIR))
.toList();
final SetKeysRequest setKeysRequest = new SetKeysRequest(null, null, pqPreKeys, null);
Response response =
resources.getJerseyTest()
.target("/v2/keys")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(setKeysRequest, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(422);
verifyNoMoreInteractions(KEYS);
}
@Test @Test
void putKeysByPhoneNumberIdentifierTestV2() { void putKeysByPhoneNumberIdentifierTestV2() {
final ECPreKey preKey = KeysHelper.ecPreKey(31337); final ECPreKey preKey = KeysHelper.ecPreKey(31337);

View File

@ -292,7 +292,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(200))); assertThat("Good Response", response.getStatus(), is(equalTo(200)));
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class); @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), eq(Optional.empty()), any()); verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), any());
assertEquals(1, captor.getValue().size()); assertEquals(1, captor.getValue().size());
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
@ -319,19 +319,7 @@ class MessageControllerTest {
IncomingMessageList.class), IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE))) { MediaType.APPLICATION_JSON_TYPE))) {
if (sendToPni) { assertThat(response.getStatus(), is(equalTo(sendToPni ? 403 : 200)));
assertThat(response.getStatus(), is(equalTo(403)));
verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any(), any());
} else {
assertThat(response.getStatus(), is(equalTo(200)));
verify(messageSender).sendMessages(any(),
eq(serviceIdentifier),
any(),
any(),
eq(Optional.of(Device.PRIMARY_ID)),
any());
}
} }
} }
@ -349,7 +337,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(200))); assertThat("Good Response", response.getStatus(), is(equalTo(200)));
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class); @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), eq(Optional.empty()), any()); verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), any());
assertEquals(1, captor.getValue().size()); assertEquals(1, captor.getValue().size());
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
@ -374,7 +362,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(200))); assertThat("Good Response", response.getStatus(), is(equalTo(200)));
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class); @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), eq(Optional.empty()), any()); verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), any());
assertEquals(1, captor.getValue().size()); assertEquals(1, captor.getValue().size());
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
@ -412,7 +400,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(200))); assertThat("Good Response", response.getStatus(), is(equalTo(200)));
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class); @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), eq(Optional.empty()), any()); verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), any());
assertEquals(1, captor.getValue().size()); assertEquals(1, captor.getValue().size());
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
@ -451,7 +439,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(expectedResponse))); assertThat("Good Response", response.getStatus(), is(equalTo(expectedResponse)));
if (expectedResponse == 200) { if (expectedResponse == 200) {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class); @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), eq(Optional.empty()), any()); verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), any());
assertEquals(1, captor.getValue().size()); assertEquals(1, captor.getValue().size());
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
@ -548,7 +536,7 @@ class MessageControllerTest {
@Test @Test
void testMultiDeviceMissing() throws Exception { void testMultiDeviceMissing() throws Exception {
doThrow(new MismatchedDevicesException(new MismatchedDevices(Set.of((byte) 2, (byte) 3), Collections.emptySet(), Collections.emptySet()))) doThrow(new MismatchedDevicesException(new MismatchedDevices(Set.of((byte) 2, (byte) 3), Collections.emptySet(), Collections.emptySet())))
.when(messageSender).sendMessages(any(), any(), any(), any(), any(), any()); .when(messageSender).sendMessages(any(), any(), any(), any(), any());
try (final Response response = try (final Response response =
resources.getJerseyTest() resources.getJerseyTest()
@ -570,7 +558,7 @@ class MessageControllerTest {
@Test @Test
void testMultiDeviceExtra() throws Exception { void testMultiDeviceExtra() throws Exception {
doThrow(new MismatchedDevicesException(new MismatchedDevices(Set.of((byte) 2), Set.of((byte) 4), Collections.emptySet()))) doThrow(new MismatchedDevicesException(new MismatchedDevices(Set.of((byte) 2), Set.of((byte) 4), Collections.emptySet())))
.when(messageSender).sendMessages(any(), any(), any(), any(), any(), any()); .when(messageSender).sendMessages(any(), any(), any(), any(), any());
try (final Response response = try (final Response response =
resources.getJerseyTest() resources.getJerseyTest()
@ -621,7 +609,7 @@ class MessageControllerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> envelopeCaptor = @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class); ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(Account.class), any(), envelopeCaptor.capture(), any(), eq(Optional.empty()), any()); verify(messageSender).sendMessages(any(Account.class), any(), envelopeCaptor.capture(), any(), any());
assertEquals(3, envelopeCaptor.getValue().size()); assertEquals(3, envelopeCaptor.getValue().size());
@ -645,7 +633,7 @@ class MessageControllerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> envelopeCaptor = @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class); ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(Account.class), any(), envelopeCaptor.capture(), any(), eq(Optional.empty()), any()); verify(messageSender).sendMessages(any(Account.class), any(), envelopeCaptor.capture(), any(), any());
assertEquals(3, envelopeCaptor.getValue().size()); assertEquals(3, envelopeCaptor.getValue().size());
@ -670,7 +658,6 @@ class MessageControllerTest {
any(), any(),
argThat(messagesByDeviceId -> messagesByDeviceId.size() == 3), argThat(messagesByDeviceId -> messagesByDeviceId.size() == 3),
any(), any(),
eq(Optional.empty()),
any()); any());
} }
} }
@ -678,7 +665,7 @@ class MessageControllerTest {
@Test @Test
void testRegistrationIdMismatch() throws Exception { void testRegistrationIdMismatch() throws Exception {
doThrow(new MismatchedDevicesException(new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of((byte) 2)))) doThrow(new MismatchedDevicesException(new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of((byte) 2))))
.when(messageSender).sendMessages(any(), any(), any(), any(), any(), any()); .when(messageSender).sendMessages(any(), any(), any(), any(), any());
try (final Response response = try (final Response response =
resources.getJerseyTest().target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID)) resources.getJerseyTest().target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID))
@ -1103,7 +1090,7 @@ class MessageControllerTest {
@Test @Test
void testValidateContentLength() throws MismatchedDevicesException, MessageTooLargeException, IOException { void testValidateContentLength() throws MismatchedDevicesException, MessageTooLargeException, IOException {
doThrow(new MessageTooLargeException()).when(messageSender).sendMessages(any(), any(), any(), any(), any(), any()); doThrow(new MessageTooLargeException()).when(messageSender).sendMessages(any(), any(), any(), any(), any());
try (final Response response = try (final Response response =
resources.getJerseyTest() resources.getJerseyTest()
@ -1132,10 +1119,10 @@ class MessageControllerTest {
if (expectOk) { if (expectOk) {
assertEquals(200, response.getStatus()); assertEquals(200, response.getStatus());
verify(messageSender).sendMessages(any(), any(), any(), any(), any(), any()); verify(messageSender).sendMessages(any(), any(), any(), any(), any());
} else { } else {
assertEquals(422, response.getStatus()); assertEquals(422, response.getStatus());
verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any(), any()); verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any());
} }
} }
} }

View File

@ -1140,58 +1140,6 @@ class ProfileControllerTest {
new AccountBadge("TEST2", Instant.ofEpochSecond(42 + 86400), false), new AccountBadge("TEST2", Instant.ofEpochSecond(42 + 86400), false),
new AccountBadge("TEST3", Instant.ofEpochSecond(42 + 86400), false)); new AccountBadge("TEST3", Instant.ofEpochSecond(42 + 86400), false));
} }
}
@Test
void testSetProfileBadgeAfterUpdateTries() throws Exception {
final ProfileKeyCommitment commitment = new ProfileKey(new byte[32]).getCommitment(
new ServiceId.Aci(AuthHelper.VALID_UUID));
final byte[] name = TestRandomUtil.nextBytes(81);
final byte[] emoji = TestRandomUtil.nextBytes(60);
final byte[] about = TestRandomUtil.nextBytes(156);
final String version = versionHex("anotherversion");
clearInvocations(AuthHelper.VALID_ACCOUNT_TWO);
reset(accountsManager);
final int accountsManagerUpdateRetryCount = 2;
AccountsHelper.setupMockUpdateWithRetries(accountsManager, accountsManagerUpdateRetryCount);
// set up two invocations -- one for each AccountsManager#update try
when(AuthHelper.VALID_ACCOUNT_TWO.getBadges())
.thenReturn(List.of(
new AccountBadge("TEST2", Instant.ofEpochSecond(42 + 86400), true),
new AccountBadge("TEST3", Instant.ofEpochSecond(42 + 86400), true)
))
.thenReturn(List.of(
new AccountBadge("TEST2", Instant.ofEpochSecond(42 + 86400), true),
new AccountBadge("TEST3", Instant.ofEpochSecond(42 + 86400), true),
new AccountBadge("TEST4", Instant.ofEpochSecond(43 + 86400), true)
));
try (final Response response = resources.getJerseyTest()
.target("/v1/profile/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID_TWO, AuthHelper.VALID_PASSWORD_TWO))
.put(Entity.entity(new CreateProfileRequest(commitment, version, name, emoji, about, null, false, false,
Optional.of(List.of("TEST1")), null), MediaType.APPLICATION_JSON_TYPE))) {
assertThat(response.getStatus()).isEqualTo(200);
assertThat(response.hasEntity()).isFalse();
//noinspection unchecked
final ArgumentCaptor<List<AccountBadge>> badgeCaptor = ArgumentCaptor.forClass(List.class);
verify(AuthHelper.VALID_ACCOUNT_TWO, times(accountsManagerUpdateRetryCount)).setBadges(refEq(clock), badgeCaptor.capture());
// since the stubbing of getBadges() is brittle, we need to verify the number of invocations, to protect against upstream changes
verify(AuthHelper.VALID_ACCOUNT_TWO, times(accountsManagerUpdateRetryCount)).getBadges();
final List<AccountBadge> badges = badgeCaptor.getValue();
assertThat(badges).isNotNull().hasSize(4).containsOnly(
new AccountBadge("TEST1", Instant.ofEpochSecond(42 + 86400), true),
new AccountBadge("TEST2", Instant.ofEpochSecond(42 + 86400), false),
new AccountBadge("TEST3", Instant.ofEpochSecond(42 + 86400), false),
new AccountBadge("TEST4", Instant.ofEpochSecond(43 + 86400), false));
}
} }
@ParameterizedTest @ParameterizedTest

View File

@ -7,7 +7,6 @@ package org.whispersystems.textsecuregcm.filters;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import com.google.common.net.InetAddresses;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import io.dropwizard.core.Application; import io.dropwizard.core.Application;
import io.dropwizard.core.Configuration; import io.dropwizard.core.Configuration;
@ -25,6 +24,7 @@ import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path; import jakarta.ws.rs.Path;
import jakarta.ws.rs.client.Client; import jakarta.ws.rs.client.Client;
import jakarta.ws.rs.core.Response; import jakarta.ws.rs.core.Response;
import java.net.InetAddress;
import java.util.Collections; import java.util.Collections;
import java.util.EnumSet; import java.util.EnumSet;
import java.util.Set; import java.util.Set;
@ -39,7 +39,6 @@ import org.signal.chat.rpc.EchoServiceGrpc;
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl; import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils; import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.MockRequestAttributesInterceptor; import org.whispersystems.textsecuregcm.grpc.MockRequestAttributesInterceptor;
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.util.InetAddressRange; import org.whispersystems.textsecuregcm.util.InetAddressRange;
@ExtendWith(DropwizardExtensionsSupport.class) @ExtendWith(DropwizardExtensionsSupport.class)
@ -158,7 +157,7 @@ class ExternalRequestFilterTest {
@BeforeEach @BeforeEach
void setUp() throws Exception { void setUp() throws Exception {
final MockRequestAttributesInterceptor mockRequestAttributesInterceptor = new MockRequestAttributesInterceptor(); final MockRequestAttributesInterceptor mockRequestAttributesInterceptor = new MockRequestAttributesInterceptor();
mockRequestAttributesInterceptor.setRequestAttributes(new RequestAttributes(InetAddresses.forString("127.0.0.1"), null, null)); mockRequestAttributesInterceptor.setRemoteAddress(InetAddress.getByName("127.0.0.1"));
testServer = InProcessServerBuilder.forName("ExternalRequestFilterTest") testServer = InProcessServerBuilder.forName("ExternalRequestFilterTest")
.directExecutor() .directExecutor()

View File

@ -15,7 +15,6 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.google.common.net.HttpHeaders; import com.google.common.net.HttpHeaders;
import com.google.common.net.InetAddresses;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import com.vdurmont.semver4j.Semver; import com.vdurmont.semver4j.Semver;
import io.grpc.ManagedChannel; import io.grpc.ManagedChannel;
@ -41,10 +40,11 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRemoteDeprecationConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRemoteDeprecationConfiguration;
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl; import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
import org.whispersystems.textsecuregcm.grpc.MockRequestAttributesInterceptor; import org.whispersystems.textsecuregcm.grpc.MockRequestAttributesInterceptor;
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.grpc.StatusConstants; import org.whispersystems.textsecuregcm.grpc.StatusConstants;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
class RemoteDeprecationFilterTest { class RemoteDeprecationFilterTest {
@ -130,7 +130,11 @@ class RemoteDeprecationFilterTest {
@MethodSource(value="testFilter") @MethodSource(value="testFilter")
void testGrpcFilter(final String userAgentString, final boolean expectDeprecation) throws IOException, InterruptedException { void testGrpcFilter(final String userAgentString, final boolean expectDeprecation) throws IOException, InterruptedException {
final MockRequestAttributesInterceptor mockRequestAttributesInterceptor = new MockRequestAttributesInterceptor(); final MockRequestAttributesInterceptor mockRequestAttributesInterceptor = new MockRequestAttributesInterceptor();
mockRequestAttributesInterceptor.setRequestAttributes(new RequestAttributes(InetAddresses.forString("127.0.0.1"), userAgentString, null));
try {
mockRequestAttributesInterceptor.setUserAgent(UserAgentUtil.parseUserAgentString(userAgentString));
} catch (UnrecognizedUserAgentException ignored) {
}
final Server testServer = InProcessServerBuilder.forName("RemoteDeprecationFilterTest") final Server testServer = InProcessServerBuilder.forName("RemoteDeprecationFilterTest")
.directExecutor() .directExecutor()

View File

@ -72,8 +72,7 @@ class AccountsAnonymousGrpcServiceTest extends
when(rateLimiter.validateReactive(anyString())).thenReturn(Mono.empty()); when(rateLimiter.validateReactive(anyString())).thenReturn(Mono.empty());
getMockRequestAttributesInterceptor().setRequestAttributes( getMockRequestAttributesInterceptor().setRemoteAddress(InetAddresses.forString("127.0.0.1"));
new RequestAttributes(InetAddresses.forString("127.0.0.1"), null, null));
return new AccountsAnonymousGrpcService(accountsManager, rateLimiters); return new AccountsAnonymousGrpcService(accountsManager, rateLimiters);
} }

View File

@ -1,88 +0,0 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.Status;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
class ChannelShutdownInterceptorTest {
private GrpcClientConnectionManager grpcClientConnectionManager;
private ChannelShutdownInterceptor channelShutdownInterceptor;
private ServerCallHandler<String, String> nextCallHandler;
private static final Metadata HEADERS = new Metadata();
@BeforeEach
void setUp() {
grpcClientConnectionManager = mock(GrpcClientConnectionManager.class);
channelShutdownInterceptor = new ChannelShutdownInterceptor(grpcClientConnectionManager);
//noinspection unchecked
nextCallHandler = mock(ServerCallHandler.class);
//noinspection unchecked
when(nextCallHandler.startCall(any(), any())).thenReturn(mock(ServerCall.Listener.class));
}
@Test
void interceptCallComplete() {
@SuppressWarnings("unchecked") final ServerCall<String, String> serverCall = mock(ServerCall.class);
when(grpcClientConnectionManager.handleServerCallStart(serverCall)).thenReturn(true);
final ServerCall.Listener<String> serverCallListener =
channelShutdownInterceptor.interceptCall(serverCall, HEADERS, nextCallHandler);
serverCallListener.onComplete();
verify(grpcClientConnectionManager).handleServerCallStart(serverCall);
verify(grpcClientConnectionManager).handleServerCallComplete(serverCall);
verify(serverCall, never()).close(any(), any());
}
@Test
void interceptCallCancelled() {
@SuppressWarnings("unchecked") final ServerCall<String, String> serverCall = mock(ServerCall.class);
when(grpcClientConnectionManager.handleServerCallStart(serverCall)).thenReturn(true);
final ServerCall.Listener<String> serverCallListener =
channelShutdownInterceptor.interceptCall(serverCall, HEADERS, nextCallHandler);
serverCallListener.onCancel();
verify(grpcClientConnectionManager).handleServerCallStart(serverCall);
verify(grpcClientConnectionManager).handleServerCallComplete(serverCall);
verify(serverCall, never()).close(any(), any());
}
@Test
void interceptCallChannelClosing() {
@SuppressWarnings("unchecked") final ServerCall<String, String> serverCall = mock(ServerCall.class);
when(grpcClientConnectionManager.handleServerCallStart(serverCall)).thenReturn(false);
channelShutdownInterceptor.interceptCall(serverCall, HEADERS, nextCallHandler);
verify(grpcClientConnectionManager).handleServerCallStart(serverCall);
verify(grpcClientConnectionManager, never()).handleServerCallComplete(serverCall);
verify(serverCall).close(eq(Status.UNAVAILABLE), any());
}
}

View File

@ -12,38 +12,14 @@ import org.signal.chat.rpc.EchoServiceGrpc;
public class EchoServiceImpl extends EchoServiceGrpc.EchoServiceImplBase { public class EchoServiceImpl extends EchoServiceGrpc.EchoServiceImplBase {
@Override @Override
public void echo(final EchoRequest echoRequest, final StreamObserver<EchoResponse> responseObserver) { public void echo(EchoRequest req, StreamObserver<EchoResponse> responseObserver) {
responseObserver.onNext(buildResponse(echoRequest)); responseObserver.onNext(EchoResponse.newBuilder().setPayload(req.getPayload()).build());
responseObserver.onCompleted(); responseObserver.onCompleted();
} }
@Override @Override
public void echo2(final EchoRequest echoRequest, final StreamObserver<EchoResponse> responseObserver) { public void echo2(EchoRequest req, StreamObserver<EchoResponse> responseObserver) {
responseObserver.onNext(buildResponse(echoRequest)); responseObserver.onNext(EchoResponse.newBuilder().setPayload(req.getPayload()).build());
responseObserver.onCompleted(); responseObserver.onCompleted();
} }
@Override
public StreamObserver<EchoRequest> echoStream(final StreamObserver<EchoResponse> responseObserver) {
return new StreamObserver<>() {
@Override
public void onNext(final EchoRequest echoRequest) {
responseObserver.onNext(buildResponse(echoRequest));
}
@Override
public void onError(final Throwable throwable) {
responseObserver.onError(throwable);
}
@Override
public void onCompleted() {
responseObserver.onCompleted();
}
};
}
private static EchoResponse buildResponse(final EchoRequest echoRequest) {
return EchoResponse.newBuilder().setPayload(echoRequest.getPayload()).build();
}
} }

View File

@ -1,618 +0,0 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString;
import io.grpc.Status;
import java.time.Duration;
import java.time.Instant;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.mockito.Mock;
import org.signal.chat.messages.AuthenticatedSenderMessageType;
import org.signal.chat.messages.ChallengeRequired;
import org.signal.chat.messages.IndividualRecipientMessageBundle;
import org.signal.chat.messages.MessagesGrpc;
import org.signal.chat.messages.MismatchedDevices;
import org.signal.chat.messages.SendAuthenticatedSenderMessageRequest;
import org.signal.chat.messages.SendMessageResponse;
import org.signal.chat.messages.SendSyncMessageRequest;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.CardinalityEstimator;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.MessageTooLargeException;
import org.whispersystems.textsecuregcm.spam.GrpcResponse;
import org.whispersystems.textsecuregcm.spam.MessageType;
import org.whispersystems.textsecuregcm.spam.SpamCheckResult;
import org.whispersystems.textsecuregcm.spam.SpamChecker;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
import org.whispersystems.textsecuregcm.util.TestClock;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
class MessagesGrpcServiceTest extends SimpleBaseGrpcTest<MessagesGrpcService, MessagesGrpc.MessagesBlockingStub> {
@Mock
private AccountsManager accountsManager;
@Mock
private RateLimiters rateLimiters;
@Mock
private MessageSender messageSender;
@Mock
private CardinalityEstimator messageByteLimitEstimator;
@Mock
private SpamChecker spamChecker;
@Mock
private RateLimiter rateLimiter;
@Mock
private Account authenticatedAccount;
@Mock
private Device authenticatedDevice;
@Mock
private Device linkedDevice;
@Mock
private Device secondLinkedDevice;
private static final int AUTHENTICATED_REGISTRATION_ID = 7;
private static final byte LINKED_DEVICE_ID = AUTHENTICATED_DEVICE_ID + 1;
private static final int LINKED_DEVICE_REGISTRATION_ID = 13;
private static final byte SECOND_LINKED_DEVICE_ID = LINKED_DEVICE_ID + 1;
private static final int SECOND_LINKED_DEVICE_REGISTRATION_ID = 19;
private static final TestClock CLOCK = TestClock.pinned(Instant.now());
@Override
protected MessagesGrpcService createServiceBeforeEachTest() {
return new MessagesGrpcService(accountsManager,
rateLimiters,
messageSender,
messageByteLimitEstimator,
spamChecker,
CLOCK);
}
@BeforeEach
void setUp() {
when(accountsManager.getByServiceIdentifier(any())).thenReturn(Optional.empty());
when(rateLimiters.getInboundMessageBytes()).thenReturn(rateLimiter);
when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter);
when(spamChecker.checkForIndividualRecipientSpamGrpc(any(), any(), any(), any()))
.thenReturn(new SpamCheckResult<>(Optional.empty(), Optional.empty()));
when(authenticatedDevice.getId()).thenReturn(AUTHENTICATED_DEVICE_ID);
when(authenticatedDevice.getRegistrationId()).thenReturn(AUTHENTICATED_REGISTRATION_ID);
when(linkedDevice.getId()).thenReturn(LINKED_DEVICE_ID);
when(linkedDevice.getRegistrationId()).thenReturn(LINKED_DEVICE_REGISTRATION_ID);
when(secondLinkedDevice.getId()).thenReturn(SECOND_LINKED_DEVICE_ID);
when(secondLinkedDevice.getRegistrationId()).thenReturn(SECOND_LINKED_DEVICE_REGISTRATION_ID);
when(authenticatedAccount.getUuid()).thenReturn(AUTHENTICATED_ACI);
when(authenticatedAccount.getIdentifier(IdentityType.ACI)).thenReturn(AUTHENTICATED_ACI);
when(authenticatedAccount.getDevice(anyByte())).thenReturn(Optional.empty());
when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(authenticatedDevice));
when(authenticatedAccount.getDevice(LINKED_DEVICE_ID)).thenReturn(Optional.of(linkedDevice));
when(authenticatedAccount.getDevice(SECOND_LINKED_DEVICE_ID)).thenReturn(Optional.of(secondLinkedDevice));
when(authenticatedAccount.getDevices()).thenReturn(List.of(authenticatedDevice, linkedDevice, secondLinkedDevice));
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(AUTHENTICATED_ACI)))
.thenReturn(Optional.of(authenticatedAccount));
}
@Nested
class SingleRecipient {
@CartesianTest
void sendMessage(@CartesianTest.Enum(mode = CartesianTest.Enum.Mode.EXCLUDE, names = {"UNSPECIFIED", "UNRECOGNIZED"}) final AuthenticatedSenderMessageType messageType,
@CartesianTest.Values(booleans = {true, false}) final boolean ephemeral,
@CartesianTest.Values(booleans = {true, false}) final boolean urgent,
@CartesianTest.Values(booleans = {true, false}) final boolean includeReportSpamToken)
throws MessageTooLargeException, MismatchedDevicesException {
final byte deviceId = Device.PRIMARY_ID;
final int registrationId = 7;
final Device destinationDevice = DevicesHelper.createDevice(deviceId, CLOCK.millis(), registrationId);
final Account destinationAccount = mock(Account.class);
when(destinationAccount.getDevices()).thenReturn(List.of(destinationDevice));
when(destinationAccount.getDevice(deviceId)).thenReturn(Optional.of(destinationDevice));
final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount));
final byte[] reportSpamToken = TestRandomUtil.nextBytes(64);
if (includeReportSpamToken) {
when(spamChecker.checkForIndividualRecipientSpamGrpc(any(), any(), any(), any()))
.thenReturn(new SpamCheckResult<>(Optional.empty(), Optional.of(reportSpamToken)));
}
final byte[] payload = TestRandomUtil.nextBytes(128);
final Map<Byte, IndividualRecipientMessageBundle.Message> messages =
Map.of(deviceId, IndividualRecipientMessageBundle.Message.newBuilder()
.setRegistrationId(registrationId)
.setPayload(ByteString.copyFrom(payload))
.build());
final SendMessageResponse response = authenticatedServiceStub().sendMessage(
generateRequest(serviceIdentifier, messageType, ephemeral, urgent, messages));
assertEquals(SendMessageResponse.newBuilder().build(), response);
final MessageProtos.Envelope.Type expectedEnvelopeType = switch (messageType) {
case DOUBLE_RATCHET -> MessageProtos.Envelope.Type.CIPHERTEXT;
case PREKEY_MESSAGE -> MessageProtos.Envelope.Type.PREKEY_BUNDLE;
case PLAINTEXT_CONTENT -> MessageProtos.Envelope.Type.PLAINTEXT_CONTENT;
case UNSPECIFIED, UNRECOGNIZED -> throw new IllegalArgumentException("Unexpected message type: " + messageType);
};
final MessageProtos.Envelope.Builder expectedEnvelopeBuilder = MessageProtos.Envelope.newBuilder()
.setType(expectedEnvelopeType)
.setSourceServiceId(new AciServiceIdentifier(AUTHENTICATED_ACI).toServiceIdentifierString())
.setSourceDevice(AUTHENTICATED_DEVICE_ID)
.setDestinationServiceId(serviceIdentifier.toServiceIdentifierString())
.setClientTimestamp(CLOCK.millis())
.setServerTimestamp(CLOCK.millis())
.setEphemeral(ephemeral)
.setUrgent(urgent)
.setContent(ByteString.copyFrom(payload));
if (includeReportSpamToken) {
expectedEnvelopeBuilder.setReportSpamToken(ByteString.copyFrom(reportSpamToken));
}
verify(spamChecker).checkForIndividualRecipientSpamGrpc(MessageType.INDIVIDUAL_IDENTIFIED_SENDER,
Optional.of(new AuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID)),
Optional.of(destinationAccount),
serviceIdentifier);
verify(messageSender).sendMessages(destinationAccount,
serviceIdentifier,
Map.of(deviceId, expectedEnvelopeBuilder.build()),
Map.of(deviceId, registrationId),
Optional.empty(),
null);
}
@Test
void mismatchedDevices() throws MessageTooLargeException, MismatchedDevicesException {
final byte missingDeviceId = Device.PRIMARY_ID;
final byte extraDeviceId = missingDeviceId + 1;
final byte staleDeviceId = extraDeviceId + 1;
final Account destinationAccount = mock(Account.class);
final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount));
final Map<Byte, IndividualRecipientMessageBundle.Message> messages = Map.of(
staleDeviceId, IndividualRecipientMessageBundle.Message.newBuilder()
.setRegistrationId(Device.PRIMARY_ID)
.setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128)))
.build());
doThrow(new MismatchedDevicesException(new org.whispersystems.textsecuregcm.controllers.MismatchedDevices(
Set.of(missingDeviceId), Set.of(extraDeviceId), Set.of(staleDeviceId))))
.when(messageSender).sendMessages(any(), any(), any(), any(), any(), any());
final SendMessageResponse response = authenticatedServiceStub().sendMessage(
generateRequest(serviceIdentifier, AuthenticatedSenderMessageType.DOUBLE_RATCHET, false, true, messages));
final SendMessageResponse expectedResponse = SendMessageResponse.newBuilder()
.setMismatchedDevices(MismatchedDevices.newBuilder()
.setServiceIdentifier(ServiceIdentifierUtil.toGrpcServiceIdentifier(serviceIdentifier))
.addMissingDevices(missingDeviceId)
.addStaleDevices(staleDeviceId)
.addExtraDevices(extraDeviceId)
.build())
.build();
assertEquals(expectedResponse, response);
}
@Test
void destinationNotFound() throws MessageTooLargeException, MismatchedDevicesException {
final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
final Map<Byte, IndividualRecipientMessageBundle.Message> messages =
Map.of(Device.PRIMARY_ID, IndividualRecipientMessageBundle.Message.newBuilder()
.setRegistrationId(1234)
.setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128)))
.build());
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.NOT_FOUND,
() -> authenticatedServiceStub().sendMessage(
generateRequest(serviceIdentifier, AuthenticatedSenderMessageType.DOUBLE_RATCHET, false, true, messages)));
verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any(), any());
}
@Test
void rateLimited() throws RateLimitExceededException, MessageTooLargeException, MismatchedDevicesException {
final byte deviceId = Device.PRIMARY_ID;
final int registrationId = 7;
final Device destinationDevice = DevicesHelper.createDevice(deviceId, CLOCK.millis(), registrationId);
final Account destinationAccount = mock(Account.class);
when(destinationAccount.getDevices()).thenReturn(List.of(destinationDevice));
when(destinationAccount.getDevice(deviceId)).thenReturn(Optional.of(destinationDevice));
final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount));
final Duration retryDuration = Duration.ofHours(7);
doThrow(new RateLimitExceededException(retryDuration))
.when(rateLimiter).validate(eq(serviceIdentifier.uuid()), anyInt());
final Map<Byte, IndividualRecipientMessageBundle.Message> messages =
Map.of(deviceId, IndividualRecipientMessageBundle.Message.newBuilder()
.setRegistrationId(registrationId)
.setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128)))
.build());
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertRateLimitExceeded(retryDuration,
() -> authenticatedServiceStub().sendMessage(
generateRequest(serviceIdentifier, AuthenticatedSenderMessageType.DOUBLE_RATCHET, false, true, messages)));
verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any(), any());
verify(messageByteLimitEstimator).add(serviceIdentifier.uuid().toString());
}
@Test
void oversizedMessage() throws MessageTooLargeException, MismatchedDevicesException {
final byte missingDeviceId = Device.PRIMARY_ID;
final byte extraDeviceId = missingDeviceId + 1;
final byte staleDeviceId = extraDeviceId + 1;
final Account destinationAccount = mock(Account.class);
final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount));
final Map<Byte, IndividualRecipientMessageBundle.Message> messages = Map.of(
staleDeviceId, IndividualRecipientMessageBundle.Message.newBuilder()
.setRegistrationId(Device.PRIMARY_ID)
.setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128)))
.build());
doThrow(new MessageTooLargeException())
.when(messageSender).sendMessages(any(), any(), any(), any(), any(), any());
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.INVALID_ARGUMENT,
() -> authenticatedServiceStub().sendMessage(
generateRequest(serviceIdentifier, AuthenticatedSenderMessageType.DOUBLE_RATCHET, false, true, messages)));
}
@Test
void spamWithStatus() throws MessageTooLargeException, MismatchedDevicesException {
final byte deviceId = Device.PRIMARY_ID;
final int registrationId = 7;
final Device destinationDevice = DevicesHelper.createDevice(deviceId, CLOCK.millis(), registrationId);
final Account destinationAccount = mock(Account.class);
when(destinationAccount.getDevices()).thenReturn(List.of(destinationDevice));
when(destinationAccount.getDevice(deviceId)).thenReturn(Optional.of(destinationDevice));
final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount));
final Map<Byte, IndividualRecipientMessageBundle.Message> messages =
Map.of(deviceId, IndividualRecipientMessageBundle.Message.newBuilder()
.setRegistrationId(registrationId)
.setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128)))
.build());
when(spamChecker.checkForIndividualRecipientSpamGrpc(any(), any(), any(), any()))
.thenReturn(new SpamCheckResult<>(Optional.of(GrpcResponse.withStatus(Status.RESOURCE_EXHAUSTED)), Optional.empty()));
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.RESOURCE_EXHAUSTED,
() -> authenticatedServiceStub().sendMessage(
generateRequest(serviceIdentifier, AuthenticatedSenderMessageType.DOUBLE_RATCHET, false, true, messages)));
verify(spamChecker).checkForIndividualRecipientSpamGrpc(MessageType.INDIVIDUAL_IDENTIFIED_SENDER,
Optional.of(new AuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID)),
Optional.of(destinationAccount),
serviceIdentifier);
verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any(), any());
}
@Test
void spamWithResponse() throws MessageTooLargeException, MismatchedDevicesException {
final byte deviceId = Device.PRIMARY_ID;
final int registrationId = 7;
final Device destinationDevice = DevicesHelper.createDevice(deviceId, CLOCK.millis(), registrationId);
final Account destinationAccount = mock(Account.class);
when(destinationAccount.getDevices()).thenReturn(List.of(destinationDevice));
when(destinationAccount.getDevice(deviceId)).thenReturn(Optional.of(destinationDevice));
final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount));
final Map<Byte, IndividualRecipientMessageBundle.Message> messages =
Map.of(deviceId, IndividualRecipientMessageBundle.Message.newBuilder()
.setRegistrationId(registrationId)
.setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128)))
.build());
final SendMessageResponse response = SendMessageResponse.newBuilder()
.setChallengeRequired(ChallengeRequired.newBuilder()
.addChallengeOptions(ChallengeRequired.ChallengeType.CAPTCHA))
.build();
when(spamChecker.checkForIndividualRecipientSpamGrpc(any(), any(), any(), any()))
.thenReturn(new SpamCheckResult<>(Optional.of(GrpcResponse.withResponse(response)), Optional.empty()));
assertEquals(response, authenticatedServiceStub().sendMessage(
generateRequest(serviceIdentifier, AuthenticatedSenderMessageType.DOUBLE_RATCHET, false, true, messages)));
verify(spamChecker).checkForIndividualRecipientSpamGrpc(MessageType.INDIVIDUAL_IDENTIFIED_SENDER,
Optional.of(new AuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID)),
Optional.of(destinationAccount),
serviceIdentifier);
verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any(), any());
}
private static SendAuthenticatedSenderMessageRequest generateRequest(final ServiceIdentifier serviceIdentifier,
final AuthenticatedSenderMessageType messageType,
final boolean ephemeral,
final boolean urgent,
final Map<Byte, IndividualRecipientMessageBundle.Message> messages) {
final IndividualRecipientMessageBundle.Builder messageBundleBuilder = IndividualRecipientMessageBundle.newBuilder()
.setTimestamp(CLOCK.millis());
messages.forEach(messageBundleBuilder::putMessages);
final SendAuthenticatedSenderMessageRequest.Builder requestBuilder = SendAuthenticatedSenderMessageRequest.newBuilder()
.setDestination(ServiceIdentifierUtil.toGrpcServiceIdentifier(serviceIdentifier))
.setType(messageType)
.setMessages(messageBundleBuilder)
.setEphemeral(ephemeral)
.setUrgent(urgent);
return requestBuilder.build();
}
}
@Nested
class Sync {
@CartesianTest
void sendMessage(@CartesianTest.Enum(mode = CartesianTest.Enum.Mode.EXCLUDE, names = {"UNSPECIFIED", "UNRECOGNIZED"}) final AuthenticatedSenderMessageType messageType,
@CartesianTest.Values(booleans = {true, false}) final boolean urgent,
@CartesianTest.Values(booleans = {true, false}) final boolean includeReportSpamToken)
throws MessageTooLargeException, MismatchedDevicesException {
final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(AUTHENTICATED_ACI);
final byte[] payload = TestRandomUtil.nextBytes(128);
final Map<Byte, IndividualRecipientMessageBundle.Message> messages =
Map.of(LINKED_DEVICE_ID, IndividualRecipientMessageBundle.Message.newBuilder()
.setRegistrationId(LINKED_DEVICE_REGISTRATION_ID)
.setPayload(ByteString.copyFrom(payload))
.build(),
SECOND_LINKED_DEVICE_ID, IndividualRecipientMessageBundle.Message.newBuilder()
.setRegistrationId(SECOND_LINKED_DEVICE_REGISTRATION_ID)
.setPayload(ByteString.copyFrom(payload))
.build());
final byte[] reportSpamToken = TestRandomUtil.nextBytes(64);
if (includeReportSpamToken) {
when(spamChecker.checkForIndividualRecipientSpamGrpc(any(), any(), any(), any()))
.thenReturn(new SpamCheckResult<>(Optional.empty(), Optional.of(reportSpamToken)));
}
final SendMessageResponse response =
authenticatedServiceStub().sendSyncMessage(generateRequest(messageType, urgent, messages));
assertEquals(SendMessageResponse.newBuilder().build(), response);
final MessageProtos.Envelope.Type expectedEnvelopeType = switch (messageType) {
case DOUBLE_RATCHET -> MessageProtos.Envelope.Type.CIPHERTEXT;
case PREKEY_MESSAGE -> MessageProtos.Envelope.Type.PREKEY_BUNDLE;
case PLAINTEXT_CONTENT -> MessageProtos.Envelope.Type.PLAINTEXT_CONTENT;
case UNSPECIFIED, UNRECOGNIZED -> throw new IllegalArgumentException("Unexpected message type: " + messageType);
};
final Map<Byte, MessageProtos.Envelope> expectedEnvelopes = new HashMap<>(Map.of(
LINKED_DEVICE_ID, MessageProtos.Envelope.newBuilder()
.setType(expectedEnvelopeType)
.setSourceServiceId(serviceIdentifier.toServiceIdentifierString())
.setSourceDevice(AUTHENTICATED_DEVICE_ID)
.setDestinationServiceId(serviceIdentifier.toServiceIdentifierString())
.setClientTimestamp(CLOCK.millis())
.setServerTimestamp(CLOCK.millis())
.setEphemeral(false)
.setUrgent(urgent)
.setContent(ByteString.copyFrom(payload))
.build(),
SECOND_LINKED_DEVICE_ID, MessageProtos.Envelope.newBuilder()
.setType(expectedEnvelopeType)
.setSourceServiceId(serviceIdentifier.toServiceIdentifierString())
.setSourceDevice(AUTHENTICATED_DEVICE_ID)
.setDestinationServiceId(serviceIdentifier.toServiceIdentifierString())
.setClientTimestamp(CLOCK.millis())
.setServerTimestamp(CLOCK.millis())
.setEphemeral(false)
.setUrgent(urgent)
.setContent(ByteString.copyFrom(payload))
.build()
));
if (includeReportSpamToken) {
expectedEnvelopes.replaceAll((deviceId, envelope) ->
envelope.toBuilder().setReportSpamToken(ByteString.copyFrom(reportSpamToken)).build());
}
verify(spamChecker).checkForIndividualRecipientSpamGrpc(MessageType.SYNC,
Optional.of(new AuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID)),
Optional.of(authenticatedAccount),
serviceIdentifier);
verify(messageSender).sendMessages(authenticatedAccount,
serviceIdentifier,
expectedEnvelopes,
Map.of(LINKED_DEVICE_ID, LINKED_DEVICE_REGISTRATION_ID,
SECOND_LINKED_DEVICE_ID, SECOND_LINKED_DEVICE_REGISTRATION_ID),
Optional.of(AUTHENTICATED_DEVICE_ID),
null);
}
@Test
void mismatchedDevices() throws MessageTooLargeException, MismatchedDevicesException {
final byte missingDeviceId = Device.PRIMARY_ID;
final byte extraDeviceId = missingDeviceId + 1;
final byte staleDeviceId = extraDeviceId + 1;
final Map<Byte, IndividualRecipientMessageBundle.Message> messages = Map.of(
staleDeviceId, IndividualRecipientMessageBundle.Message.newBuilder()
.setRegistrationId(Device.PRIMARY_ID)
.setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128)))
.build());
doThrow(new MismatchedDevicesException(new org.whispersystems.textsecuregcm.controllers.MismatchedDevices(
Set.of(missingDeviceId), Set.of(extraDeviceId), Set.of(staleDeviceId))))
.when(messageSender).sendMessages(any(), any(), any(), any(), any(), any());
final SendMessageResponse response = authenticatedServiceStub().sendSyncMessage(
generateRequest(AuthenticatedSenderMessageType.DOUBLE_RATCHET, true, messages));
final SendMessageResponse expectedResponse = SendMessageResponse.newBuilder()
.setMismatchedDevices(MismatchedDevices.newBuilder()
.setServiceIdentifier(ServiceIdentifierUtil.toGrpcServiceIdentifier(new AciServiceIdentifier(AUTHENTICATED_ACI)))
.addMissingDevices(missingDeviceId)
.addStaleDevices(staleDeviceId)
.addExtraDevices(extraDeviceId)
.build())
.build();
assertEquals(expectedResponse, response);
}
@Test
void rateLimited() throws RateLimitExceededException, MessageTooLargeException, MismatchedDevicesException {
final Duration retryDuration = Duration.ofHours(7);
doThrow(new RateLimitExceededException(retryDuration))
.when(rateLimiter).validate(eq(AUTHENTICATED_ACI), anyInt());
final Map<Byte, IndividualRecipientMessageBundle.Message> messages =
Map.of(AUTHENTICATED_DEVICE_ID, IndividualRecipientMessageBundle.Message.newBuilder()
.setRegistrationId(AUTHENTICATED_REGISTRATION_ID)
.setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128)))
.build());
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertRateLimitExceeded(retryDuration,
() -> authenticatedServiceStub().sendSyncMessage(
generateRequest(AuthenticatedSenderMessageType.DOUBLE_RATCHET, true, messages)));
verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any(), any());
verify(messageByteLimitEstimator).add(AUTHENTICATED_ACI.toString());
}
@Test
void oversizedMessage() throws MessageTooLargeException, MismatchedDevicesException {
final byte missingDeviceId = Device.PRIMARY_ID;
final byte extraDeviceId = missingDeviceId + 1;
final byte staleDeviceId = extraDeviceId + 1;
final Account destinationAccount = mock(Account.class);
final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount));
final Map<Byte, IndividualRecipientMessageBundle.Message> messages = Map.of(
staleDeviceId, IndividualRecipientMessageBundle.Message.newBuilder()
.setRegistrationId(Device.PRIMARY_ID)
.setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128)))
.build());
doThrow(new MessageTooLargeException())
.when(messageSender).sendMessages(any(), any(), any(), any(), any(), any());
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.INVALID_ARGUMENT,
() -> authenticatedServiceStub().sendSyncMessage(
generateRequest(AuthenticatedSenderMessageType.DOUBLE_RATCHET, true, messages)));
}
private static SendSyncMessageRequest generateRequest(final AuthenticatedSenderMessageType messageType,
final boolean urgent,
final Map<Byte, IndividualRecipientMessageBundle.Message> messages) {
final IndividualRecipientMessageBundle.Builder messageBundleBuilder = IndividualRecipientMessageBundle.newBuilder()
.setTimestamp(CLOCK.millis());
messages.forEach(messageBundleBuilder::putMessages);
final SendSyncMessageRequest.Builder requestBuilder = SendSyncMessageRequest.newBuilder()
.setType(messageType)
.setMessages(messageBundleBuilder)
.setUrgent(urgent);
return requestBuilder.build();
}
}
}

View File

@ -5,7 +5,6 @@
package org.whispersystems.textsecuregcm.grpc; package org.whispersystems.textsecuregcm.grpc;
import com.google.common.net.InetAddresses;
import io.grpc.Context; import io.grpc.Context;
import io.grpc.Contexts; import io.grpc.Contexts;
import io.grpc.Metadata; import io.grpc.Metadata;
@ -20,10 +19,25 @@ import org.whispersystems.textsecuregcm.util.ua.UserAgent;
public class MockRequestAttributesInterceptor implements ServerInterceptor { public class MockRequestAttributesInterceptor implements ServerInterceptor {
private RequestAttributes requestAttributes = new RequestAttributes(InetAddresses.forString("127.0.0.1"), null, null); @Nullable
private InetAddress remoteAddress;
public void setRequestAttributes(final RequestAttributes requestAttributes) { @Nullable
this.requestAttributes = requestAttributes; private UserAgent userAgent;
@Nullable
private List<Locale.LanguageRange> acceptLanguage;
public void setRemoteAddress(@Nullable final InetAddress remoteAddress) {
this.remoteAddress = remoteAddress;
}
public void setUserAgent(@Nullable final UserAgent userAgent) {
this.userAgent = userAgent;
}
public void setAcceptLanguage(@Nullable final List<Locale.LanguageRange> acceptLanguage) {
this.acceptLanguage = acceptLanguage;
} }
@Override @Override
@ -31,7 +45,20 @@ public class MockRequestAttributesInterceptor implements ServerInterceptor {
final Metadata headers, final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) { final ServerCallHandler<ReqT, RespT> next) {
return Contexts.interceptCall(Context.current() Context context = Context.current();
.withValue(RequestAttributesUtil.REQUEST_ATTRIBUTES_CONTEXT_KEY, requestAttributes), serverCall, headers, next);
if (remoteAddress != null) {
context = context.withValue(RequestAttributesUtil.REMOTE_ADDRESS_CONTEXT_KEY, remoteAddress);
}
if (userAgent != null) {
context = context.withValue(RequestAttributesUtil.USER_AGENT_CONTEXT_KEY, userAgent);
}
if (acceptLanguage != null) {
context = context.withValue(RequestAttributesUtil.ACCEPT_LANGUAGE_CONTEXT_KEY, acceptLanguage);
}
return Contexts.interceptCall(context, serverCall, headers, next);
} }
} }

View File

@ -15,7 +15,6 @@ import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException; import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
import com.google.common.net.InetAddresses;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import io.grpc.Status; import io.grpc.Status;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
@ -76,6 +75,8 @@ import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper; import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper;
import org.whispersystems.textsecuregcm.util.TestRandomUtil; import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil; import org.whispersystems.textsecuregcm.util.UUIDUtil;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<ProfileAnonymousGrpcService, ProfileAnonymousGrpc.ProfileAnonymousBlockingStub> { public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<ProfileAnonymousGrpcService, ProfileAnonymousGrpc.ProfileAnonymousBlockingStub> {
@ -95,9 +96,13 @@ public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<ProfileA
@Override @Override
protected ProfileAnonymousGrpcService createServiceBeforeEachTest() { protected ProfileAnonymousGrpcService createServiceBeforeEachTest() {
getMockRequestAttributesInterceptor().setRequestAttributes(new RequestAttributes(InetAddresses.forString("127.0.0.1"), getMockRequestAttributesInterceptor().setAcceptLanguage(Locale.LanguageRange.parse("en-us"));
"Signal-Android/1.2.3",
Locale.LanguageRange.parse("en-us"))); try {
getMockRequestAttributesInterceptor().setUserAgent(UserAgentUtil.parseUserAgentString("Signal-Android/1.2.3"));
} catch (final UnrecognizedUserAgentException e) {
throw new IllegalArgumentException(e);
}
return new ProfileAnonymousGrpcService( return new ProfileAnonymousGrpcService(
accountsManager, accountsManager,

View File

@ -15,16 +15,13 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.refEq;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded; import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException; import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
import com.google.common.net.InetAddresses;
import com.google.i18n.phonenumbers.PhoneNumberUtil; import com.google.i18n.phonenumbers.PhoneNumberUtil;
import com.google.i18n.phonenumbers.Phonenumber; import com.google.i18n.phonenumbers.Phonenumber;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
@ -33,7 +30,6 @@ import java.nio.charset.StandardCharsets;
import java.time.Clock; import java.time.Clock;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.time.ZoneId;
import java.time.temporal.ChronoUnit; import java.time.temporal.ChronoUnit;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
@ -97,18 +93,18 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.s3.PolicySigner; import org.whispersystems.textsecuregcm.s3.PolicySigner;
import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator; import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountBadge;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DeviceCapability; import org.whispersystems.textsecuregcm.storage.DeviceCapability;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.ProfilesManager; import org.whispersystems.textsecuregcm.storage.ProfilesManager;
import org.whispersystems.textsecuregcm.storage.VersionedProfile; import org.whispersystems.textsecuregcm.storage.VersionedProfile;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper; import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper;
import org.whispersystems.textsecuregcm.util.MockUtils; import org.whispersystems.textsecuregcm.util.MockUtils;
import org.whispersystems.textsecuregcm.util.TestRandomUtil; import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil; import org.whispersystems.textsecuregcm.util.UUIDUtil;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
@ -148,8 +144,6 @@ public class ProfileGrpcServiceTest extends SimpleBaseGrpcTest<ProfileGrpcServic
@Mock @Mock
private ServerZkProfileOperations serverZkProfileOperations; private ServerZkProfileOperations serverZkProfileOperations;
private Clock clock;
@Override @Override
protected ProfileGrpcService createServiceBeforeEachTest() { protected ProfileGrpcService createServiceBeforeEachTest() {
@SuppressWarnings("unchecked") final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager = mock(DynamicConfigurationManager.class); @SuppressWarnings("unchecked") final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
@ -176,9 +170,13 @@ public class ProfileGrpcServiceTest extends SimpleBaseGrpcTest<ProfileGrpcServic
PhoneNumberUtil.getInstance().getExampleNumber("US"), PhoneNumberUtil.getInstance().getExampleNumber("US"),
PhoneNumberUtil.PhoneNumberFormat.E164); PhoneNumberUtil.PhoneNumberFormat.E164);
getMockRequestAttributesInterceptor().setRequestAttributes(new RequestAttributes(InetAddresses.forString("127.0.0.1"), getMockRequestAttributesInterceptor().setAcceptLanguage(Locale.LanguageRange.parse("en-us"));
"Signal-Android/1.2.3",
Locale.LanguageRange.parse("en-us"))); try {
getMockRequestAttributesInterceptor().setUserAgent(UserAgentUtil.parseUserAgentString("Signal-Android/1.2.3"));
} catch (final UnrecognizedUserAgentException e) {
throw new IllegalArgumentException(e);
}
when(rateLimiters.getProfileLimiter()).thenReturn(rateLimiter); when(rateLimiters.getProfileLimiter()).thenReturn(rateLimiter);
when(rateLimiter.validateReactive(any(UUID.class))).thenReturn(Mono.empty()); when(rateLimiter.validateReactive(any(UUID.class))).thenReturn(Mono.empty());
@ -205,10 +203,8 @@ public class ProfileGrpcServiceTest extends SimpleBaseGrpcTest<ProfileGrpcServic
when(asyncS3client.deleteObject(any(DeleteObjectRequest.class))).thenReturn(CompletableFuture.completedFuture(null)); when(asyncS3client.deleteObject(any(DeleteObjectRequest.class))).thenReturn(CompletableFuture.completedFuture(null));
clock = Clock.fixed(Instant.ofEpochSecond(42), ZoneId.of("Etc/UTC"));
return new ProfileGrpcService( return new ProfileGrpcService(
clock, Clock.systemUTC(),
accountsManager, accountsManager,
profilesManager, profilesManager,
dynamicConfigurationManager, dynamicConfigurationManager,
@ -396,42 +392,6 @@ public class ProfileGrpcServiceTest extends SimpleBaseGrpcTest<ProfileGrpcServic
} }
} }
@Test
void setProfileBadges() throws InvalidInputException {
final byte[] commitment = new ProfileKey(new byte[32]).getCommitment(new ServiceId.Aci(AUTHENTICATED_ACI)).serialize();
final SetProfileRequest request = SetProfileRequest.newBuilder()
.setVersion(VERSION)
.setName(ByteString.copyFrom(VALID_NAME))
.setAvatarChange(AvatarChange.AVATAR_CHANGE_UNCHANGED)
.addAllBadgeIds(List.of("TEST3"))
.setCommitment(ByteString.copyFrom(commitment))
.build();
final int accountsManagerUpdateRetryCount = 2;
AccountsHelper.setupMockUpdateWithRetries(accountsManager, accountsManagerUpdateRetryCount);
// set up two invocations -- one for each AccountsManager#update try
when(account.getBadges())
.thenReturn(List.of(new AccountBadge("TEST3", Instant.ofEpochSecond(41), false)))
.thenReturn(List.of(new AccountBadge("TEST2", Instant.ofEpochSecond(41), true),
new AccountBadge("TEST3", Instant.ofEpochSecond(41), false)));
//noinspection ResultOfMethodCallIgnored
authenticatedServiceStub().setProfile(request);
//noinspection unchecked
final ArgumentCaptor<List<AccountBadge>> badgeCaptor = ArgumentCaptor.forClass(List.class);
verify(account, times(2)).setBadges(refEq(clock), badgeCaptor.capture());
// since the stubbing of getBadges() is brittle, we need to verify the number of invocations, to protect against upstream changes
verify(account, times(accountsManagerUpdateRetryCount)).getBadges();
assertEquals(List.of(
new AccountBadge("TEST3", Instant.ofEpochSecond(41), true),
new AccountBadge("TEST2", Instant.ofEpochSecond(41), false)),
badgeCaptor.getValue());
}
@ParameterizedTest @ParameterizedTest
@EnumSource(value = org.signal.chat.common.IdentityType.class, names = {"IDENTITY_TYPE_ACI", "IDENTITY_TYPE_PNI"}) @EnumSource(value = org.signal.chat.common.IdentityType.class, names = {"IDENTITY_TYPE_ACI", "IDENTITY_TYPE_PNI"})
void getUnversionedProfile(final IdentityType identityType) { void getUnversionedProfile(final IdentityType identityType) {

View File

@ -6,6 +6,7 @@ import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.signal.chat.rpc.GetRequestAttributesRequest; import org.signal.chat.rpc.GetRequestAttributesRequest;
import org.signal.chat.rpc.GetRequestAttributesResponse; import org.signal.chat.rpc.GetRequestAttributesResponse;
import org.signal.chat.rpc.RequestAttributesGrpc; import org.signal.chat.rpc.RequestAttributesGrpc;
import org.signal.chat.rpc.UserAgent;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticationUtil; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticationUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil; import org.whispersystems.textsecuregcm.util.UUIDUtil;
@ -18,15 +19,21 @@ public class RequestAttributesServiceImpl extends RequestAttributesGrpc.RequestA
final GetRequestAttributesResponse.Builder responseBuilder = GetRequestAttributesResponse.newBuilder(); final GetRequestAttributesResponse.Builder responseBuilder = GetRequestAttributesResponse.newBuilder();
RequestAttributesUtil.getAcceptableLanguages() RequestAttributesUtil.getAcceptableLanguages().ifPresent(acceptableLanguages ->
.forEach(languageRange -> responseBuilder.addAcceptableLanguages(languageRange.toString())); acceptableLanguages.forEach(languageRange -> responseBuilder.addAcceptableLanguages(languageRange.toString())));
RequestAttributesUtil.getAvailableAcceptedLocales().forEach(locale -> RequestAttributesUtil.getAvailableAcceptedLocales().forEach(locale ->
responseBuilder.addAvailableAcceptedLocales(locale.toLanguageTag())); responseBuilder.addAvailableAcceptedLocales(locale.toLanguageTag()));
responseBuilder.setRemoteAddress(RequestAttributesUtil.getRemoteAddress().getHostAddress()); responseBuilder.setRemoteAddress(RequestAttributesUtil.getRemoteAddress().getHostAddress());
RequestAttributesUtil.getUserAgent().ifPresent(responseBuilder::setUserAgent); RequestAttributesUtil.getUserAgent().ifPresent(userAgent -> responseBuilder.setUserAgent(UserAgent.newBuilder()
.setPlatform(userAgent.getPlatform().toString())
.setVersion(userAgent.getVersion().toString())
.setAdditionalSpecifiers(userAgent.getAdditionalSpecifiers().orElse(""))
.build()));
RequestAttributesUtil.getRawUserAgent().ifPresent(responseBuilder::setRawUserAgent);
responseObserver.onNext(responseBuilder.build()); responseObserver.onNext(responseBuilder.build());
responseObserver.onCompleted(); responseObserver.onCompleted();

View File

@ -3,84 +3,172 @@ package org.whispersystems.textsecuregcm.grpc;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.common.net.InetAddresses; import com.google.common.net.InetAddresses;
import io.grpc.Context; import io.grpc.ManagedChannel;
import java.net.InetAddress; import io.grpc.Server;
import java.util.Collections; import io.grpc.Status;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.NettyServerBuilder;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.Locale; import java.util.Locale;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.Callable; import org.junit.jupiter.api.AfterAll;
import javax.annotation.Nullable; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.signal.chat.rpc.GetRequestAttributesRequest;
import org.signal.chat.rpc.GetRequestAttributesResponse;
import org.signal.chat.rpc.RequestAttributesGrpc;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
class RequestAttributesUtilTest { class RequestAttributesUtilTest {
private static final InetAddress REMOTE_ADDRESS = InetAddresses.forString("127.0.0.1"); private static DefaultEventLoopGroup eventLoopGroup;
@Test private GrpcClientConnectionManager grpcClientConnectionManager;
void getAcceptableLanguages() throws Exception {
assertEquals(Collections.emptyList(),
callWithRequestAttributes(buildRequestAttributes(Collections.emptyList()),
RequestAttributesUtil::getAcceptableLanguages));
assertEquals(Locale.LanguageRange.parse("en,ja"), private Server server;
callWithRequestAttributes(buildRequestAttributes(Locale.LanguageRange.parse("en,ja")), private ManagedChannel managedChannel;
RequestAttributesUtil::getAcceptableLanguages));
@BeforeAll
static void setUpBeforeAll() {
eventLoopGroup = new DefaultEventLoopGroup();
}
@BeforeEach
void setUp() throws IOException {
final LocalAddress serverAddress = new LocalAddress("test-request-metadata-server");
grpcClientConnectionManager = mock(GrpcClientConnectionManager.class);
when(grpcClientConnectionManager.getRemoteAddress(any()))
.thenReturn(Optional.of(InetAddresses.forString("127.0.0.1")));
// `RequestAttributesInterceptor` operates on `LocalAddresses`, so we need to do some slightly fancy plumbing to make
// sure that we're using local channels and addresses
server = NettyServerBuilder.forAddress(serverAddress)
.channelType(LocalServerChannel.class)
.bossEventLoopGroup(eventLoopGroup)
.workerEventLoopGroup(eventLoopGroup)
.intercept(new RequestAttributesInterceptor(grpcClientConnectionManager))
.addService(new RequestAttributesServiceImpl())
.build()
.start();
managedChannel = NettyChannelBuilder.forAddress(serverAddress)
.channelType(LocalChannel.class)
.eventLoopGroup(eventLoopGroup)
.usePlaintext()
.build();
}
@AfterEach
void tearDown() {
managedChannel.shutdown();
server.shutdown();
}
@AfterAll
static void tearDownAfterAll() throws InterruptedException {
eventLoopGroup.shutdownGracefully().await();
} }
@Test @Test
void getAvailableAcceptedLocales() throws Exception { void getAcceptableLanguages() {
assertEquals(Collections.emptyList(), when(grpcClientConnectionManager.getAcceptableLanguages(any()))
callWithRequestAttributes(buildRequestAttributes(Collections.emptyList()), .thenReturn(Optional.empty());
RequestAttributesUtil::getAvailableAcceptedLocales));
final List<Locale> availableAcceptedLocales = assertTrue(getRequestAttributes().getAcceptableLanguagesList().isEmpty());
callWithRequestAttributes(buildRequestAttributes(Locale.LanguageRange.parse("en,ja")),
RequestAttributesUtil::getAvailableAcceptedLocales);
assertFalse(availableAcceptedLocales.isEmpty()); when(grpcClientConnectionManager.getAcceptableLanguages(any()))
.thenReturn(Optional.of(Locale.LanguageRange.parse("en,ja")));
availableAcceptedLocales.forEach(locale -> assertEquals(List.of("en", "ja"), getRequestAttributes().getAcceptableLanguagesList());
assertTrue("en".equals(locale.getLanguage()) || "ja".equals(locale.getLanguage())));
} }
@Test @Test
void getRemoteAddress() throws Exception { void getAvailableAcceptedLocales() {
assertEquals(REMOTE_ADDRESS, when(grpcClientConnectionManager.getAcceptableLanguages(any()))
callWithRequestAttributes(new RequestAttributes(REMOTE_ADDRESS, null, null), .thenReturn(Optional.empty());
RequestAttributesUtil::getRemoteAddress));
assertTrue(getRequestAttributes().getAvailableAcceptedLocalesList().isEmpty());
when(grpcClientConnectionManager.getAcceptableLanguages(any()))
.thenReturn(Optional.of(Locale.LanguageRange.parse("en,ja")));
final GetRequestAttributesResponse response = getRequestAttributes();
assertFalse(response.getAvailableAcceptedLocalesList().isEmpty());
response.getAvailableAcceptedLocalesList().forEach(languageTag -> {
final Locale locale = Locale.forLanguageTag(languageTag);
assertTrue("en".equals(locale.getLanguage()) || "ja".equals(locale.getLanguage()));
});
} }
@Test @Test
void getUserAgent() throws Exception { void getRemoteAddress() {
assertEquals(Optional.empty(), when(grpcClientConnectionManager.getRemoteAddress(any()))
callWithRequestAttributes(buildRequestAttributes((String) null), .thenReturn(Optional.empty());
RequestAttributesUtil::getUserAgent));
assertEquals(Optional.of("Signal-Desktop/1.2.3 Linux"), GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getRequestAttributes);
callWithRequestAttributes(buildRequestAttributes("Signal-Desktop/1.2.3 Linux"),
RequestAttributesUtil::getUserAgent)); final String remoteAddressString = "6.7.8.9";
when(grpcClientConnectionManager.getRemoteAddress(any()))
.thenReturn(Optional.of(InetAddresses.forString(remoteAddressString)));
assertEquals(remoteAddressString, getRequestAttributes().getRemoteAddress());
} }
private static <V> V callWithRequestAttributes(final RequestAttributes requestAttributes, final Callable<V> callable) throws Exception { @Test
return Context.current() void getUserAgent() throws UnrecognizedUserAgentException {
.withValue(RequestAttributesUtil.REQUEST_ATTRIBUTES_CONTEXT_KEY, requestAttributes) when(grpcClientConnectionManager.getUserAgent(any()))
.call(callable); .thenReturn(Optional.empty());
assertFalse(getRequestAttributes().hasUserAgent());
final UserAgent userAgent = UserAgentUtil.parseUserAgentString("Signal-Desktop/1.2.3 Linux");
when(grpcClientConnectionManager.getUserAgent(any()))
.thenReturn(Optional.of(userAgent));
final GetRequestAttributesResponse response = getRequestAttributes();
assertTrue(response.hasUserAgent());
assertEquals("DESKTOP", response.getUserAgent().getPlatform());
assertEquals("1.2.3", response.getUserAgent().getVersion());
assertEquals("Linux", response.getUserAgent().getAdditionalSpecifiers());
} }
private static RequestAttributes buildRequestAttributes(final String userAgent) { @Test
return buildRequestAttributes(userAgent, Collections.emptyList()); void getRawUserAgent() {
when(grpcClientConnectionManager.getRawUserAgent(any()))
.thenReturn(Optional.empty());
assertTrue(getRequestAttributes().getRawUserAgent().isBlank());
final String userAgentString = "Signal-Desktop/1.2.3 Linux";
when(grpcClientConnectionManager.getRawUserAgent(any()))
.thenReturn(Optional.of(userAgentString));
assertEquals(userAgentString, getRequestAttributes().getRawUserAgent());
} }
private static RequestAttributes buildRequestAttributes(final List<Locale.LanguageRange> acceptLanguage) { private GetRequestAttributesResponse getRequestAttributes() {
return buildRequestAttributes(null, acceptLanguage); return RequestAttributesGrpc.newBlockingStub(managedChannel)
} .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build());
private static RequestAttributes buildRequestAttributes(@Nullable final String userAgent,
final List<Locale.LanguageRange> acceptLanguage) {
return new RequestAttributes(REMOTE_ADDRESS, userAgent, acceptLanguage);
} }
} }

View File

@ -1,11 +1,7 @@
package org.whispersystems.textsecuregcm.grpc.net; package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import com.google.common.net.InetAddresses; import com.google.common.net.InetAddresses;
import com.vdurmont.semver4j.Semver;
import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap; import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel; import io.netty.channel.Channel;
@ -16,12 +12,6 @@ import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel; import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel; import io.netty.channel.local.LocalServerChannel;
import java.net.InetAddress;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.UUID;
import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeAll;
@ -31,9 +21,20 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import javax.annotation.Nullable;
import java.net.InetAddress;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.UUID;
import static org.junit.jupiter.api.Assertions.*;
class GrpcClientConnectionManagerTest { class GrpcClientConnectionManagerTest {
@ -102,7 +103,7 @@ class GrpcClientConnectionManagerTest {
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, maybeAuthenticatedDevice); grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, maybeAuthenticatedDevice);
assertEquals(maybeAuthenticatedDevice, assertEquals(maybeAuthenticatedDevice,
grpcClientConnectionManager.getAuthenticatedDevice(remoteChannel)); grpcClientConnectionManager.getAuthenticatedDevice(localChannel.localAddress()));
} }
private static List<Optional<AuthenticatedDevice>> getAuthenticatedDevice() { private static List<Optional<AuthenticatedDevice>> getAuthenticatedDevice() {
@ -113,115 +114,170 @@ class GrpcClientConnectionManagerTest {
} }
@Test @Test
void getRequestAttributes() { void getAcceptableLanguages() {
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty()); grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertThrows(IllegalStateException.class, () -> grpcClientConnectionManager.getRequestAttributes(remoteChannel)); assertEquals(Optional.empty(),
grpcClientConnectionManager.getAcceptableLanguages(localChannel.localAddress()));
final RequestAttributes requestAttributes = new RequestAttributes(InetAddresses.forString("6.7.8.9"), null, null); final List<Locale.LanguageRange> acceptLanguageRanges = Locale.LanguageRange.parse("en,ja");
remoteChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY).set(requestAttributes); remoteChannel.attr(GrpcClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).set(acceptLanguageRanges);
assertEquals(requestAttributes, grpcClientConnectionManager.getRequestAttributes(remoteChannel)); assertEquals(Optional.of(acceptLanguageRanges),
grpcClientConnectionManager.getAcceptableLanguages(localChannel.localAddress()));
} }
@Test @Test
void closeConnection() throws InterruptedException, ChannelNotFoundException { void getRemoteAddress() {
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertEquals(Optional.empty(),
grpcClientConnectionManager.getRemoteAddress(localChannel.localAddress()));
final InetAddress remoteAddress = InetAddresses.forString("6.7.8.9");
remoteChannel.attr(GrpcClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).set(remoteAddress);
assertEquals(Optional.of(remoteAddress),
grpcClientConnectionManager.getRemoteAddress(localChannel.localAddress()));
}
@Test
void getUserAgent() throws UnrecognizedUserAgentException {
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertEquals(Optional.empty(),
grpcClientConnectionManager.getUserAgent(localChannel.localAddress()));
final UserAgent userAgent = UserAgentUtil.parseUserAgentString("Signal-Desktop/1.2.3 Linux");
remoteChannel.attr(GrpcClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY).set(userAgent);
assertEquals(Optional.of(userAgent),
grpcClientConnectionManager.getUserAgent(localChannel.localAddress()));
}
@Test
void closeConnection() throws InterruptedException {
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID); final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice)); grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice));
assertTrue(remoteChannel.isOpen()); assertTrue(remoteChannel.isOpen());
assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress())); assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
assertEquals(List.of(remoteChannel), assertEquals(List.of(remoteChannel),
grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
remoteChannel.close().await(); remoteChannel.close().await();
assertThrows(ChannelNotFoundException.class, assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
} }
@Test
void handleWebSocketHandshakeCompleteRemoteAddress() {
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
final InetAddress preferredRemoteAddress = InetAddresses.forString("192.168.1.1");
GrpcClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel,
preferredRemoteAddress,
null,
null);
assertEquals(preferredRemoteAddress,
embeddedChannel.attr(GrpcClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).get());
}
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource
void handleHandshakeCompleteRequestAttributes(final InetAddress preferredRemoteAddress, void handleWebSocketHandshakeCompleteUserAgent(@Nullable final String userAgentHeader,
final String userAgentHeader, @Nullable final UserAgent expectedParsedUserAgent) {
final String acceptLanguageHeader,
final RequestAttributes expectedRequestAttributes) {
final EmbeddedChannel embeddedChannel = new EmbeddedChannel(); final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
GrpcClientConnectionManager.handleHandshakeComplete(embeddedChannel, GrpcClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel,
preferredRemoteAddress, InetAddresses.forString("127.0.0.1"),
userAgentHeader, userAgentHeader,
acceptLanguageHeader); null);
assertEquals(expectedRequestAttributes, assertEquals(userAgentHeader,
embeddedChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY).get()); embeddedChannel.attr(GrpcClientConnectionManager.RAW_USER_AGENT_ATTRIBUTE_KEY).get());
assertEquals(expectedParsedUserAgent,
embeddedChannel.attr(GrpcClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY).get());
} }
private static List<Arguments> handleHandshakeCompleteRequestAttributes() { private static List<Arguments> handleWebSocketHandshakeCompleteUserAgent() {
final InetAddress preferredRemoteAddress = InetAddresses.forString("192.168.1.1");
return List.of( return List.of(
Arguments.argumentSet("Null User-Agent and Accept-Language headers", // Recognized user-agent
preferredRemoteAddress, null, null, Arguments.of("Signal-Desktop/1.2.3 Linux", new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "Linux")),
new RequestAttributes(preferredRemoteAddress, null, Collections.emptyList())),
Arguments.argumentSet("Recognized User-Agent and null Accept-Language header", // Unrecognized user-agent
preferredRemoteAddress, "Signal-Desktop/1.2.3 Linux", null, Arguments.of("Not a valid user-agent string", null),
new RequestAttributes(preferredRemoteAddress, "Signal-Desktop/1.2.3 Linux", Collections.emptyList())),
Arguments.argumentSet("Unparsable User-Agent and null Accept-Language header", // Missing user-agent
preferredRemoteAddress, "Not a valid user-agent string", null, Arguments.of(null, null)
new RequestAttributes(preferredRemoteAddress, "Not a valid user-agent string", Collections.emptyList())), );
}
Arguments.argumentSet("Null User-Agent and parsable Accept-Language header",
preferredRemoteAddress, null, "ja,en;q=0.4",
new RequestAttributes(preferredRemoteAddress, null, Locale.LanguageRange.parse("ja,en;q=0.4"))),
Arguments.argumentSet("Null User-Agent and unparsable Accept-Language header", @ParameterizedTest
preferredRemoteAddress, null, "This is not a valid language preference list", @MethodSource
new RequestAttributes(preferredRemoteAddress, null, Collections.emptyList())) void handleWebSocketHandshakeCompleteAcceptLanguage(@Nullable final String acceptLanguageHeader,
@Nullable final List<Locale.LanguageRange> expectedLanguageRanges) {
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
GrpcClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel,
InetAddresses.forString("127.0.0.1"),
null,
acceptLanguageHeader);
assertEquals(expectedLanguageRanges,
embeddedChannel.attr(GrpcClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).get());
}
private static List<Arguments> handleWebSocketHandshakeCompleteAcceptLanguage() {
return List.of(
// Parseable list
Arguments.of("ja,en;q=0.4", Locale.LanguageRange.parse("ja,en;q=0.4")),
// Unparsable list
Arguments.of("This is not a valid language preference list", null),
// Missing list
Arguments.of(null, null)
); );
} }
@Test @Test
void handleConnectionEstablishedAuthenticated() throws InterruptedException, ChannelNotFoundException { void handleConnectionEstablishedAuthenticated() throws InterruptedException {
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID); final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
assertThrows(ChannelNotFoundException.class, assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice)); grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice));
assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress())); assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
assertEquals(List.of(remoteChannel), grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); assertEquals(List.of(remoteChannel), grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
remoteChannel.close().await(); remoteChannel.close().await();
assertThrows(ChannelNotFoundException.class, assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
} }
@Test @Test
void handleConnectionEstablishedAnonymous() throws InterruptedException, ChannelNotFoundException { void handleConnectionEstablishedAnonymous() throws InterruptedException {
assertThrows(ChannelNotFoundException.class, assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty()); grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress())); assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
remoteChannel.close().await(); remoteChannel.close().await();
assertThrows(ChannelNotFoundException.class, assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
} }
} }

View File

@ -1,7 +1,6 @@
package org.whispersystems.textsecuregcm.grpc.net; package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.ArgumentMatchers.anyByte;
@ -9,12 +8,10 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString;
import io.grpc.ManagedChannel; import io.grpc.ManagedChannel;
import io.grpc.ServerBuilder; import io.grpc.ServerBuilder;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.netty.NettyChannelBuilder; import io.grpc.netty.NettyChannelBuilder;
import io.grpc.stub.StreamObserver;
import io.netty.channel.DefaultEventLoopGroup; import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel; import io.netty.channel.local.LocalChannel;
@ -64,9 +61,6 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import org.signal.chat.rpc.EchoRequest;
import org.signal.chat.rpc.EchoResponse;
import org.signal.chat.rpc.EchoServiceGrpc;
import org.signal.chat.rpc.GetAuthenticatedDeviceRequest; import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse; import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.signal.chat.rpc.GetRequestAttributesRequest; import org.signal.chat.rpc.GetRequestAttributesRequest;
@ -77,8 +71,6 @@ import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor; import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor; import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.grpc.ChannelShutdownInterceptor;
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils; import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor; import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor;
import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl; import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl;
@ -91,7 +83,6 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
private static NioEventLoopGroup nioEventLoopGroup; private static NioEventLoopGroup nioEventLoopGroup;
private static DefaultEventLoopGroup defaultEventLoopGroup; private static DefaultEventLoopGroup defaultEventLoopGroup;
private static ExecutorService delegatedTaskExecutor; private static ExecutorService delegatedTaskExecutor;
private static ExecutorService serverCallExecutor;
private static X509Certificate serverTlsCertificate; private static X509Certificate serverTlsCertificate;
@ -145,8 +136,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
static void setUpBeforeAll() throws CertificateException { static void setUpBeforeAll() throws CertificateException {
nioEventLoopGroup = new NioEventLoopGroup(); nioEventLoopGroup = new NioEventLoopGroup();
defaultEventLoopGroup = new DefaultEventLoopGroup(); defaultEventLoopGroup = new DefaultEventLoopGroup();
delegatedTaskExecutor = Executors.newVirtualThreadPerTaskExecutor(); delegatedTaskExecutor = Executors.newSingleThreadExecutor();
serverCallExecutor = Executors.newVirtualThreadPerTaskExecutor();
final CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509"); final CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509");
serverTlsCertificate = (X509Certificate) certificateFactory.generateCertificate( serverTlsCertificate = (X509Certificate) certificateFactory.generateCertificate(
@ -181,11 +171,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, defaultEventLoopGroup) { authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, defaultEventLoopGroup) {
@Override @Override
protected void configureServer(final ServerBuilder<?> serverBuilder) { protected void configureServer(final ServerBuilder<?> serverBuilder) {
serverBuilder serverBuilder.addService(new RequestAttributesServiceImpl())
.executor(serverCallExecutor)
.addService(new RequestAttributesServiceImpl())
.addService(new EchoServiceImpl())
.intercept(new ChannelShutdownInterceptor(grpcClientConnectionManager))
.intercept(new RequestAttributesInterceptor(grpcClientConnectionManager)) .intercept(new RequestAttributesInterceptor(grpcClientConnectionManager))
.intercept(new RequireAuthenticationInterceptor(grpcClientConnectionManager)); .intercept(new RequireAuthenticationInterceptor(grpcClientConnectionManager));
} }
@ -196,9 +182,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, defaultEventLoopGroup) { anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, defaultEventLoopGroup) {
@Override @Override
protected void configureServer(final ServerBuilder<?> serverBuilder) { protected void configureServer(final ServerBuilder<?> serverBuilder) {
serverBuilder serverBuilder.addService(new RequestAttributesServiceImpl())
.executor(serverCallExecutor)
.addService(new RequestAttributesServiceImpl())
.intercept(new RequestAttributesInterceptor(grpcClientConnectionManager)) .intercept(new RequestAttributesInterceptor(grpcClientConnectionManager))
.intercept(new ProhibitAuthenticationInterceptor(grpcClientConnectionManager)); .intercept(new ProhibitAuthenticationInterceptor(grpcClientConnectionManager));
} }
@ -211,7 +195,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
serverTlsPrivateKey, serverTlsPrivateKey,
nioEventLoopGroup, nioEventLoopGroup,
delegatedTaskExecutor, delegatedTaskExecutor,
grpcClientConnectionManager, grpcClientConnectionManager,
clientPublicKeysManager, clientPublicKeysManager,
serverKeyPair, serverKeyPair,
authenticatedGrpcServerAddress, authenticatedGrpcServerAddress,
@ -225,7 +209,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
null, null,
nioEventLoopGroup, nioEventLoopGroup,
delegatedTaskExecutor, delegatedTaskExecutor,
grpcClientConnectionManager, grpcClientConnectionManager,
clientPublicKeysManager, clientPublicKeysManager,
serverKeyPair, serverKeyPair,
authenticatedGrpcServerAddress, authenticatedGrpcServerAddress,
@ -251,10 +235,6 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
delegatedTaskExecutor.shutdown(); delegatedTaskExecutor.shutdown();
//noinspection ResultOfMethodCallIgnored //noinspection ResultOfMethodCallIgnored
delegatedTaskExecutor.awaitTermination(1, TimeUnit.SECONDS); delegatedTaskExecutor.awaitTermination(1, TimeUnit.SECONDS);
serverCallExecutor.shutdown();
//noinspection ResultOfMethodCallIgnored
serverCallExecutor.awaitTermination(1, TimeUnit.SECONDS);
} }
@ParameterizedTest @ParameterizedTest
@ -543,7 +523,10 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
assertEquals(remoteAddress, response.getRemoteAddress()); assertEquals(remoteAddress, response.getRemoteAddress());
assertEquals(List.of(acceptLanguage), response.getAcceptableLanguagesList()); assertEquals(List.of(acceptLanguage), response.getAcceptableLanguagesList());
assertEquals(userAgent, response.getUserAgent());
assertEquals("DESKTOP", response.getUserAgent().getPlatform());
assertEquals("1.2.3", response.getUserAgent().getVersion());
assertEquals("Linux", response.getUserAgent().getAdditionalSpecifiers());
} finally { } finally {
channel.shutdown(); channel.shutdown();
} }
@ -599,89 +582,6 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
} }
} }
@Test
void waitForCallCompletion() throws InterruptedException {
final CountDownLatch connectionCloseLatch = new CountDownLatch(1);
final AtomicInteger serverCloseStatusCode = new AtomicInteger(0);
final AtomicBoolean closedByServer = new AtomicBoolean(false);
final WebSocketCloseListener webSocketCloseListener = new WebSocketCloseListener() {
@Override
public void handleWebSocketClosedByClient(final int statusCode) {
serverCloseStatusCode.set(statusCode);
closedByServer.set(false);
connectionCloseLatch.countDown();
}
@Override
public void handleWebSocketClosedByServer(final int statusCode) {
serverCloseStatusCode.set(statusCode);
closedByServer.set(true);
connectionCloseLatch.countDown();
}
};
try (final NoiseWebSocketTunnelClient client = authenticated()
.setWebSocketCloseListener(webSocketCloseListener)
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
assertEquals(DEVICE_ID, response.getDeviceId());
final CountDownLatch responseCountDownLatch = new CountDownLatch(1);
// Start an open-ended server call and leave it in a non-complete state
final StreamObserver<EchoRequest> echoRequestStreamObserver = EchoServiceGrpc.newStub(channel).echoStream(
new StreamObserver<>() {
@Override
public void onNext(final EchoResponse echoResponse) {
responseCountDownLatch.countDown();
}
@Override
public void onError(final Throwable throwable) {
}
@Override
public void onCompleted() {
}
});
// Requests are transmitted asynchronously; it's possible that we'll issue the "close connection" request before
// the request even starts. Make sure we've done at least one request/response pair to ensure that the call has
// truly started before requesting connection closure.
echoRequestStreamObserver.onNext(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("Test")).build());
assertTrue(responseCountDownLatch.await(1, TimeUnit.SECONDS));
grpcClientConnectionManager.closeConnection(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID));
assertFalse(connectionCloseLatch.await(1, TimeUnit.SECONDS),
"Channel should not close until active requests have finished");
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, () -> EchoServiceGrpc.newBlockingStub(channel)
.echo(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("Test")).build()));
// Complete the open-ended server call
echoRequestStreamObserver.onCompleted();
assertTrue(connectionCloseLatch.await(1, TimeUnit.SECONDS),
"Channel should close once active requests have finished");
assertTrue(closedByServer.get());
assertEquals(4004, serverCloseStatusCode.get());
} finally {
channel.shutdown();
}
}
}
private NoiseWebSocketTunnelClient.Builder anonymous() { private NoiseWebSocketTunnelClient.Builder anonymous() {
return new NoiseWebSocketTunnelClient return new NoiseWebSocketTunnelClient
.Builder(tlsNoiseWebSocketTunnelServer.getLocalAddress(), nioEventLoopGroup, serverKeyPair.getPublicKey()) .Builder(tlsNoiseWebSocketTunnelServer.getLocalAddress(), nioEventLoopGroup, serverKeyPair.getPublicKey())

View File

@ -4,7 +4,6 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.params.provider.Arguments.argumentSet;
import static org.junit.jupiter.params.provider.Arguments.arguments; import static org.junit.jupiter.params.provider.Arguments.arguments;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
@ -17,7 +16,6 @@ import io.netty.channel.local.LocalAddress;
import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.util.Attribute;
import java.net.InetAddress; import java.net.InetAddress;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.SocketAddress; import java.net.SocketAddress;
@ -33,7 +31,6 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest { class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
@ -137,13 +134,8 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
embeddedChannel.setRemoteAddress(remoteAddress); embeddedChannel.setRemoteAddress(remoteAddress);
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent); embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
assertEquals(expectedRemoteAddress, assertEquals(expectedRemoteAddress,
Optional.ofNullable(embeddedChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY)) embeddedChannel.attr(GrpcClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).get());
.map(Attribute::get)
.map(RequestAttributes::remoteAddress)
.orElse(null));
} }
private static List<Arguments> getRemoteAddress() { private static List<Arguments> getRemoteAddress() {
@ -152,53 +144,53 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
final InetAddress proxyAddress = InetAddresses.forString("4.3.2.1"); final InetAddress proxyAddress = InetAddresses.forString("4.3.2.1");
return List.of( return List.of(
argumentSet("Recognized proxy, single forwarded-for address", // Recognized proxy, single forwarded-for address
new DefaultHttpHeaders() Arguments.of(new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET) .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()), .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
remoteAddress, remoteAddress,
clientAddress), clientAddress),
argumentSet("Recognized proxy, multiple forwarded-for addresses", // Recognized proxy, multiple forwarded-for addresses
new DefaultHttpHeaders() Arguments.of(new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET) .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress() + "," + proxyAddress.getHostAddress()), .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress() + "," + proxyAddress.getHostAddress()),
remoteAddress, remoteAddress,
proxyAddress), proxyAddress),
argumentSet("No recognized proxy header, single forwarded-for address", // No recognized proxy header, single forwarded-for address
new DefaultHttpHeaders() Arguments.of(new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()), .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
remoteAddress, remoteAddress,
remoteAddress.getAddress()), remoteAddress.getAddress()),
argumentSet("No recognized proxy header, no forwarded-for address", // No recognized proxy header, no forwarded-for address
new DefaultHttpHeaders(), Arguments.of(new DefaultHttpHeaders(),
remoteAddress, remoteAddress,
remoteAddress.getAddress()), remoteAddress.getAddress()),
argumentSet("Incorrect proxy header, single forwarded-for address", // Incorrect proxy header, single forwarded-for address
new DefaultHttpHeaders() Arguments.of(new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET + "-incorrect") .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET + "-incorrect")
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()), .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
remoteAddress, remoteAddress,
remoteAddress.getAddress()), remoteAddress.getAddress()),
argumentSet("Recognized proxy, no forwarded-for address", // Recognized proxy, no forwarded-for address
new DefaultHttpHeaders() Arguments.of(new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET), .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET),
remoteAddress, remoteAddress,
remoteAddress.getAddress()), remoteAddress.getAddress()),
argumentSet("Recognized proxy, bogus forwarded-for address", // Recognized proxy, bogus forwarded-for address
new DefaultHttpHeaders() Arguments.of(new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET) .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, "not a valid address"), .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, "not a valid address"),
remoteAddress, remoteAddress,
null), null),
argumentSet("No forwarded-for address, non-InetSocketAddress remote address", // No forwarded-for address, non-InetSocketAddress remote address
new DefaultHttpHeaders() Arguments.of(new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET), .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET),
new LocalAddress("local-address"), new LocalAddress("local-address"),
null) null)

View File

@ -0,0 +1,91 @@
/*
* Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.metrics;
import static org.junit.jupiter.api.Assertions.assertEquals;
import java.util.stream.Stream;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
class OperatingSystemMemoryGaugeTest {
private static final String MEMINFO =
"""
MemTotal: 16052208 kB
MemFree: 4568468 kB
MemAvailable: 7702848 kB
Buffers: 636372 kB
Cached: 5019116 kB
SwapCached: 6692 kB
Active: 7746436 kB
Inactive: 2729876 kB
Active(anon): 5580980 kB
Inactive(anon): 1648108 kB
Active(file): 2165456 kB
Inactive(file): 1081768 kB
Unevictable: 443948 kB
Mlocked: 4924 kB
SwapTotal: 1003516 kB
SwapFree: 935932 kB
Dirty: 28308 kB
Writeback: 0 kB
AnonPages: 5258396 kB
Mapped: 1530740 kB
Shmem: 2419340 kB
KReclaimable: 229392 kB
Slab: 408156 kB
SReclaimable: 229392 kB
SUnreclaim: 178764 kB
KernelStack: 17360 kB
PageTables: 50436 kB
NFS_Unstable: 0 kB
Bounce: 0 kB
WritebackTmp: 0 kB
CommitLimit: 9029620 kB
Committed_AS: 16681884 kB
VmallocTotal: 34359738367 kB
VmallocUsed: 41944 kB
VmallocChunk: 0 kB
Percpu: 4240 kB
HardwareCorrupted: 0 kB
AnonHugePages: 0 kB
ShmemHugePages: 0 kB
ShmemPmdMapped: 0 kB
FileHugePages: 0 kB
FilePmdMapped: 0 kB
CmaTotal: 0 kB
CmaFree: 0 kB
HugePages_Total: 0
HugePages_Free: 7
HugePages_Rsvd: 0
HugePages_Surp: 0
Hugepagesize: 2048 kB
Hugetlb: 0 kB
DirectMap4k: 481804 kB
DirectMap2M: 14901248 kB
DirectMap1G: 2097152 kB
""";
@ParameterizedTest
@MethodSource
void testGetValue(final String metricName, final long expectedValue) {
assertEquals(expectedValue, OperatingSystemMemoryGauge.getValue(MEMINFO.lines(), metricName));
}
@SuppressWarnings("unused")
private static Stream<Arguments> testGetValue() {
return Stream.of(
Arguments.of("MemTotal", 16052208L),
Arguments.of("Active(anon)", 5580980L),
Arguments.of("Committed_AS", 16681884L),
Arguments.of("HugePages_Free", 7L),
Arguments.of("NonsenseMetric", 0L)
);
}
}

View File

@ -12,7 +12,6 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -40,7 +39,6 @@ import org.whispersystems.textsecuregcm.controllers.MismatchedDevices;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException; import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
@ -56,67 +54,13 @@ class MessageSenderTest {
private MessagesManager messagesManager; private MessagesManager messagesManager;
private PushNotificationManager pushNotificationManager; private PushNotificationManager pushNotificationManager;
private MessageSender messageSender; private MessageSender messageSender;
private ExperimentEnrollmentManager experimentEnrollmentManager;
@BeforeEach @BeforeEach
void setUp() { void setUp() {
messagesManager = mock(MessagesManager.class); messagesManager = mock(MessagesManager.class);
pushNotificationManager = mock(PushNotificationManager.class); pushNotificationManager = mock(PushNotificationManager.class);
experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class);
messageSender = new MessageSender(messagesManager, pushNotificationManager, experimentEnrollmentManager); messageSender = new MessageSender(messagesManager, pushNotificationManager);
}
@CartesianTest
void pushSkippedExperiment(
@CartesianTest.Values(booleans = {true, false}) final boolean hasGcmToken,
@CartesianTest.Values(booleans = {true, false}) final boolean isUrgent,
@CartesianTest.Values(booleans = {true, false}) final boolean inExperiment) throws NotPushRegisteredException {
final boolean shouldSkip = hasGcmToken && !isUrgent && inExperiment;
final UUID accountIdentifier = UUID.randomUUID();
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier);
final byte deviceId = Device.PRIMARY_ID;
final int registrationId = 17;
final Account account = mock(Account.class);
final Device device = mock(Device.class);
final MessageProtos.Envelope message = MessageProtos.Envelope.newBuilder()
.setEphemeral(false)
.setUrgent(isUrgent)
.build();
when(account.getUuid()).thenReturn(accountIdentifier);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true);
when(account.getDevices()).thenReturn(List.of(device));
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(device.getId()).thenReturn(deviceId);
when(device.getRegistrationId()).thenReturn(registrationId);
if (hasGcmToken) {
when(device.getGcmId()).thenReturn("gcm-token");
} else {
when(device.getApnId()).thenReturn("apn-token");
}
when(messagesManager.insert(any(), any())).thenReturn(Map.of(deviceId, false));
when(experimentEnrollmentManager.isEnrolled(accountIdentifier, MessageSender.ANDROID_SKIP_LOW_URGENCY_PUSH_EXPERIMENT))
.thenReturn(inExperiment);
assertDoesNotThrow(() -> messageSender.sendMessages(account,
serviceIdentifier,
Map.of(device.getId(), message),
Map.of(device.getId(), registrationId),
Optional.empty(),
null));
if (shouldSkip) {
verifyNoInteractions(pushNotificationManager);
} else {
verify(pushNotificationManager).sendNewMessageNotification(account, deviceId, isUrgent);
}
} }
@CartesianTest @CartesianTest
@ -160,7 +104,6 @@ class MessageSenderTest {
serviceIdentifier, serviceIdentifier,
Map.of(device.getId(), message), Map.of(device.getId(), message),
Map.of(device.getId(), registrationId), Map.of(device.getId(), registrationId),
Optional.empty(),
null)); null));
final MessageProtos.Envelope expectedMessage = ephemeral final MessageProtos.Envelope expectedMessage = ephemeral
@ -201,7 +144,6 @@ class MessageSenderTest {
serviceIdentifier, serviceIdentifier,
Map.of(device.getId(), message), Map.of(device.getId(), message),
Map.of(device.getId(), registrationId + 1), Map.of(device.getId(), registrationId + 1),
Optional.empty(),
null)); null));
assertEquals(new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of(deviceId)), assertEquals(new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of(deviceId)),
@ -402,64 +344,4 @@ class MessageSenderTest {
Optional.of(new MismatchedDevices(Set.of(primaryDeviceId), Set.of(extraDeviceId), Set.of(linkedDeviceId)))) Optional.of(new MismatchedDevices(Set.of(primaryDeviceId), Set.of(extraDeviceId), Set.of(linkedDeviceId))))
); );
} }
@Test
void sendMessageEmptyMessageList() {
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
final Device device = mock(Device.class);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
final Account account = mock(Account.class);
when(account.getDevices()).thenReturn(List.of(device));
when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true);
assertThrows(MismatchedDevicesException.class, () -> messageSender.sendMessages(account,
serviceIdentifier,
Collections.emptyMap(),
Collections.emptyMap(),
Optional.empty(),
null));
assertDoesNotThrow(() -> messageSender.sendMessages(account,
serviceIdentifier,
Collections.emptyMap(),
Collections.emptyMap(),
Optional.of(Device.PRIMARY_ID),
null));
}
@Test
void sendSyncMessageMismatchedAddressing() {
final UUID accountIdentifier = UUID.randomUUID();
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier);
final byte deviceId = Device.PRIMARY_ID;
final Account account = mock(Account.class);
when(account.getUuid()).thenReturn(accountIdentifier);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true);
final Account nonSyncDestination = mock(Account.class);
when(nonSyncDestination.isIdentifiedBy(any())).thenReturn(true);
assertThrows(IllegalArgumentException.class, () -> messageSender.sendMessages(nonSyncDestination,
new AciServiceIdentifier(UUID.randomUUID()),
Map.of(deviceId, MessageProtos.Envelope.newBuilder().build()),
Map.of(deviceId, 17),
Optional.of(deviceId),
null),
"Should throw an IllegalArgumentException for inter-account messages with a sync message device ID");
assertThrows(IllegalArgumentException.class, () -> messageSender.sendMessages(account,
serviceIdentifier,
Map.of(deviceId, MessageProtos.Envelope.newBuilder()
.setSourceServiceId(serviceIdentifier.toServiceIdentifierString())
.setSourceDevice(deviceId)
.build()),
Map.of(deviceId, 17),
Optional.empty(),
null),
"Should throw an IllegalArgumentException for self-addressed messages without a sync message device ID");
}
} }

View File

@ -6,10 +6,8 @@ import static org.mockito.Mockito.when;
import com.google.i18n.phonenumbers.PhoneNumberUtil; import com.google.i18n.phonenumbers.PhoneNumberUtil;
import java.time.Clock; import java.time.Clock;
import java.time.LocalDateTime;
import java.time.LocalTime; import java.time.LocalTime;
import java.time.ZoneId; import java.time.ZoneId;
import java.time.ZoneOffset;
import java.time.ZonedDateTime; import java.time.ZonedDateTime;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
@ -67,26 +65,4 @@ class SchedulingUtilTest {
Clock.fixed(afterNotificationTime.toInstant(), ZoneId.systemDefault()))); Clock.fixed(afterNotificationTime.toInstant(), ZoneId.systemDefault())));
} }
} }
@Test
void getNextRecommendedNotificationTimeDaylightSavings() {
final Account account = mock(Account.class);
// The account has a phone number that can be resolved to a region with known timezones
when(account.getNumber()).thenReturn(PhoneNumberUtil.getInstance().format(
PhoneNumberUtil.getInstance().getExampleNumber("DE"), PhoneNumberUtil.PhoneNumberFormat.E164));
final LocalDateTime afterNotificationTime = LocalDateTime.of(2025, 3, 29, 15, 0);
final ZoneId berlinZoneId = ZoneId.of("Europe/Berlin");
final ZoneOffset berlineZoneOffset = berlinZoneId.getRules().getOffset(afterNotificationTime);
// Daylight Savings Time started on 2025-03-30 at 2:00AM in Germany.
// Instantiating a ZonedDateTime with a zone ID factors in daylight savings when we adjust the time.
final ZonedDateTime afterNotificationTimeWithZoneId = ZonedDateTime.of(afterNotificationTime, berlinZoneId);
assertEquals(
afterNotificationTimeWithZoneId.with(LocalTime.of(14, 0)).plusDays(1).toInstant(),
SchedulingUtil.getNextRecommendedNotificationTime(account, LocalTime.of(14, 0),
Clock.fixed(afterNotificationTime.toInstant(berlineZoneOffset), berlinZoneId)));
}
} }

View File

@ -19,6 +19,7 @@ import static org.whispersystems.textsecuregcm.tests.util.DevicesHelper.createDe
import com.fasterxml.jackson.annotation.JsonFilter; import com.fasterxml.jackson.annotation.JsonFilter;
import java.lang.annotation.Annotation; import java.lang.annotation.Annotation;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.time.Clock;
import java.time.Instant; import java.time.Instant;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
@ -186,7 +187,7 @@ class AccountTest {
@Test @Test
void addAndRemoveBadges() { void addAndRemoveBadges() {
final Account account = AccountsHelper.generateTestAccount("+14151234567", UUID.randomUUID(), UUID.randomUUID(), List.of(createDevice(Device.PRIMARY_ID)), new byte[0]); final Account account = AccountsHelper.generateTestAccount("+14151234567", UUID.randomUUID(), UUID.randomUUID(), List.of(createDevice(Device.PRIMARY_ID)), new byte[0]);
final TestClock clock = TestClock.pinned(Instant.ofEpochSecond(40)); final Clock clock = TestClock.pinned(Instant.ofEpochSecond(40));
account.addBadge(clock, new AccountBadge("foo", Instant.ofEpochSecond(42), false)); account.addBadge(clock, new AccountBadge("foo", Instant.ofEpochSecond(42), false));
account.addBadge(clock, new AccountBadge("bar", Instant.ofEpochSecond(44), true)); account.addBadge(clock, new AccountBadge("bar", Instant.ofEpochSecond(44), true));
@ -213,17 +214,6 @@ class AccountTest {
assertThat(badge.expiration().getEpochSecond()).isEqualTo(51); assertThat(badge.expiration().getEpochSecond()).isEqualTo(51);
assertThat(badge.visible()).isTrue(); assertThat(badge.visible()).isTrue();
}); });
clock.pin(Instant.ofEpochSecond(52));
// for a merged badge, visible = true is preferred
account.addBadge(clock, new AccountBadge("foo", Instant.ofEpochSecond(53), false));
assertThat(account.getBadges()).hasSize(1).element(0).satisfies(badge -> {
assertThat(badge.id()).isEqualTo("foo");
assertThat(badge.expiration().getEpochSecond()).isEqualTo(53);
assertThat(badge.visible()).isTrue();
});
} }
@Test @Test

View File

@ -9,16 +9,13 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
@ -38,10 +35,8 @@ import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.TestClock;
public class ChangeNumberManagerTest { public class ChangeNumberManagerTest {
private AccountsManager accountsManager; private AccountsManager accountsManager;
@ -50,13 +45,11 @@ public class ChangeNumberManagerTest {
private Map<Account, UUID> updatedPhoneNumberIdentifiersByAccount; private Map<Account, UUID> updatedPhoneNumberIdentifiersByAccount;
private static final TestClock CLOCK = TestClock.pinned(Instant.now());
@BeforeEach @BeforeEach
void setUp() throws Exception { void setUp() throws Exception {
accountsManager = mock(AccountsManager.class); accountsManager = mock(AccountsManager.class);
messageSender = mock(MessageSender.class); messageSender = mock(MessageSender.class);
changeNumberManager = new ChangeNumberManager(messageSender, accountsManager, CLOCK); changeNumberManager = new ChangeNumberManager(messageSender, accountsManager);
updatedPhoneNumberIdentifiersByAccount = new HashMap<>(); updatedPhoneNumberIdentifiersByAccount = new HashMap<>();
@ -110,7 +103,7 @@ public class ChangeNumberManagerTest {
changeNumberManager.changeNumber(account, "+18025551234", null, null, null, null, null, null); changeNumberManager.changeNumber(account, "+18025551234", null, null, null, null, null, null);
verify(accountsManager).changeNumber(account, "+18025551234", null, null, null, null); verify(accountsManager).changeNumber(account, "+18025551234", null, null, null, null);
verify(accountsManager, never()).updateDevice(any(), anyByte(), any()); verify(accountsManager, never()).updateDevice(any(), anyByte(), any());
verify(messageSender, never()).sendMessages(eq(account), any(), any(), any(), any(), any()); verify(messageSender, never()).sendMessages(eq(account), any(), any(), any(), any());
} }
@Test @Test
@ -124,7 +117,7 @@ public class ChangeNumberManagerTest {
changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap(), null); changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap(), null);
verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap()); verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap());
verify(messageSender, never()).sendMessages(eq(account), any(), any(), any(), any(), any()); verify(messageSender, never()).sendMessages(eq(account), any(), any(), any(), any());
} }
@Test @Test
@ -139,59 +132,45 @@ public class ChangeNumberManagerTest {
when(account.getUuid()).thenReturn(aci); when(account.getUuid()).thenReturn(aci);
when(account.getPhoneNumberIdentifier()).thenReturn(pni); when(account.getPhoneNumberIdentifier()).thenReturn(pni);
final Device primaryDevice = mock(Device.class); final Device d2 = mock(Device.class);
when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID); final byte deviceId2 = 2;
when(primaryDevice.getRegistrationId()).thenReturn(7); when(d2.getId()).thenReturn(deviceId2);
final Device linkedDevice = mock(Device.class); when(account.getDevice(deviceId2)).thenReturn(Optional.of(d2));
final byte linkedDeviceId = Device.PRIMARY_ID + 1; when(account.getDevices()).thenReturn(List.of(d2));
final int linkedDeviceRegistrationId = 17;
when(linkedDevice.getId()).thenReturn(linkedDeviceId);
when(linkedDevice.getRegistrationId()).thenReturn(linkedDeviceRegistrationId);
when(account.getDevice(anyByte())).thenReturn(Optional.empty());
when(account.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(primaryDevice));
when(account.getDevice(linkedDeviceId)).thenReturn(Optional.of(linkedDevice));
when(account.getDevices()).thenReturn(List.of(primaryDevice, linkedDevice));
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final Map<Byte, ECSignedPreKey> prekeys = Map.of(Device.PRIMARY_ID, final Map<Byte, ECSignedPreKey> prekeys = Map.of(Device.PRIMARY_ID,
KeysHelper.signedECPreKey(1, pniIdentityKeyPair), KeysHelper.signedECPreKey(1, pniIdentityKeyPair),
linkedDeviceId, KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); deviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, linkedDeviceId, 19); final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, deviceId2, 19);
final IncomingMessage msg = mock(IncomingMessage.class); final IncomingMessage msg = mock(IncomingMessage.class);
when(msg.type()).thenReturn(1); when(msg.destinationDeviceId()).thenReturn(deviceId2);
when(msg.destinationDeviceId()).thenReturn(linkedDeviceId);
when(msg.destinationRegistrationId()).thenReturn(linkedDeviceRegistrationId);
when(msg.content()).thenReturn(new byte[]{1}); when(msg.content()).thenReturn(new byte[]{1});
changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, null, List.of(msg), registrationIds, null); changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, null, List.of(msg), registrationIds, null);
verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, null, registrationIds); verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, null, registrationIds);
final MessageProtos.Envelope expectedEnvelope = MessageProtos.Envelope.newBuilder() @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
.setType(MessageProtos.Envelope.Type.forNumber(msg.type())) ArgumentCaptor.forClass(Map.class);
.setClientTimestamp(CLOCK.millis())
.setServerTimestamp(CLOCK.millis())
.setDestinationServiceId(new AciServiceIdentifier(aci).toServiceIdentifierString())
.setContent(ByteString.copyFrom(msg.content()))
.setSourceServiceId(new AciServiceIdentifier(aci).toServiceIdentifierString())
.setSourceDevice(Device.PRIMARY_ID)
.setUpdatedPni(updatedPhoneNumberIdentifiersByAccount.get(account).toString())
.setUrgent(true)
.setEphemeral(false)
.build();
verify(messageSender).sendMessages(argThat(a -> a.getUuid().equals(aci)), verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any());
eq(new AciServiceIdentifier(aci)),
eq(Map.of(linkedDeviceId, expectedEnvelope)), assertEquals(1, envelopeCaptor.getValue().size());
eq(Map.of(linkedDeviceId, linkedDeviceRegistrationId)), assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
eq(Optional.of(Device.PRIMARY_ID)),
any()); final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2);
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice());
assertEquals(updatedPhoneNumberIdentifiersByAccount.get(account), UUID.fromString(envelope.getUpdatedPni()));
} }
@Test @Test
void changeNumberSetPrimaryDevicePrekeyPqAndSendMessages() throws Exception { void changeNumberSetPrimaryDevicePrekeyPqAndSendMessages() throws Exception {
final String originalE164 = "+18005551234"; final String originalE164 = "+18005551234";
@ -231,7 +210,7 @@ public class ChangeNumberManagerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor = @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class); ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any(), any()); verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any());
assertEquals(1, envelopeCaptor.getValue().size()); assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
@ -282,7 +261,7 @@ public class ChangeNumberManagerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor = @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class); ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any(), any()); verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any());
assertEquals(1, envelopeCaptor.getValue().size()); assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
@ -329,7 +308,7 @@ public class ChangeNumberManagerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor = @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class); ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any(), any()); verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any());
assertEquals(1, envelopeCaptor.getValue().size()); assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
@ -378,7 +357,7 @@ public class ChangeNumberManagerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor = @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class); ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any(), any()); verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any());
assertEquals(1, envelopeCaptor.getValue().size()); assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());

View File

@ -32,7 +32,6 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventListener; import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventListener;
import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventManager; import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventManager;
@ -98,7 +97,7 @@ class MessagePersisterIntegrationTest {
webSocketConnectionEventManager.start(); webSocketConnectionEventManager.start();
messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager,
dynamicConfigurationManager, mock(ExperimentEnrollmentManager.class), PERSIST_DELAY, 1); dynamicConfigurationManager, PERSIST_DELAY, 1);
account = mock(Account.class); account = mock(Account.class);

View File

@ -54,7 +54,6 @@ import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessagePersisterConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessagePersisterConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper; import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
@ -119,7 +118,7 @@ class MessagePersisterTest {
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC()); messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC());
messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager,
dynamicConfigurationManager, mock(ExperimentEnrollmentManager.class), PERSIST_DELAY, 1); dynamicConfigurationManager, PERSIST_DELAY, 1);
when(messagesManager.clear(any(UUID.class), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); when(messagesManager.clear(any(UUID.class), anyByte())).thenReturn(CompletableFuture.completedFuture(null));
@ -258,7 +257,7 @@ class MessagePersisterTest {
assertTimeoutPreemptively(Duration.ofSeconds(1), () -> assertTimeoutPreemptively(Duration.ofSeconds(1), () ->
assertThrows(MessagePersistenceException.class, assertThrows(MessagePersistenceException.class,
() -> messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE, "test"))); () -> messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE)));
} }
@Test @Test
@ -299,7 +298,7 @@ class MessagePersisterTest {
when(messagesManager.persistMessages(any(UUID.class), any(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build()); when(messagesManager.persistMessages(any(UUID.class), any(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build());
assertTimeoutPreemptively(Duration.ofSeconds(1), () -> assertTimeoutPreemptively(Duration.ofSeconds(1), () ->
messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE, "test")); messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE));
verify(accountsManager, exactly()).removeDevice(destinationAccount, DESTINATION_DEVICE_ID); verify(accountsManager, exactly()).removeDevice(destinationAccount, DESTINATION_DEVICE_ID);
} }
@ -401,7 +400,7 @@ class MessagePersisterTest {
when(messagesManager.persistMessages(any(UUID.class), any(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build()); when(messagesManager.persistMessages(any(UUID.class), any(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build());
when(accountsManager.removeDevice(destinationAccount, DESTINATION_DEVICE_ID)).thenReturn(CompletableFuture.failedFuture(new TimeoutException())); when(accountsManager.removeDevice(destinationAccount, DESTINATION_DEVICE_ID)).thenReturn(CompletableFuture.failedFuture(new TimeoutException()));
assertThrows(CompletionException.class, () -> messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE, "test")); assertThrows(CompletionException.class, () -> messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE));
} }
@SuppressWarnings("SameParameterValue") @SuppressWarnings("SameParameterValue")

Some files were not shown because too many files have changed in this diff Show More