Compare commits

...

34 Commits

Author SHA1 Message Date
Jon Chambers f191c68efc Close remote connections only after all active server calls have completed 2025-04-22 17:00:48 -04:00
Jon Chambers bb8ce6d981 Introduce `ClosableEpoch` 2025-04-22 17:00:48 -04:00
Katherine e0ee75e0d0
Fix Daylight Savings bug in recommended notification time calculation 2025-04-22 16:56:10 -04:00
Jon Chambers 1ef3a230a1 Tag queue size distribution with client platform 2025-04-22 16:55:16 -04:00
Jon Chambers b1805d4bf1 Add a "persisted bytes" counter 2025-04-22 16:55:16 -04:00
Jon Chambers cac979c7fd Count individual persisted messages 2025-04-22 16:55:16 -04:00
Jon Chambers 4072dcdda5 Introduce `DevicePlatformUtil` 2025-04-22 16:55:16 -04:00
Jonathan Klabunde Tomer ed382fff6d log slot number and shard host of message persister failures 2025-04-22 16:55:16 -04:00
Jon Chambers 23bb8277d5 Update to the latest version of the spam filter 2025-04-18 15:56:17 -04:00
Jon Chambers 8099d6465c
Clarify guarantees around remote channnel/request attribute presence 2025-04-18 15:44:21 -04:00
Jon Chambers 28a0b9e84e
Include a TURN credential TTL for clients in `GetCallingRelaysResponse` 2025-04-17 10:30:58 -04:00
Chris Eager 9287aaf7ce Add app info to Stripe API calls 2025-04-17 09:30:34 -05:00
Chris Eager 0585f862cb Add regression test for set profile badges calculation 2025-04-17 09:29:11 -05:00
Chris Eager 7cac6f6f72 Remove extraneous account fetch in POST /v1/donation/redeem-receipt 2025-04-17 09:28:57 -05:00
Jon Chambers 57be4d798b Add a counter for attempts to send empty message lists 2025-04-17 10:27:46 -04:00
Jon Chambers 05c74f1997 Simplify `UserAgentUtil` 2025-04-17 10:27:24 -04:00
Jon Chambers f5e49b6db7 Convert `UserAgent` to a record 2025-04-15 14:58:09 -04:00
Jon Chambers 3c40e72d27
Fix registration ID map construction when changing numbers 2025-04-15 14:57:28 -04:00
Ravi Khadiwala 2f2ae7cec5 simplify story tag calculation 2025-04-11 14:04:09 -05:00
Chris Eager b236b53dc3 set profile: move updated badge calculation into account updater lambda 2025-04-11 14:03:05 -05:00
Katherine eb71e30046
Update to protobuf 4.x 2025-04-10 13:05:23 -04:00
Jon Chambers aa5fd52302 Explicitly pass sync message sender device ID as an argument to `sendMessage` 2025-04-10 11:40:32 -04:00
Jon Chambers d6bc2765b6 Close gRPC channels from a copied list to avoid concurrent modification issues 2025-04-09 21:54:18 -04:00
Jon Chambers 01258de560
Throw a `MismatchedDevicesException` for empty message lists to support iOS clients 2025-04-09 21:53:58 -04:00
Jon Chambers 3af2cc5c70 Add tests for spam-reporting token presence 2025-04-09 14:24:34 -04:00
Jon Chambers 2278842531 Add gRPC endpoints for sending messages from identified/authenticated senders 2025-04-09 14:24:34 -04:00
Jon Chambers 60ab00ecc6 Specify bounds for message timestamps 2025-04-09 14:24:20 -04:00
Jon Chambers 1fb6d23500 Allow range validators to accept 64-bit min/max values 2025-04-09 14:24:20 -04:00
Jon Chambers 8d8a2a5583
Extract common message-sending methods into a shared utility class 2025-04-08 17:39:45 -04:00
Jon Chambers caa81b4885 Implement story sending via gRPC 2025-04-08 17:30:33 -04:00
Jon Chambers 37c4a0451a Simplify returning spam responses from gRPC 2025-04-08 17:30:33 -04:00
Jon Chambers 11df8fcc6c Add gRPC endpoints for sending unauthenticated (i.e. sealed-sender) messages 2025-04-08 17:30:33 -04:00
Jon Chambers 5a7f4d8381 Make the utility method for checking group send credentials blocking 2025-04-08 17:30:33 -04:00
Jon Chambers 1f1e4c72ec Add `simple-grpc` as a dependency/generator 2025-04-08 17:30:33 -04:00
95 changed files with 4466 additions and 1023 deletions

31
pom.xml
View File

@ -46,7 +46,7 @@
<!-- can be updated to latest version with Dropwizard 5 (Jetty 12); will then need to disable telemetry --> <!-- can be updated to latest version with Dropwizard 5 (Jetty 12); will then need to disable telemetry -->
<dynamodblocal.version>2.2.1</dynamodblocal.version> <dynamodblocal.version>2.2.1</dynamodblocal.version>
<google-cloud-libraries.version>26.57.0</google-cloud-libraries.version> <google-cloud-libraries.version>26.57.0</google-cloud-libraries.version>
<grpc.version>1.69.0</grpc.version> <!-- should be kept in sync with the value from Google libraries-bom --> <grpc.version>1.70.0</grpc.version> <!-- should be kept in sync with the value from Google libraries-bom -->
<gson.version>2.12.1</gson.version> <gson.version>2.12.1</gson.version>
<!-- several libraries (AWS, Google Cloud) use Apache http components transitively, and we need to align them --> <!-- several libraries (AWS, Google Cloud) use Apache http components transitively, and we need to align them -->
<httpcore.version>4.4.16</httpcore.version> <httpcore.version>4.4.16</httpcore.version>
@ -65,14 +65,15 @@
<luajava.version>3.5.0</luajava.version> <luajava.version>3.5.0</luajava.version>
<micrometer.version>1.14.5</micrometer.version> <micrometer.version>1.14.5</micrometer.version>
<netty.version>4.1.119.Final</netty.version> <netty.version>4.1.119.Final</netty.version>
<!-- Must be greater than or equal to the value from Google libraries-bom <!-- Must be less than or equal to the value from Google libraries-bom which controls the protobuf runtime version.
since some of its libraries generate code. See https://protobuf.dev/support/cross-version-runtime-guarantee/. --> See https://protobuf.dev/support/cross-version-runtime-guarantee/. -->
<protobuf.version>3.25.5</protobuf.version> <protoc.version>4.29.4</protoc.version>
<pushy.version>0.15.4</pushy.version> <pushy.version>0.15.4</pushy.version>
<reactive.grpc.version>1.2.4</reactive.grpc.version> <reactive.grpc.version>1.2.4</reactive.grpc.version>
<reactor-bom.version>2024.0.4</reactor-bom.version> <!-- 3.7.4, see https://github.com/reactor/reactor#bom-versioning-scheme --> <reactor-bom.version>2024.0.4</reactor-bom.version> <!-- 3.7.4, see https://github.com/reactor/reactor#bom-versioning-scheme -->
<resilience4j.version>2.3.0</resilience4j.version> <resilience4j.version>2.3.0</resilience4j.version>
<semver4j.version>3.1.0</semver4j.version> <semver4j.version>3.1.0</semver4j.version>
<simple-grpc.version>0.1.0</simple-grpc.version>
<slf4j.version>2.0.17</slf4j.version> <slf4j.version>2.0.17</slf4j.version>
<stripe.version>23.10.0</stripe.version> <stripe.version>23.10.0</stripe.version>
<swagger.version>2.2.27</swagger.version> <swagger.version>2.2.27</swagger.version>
@ -126,7 +127,7 @@
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.google.cloud</groupId> <groupId>com.google.cloud</groupId>
<artifactId>libraries-bom-protobuf3</artifactId> <artifactId>libraries-bom</artifactId>
<version>${google-cloud-libraries.version}</version> <version>${google-cloud-libraries.version}</version>
<type>pom</type> <type>pom</type>
<scope>import</scope> <scope>import</scope>
@ -174,11 +175,6 @@
<artifactId>pushy-dropwizard-metrics-listener</artifactId> <artifactId>pushy-dropwizard-metrics-listener</artifactId>
<version>${pushy.version}</version> <version>${pushy.version}</version>
</dependency> </dependency>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
<version>${protobuf.version}</version>
</dependency>
<dependency> <dependency>
<groupId>com.googlecode.libphonenumber</groupId> <groupId>com.googlecode.libphonenumber</groupId>
<artifactId>libphonenumber</artifactId> <artifactId>libphonenumber</artifactId>
@ -267,6 +263,11 @@
<artifactId>libsignal-server</artifactId> <artifactId>libsignal-server</artifactId>
<version>0.67.6</version> <version>0.67.6</version>
</dependency> </dependency>
<dependency>
<groupId>org.signal</groupId>
<artifactId>simple-grpc-runtime</artifactId>
<version>${simple-grpc.version}</version>
</dependency>
<dependency> <dependency>
<groupId>org.signal.forks</groupId> <groupId>org.signal.forks</groupId>
<artifactId>noise-java</artifactId> <artifactId>noise-java</artifactId>
@ -437,7 +438,7 @@
<version>0.6.1</version> <version>0.6.1</version>
<configuration> <configuration>
<checkStaleness>false</checkStaleness> <checkStaleness>false</checkStaleness>
<protocArtifact>com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier}</protocArtifact> <protocArtifact>com.google.protobuf:protoc:${protoc.version}:exe:${os.detected.classifier}</protocArtifact>
<pluginId>grpc-java</pluginId> <pluginId>grpc-java</pluginId>
<pluginArtifact>io.grpc:protoc-gen-grpc-java:${grpc.version}:exe:${os.detected.classifier}</pluginArtifact> <pluginArtifact>io.grpc:protoc-gen-grpc-java:${grpc.version}:exe:${os.detected.classifier}</pluginArtifact>
@ -449,6 +450,14 @@
<version>${reactive.grpc.version}</version> <version>${reactive.grpc.version}</version>
<mainClass>com.salesforce.reactorgrpc.ReactorGrpcGenerator</mainClass> <mainClass>com.salesforce.reactorgrpc.ReactorGrpcGenerator</mainClass>
</protocPlugin> </protocPlugin>
<protocPlugin>
<id>simple</id>
<groupId>org.signal</groupId>
<artifactId>simple-grpc-generator</artifactId>
<version>${simple-grpc.version}</version>
<mainClass>org.signal.grpc.simple.SimpleGrpcGenerator</mainClass>
</protocPlugin>
</protocPlugins> </protocPlugins>
</configuration> </configuration>
<executions> <executions>

View File

@ -482,7 +482,8 @@ turn:
- turn:%s - turn:%s
- turn:%s:80?transport=tcp - turn:%s:80?transport=tcp
- turns:%s:443?transport=tcp - turns:%s:443?transport=tcp
ttl: 86400 requestedCredentialTtl: PT24H
clientCredentialTtl: PT12H
hostname: turn.cloudflare.example.com hostname: turn.cloudflare.example.com
numHttpClients: 1 numHttpClients: 1

View File

@ -80,6 +80,11 @@
<artifactId>noise-java</artifactId> <artifactId>noise-java</artifactId>
</dependency> </dependency>
<dependency>
<groupId>org.signal</groupId>
<artifactId>simple-grpc-runtime</artifactId>
</dependency>
<dependency> <dependency>
<groupId>io.dropwizard</groupId> <groupId>io.dropwizard</groupId>
<artifactId>dropwizard-core</artifactId> <artifactId>dropwizard-core</artifactId>

View File

@ -673,7 +673,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager = new CloudflareTurnCredentialsManager( final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager = new CloudflareTurnCredentialsManager(
config.getTurnConfiguration().cloudflare().apiToken().value(), config.getTurnConfiguration().cloudflare().apiToken().value(),
config.getTurnConfiguration().cloudflare().endpoint(), config.getTurnConfiguration().cloudflare().endpoint(),
config.getTurnConfiguration().cloudflare().ttl(), config.getTurnConfiguration().cloudflare().requestedCredentialTtl(),
config.getTurnConfiguration().cloudflare().clientCredentialTtl(),
config.getTurnConfiguration().cloudflare().urls(), config.getTurnConfiguration().cloudflare().urls(),
config.getTurnConfiguration().cloudflare().urlsWithIps(), config.getTurnConfiguration().cloudflare().urlsWithIps(),
config.getTurnConfiguration().cloudflare().hostname(), config.getTurnConfiguration().cloudflare().hostname(),
@ -693,7 +694,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
PushChallengeManager pushChallengeManager = new PushChallengeManager(pushNotificationManager, PushChallengeManager pushChallengeManager = new PushChallengeManager(pushNotificationManager,
pushChallengeDynamoDb); pushChallengeDynamoDb);
ChangeNumberManager changeNumberManager = new ChangeNumberManager(messageSender, accountsManager); ChangeNumberManager changeNumberManager = new ChangeNumberManager(messageSender, accountsManager, Clock.systemUTC());
HttpClient currencyClient = HttpClient.newBuilder().version(HttpClient.Version.HTTP_2).connectTimeout(Duration.ofSeconds(10)).build(); HttpClient currencyClient = HttpClient.newBuilder().version(HttpClient.Version.HTTP_2).connectTimeout(Duration.ofSeconds(10)).build();
FixerClient fixerClient = config.getPaymentsServiceConfiguration().externalClients() FixerClient fixerClient = config.getPaymentsServiceConfiguration().externalClients()

View File

@ -15,6 +15,7 @@ import java.net.Inet6Address;
import java.net.URI; import java.net.URI;
import java.net.http.HttpRequest; import java.net.http.HttpRequest;
import java.net.http.HttpResponse; import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionException;
@ -39,16 +40,18 @@ public class CloudflareTurnCredentialsManager {
private final List<String> cloudflareTurnUrls; private final List<String> cloudflareTurnUrls;
private final List<String> cloudflareTurnUrlsWithIps; private final List<String> cloudflareTurnUrlsWithIps;
private final String cloudflareTurnHostname; private final String cloudflareTurnHostname;
private final HttpRequest request; private final HttpRequest getCredentialsRequest;
private final FaultTolerantHttpClient cloudflareTurnClient; private final FaultTolerantHttpClient cloudflareTurnClient;
private final DnsNameResolver dnsNameResolver; private final DnsNameResolver dnsNameResolver;
record CredentialRequest(long ttl) {} private final Duration clientCredentialTtl;
record CloudflareTurnResponse(IceServer iceServers) { private record CredentialRequest(long ttl) {}
record IceServer( private record CloudflareTurnResponse(IceServer iceServers) {
private record IceServer(
String username, String username,
String credential, String credential,
List<String> urls) { List<String> urls) {
@ -56,10 +59,17 @@ public class CloudflareTurnCredentialsManager {
} }
public CloudflareTurnCredentialsManager(final String cloudflareTurnApiToken, public CloudflareTurnCredentialsManager(final String cloudflareTurnApiToken,
final String cloudflareTurnEndpoint, final long cloudflareTurnTtl, final List<String> cloudflareTurnUrls, final String cloudflareTurnEndpoint,
final List<String> cloudflareTurnUrlsWithIps, final String cloudflareTurnHostname, final Duration requestedCredentialTtl,
final int cloudflareTurnNumHttpClients, final CircuitBreakerConfiguration circuitBreaker, final Duration clientCredentialTtl,
final ExecutorService executor, final RetryConfiguration retry, final ScheduledExecutorService retryExecutor, final List<String> cloudflareTurnUrls,
final List<String> cloudflareTurnUrlsWithIps,
final String cloudflareTurnHostname,
final int cloudflareTurnNumHttpClients,
final CircuitBreakerConfiguration circuitBreaker,
final ExecutorService executor,
final RetryConfiguration retry,
final ScheduledExecutorService retryExecutor,
final DnsNameResolver dnsNameResolver) { final DnsNameResolver dnsNameResolver) {
this.cloudflareTurnClient = FaultTolerantHttpClient.newBuilder() this.cloudflareTurnClient = FaultTolerantHttpClient.newBuilder()
@ -75,17 +85,24 @@ public class CloudflareTurnCredentialsManager {
this.cloudflareTurnHostname = cloudflareTurnHostname; this.cloudflareTurnHostname = cloudflareTurnHostname;
this.dnsNameResolver = dnsNameResolver; this.dnsNameResolver = dnsNameResolver;
final String credentialsRequestBody;
try { try {
final String body = SystemMapper.jsonMapper().writeValueAsString(new CredentialRequest(cloudflareTurnTtl)); credentialsRequestBody =
this.request = HttpRequest.newBuilder() SystemMapper.jsonMapper().writeValueAsString(new CredentialRequest(requestedCredentialTtl.toSeconds()));
.uri(URI.create(cloudflareTurnEndpoint)) } catch (final JsonProcessingException e) {
.header("Content-Type", "application/json")
.header("Authorization", String.format("Bearer %s", cloudflareTurnApiToken))
.POST(HttpRequest.BodyPublishers.ofString(body))
.build();
} catch (JsonProcessingException e) {
throw new IllegalArgumentException(e); throw new IllegalArgumentException(e);
} }
// We repeat the same request to Cloudflare every time, so we can construct it once and re-use it
this.getCredentialsRequest = HttpRequest.newBuilder()
.uri(URI.create(cloudflareTurnEndpoint))
.header("Content-Type", "application/json")
.header("Authorization", String.format("Bearer %s", cloudflareTurnApiToken))
.POST(HttpRequest.BodyPublishers.ofString(credentialsRequestBody))
.build();
this.clientCredentialTtl = clientCredentialTtl;
} }
public TurnToken retrieveFromCloudflare() throws IOException { public TurnToken retrieveFromCloudflare() throws IOException {
@ -105,7 +122,7 @@ public class CloudflareTurnCredentialsManager {
final Timer.Sample sample = Timer.start(); final Timer.Sample sample = Timer.start();
final HttpResponse<String> response; final HttpResponse<String> response;
try { try {
response = cloudflareTurnClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()).join(); response = cloudflareTurnClient.sendAsync(getCredentialsRequest, HttpResponse.BodyHandlers.ofString()).join();
sample.stop(Timer.builder(CREDENTIAL_FETCH_TIMER_NAME) sample.stop(Timer.builder(CREDENTIAL_FETCH_TIMER_NAME)
.publishPercentileHistogram(true) .publishPercentileHistogram(true)
.tags("outcome", "success") .tags("outcome", "success")
@ -130,6 +147,7 @@ public class CloudflareTurnCredentialsManager {
return new TurnToken( return new TurnToken(
cloudflareTurnResponse.iceServers().username(), cloudflareTurnResponse.iceServers().username(),
cloudflareTurnResponse.iceServers().credential(), cloudflareTurnResponse.iceServers().credential(),
clientCredentialTtl.toSeconds(),
cloudflareTurnUrls == null ? Collections.emptyList() : cloudflareTurnUrls, cloudflareTurnUrls == null ? Collections.emptyList() : cloudflareTurnUrls,
cloudflareTurnComposedUrls, cloudflareTurnComposedUrls,
cloudflareTurnHostname cloudflareTurnHostname

View File

@ -5,13 +5,15 @@
package org.whispersystems.textsecuregcm.auth; package org.whispersystems.textsecuregcm.auth;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.List;
import javax.annotation.Nonnull; import javax.annotation.Nonnull;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.List;
public record TurnToken( public record TurnToken(
String username, String username,
String password, String password,
@JsonProperty("ttl") long ttlSeconds,
@Nonnull List<String> urls, @Nonnull List<String> urls,
@Nonnull List<String> urlsWithIps, @Nonnull List<String> urlsWithIps,
@Nullable String hostname) { @Nullable String hostname) {

View File

@ -1,34 +1,22 @@
package org.whispersystems.textsecuregcm.auth.grpc; package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall; import io.grpc.ServerCall;
import io.grpc.ServerInterceptor; import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.netty.channel.local.LocalAddress;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import java.util.Optional; import java.util.Optional;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
abstract class AbstractAuthenticationInterceptor implements ServerInterceptor { abstract class AbstractAuthenticationInterceptor implements ServerInterceptor {
private final GrpcClientConnectionManager grpcClientConnectionManager; private final GrpcClientConnectionManager grpcClientConnectionManager;
private static final Metadata EMPTY_TRAILERS = new Metadata();
AbstractAuthenticationInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) { AbstractAuthenticationInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
this.grpcClientConnectionManager = grpcClientConnectionManager; this.grpcClientConnectionManager = grpcClientConnectionManager;
} }
protected Optional<AuthenticatedDevice> getAuthenticatedDevice(final ServerCall<?, ?> call) { protected Optional<AuthenticatedDevice> getAuthenticatedDevice(final ServerCall<?, ?> call)
if (call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress) { throws ChannelNotFoundException {
return grpcClientConnectionManager.getAuthenticatedDevice(localAddress);
} else {
throw new AssertionError("Unexpected channel type: " + call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
}
}
protected <ReqT, RespT> ServerCall.Listener<ReqT> closeAsUnauthenticated(final ServerCall<ReqT, RespT> call) { return grpcClientConnectionManager.getAuthenticatedDevice(call);
call.close(Status.UNAUTHENTICATED, EMPTY_TRAILERS);
return new ServerCall.Listener<>() {};
} }
} }

View File

@ -3,12 +3,17 @@ package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.ServerCall; import io.grpc.ServerCall;
import io.grpc.ServerCallHandler; import io.grpc.ServerCallHandler;
import io.grpc.Status;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.ServerInterceptorUtil;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
/** /**
* A "prohibit authentication" interceptor ensures that requests to endpoints that should be invoked anonymously do not * A "prohibit authentication" interceptor ensures that requests to endpoints that should be invoked anonymously do not
* originate from a channel that is associated with an authenticated device. Calls with an associated authenticated * originate from a channel that is associated with an authenticated device. Calls with an associated authenticated
* device are closed with an {@code UNAUTHENTICATED} status. * device are closed with an {@code UNAUTHENTICATED} status. If a call's authentication status cannot be determined
* (i.e. because the underlying remote channel closed before the {@code ServerCall} started), the interceptor will
* reject the call with a status of {@code UNAVAILABLE}.
*/ */
public class ProhibitAuthenticationInterceptor extends AbstractAuthenticationInterceptor { public class ProhibitAuthenticationInterceptor extends AbstractAuthenticationInterceptor {
@ -21,8 +26,15 @@ public class ProhibitAuthenticationInterceptor extends AbstractAuthenticationInt
final Metadata headers, final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) { final ServerCallHandler<ReqT, RespT> next) {
return getAuthenticatedDevice(call) try {
.map(ignored -> closeAsUnauthenticated(call)) return getAuthenticatedDevice(call)
.orElseGet(() -> next.startCall(call, headers)); // Status.INTERNAL may seem a little surprising here, but if a caller is reaching an authentication-prohibited
// service via an authenticated connection, then that's actually a server configuration issue and not a
// problem with the client's request.
.map(ignored -> ServerInterceptorUtil.closeWithStatus(call, Status.INTERNAL))
.orElseGet(() -> next.startCall(call, headers));
} catch (final ChannelNotFoundException e) {
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
}
} }
} }

View File

@ -5,12 +5,16 @@ import io.grpc.Contexts;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.ServerCall; import io.grpc.ServerCall;
import io.grpc.ServerCallHandler; import io.grpc.ServerCallHandler;
import io.grpc.Status;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.ServerInterceptorUtil;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
/** /**
* A "require authentication" interceptor requires that requests be issued from a connection that is associated with an * A "require authentication" interceptor requires that requests be issued from a connection that is associated with an
* authenticated device. Calls without an associated authenticated device are closed with an {@code UNAUTHENTICATED} * authenticated device. Calls without an associated authenticated device are closed with an {@code UNAUTHENTICATED}
* status. * status. If a call's authentication status cannot be determined (i.e. because the underlying remote channel closed
* before the {@code ServerCall} started), the interceptor will reject the call with a status of {@code UNAVAILABLE}.
*/ */
public class RequireAuthenticationInterceptor extends AbstractAuthenticationInterceptor { public class RequireAuthenticationInterceptor extends AbstractAuthenticationInterceptor {
@ -23,10 +27,17 @@ public class RequireAuthenticationInterceptor extends AbstractAuthenticationInte
final Metadata headers, final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) { final ServerCallHandler<ReqT, RespT> next) {
return getAuthenticatedDevice(call) try {
.map(authenticatedDevice -> Contexts.interceptCall(Context.current() return getAuthenticatedDevice(call)
.withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE, authenticatedDevice), .map(authenticatedDevice -> Contexts.interceptCall(Context.current()
call, headers, next)) .withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE, authenticatedDevice),
.orElseGet(() -> closeAsUnauthenticated(call)); call, headers, next))
// Status.INTERNAL may seem a little surprising here, but if a caller is reaching an authentication-required
// service via an unauthenticated connection, then that's actually a server configuration issue and not a
// problem with the client's request.
.orElseGet(() -> ServerInterceptorUtil.closeWithStatus(call, Status.INTERNAL));
} catch (final ChannelNotFoundException e) {
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
}
} }
} }

View File

@ -6,16 +6,36 @@
package org.whispersystems.textsecuregcm.configuration; package org.whispersystems.textsecuregcm.configuration;
import jakarta.validation.Valid; import jakarta.validation.Valid;
import jakarta.validation.constraints.AssertTrue;
import jakarta.validation.constraints.NotBlank; import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotEmpty; import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull; import jakarta.validation.constraints.NotNull;
import java.time.Duration;
import java.util.List; import java.util.List;
import jakarta.validation.constraints.Positive; import jakarta.validation.constraints.Positive;
import org.whispersystems.textsecuregcm.configuration.secrets.SecretString; import org.whispersystems.textsecuregcm.configuration.secrets.SecretString;
/**
* Configuration properties for Cloudflare TURN integration.
*
* @param apiToken the API token to use when requesting TURN tokens from Cloudflare
* @param endpoint the URI of the Cloudflare API endpoint that vends TURN tokens
* @param requestedCredentialTtl the lifetime of TURN tokens to request from Cloudflare
* @param clientCredentialTtl the time clients may cache a TURN token; must be less than or equal to {@link #requestedCredentialTtl}
* @param urls a collection of TURN URLs to include verbatim in responses to clients
* @param urlsWithIps a collection of {@link String#format(String, Object...)} patterns to be populated with resolved IP
* addresses for {@link #hostname} in responses to clients; each pattern must include a single
* {@code %s} placeholder for the IP address
* @param circuitBreaker a circuit breaker for requests to Cloudflare
* @param retry a retry policy for requests to Cloudflare
* @param hostname the hostname to resolve to IP addresses for use with {@link #urlsWithIps}; also transmitted to
* clients for use as an SNI when connecting to pre-resolved hosts
* @param numHttpClients the number of parallel HTTP clients to use to communicate with Cloudflare
*/
public record CloudflareTurnConfiguration(@NotNull SecretString apiToken, public record CloudflareTurnConfiguration(@NotNull SecretString apiToken,
@NotBlank String endpoint, @NotBlank String endpoint,
@NotBlank long ttl, @NotNull Duration requestedCredentialTtl,
@NotNull Duration clientCredentialTtl,
@NotNull @NotEmpty @Valid List<@NotBlank String> urls, @NotNull @NotEmpty @Valid List<@NotBlank String> urls,
@NotNull @NotEmpty @Valid List<@NotBlank String> urlsWithIps, @NotNull @NotEmpty @Valid List<@NotBlank String> urlsWithIps,
@NotNull @Valid CircuitBreakerConfiguration circuitBreaker, @NotNull @Valid CircuitBreakerConfiguration circuitBreaker,
@ -35,4 +55,9 @@ public record CloudflareTurnConfiguration(@NotNull SecretString apiToken,
retry = new RetryConfiguration(); retry = new RetryConfiguration();
} }
} }
@AssertTrue
public boolean isClientTtlShorterThanRequestedTtl() {
return clientCredentialTtl.compareTo(requestedCredentialTtl) <= 0;
}
} }

View File

@ -15,16 +15,12 @@ import io.swagger.v3.oas.annotations.responses.ApiResponse;
import jakarta.ws.rs.GET; import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path; import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces; import jakarta.ws.rs.Produces;
import jakarta.ws.rs.container.ContainerRequestContext;
import jakarta.ws.rs.core.Context;
import jakarta.ws.rs.core.MediaType; import jakarta.ws.rs.core.MediaType;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.UUID; import java.util.UUID;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.CloudflareTurnCredentialsManager; import org.whispersystems.textsecuregcm.auth.CloudflareTurnCredentialsManager;
import org.whispersystems.textsecuregcm.auth.TurnToken;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.websocket.auth.ReadOnly; import org.whispersystems.websocket.auth.ReadOnly;
@ -32,14 +28,16 @@ import org.whispersystems.websocket.auth.ReadOnly;
@Path("/v2/calling") @Path("/v2/calling")
public class CallRoutingControllerV2 { public class CallRoutingControllerV2 {
private static final Counter CLOUDFLARE_TURN_ERROR_COUNTER = Metrics.counter(name(CallRoutingControllerV2.class, "cloudflareTurnError"));
private final RateLimiters rateLimiters; private final RateLimiters rateLimiters;
private final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager; private final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager;
private static final Counter CLOUDFLARE_TURN_ERROR_COUNTER =
Metrics.counter(name(CallRoutingControllerV2.class, "cloudflareTurnError"));
public CallRoutingControllerV2( public CallRoutingControllerV2(
final RateLimiters rateLimiters, final RateLimiters rateLimiters,
final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager) {
) {
this.rateLimiters = rateLimiters; this.rateLimiters = rateLimiters;
this.cloudflareTurnCredentialsManager = cloudflareTurnCredentialsManager; this.cloudflareTurnCredentialsManager = cloudflareTurnCredentialsManager;
} }
@ -58,25 +56,17 @@ public class CallRoutingControllerV2 {
@ApiResponse(responseCode = "401", description = "Account authentication check failed.") @ApiResponse(responseCode = "401", description = "Account authentication check failed.")
@ApiResponse(responseCode = "422", description = "Invalid request format.") @ApiResponse(responseCode = "422", description = "Invalid request format.")
@ApiResponse(responseCode = "429", description = "Rate limited.") @ApiResponse(responseCode = "429", description = "Rate limited.")
public GetCallingRelaysResponse getCallingRelays( public GetCallingRelaysResponse getCallingRelays(final @ReadOnly @Auth AuthenticatedDevice auth)
final @ReadOnly @Auth AuthenticatedDevice auth throws RateLimitExceededException, IOException {
) throws RateLimitExceededException, IOException {
UUID aci = auth.getAccount().getUuid(); final UUID aci = auth.getAccount().getUuid();
rateLimiters.getCallEndpointLimiter().validate(aci); rateLimiters.getCallEndpointLimiter().validate(aci);
List<TurnToken> tokens = new ArrayList<>();
try { try {
tokens.add(cloudflareTurnCredentialsManager.retrieveFromCloudflare()); return new GetCallingRelaysResponse(List.of(cloudflareTurnCredentialsManager.retrieveFromCloudflare()));
} catch (Exception e) { } catch (final Exception e) {
CallRoutingControllerV2.CLOUDFLARE_TURN_ERROR_COUNTER.increment(); CLOUDFLARE_TURN_ERROR_COUNTER.increment();
throw e; throw e;
} }
return new GetCallingRelaysResponse(tokens);
}
public record GetCallingRelaysResponse(
List<TurnToken> relays
) {
} }
} }

View File

@ -44,7 +44,6 @@ import java.util.EnumMap;
import java.util.List; import java.util.List;
import java.util.Locale; import java.util.Locale;
import java.util.Map; import java.util.Map;
import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionException;
@ -52,7 +51,6 @@ import java.util.concurrent.CompletionStage;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.glassfish.jersey.server.ContainerRequest; import org.glassfish.jersey.server.ContainerRequest;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.BasicAuthorizationHeader; import org.whispersystems.textsecuregcm.auth.BasicAuthorizationHeader;
@ -74,6 +72,7 @@ import org.whispersystems.textsecuregcm.entities.TransferArchiveUploadedRequest;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.limits.RateLimitedByIp; import org.whispersystems.textsecuregcm.limits.RateLimitedByIp;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.DevicePlatformUtil;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
@ -402,7 +401,7 @@ public class DeviceController {
private AtomicInteger getCounterForLinkedDeviceListeners(final String userAgent) { private AtomicInteger getCounterForLinkedDeviceListeners(final String userAgent) {
try { try {
return linkedDeviceListenersByPlatform.get(UserAgentUtil.parseUserAgentString(userAgent).getPlatform()); return linkedDeviceListenersByPlatform.get(UserAgentUtil.parseUserAgentString(userAgent).platform());
} catch (final UnrecognizedUserAgentException ignored) { } catch (final UnrecognizedUserAgentException ignored) {
return linkedDeviceListenersForUnrecognizedPlatforms; return linkedDeviceListenersForUnrecognizedPlatforms;
} }
@ -600,25 +599,9 @@ public class DeviceController {
} }
private static io.micrometer.core.instrument.Tag primaryPlatformTag(final Account account) { private static io.micrometer.core.instrument.Tag primaryPlatformTag(final Account account) {
final Device primaryDevice = account.getPrimaryDevice();
Optional<ClientPlatform> clientPlatform = Optional.empty();
if (StringUtils.isNotBlank(primaryDevice.getGcmId())) {
clientPlatform = Optional.of(ClientPlatform.ANDROID);
} else if (StringUtils.isNotBlank(primaryDevice.getApnId())) {
clientPlatform = Optional.of(ClientPlatform.IOS);
}
clientPlatform = clientPlatform.or(() -> Optional.ofNullable(
switch (primaryDevice.getUserAgent()) {
case "OWA" -> ClientPlatform.ANDROID;
case "OWI", "OWP" -> ClientPlatform.IOS;
case "OWD" -> ClientPlatform.DESKTOP;
case null, default -> null;
}));
return io.micrometer.core.instrument.Tag.of( return io.micrometer.core.instrument.Tag.of(
"primaryPlatform", "primaryPlatform",
clientPlatform DevicePlatformUtil.getDevicePlatform(account.getPrimaryDevice())
.map(p -> p.name().toLowerCase(Locale.ROOT)) .map(p -> p.name().toLowerCase(Locale.ROOT))
.orElse("unknown")); .orElse("unknown"));
} }

View File

@ -97,21 +97,20 @@ public class DonationController {
return redeemedReceiptsManager.put( return redeemedReceiptsManager.put(
receiptSerial, receiptExpiration.getEpochSecond(), receiptLevel, auth.getAccount().getUuid()) receiptSerial, receiptExpiration.getEpochSecond(), receiptLevel, auth.getAccount().getUuid())
.thenCompose(receiptMatched -> { .thenCompose(receiptMatched -> {
if (!receiptMatched) { if (!receiptMatched) {
return CompletableFuture.completedFuture(
Response.status(Status.BAD_REQUEST).entity("receipt serial is already redeemed")
.type(MediaType.TEXT_PLAIN_TYPE).build());
}
return accountsManager.getByAccountIdentifierAsync(auth.getAccount().getUuid()) return CompletableFuture.completedFuture(
.thenCompose(optionalAccount -> Response.status(Status.BAD_REQUEST).entity("receipt serial is already redeemed")
optionalAccount.map(account -> accountsManager.updateAsync(account, a -> { .type(MediaType.TEXT_PLAIN_TYPE).build());
}
return accountsManager.updateAsync(auth.getAccount(), a -> {
a.addBadge(clock, new AccountBadge(badgeId, receiptExpiration, request.isVisible())); a.addBadge(clock, new AccountBadge(badgeId, receiptExpiration, request.isVisible()));
if (request.isPrimary()) { if (request.isPrimary()) {
a.makeBadgePrimaryIfExists(clock, badgeId); a.makeBadgePrimaryIfExists(clock, badgeId);
} }
})).orElse(CompletableFuture.completedFuture(null))) })
.thenApply(ignored -> Response.ok().build()); .thenApply(ignored -> Response.ok().build());
}); });
}).thenCompose(Function.identity()); }).thenCompose(Function.identity());
} }

View File

@ -0,0 +1,13 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import org.whispersystems.textsecuregcm.auth.TurnToken;
import java.util.List;
public record GetCallingRelaysResponse(List<TurnToken> relays) {
}

View File

@ -43,7 +43,6 @@ import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.core.Response.Status; import jakarta.ws.rs.core.Response.Status;
import java.time.Clock; import java.time.Clock;
import java.time.Duration; import java.time.Duration;
import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
@ -96,6 +95,7 @@ import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider; import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider;
import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.MessageTooLargeException; import org.whispersystems.textsecuregcm.push.MessageTooLargeException;
import org.whispersystems.textsecuregcm.push.MessageUtil;
import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.push.ReceiptSender;
@ -105,7 +105,6 @@ import org.whispersystems.textsecuregcm.spam.SpamChecker;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers; import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager; import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
@ -114,11 +113,7 @@ import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.websocket.WebSocketConnection; import org.whispersystems.textsecuregcm.websocket.WebSocketConnection;
import org.whispersystems.websocket.WebsocketHeaders; import org.whispersystems.websocket.WebsocketHeaders;
import org.whispersystems.websocket.auth.ReadOnly; import org.whispersystems.websocket.auth.ReadOnly;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Scheduler;
import reactor.util.function.Tuple2;
import reactor.util.function.Tuples;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@Path("/v1/messages") @Path("/v1/messages")
@ -145,8 +140,6 @@ public class MessageController {
private final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor; private final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor;
private final Clock clock; private final Clock clock;
private static final int MAX_FETCH_ACCOUNT_CONCURRENCY = 8;
private static final CompletableFuture<?>[] EMPTY_FUTURE_ARRAY = new CompletableFuture<?>[0]; private static final CompletableFuture<?>[] EMPTY_FUTURE_ARRAY = new CompletableFuture<?>[0];
private static final String OUTGOING_MESSAGE_LIST_SIZE_BYTES_DISTRIBUTION_NAME = name(MessageController.class, "outgoingMessageListSizeBytes"); private static final String OUTGOING_MESSAGE_LIST_SIZE_BYTES_DISTRIBUTION_NAME = name(MessageController.class, "outgoingMessageListSizeBytes");
@ -443,11 +436,16 @@ public class MessageController {
final Map<Byte, Integer> registrationIdsByDeviceId = messages.messages().stream() final Map<Byte, Integer> registrationIdsByDeviceId = messages.messages().stream()
.collect(Collectors.toMap(IncomingMessage::destinationDeviceId, IncomingMessage::destinationRegistrationId)); .collect(Collectors.toMap(IncomingMessage::destinationDeviceId, IncomingMessage::destinationRegistrationId));
final Optional<Byte> syncMessageSenderDeviceId = messageType == MessageType.SYNC
? Optional.ofNullable(sender).map(authenticatedDevice -> authenticatedDevice.getAuthenticatedDevice().getId())
: Optional.empty();
try { try {
messageSender.sendMessages(destination, messageSender.sendMessages(destination,
destinationIdentifier, destinationIdentifier,
messagesByDeviceId, messagesByDeviceId,
registrationIdsByDeviceId, registrationIdsByDeviceId,
syncMessageSenderDeviceId,
userAgent); userAgent);
} catch (final MismatchedDevicesException e) { } catch (final MismatchedDevicesException e) {
if (!e.getMismatchedDevices().staleDeviceIds().isEmpty()) { if (!e.getMismatchedDevices().staleDeviceIds().isEmpty()) {
@ -563,7 +561,9 @@ public class MessageController {
final ContainerRequestContext context) { final ContainerRequestContext context) {
// Perform fast, inexpensive checks before attempting to resolve recipients // Perform fast, inexpensive checks before attempting to resolve recipients
validateNoDuplicateDevices(multiRecipientMessage); if (MessageUtil.hasDuplicateDevices(multiRecipientMessage)) {
throw new BadRequestException("Multi-recipient message contains duplicate recipient");
}
if (groupSendTokenHeader == null && combinedUnidentifiedSenderAccessKeys == null) { if (groupSendTokenHeader == null && combinedUnidentifiedSenderAccessKeys == null) {
throw new NotAuthorizedException("A group send endorsement token or unidentified access key is required for non-story messages"); throw new NotAuthorizedException("A group send endorsement token or unidentified access key is required for non-story messages");
@ -582,7 +582,14 @@ public class MessageController {
// At this point, the caller has at least superficially provided the information needed to send a multi-recipient // At this point, the caller has at least superficially provided the information needed to send a multi-recipient
// message. Attempt to resolve the destination service identifiers to Signal accounts. // message. Attempt to resolve the destination service identifiers to Signal accounts.
final Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolvedRecipients = final Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolvedRecipients =
resolveRecipients(multiRecipientMessage, groupSendTokenHeader == null); MessageUtil.resolveRecipients(accountsManager, multiRecipientMessage);
final List<ServiceIdentifier> unresolvedRecipientServiceIdentifiers =
MessageUtil.getUnresolvedRecipients(multiRecipientMessage, resolvedRecipients);
if (groupSendTokenHeader == null && !unresolvedRecipientServiceIdentifiers.isEmpty()) {
throw new NotFoundException();
}
// Access keys are checked against the UAK in the resolved accounts, so we have to check after resolving accounts above. // Access keys are checked against the UAK in the resolved accounts, so we have to check after resolving accounts above.
// Group send endorsements are checked earlier; for stories, we don't check permissions at all because only clients check them // Group send endorsements are checked earlier; for stories, we don't check permissions at all because only clients check them
@ -598,17 +605,6 @@ public class MessageController {
urgent, urgent,
context); context);
final List<ServiceIdentifier> unresolvedRecipientServiceIdentifiers;
if (groupSendTokenHeader != null) {
unresolvedRecipientServiceIdentifiers = multiRecipientMessage.getRecipients().entrySet().stream()
.filter(entry -> !resolvedRecipients.containsKey(entry.getValue()))
.map(entry -> ServiceIdentifier.fromLibsignal(entry.getKey()))
.toList();
} else {
unresolvedRecipientServiceIdentifiers = List.of();
}
return new SendMultiRecipientMessageResponse(unresolvedRecipientServiceIdentifiers); return new SendMultiRecipientMessageResponse(unresolvedRecipientServiceIdentifiers);
} }
@ -620,12 +616,14 @@ public class MessageController {
final ContainerRequestContext context) { final ContainerRequestContext context) {
// Perform fast, inexpensive checks before attempting to resolve recipients // Perform fast, inexpensive checks before attempting to resolve recipients
validateNoDuplicateDevices(multiRecipientMessage); if (MessageUtil.hasDuplicateDevices(multiRecipientMessage)) {
throw new BadRequestException("Multi-recipient message contains duplicate recipient");
}
// At this point, the caller has at least superficially provided the information needed to send a multi-recipient // At this point, the caller has at least superficially provided the information needed to send a multi-recipient
// message. Attempt to resolve the destination service identifiers to Signal accounts. // message. Attempt to resolve the destination service identifiers to Signal accounts.
final Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolvedRecipients = final Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolvedRecipients =
resolveRecipients(multiRecipientMessage, false); MessageUtil.resolveRecipients(accountsManager, multiRecipientMessage);
// We might filter out all the recipients of a story (if none exist). // We might filter out all the recipients of a story (if none exist).
// In this case there is no error so we should just return 200 now. // In this case there is no error so we should just return 200 now.
@ -909,43 +907,4 @@ public class MessageController {
return Response.status(Status.ACCEPTED) return Response.status(Status.ACCEPTED)
.build(); .build();
} }
private static void validateNoDuplicateDevices(final SealedSenderMultiRecipientMessage multiRecipientMessage) {
final boolean[] usedDeviceIds = new boolean[Device.MAXIMUM_DEVICE_ID + 1];
for (final SealedSenderMultiRecipientMessage.Recipient recipient : multiRecipientMessage.getRecipients().values()) {
if (recipient.getDevices().length == 1) {
// A recipient can't have repeated devices if they only have one device
continue;
}
Arrays.fill(usedDeviceIds, false);
for (final byte deviceId : recipient.getDevices()) {
if (usedDeviceIds[deviceId]) {
throw new BadRequestException();
}
usedDeviceIds[deviceId] = true;
}
}
}
private Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolveRecipients(final SealedSenderMultiRecipientMessage multiRecipientMessage,
final boolean throwOnNotFound) {
return Flux.fromIterable(multiRecipientMessage.getRecipients().entrySet())
.flatMap(serviceIdAndRecipient -> {
final ServiceIdentifier serviceIdentifier =
ServiceIdentifier.fromLibsignal(serviceIdAndRecipient.getKey());
return Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(serviceIdentifier))
.flatMap(Mono::justOrEmpty)
.switchIfEmpty(throwOnNotFound ? Mono.error(NotFoundException::new) : Mono.empty())
.map(account -> Tuples.of(serviceIdAndRecipient.getValue(), account));
}, MAX_FETCH_ACCOUNT_CONCURRENCY)
.collectMap(Tuple2::getT1, Tuple2::getT2)
.blockOptional()
.orElse(Collections.emptyMap());
}
} }

View File

@ -428,7 +428,7 @@ public class OneTimeDonationController {
@Nullable @Nullable
private static ClientPlatform getClientPlatform(@Nullable final String userAgentString) { private static ClientPlatform getClientPlatform(@Nullable final String userAgentString) {
try { try {
return UserAgentUtil.parseUserAgentString(userAgentString).getPlatform(); return UserAgentUtil.parseUserAgentString(userAgentString).platform();
} catch (final UnrecognizedUserAgentException e) { } catch (final UnrecognizedUserAgentException e) {
return null; return null;
} }

View File

@ -204,11 +204,12 @@ public class ProfileController {
.build())); .build()));
} }
final List<AccountBadge> updatedBadges = request.badges()
.map(badges -> ProfileHelper.mergeBadgeIdsWithExistingAccountBadges(clock, badgeConfigurationMap, badges, auth.getAccount().getBadges()))
.orElseGet(() -> auth.getAccount().getBadges());
accountsManager.update(auth.getAccount(), a -> { accountsManager.update(auth.getAccount(), a -> {
final List<AccountBadge> updatedBadges = request.badges()
.map(badges -> ProfileHelper.mergeBadgeIdsWithExistingAccountBadges(clock, badgeConfigurationMap, badges, a.getBadges()))
.orElseGet(a::getBadges);
a.setBadges(clock, updatedBadges); a.setBadges(clock, updatedBadges);
a.setCurrentProfileVersion(request.version()); a.setCurrentProfileVersion(request.version());
}); });

View File

@ -755,7 +755,7 @@ public class SubscriptionController {
@Nullable @Nullable
private static ClientPlatform getClientPlatform(@Nullable final String userAgentString) { private static ClientPlatform getClientPlatform(@Nullable final String userAgentString) {
try { try {
return UserAgentUtil.parseUserAgentString(userAgentString).getPlatform(); return UserAgentUtil.parseUserAgentString(userAgentString).platform();
} catch (final UnrecognizedUserAgentException e) { } catch (final UnrecognizedUserAgentException e) {
return null; return null;
} }

View File

@ -81,7 +81,16 @@ public class RemoteDeprecationFilter implements Filter, ServerInterceptor {
final Metadata headers, final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) { final ServerCallHandler<ReqT, RespT> next) {
if (shouldBlock(RequestAttributesUtil.getUserAgent().orElse(null))) { @Nullable final UserAgent userAgent = RequestAttributesUtil.getUserAgent()
.map(userAgentString -> {
try {
return UserAgentUtil.parseUserAgentString(userAgentString);
} catch (final UnrecognizedUserAgentException e) {
return null;
}
}).orElse(null);
if (shouldBlock(userAgent)) {
call.close(StatusConstants.UPGRADE_NEEDED_STATUS, new Metadata()); call.close(StatusConstants.UPGRADE_NEEDED_STATUS, new Metadata());
return new ServerCall.Listener<>() {}; return new ServerCall.Listener<>() {};
} else { } else {
@ -108,28 +117,28 @@ public class RemoteDeprecationFilter implements Filter, ServerInterceptor {
return true; return true;
} }
if (blockedVersionsByPlatform.containsKey(userAgent.getPlatform())) { if (blockedVersionsByPlatform.containsKey(userAgent.platform())) {
if (blockedVersionsByPlatform.get(userAgent.getPlatform()).contains(userAgent.getVersion())) { if (blockedVersionsByPlatform.get(userAgent.platform()).contains(userAgent.version())) {
recordDeprecation(userAgent, BLOCKED_CLIENT_REASON); recordDeprecation(userAgent, BLOCKED_CLIENT_REASON);
shouldBlock = true; shouldBlock = true;
} }
} }
if (minimumVersionsByPlatform.containsKey(userAgent.getPlatform())) { if (minimumVersionsByPlatform.containsKey(userAgent.platform())) {
if (userAgent.getVersion().isLowerThan(minimumVersionsByPlatform.get(userAgent.getPlatform()))) { if (userAgent.version().isLowerThan(minimumVersionsByPlatform.get(userAgent.platform()))) {
recordDeprecation(userAgent, EXPIRED_CLIENT_REASON); recordDeprecation(userAgent, EXPIRED_CLIENT_REASON);
shouldBlock = true; shouldBlock = true;
} }
} }
if (versionsPendingBlockByPlatform.containsKey(userAgent.getPlatform())) { if (versionsPendingBlockByPlatform.containsKey(userAgent.platform())) {
if (versionsPendingBlockByPlatform.get(userAgent.getPlatform()).contains(userAgent.getVersion())) { if (versionsPendingBlockByPlatform.get(userAgent.platform()).contains(userAgent.version())) {
recordPendingDeprecation(userAgent, BLOCKED_CLIENT_REASON); recordPendingDeprecation(userAgent, BLOCKED_CLIENT_REASON);
} }
} }
if (versionsPendingDeprecationByPlatform.containsKey(userAgent.getPlatform())) { if (versionsPendingDeprecationByPlatform.containsKey(userAgent.platform())) {
if (userAgent.getVersion().isLowerThan(versionsPendingDeprecationByPlatform.get(userAgent.getPlatform()))) { if (userAgent.version().isLowerThan(versionsPendingDeprecationByPlatform.get(userAgent.platform()))) {
recordPendingDeprecation(userAgent, EXPIRED_CLIENT_REASON); recordPendingDeprecation(userAgent, EXPIRED_CLIENT_REASON);
} }
} }
@ -139,13 +148,13 @@ public class RemoteDeprecationFilter implements Filter, ServerInterceptor {
private void recordDeprecation(final UserAgent userAgent, final String reason) { private void recordDeprecation(final UserAgent userAgent, final String reason) {
Metrics.counter(DEPRECATED_CLIENT_COUNTER_NAME, Metrics.counter(DEPRECATED_CLIENT_COUNTER_NAME,
PLATFORM_TAG, userAgent != null ? userAgent.getPlatform().name().toLowerCase() : "unrecognized", PLATFORM_TAG, userAgent != null ? userAgent.platform().name().toLowerCase() : "unrecognized",
REASON_TAG_NAME, reason).increment(); REASON_TAG_NAME, reason).increment();
} }
private void recordPendingDeprecation(final UserAgent userAgent, final String reason) { private void recordPendingDeprecation(final UserAgent userAgent, final String reason) {
Metrics.counter(PENDING_DEPRECATION_COUNTER_NAME, Metrics.counter(PENDING_DEPRECATION_COUNTER_NAME,
PLATFORM_TAG, userAgent.getPlatform().name().toLowerCase(), PLATFORM_TAG, userAgent.platform().name().toLowerCase(),
REASON_TAG_NAME, reason).increment(); REASON_TAG_NAME, reason).increment();
} }
} }

View File

@ -15,8 +15,6 @@ import jakarta.ws.rs.container.ContainerRequestFilter;
import jakarta.ws.rs.core.SecurityContext; import jakarta.ws.rs.core.SecurityContext;
import java.io.IOException; import java.io.IOException;
import java.util.Map; import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
@ -70,8 +68,8 @@ public class RestDeprecationFilter implements ContainerRequestFilter {
try { try {
final UserAgent userAgent = UserAgentUtil.parseUserAgentString(userAgentString); final UserAgent userAgent = UserAgentUtil.parseUserAgentString(userAgentString);
final ClientPlatform platform = userAgent.getPlatform(); final ClientPlatform platform = userAgent.platform();
final Semver version = userAgent.getVersion(); final Semver version = userAgent.version();
if (!minimumRestFreeVersion.containsKey(platform)) { if (!minimumRestFreeVersion.containsKey(platform)) {
return; return;
} }

View File

@ -0,0 +1,12 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
/**
* Indicates that a remote channel was not found for a given server call or remote address.
*/
public class ChannelNotFoundException extends Exception {
}

View File

@ -0,0 +1,55 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Context;
import io.grpc.ForwardingServerCallListener;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.netty.channel.local.LocalAddress;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
/**
* Then channel shutdown interceptor rejects new requests if a channel is shutting down and works in tandem with
* {@link GrpcClientConnectionManager} to maintain an active call count for each channel otherwise.
*/
public class ChannelShutdownInterceptor implements ServerInterceptor {
private final GrpcClientConnectionManager grpcClientConnectionManager;
public ChannelShutdownInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
this.grpcClientConnectionManager = grpcClientConnectionManager;
}
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
if (!grpcClientConnectionManager.handleServerCallStart(call)) {
// Don't allow new calls if the connection is getting ready to close
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
}
return new ForwardingServerCallListener.SimpleForwardingServerCallListener<>(next.startCall(call, headers)) {
@Override
public void onComplete() {
grpcClientConnectionManager.handleServerCallComplete(call);
super.onComplete();
}
@Override
public void onCancel() {
grpcClientConnectionManager.handleServerCallComplete(call);
super.onCancel();
}
};
}
}

View File

@ -7,8 +7,9 @@ package org.whispersystems.textsecuregcm.grpc;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.StatusException;
import java.time.Clock; import java.time.Clock;
import java.util.Collection;
import java.util.List; import java.util.List;
import org.signal.libsignal.protocol.ServiceId; import org.signal.libsignal.protocol.ServiceId;
import org.signal.libsignal.zkgroup.InvalidInputException; import org.signal.libsignal.zkgroup.InvalidInputException;
@ -18,8 +19,6 @@ import org.signal.libsignal.zkgroup.groupsend.GroupSendDerivedKeyPair;
import org.signal.libsignal.zkgroup.groupsend.GroupSendFullToken; import org.signal.libsignal.zkgroup.groupsend.GroupSendFullToken;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import reactor.core.publisher.Mono;
public class GroupSendTokenUtil { public class GroupSendTokenUtil {
private final ServerSecretParams serverSecretParams; private final ServerSecretParams serverSecretParams;
@ -30,16 +29,22 @@ public class GroupSendTokenUtil {
this.clock = clock; this.clock = clock;
} }
public Mono<Void> checkGroupSendToken(final ByteString serializedGroupSendToken, List<ServiceIdentifier> serviceIdentifiers) { public void checkGroupSendToken(final ByteString serializedGroupSendToken,
final ServiceIdentifier serviceIdentifier) throws StatusException {
checkGroupSendToken(serializedGroupSendToken, List.of(serviceIdentifier.toLibsignal()));
}
public void checkGroupSendToken(final ByteString serializedGroupSendToken,
final Collection<ServiceId> serviceIds) throws StatusException {
try { try {
final GroupSendFullToken token = new GroupSendFullToken(serializedGroupSendToken.toByteArray()); final GroupSendFullToken token = new GroupSendFullToken(serializedGroupSendToken.toByteArray());
final List<ServiceId> serviceIds = serviceIdentifiers.stream().map(ServiceIdentifier::toLibsignal).toList();
token.verify(serviceIds, clock.instant(), GroupSendDerivedKeyPair.forExpiration(token.getExpiration(), serverSecretParams)); token.verify(serviceIds, clock.instant(), GroupSendDerivedKeyPair.forExpiration(token.getExpiration(), serverSecretParams));
return Mono.empty(); } catch (final InvalidInputException e) {
} catch (InvalidInputException e) { throw Status.INVALID_ARGUMENT.asException();
return Mono.error(Status.INVALID_ARGUMENT.asException());
} catch (VerificationFailedException e) { } catch (VerificationFailedException e) {
return Mono.error(Status.UNAUTHENTICATED.asException()); throw Status.UNAUTHENTICATED.asException();
} }
} }
} }

View File

@ -7,12 +7,11 @@ package org.whispersystems.textsecuregcm.grpc;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.StatusException;
import java.security.MessageDigest; import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException; import java.security.NoSuchAlgorithmException;
import java.time.Clock; import java.time.Clock;
import java.util.Arrays; import java.util.Arrays;
import java.util.List;
import org.signal.chat.keys.CheckIdentityKeyRequest; import org.signal.chat.keys.CheckIdentityKeyRequest;
import org.signal.chat.keys.CheckIdentityKeyResponse; import org.signal.chat.keys.CheckIdentityKeyResponse;
import org.signal.chat.keys.GetPreKeysAnonymousRequest; import org.signal.chat.keys.GetPreKeysAnonymousRequest;
@ -52,16 +51,24 @@ public class KeysAnonymousGrpcService extends ReactorKeysAnonymousGrpc.KeysAnony
: KeysGrpcHelper.ALL_DEVICES; : KeysGrpcHelper.ALL_DEVICES;
return switch (request.getAuthorizationCase()) { return switch (request.getAuthorizationCase()) {
case GROUP_SEND_TOKEN -> case GROUP_SEND_TOKEN -> {
groupSendTokenUtil.checkGroupSendToken(request.getGroupSendToken(), List.of(serviceIdentifier)) try {
.then(lookUpAccount(serviceIdentifier, Status.NOT_FOUND)) groupSendTokenUtil.checkGroupSendToken(request.getGroupSendToken(), serviceIdentifier);
yield lookUpAccount(serviceIdentifier, Status.NOT_FOUND)
.flatMap(targetAccount -> KeysGrpcHelper.getPreKeys(targetAccount, serviceIdentifier.identityType(), deviceId, keysManager)); .flatMap(targetAccount -> KeysGrpcHelper.getPreKeys(targetAccount, serviceIdentifier.identityType(), deviceId, keysManager));
} catch (final StatusException e) {
yield Mono.error(e);
}
}
case UNIDENTIFIED_ACCESS_KEY -> case UNIDENTIFIED_ACCESS_KEY ->
lookUpAccount(serviceIdentifier, Status.UNAUTHENTICATED) lookUpAccount(serviceIdentifier, Status.UNAUTHENTICATED)
.flatMap(targetAccount -> .flatMap(targetAccount ->
UnidentifiedAccessUtil.checkUnidentifiedAccess(targetAccount, request.getUnidentifiedAccessKey().toByteArray()) UnidentifiedAccessUtil.checkUnidentifiedAccess(targetAccount, request.getUnidentifiedAccessKey().toByteArray())
? KeysGrpcHelper.getPreKeys(targetAccount, serviceIdentifier.identityType(), deviceId, keysManager) ? KeysGrpcHelper.getPreKeys(targetAccount, serviceIdentifier.identityType(), deviceId, keysManager)
: Mono.error(Status.UNAUTHENTICATED.asException())); : Mono.error(Status.UNAUTHENTICATED.asException()));
default -> Mono.error(Status.INVALID_ARGUMENT.asException()); default -> Mono.error(Status.INVALID_ARGUMENT.asException());
}; };
} }

View File

@ -0,0 +1,302 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import com.google.protobuf.ByteString;
import io.grpc.Status;
import io.grpc.StatusException;
import java.time.Clock;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.signal.chat.messages.IndividualRecipientMessageBundle;
import org.signal.chat.messages.MultiRecipientMismatchedDevices;
import org.signal.chat.messages.SendMessageResponse;
import org.signal.chat.messages.SendMultiRecipientMessageRequest;
import org.signal.chat.messages.SendMultiRecipientMessageResponse;
import org.signal.chat.messages.SendMultiRecipientStoryRequest;
import org.signal.chat.messages.SendSealedSenderMessageRequest;
import org.signal.chat.messages.SendStoryMessageRequest;
import org.signal.chat.messages.SimpleMessagesAnonymousGrpc;
import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.protocol.InvalidVersionException;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil;
import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.CardinalityEstimator;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.MessageTooLargeException;
import org.whispersystems.textsecuregcm.push.MessageUtil;
import org.whispersystems.textsecuregcm.spam.GrpcResponse;
import org.whispersystems.textsecuregcm.spam.MessageType;
import org.whispersystems.textsecuregcm.spam.SpamCheckResult;
import org.whispersystems.textsecuregcm.spam.SpamChecker;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
public class MessagesAnonymousGrpcService extends SimpleMessagesAnonymousGrpc.MessagesAnonymousImplBase {
private final AccountsManager accountsManager;
private final RateLimiters rateLimiters;
private final MessageSender messageSender;
private final GroupSendTokenUtil groupSendTokenUtil;
private final CardinalityEstimator messageByteLimitEstimator;
private final SpamChecker spamChecker;
private final Clock clock;
private static final SendMessageResponse SEND_MESSAGE_SUCCESS_RESPONSE = SendMessageResponse.newBuilder().build();
public MessagesAnonymousGrpcService(final AccountsManager accountsManager,
final RateLimiters rateLimiters,
final MessageSender messageSender,
final GroupSendTokenUtil groupSendTokenUtil,
final CardinalityEstimator messageByteLimitEstimator,
final SpamChecker spamChecker,
final Clock clock) {
this.accountsManager = accountsManager;
this.rateLimiters = rateLimiters;
this.messageSender = messageSender;
this.messageByteLimitEstimator = messageByteLimitEstimator;
this.spamChecker = spamChecker;
this.clock = clock;
this.groupSendTokenUtil = groupSendTokenUtil;
}
@Override
public SendMessageResponse sendSingleRecipientMessage(final SendSealedSenderMessageRequest request)
throws StatusException, RateLimitExceededException {
final ServiceIdentifier destinationServiceIdentifier =
ServiceIdentifierUtil.fromGrpcServiceIdentifier(request.getDestination());
final Account destination = accountsManager.getByServiceIdentifier(destinationServiceIdentifier)
.orElseThrow(Status.UNAUTHENTICATED::asException);
switch (request.getAuthorizationCase()) {
case UNIDENTIFIED_ACCESS_KEY -> {
if (!UnidentifiedAccessUtil.checkUnidentifiedAccess(destination, request.getUnidentifiedAccessKey().toByteArray())) {
throw Status.UNAUTHENTICATED.asException();
}
}
case GROUP_SEND_TOKEN ->
groupSendTokenUtil.checkGroupSendToken(request.getGroupSendToken(), destinationServiceIdentifier);
case AUTHORIZATION_NOT_SET -> throw Status.UNAUTHENTICATED.asException();
}
return sendIndividualMessage(destination,
destinationServiceIdentifier,
request.getMessages(),
request.getEphemeral(),
request.getUrgent(),
false);
}
@Override
public SendMessageResponse sendStory(final SendStoryMessageRequest request)
throws StatusException, RateLimitExceededException {
final ServiceIdentifier destinationServiceIdentifier =
ServiceIdentifierUtil.fromGrpcServiceIdentifier(request.getDestination());
final Optional<Account> maybeDestination = accountsManager.getByServiceIdentifier(destinationServiceIdentifier);
if (maybeDestination.isEmpty()) {
// Don't reveal to unauthenticated callers whether a destination account actually exists
return SEND_MESSAGE_SUCCESS_RESPONSE;
}
final Account destination = maybeDestination.get();
rateLimiters.getStoriesLimiter().validate(destination.getIdentifier(IdentityType.ACI));
return sendIndividualMessage(destination,
destinationServiceIdentifier,
request.getMessages(),
false,
request.getUrgent(),
true);
}
private SendMessageResponse sendIndividualMessage(final Account destination,
final ServiceIdentifier destinationServiceIdentifier,
final IndividualRecipientMessageBundle messages,
final boolean ephemeral,
final boolean urgent,
final boolean story) throws StatusException, RateLimitExceededException {
final SpamCheckResult<GrpcResponse<SendMessageResponse>> spamCheckResult =
spamChecker.checkForIndividualRecipientSpamGrpc(
story ? MessageType.INDIVIDUAL_STORY : MessageType.INDIVIDUAL_SEALED_SENDER,
Optional.empty(),
Optional.of(destination),
destinationServiceIdentifier);
if (spamCheckResult.response().isPresent()) {
return spamCheckResult.response().get().getResponseOrThrowStatus();
}
try {
final int totalPayloadLength = messages.getMessagesMap().values().stream()
.mapToInt(message -> message.getPayload().size())
.sum();
rateLimiters.getInboundMessageBytes().validate(destinationServiceIdentifier.uuid(), totalPayloadLength);
} catch (final RateLimitExceededException e) {
messageByteLimitEstimator.add(destinationServiceIdentifier.uuid().toString());
throw e;
}
final Map<Byte, MessageProtos.Envelope> messagesByDeviceId = messages.getMessagesMap().entrySet()
.stream()
.collect(Collectors.toMap(
entry -> DeviceIdUtil.validate(entry.getKey()),
entry -> {
final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder()
.setType(MessageProtos.Envelope.Type.UNIDENTIFIED_SENDER)
.setClientTimestamp(messages.getTimestamp())
.setServerTimestamp(clock.millis())
.setDestinationServiceId(destinationServiceIdentifier.toServiceIdentifierString())
.setEphemeral(ephemeral)
.setUrgent(urgent)
.setStory(story)
.setContent(entry.getValue().getPayload());
spamCheckResult.token().ifPresent(reportSpamToken ->
envelopeBuilder.setReportSpamToken(ByteString.copyFrom(reportSpamToken)));
return envelopeBuilder.build();
}
));
final Map<Byte, Integer> registrationIdsByDeviceId = messages.getMessagesMap().entrySet().stream()
.collect(Collectors.toMap(
entry -> entry.getKey().byteValue(),
entry -> entry.getValue().getRegistrationId()));
return MessagesGrpcHelper.sendMessage(messageSender,
destination,
destinationServiceIdentifier,
messagesByDeviceId,
registrationIdsByDeviceId,
Optional.empty());
}
@Override
public SendMultiRecipientMessageResponse sendMultiRecipientMessage(final SendMultiRecipientMessageRequest request)
throws StatusException {
final SealedSenderMultiRecipientMessage multiRecipientMessage =
parseAndValidateMultiRecipientMessage(request.getMessage().getPayload().toByteArray());
groupSendTokenUtil.checkGroupSendToken(request.getGroupSendToken(), multiRecipientMessage.getRecipients().keySet());
return sendMultiRecipientMessage(multiRecipientMessage,
request.getMessage().getTimestamp(),
request.getEphemeral(),
request.getUrgent(),
false);
}
@Override
public SendMultiRecipientMessageResponse sendMultiRecipientStory(final SendMultiRecipientStoryRequest request)
throws StatusException {
final SealedSenderMultiRecipientMessage multiRecipientMessage =
parseAndValidateMultiRecipientMessage(request.getMessage().getPayload().toByteArray());
return sendMultiRecipientMessage(multiRecipientMessage,
request.getMessage().getTimestamp(),
false,
request.getUrgent(),
true)
.toBuilder()
// Don't identify unresolved recipients for stories
.clearUnresolvedRecipients()
.build();
}
private SendMultiRecipientMessageResponse sendMultiRecipientMessage(
final SealedSenderMultiRecipientMessage multiRecipientMessage,
final long timestamp,
final boolean ephemeral,
final boolean urgent,
final boolean story) throws StatusException {
final SpamCheckResult<GrpcResponse<SendMultiRecipientMessageResponse>> spamCheckResult =
spamChecker.checkForMultiRecipientSpamGrpc(story
? MessageType.MULTI_RECIPIENT_STORY
: MessageType.MULTI_RECIPIENT_SEALED_SENDER);
if (spamCheckResult.response().isPresent()) {
return spamCheckResult.response().get().getResponseOrThrowStatus();
}
// At this point, the caller has at least superficially provided the information needed to send a multi-recipient
// message. Attempt to resolve the destination service identifiers to Signal accounts.
final Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolvedRecipients =
MessageUtil.resolveRecipients(accountsManager, multiRecipientMessage);
try {
messageSender.sendMultiRecipientMessage(multiRecipientMessage,
resolvedRecipients,
timestamp,
story,
ephemeral,
urgent,
RequestAttributesUtil.getUserAgent().orElse(null));
final SendMultiRecipientMessageResponse.Builder responseBuilder = SendMultiRecipientMessageResponse.newBuilder();
MessageUtil.getUnresolvedRecipients(multiRecipientMessage, resolvedRecipients).stream()
.map(ServiceIdentifierUtil::toGrpcServiceIdentifier)
.forEach(responseBuilder::addUnresolvedRecipients);
return responseBuilder.build();
} catch (final MessageTooLargeException e) {
throw Status.INVALID_ARGUMENT
.withDescription("Message for an individual recipient was too large")
.withCause(e)
.asRuntimeException();
} catch (final MultiRecipientMismatchedDevicesException e) {
final MultiRecipientMismatchedDevices.Builder mismatchedDevicesBuilder =
MultiRecipientMismatchedDevices.newBuilder();
e.getMismatchedDevicesByServiceIdentifier().forEach((serviceIdentifier, mismatchedDevices) ->
mismatchedDevicesBuilder.addMismatchedDevices(MessagesGrpcHelper.buildMismatchedDevices(serviceIdentifier, mismatchedDevices)));
return SendMultiRecipientMessageResponse.newBuilder()
.setMismatchedDevices(mismatchedDevicesBuilder)
.build();
}
}
private SealedSenderMultiRecipientMessage parseAndValidateMultiRecipientMessage(
final byte[] serializedMultiRecipientMessage) throws StatusException {
final SealedSenderMultiRecipientMessage multiRecipientMessage;
try {
multiRecipientMessage = SealedSenderMultiRecipientMessage.parse(serializedMultiRecipientMessage);
} catch (final InvalidMessageException | InvalidVersionException e) {
throw Status.INVALID_ARGUMENT.withCause(e).asException();
}
// Check that the request is well-formed and doesn't contain repeated entries for the same device for the same
// recipient
if (MessageUtil.hasDuplicateDevices(multiRecipientMessage)) {
throw Status.INVALID_ARGUMENT.withDescription("Multi-recipient message contains duplicate recipient").asException();
}
return multiRecipientMessage;
}
}

View File

@ -0,0 +1,91 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Status;
import io.grpc.StatusException;
import java.util.Map;
import java.util.Optional;
import org.signal.chat.messages.MismatchedDevices;
import org.signal.chat.messages.SendMessageResponse;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.MessageTooLargeException;
import org.whispersystems.textsecuregcm.storage.Account;
public class MessagesGrpcHelper {
private static final SendMessageResponse SEND_MESSAGE_SUCCESS_RESPONSE = SendMessageResponse.newBuilder().build();
/**
* Sends a "bundle" of messages to an individual destination account, mapping common exceptions to appropriate gRPC
* statuses.
*
* @param messageSender the {@code MessageSender} instance to use to send the messages
* @param destination the destination account for the messages
* @param destinationServiceIdentifier the service identifier for the destination account
* @param messagesByDeviceId a map of device IDs to message payloads
* @param registrationIdsByDeviceId a map of device IDs to device registration IDs
* @param syncMessageSenderDeviceId if the message is a sync message (i.e. a message to other devices linked to the
* caller's own account), contains the ID of the device that sent the message
*
* @return a response object to send to callers
*
* @throws StatusException if the message bundle could not be sent due to an out-of-date device set or an invalid
* message payload
* @throws RateLimitExceededException if the message bundle could not be sent due to a violated rated limit
*/
public static SendMessageResponse sendMessage(final MessageSender messageSender,
final Account destination,
final ServiceIdentifier destinationServiceIdentifier,
final Map<Byte, MessageProtos.Envelope> messagesByDeviceId,
final Map<Byte, Integer> registrationIdsByDeviceId,
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<Byte> syncMessageSenderDeviceId)
throws StatusException, RateLimitExceededException {
try {
messageSender.sendMessages(destination,
destinationServiceIdentifier,
messagesByDeviceId,
registrationIdsByDeviceId,
syncMessageSenderDeviceId,
RequestAttributesUtil.getUserAgent().orElse(null));
return SEND_MESSAGE_SUCCESS_RESPONSE;
} catch (final MismatchedDevicesException e) {
return SendMessageResponse.newBuilder()
.setMismatchedDevices(buildMismatchedDevices(destinationServiceIdentifier, e.getMismatchedDevices()))
.build();
} catch (final MessageTooLargeException e) {
throw Status.INVALID_ARGUMENT.withDescription("Message too large").withCause(e).asException();
}
}
/**
* Translates an internal {@link org.whispersystems.textsecuregcm.controllers.MismatchedDevices} entity to a gRPC
* {@link MismatchedDevices} entity.
*
* @param serviceIdentifier the service identifier to which the mismatched device response applies
* @param mismatchedDevices the mismatched device entity to translate to gRPC
*
* @return a gRPC {@code MismatchedDevices} representation of the given mismatched devices
*/
public static MismatchedDevices buildMismatchedDevices(final ServiceIdentifier serviceIdentifier,
final org.whispersystems.textsecuregcm.controllers.MismatchedDevices mismatchedDevices) {
final MismatchedDevices.Builder mismatchedDevicesBuilder = MismatchedDevices.newBuilder()
.setServiceIdentifier(ServiceIdentifierUtil.toGrpcServiceIdentifier(serviceIdentifier));
mismatchedDevices.missingDeviceIds().forEach(mismatchedDevicesBuilder::addMissingDevices);
mismatchedDevices.extraDeviceIds().forEach(mismatchedDevicesBuilder::addExtraDevices);
mismatchedDevices.staleDeviceIds().forEach(mismatchedDevicesBuilder::addStaleDevices);
return mismatchedDevicesBuilder.build();
}
}

View File

@ -0,0 +1,188 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import com.google.protobuf.ByteString;
import io.grpc.Status;
import io.grpc.StatusException;
import java.time.Clock;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.signal.chat.messages.AuthenticatedSenderMessageType;
import org.signal.chat.messages.IndividualRecipientMessageBundle;
import org.signal.chat.messages.SendAuthenticatedSenderMessageRequest;
import org.signal.chat.messages.SendMessageResponse;
import org.signal.chat.messages.SendSyncMessageRequest;
import org.signal.chat.messages.SimpleMessagesGrpc;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticationUtil;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.CardinalityEstimator;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.spam.GrpcResponse;
import org.whispersystems.textsecuregcm.spam.MessageType;
import org.whispersystems.textsecuregcm.spam.SpamCheckResult;
import org.whispersystems.textsecuregcm.spam.SpamChecker;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
public class MessagesGrpcService extends SimpleMessagesGrpc.MessagesImplBase {
private final AccountsManager accountsManager;
private final RateLimiters rateLimiters;
private final MessageSender messageSender;
private final CardinalityEstimator messageByteLimitEstimator;
private final SpamChecker spamChecker;
private final Clock clock;
public MessagesGrpcService(final AccountsManager accountsManager,
final RateLimiters rateLimiters,
final MessageSender messageSender,
final CardinalityEstimator messageByteLimitEstimator,
final SpamChecker spamChecker,
final Clock clock) {
this.accountsManager = accountsManager;
this.rateLimiters = rateLimiters;
this.messageSender = messageSender;
this.messageByteLimitEstimator = messageByteLimitEstimator;
this.spamChecker = spamChecker;
this.clock = clock;
}
@Override
public SendMessageResponse sendMessage(final SendAuthenticatedSenderMessageRequest request)
throws StatusException, RateLimitExceededException {
final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedDevice();
final AciServiceIdentifier senderServiceIdentifier = new AciServiceIdentifier(authenticatedDevice.accountIdentifier());
final Account sender =
accountsManager.getByServiceIdentifier(senderServiceIdentifier).orElseThrow(Status.UNAUTHENTICATED::asException);
final ServiceIdentifier destinationServiceIdentifier =
ServiceIdentifierUtil.fromGrpcServiceIdentifier(request.getDestination());
if (sender.isIdentifiedBy(destinationServiceIdentifier)) {
throw Status.INVALID_ARGUMENT
.withDescription("Use `sendSyncMessage` to send messages to own account")
.asException();
}
final Account destination = accountsManager.getByServiceIdentifier(destinationServiceIdentifier)
.orElseThrow(Status.NOT_FOUND::asException);
rateLimiters.getMessagesLimiter().validate(authenticatedDevice.accountIdentifier(), destination.getUuid());
return sendMessage(destination,
destinationServiceIdentifier,
authenticatedDevice,
request.getType(),
MessageType.INDIVIDUAL_IDENTIFIED_SENDER,
request.getMessages(),
request.getEphemeral(),
request.getUrgent());
}
@Override
public SendMessageResponse sendSyncMessage(final SendSyncMessageRequest request)
throws StatusException, RateLimitExceededException {
final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedDevice();
final AciServiceIdentifier senderServiceIdentifier = new AciServiceIdentifier(authenticatedDevice.accountIdentifier());
final Account sender =
accountsManager.getByServiceIdentifier(senderServiceIdentifier).orElseThrow(Status.UNAUTHENTICATED::asException);
return sendMessage(sender,
senderServiceIdentifier,
authenticatedDevice,
request.getType(),
MessageType.SYNC,
request.getMessages(),
false,
request.getUrgent());
}
private SendMessageResponse sendMessage(final Account destination,
final ServiceIdentifier destinationServiceIdentifier,
final AuthenticatedDevice sender,
final AuthenticatedSenderMessageType envelopeType,
final MessageType messageType,
final IndividualRecipientMessageBundle messages,
final boolean ephemeral,
final boolean urgent) throws StatusException, RateLimitExceededException {
try {
final int totalPayloadLength = messages.getMessagesMap().values().stream()
.mapToInt(message -> message.getPayload().size())
.sum();
rateLimiters.getInboundMessageBytes().validate(destinationServiceIdentifier.uuid(), totalPayloadLength);
} catch (final RateLimitExceededException e) {
messageByteLimitEstimator.add(destinationServiceIdentifier.uuid().toString());
throw e;
}
final SpamCheckResult<GrpcResponse<SendMessageResponse>> spamCheckResult =
spamChecker.checkForIndividualRecipientSpamGrpc(messageType,
Optional.of(sender),
Optional.of(destination),
destinationServiceIdentifier);
if (spamCheckResult.response().isPresent()) {
return spamCheckResult.response().get().getResponseOrThrowStatus();
}
final Map<Byte, MessageProtos.Envelope> messagesByDeviceId = messages.getMessagesMap().entrySet()
.stream()
.collect(Collectors.toMap(
entry -> DeviceIdUtil.validate(entry.getKey()),
entry -> {
final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder()
.setType(getEnvelopeType(envelopeType))
.setClientTimestamp(messages.getTimestamp())
.setServerTimestamp(clock.millis())
.setDestinationServiceId(destinationServiceIdentifier.toServiceIdentifierString())
.setSourceServiceId(new AciServiceIdentifier(sender.accountIdentifier()).toServiceIdentifierString())
.setSourceDevice(sender.deviceId())
.setEphemeral(ephemeral)
.setUrgent(urgent)
.setContent(entry.getValue().getPayload());
spamCheckResult.token().ifPresent(reportSpamToken ->
envelopeBuilder.setReportSpamToken(ByteString.copyFrom(reportSpamToken)));
return envelopeBuilder.build();
}
));
final Map<Byte, Integer> registrationIdsByDeviceId = messages.getMessagesMap().entrySet().stream()
.collect(Collectors.toMap(
entry -> entry.getKey().byteValue(),
entry -> entry.getValue().getRegistrationId()));
return MessagesGrpcHelper.sendMessage(messageSender,
destination,
destinationServiceIdentifier,
messagesByDeviceId,
registrationIdsByDeviceId,
messageType == MessageType.SYNC ? Optional.of(sender.deviceId()) : Optional.empty());
}
private static MessageProtos.Envelope.Type getEnvelopeType(final AuthenticatedSenderMessageType type) {
return switch (type) {
case DOUBLE_RATCHET -> MessageProtos.Envelope.Type.CIPHERTEXT;
case PREKEY_MESSAGE -> MessageProtos.Envelope.Type.PREKEY_BUNDLE;
case PLAINTEXT_CONTENT -> MessageProtos.Envelope.Type.PLAINTEXT_CONTENT;
case UNSPECIFIED, UNRECOGNIZED ->
throw Status.INVALID_ARGUMENT.withDescription("Unrecognized envelope type").asRuntimeException();
};
}
}

View File

@ -6,10 +6,8 @@
package org.whispersystems.textsecuregcm.grpc; package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.StatusException;
import java.time.Clock; import java.time.Clock;
import java.util.List;
import org.signal.chat.profile.CredentialType; import org.signal.chat.profile.CredentialType;
import org.signal.chat.profile.GetExpiringProfileKeyCredentialAnonymousRequest; import org.signal.chat.profile.GetExpiringProfileKeyCredentialAnonymousRequest;
import org.signal.chat.profile.GetExpiringProfileKeyCredentialResponse; import org.signal.chat.profile.GetExpiringProfileKeyCredentialResponse;
@ -59,11 +57,17 @@ public class ProfileAnonymousGrpcService extends ReactorProfileAnonymousGrpc.Pro
} }
final Mono<Account> account = switch (request.getAuthenticationCase()) { final Mono<Account> account = switch (request.getAuthenticationCase()) {
case GROUP_SEND_TOKEN -> case GROUP_SEND_TOKEN -> {
groupSendTokenUtil.checkGroupSendToken(request.getGroupSendToken(), List.of(targetIdentifier)) try {
.then(Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(targetIdentifier))) groupSendTokenUtil.checkGroupSendToken(request.getGroupSendToken(), targetIdentifier);
yield Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(targetIdentifier))
.flatMap(Mono::justOrEmpty) .flatMap(Mono::justOrEmpty)
.switchIfEmpty(Mono.error(Status.NOT_FOUND.asException())); .switchIfEmpty(Mono.error(Status.NOT_FOUND.asException()));
} catch (final StatusException e) {
yield Mono.error(e);
}
}
case UNIDENTIFIED_ACCESS_KEY -> case UNIDENTIFIED_ACCESS_KEY ->
getTargetAccountAndValidateUnidentifiedAccess(targetIdentifier, request.getUnidentifiedAccessKey().toByteArray()); getTargetAccountAndValidateUnidentifiedAccess(targetIdentifier, request.getUnidentifiedAccessKey().toByteArray());
default -> Mono.error(Status.INVALID_ARGUMENT.asException()); default -> Mono.error(Status.INVALID_ARGUMENT.asException());

View File

@ -145,11 +145,13 @@ public class ProfileGrpcService extends ReactorProfileGrpc.ProfileImplBase {
request.getCommitment().toByteArray()))); request.getCommitment().toByteArray())));
final List<Mono<?>> updates = new ArrayList<>(2); final List<Mono<?>> updates = new ArrayList<>(2);
final List<AccountBadge> updatedBadges = Optional.of(request.getBadgeIdsList())
.map(badges -> ProfileHelper.mergeBadgeIdsWithExistingAccountBadges(clock, badgeConfigurationMap, badges, account.getBadges()))
.orElseGet(account::getBadges);
updates.add(Mono.fromFuture(() -> accountsManager.updateAsync(account, a -> { updates.add(Mono.fromFuture(() -> accountsManager.updateAsync(account, a -> {
final List<AccountBadge> updatedBadges = Optional.of(request.getBadgeIdsList())
.map(badges -> ProfileHelper.mergeBadgeIdsWithExistingAccountBadges(clock, badgeConfigurationMap, badges, a.getBadges()))
.orElseGet(a::getBadges);
a.setBadges(clock, updatedBadges); a.setBadges(clock, updatedBadges);
a.setCurrentProfileVersion(request.getVersion()); a.setCurrentProfileVersion(request.getVersion());
}))); })));

View File

@ -0,0 +1,16 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import java.net.InetAddress;
import java.util.List;
import java.util.Locale;
import javax.annotation.Nullable;
public record RequestAttributes(InetAddress remoteAddress,
@Nullable String userAgent,
List<Locale.LanguageRange> acceptLanguage) {
}

View File

@ -2,28 +2,25 @@ package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Context; import io.grpc.Context;
import io.grpc.Contexts; import io.grpc.Contexts;
import io.grpc.Grpc;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.ServerCall; import io.grpc.ServerCall;
import io.grpc.ServerCallHandler; import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor; import io.grpc.ServerInterceptor;
import io.grpc.Status; import io.grpc.Status;
import io.netty.channel.local.LocalAddress;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import java.net.InetAddress;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
/**
* The request attributes interceptor makes request attributes from the underlying remote channel available to service
* implementations by attaching them to a {@link Context} attribute that can be read via {@link RequestAttributesUtil}.
* All server calls should have request attributes, and calls will be rejected with a status of {@code UNAVAILABLE} if
* request attributes are unavailable (i.e. the underlying channel closed before the {@code ServerCall} started).
*
* @see RequestAttributesUtil
*/
public class RequestAttributesInterceptor implements ServerInterceptor { public class RequestAttributesInterceptor implements ServerInterceptor {
private final GrpcClientConnectionManager grpcClientConnectionManager; private final GrpcClientConnectionManager grpcClientConnectionManager;
private static final Logger log = LoggerFactory.getLogger(RequestAttributesInterceptor.class);
public RequestAttributesInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) { public RequestAttributesInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
this.grpcClientConnectionManager = grpcClientConnectionManager; this.grpcClientConnectionManager = grpcClientConnectionManager;
} }
@ -33,52 +30,12 @@ public class RequestAttributesInterceptor implements ServerInterceptor {
final Metadata headers, final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) { final ServerCallHandler<ReqT, RespT> next) {
if (call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress) { try {
Context context = Context.current(); return Contexts.interceptCall(Context.current()
.withValue(RequestAttributesUtil.REQUEST_ATTRIBUTES_CONTEXT_KEY,
{ grpcClientConnectionManager.getRequestAttributes(call)), call, headers, next);
final Optional<InetAddress> maybeRemoteAddress = grpcClientConnectionManager.getRemoteAddress(localAddress); } catch (final ChannelNotFoundException e) {
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
if (maybeRemoteAddress.isEmpty()) {
// We should never have a call from a party whose remote address we can't identify
log.warn("No remote address available");
call.close(Status.INTERNAL, new Metadata());
return new ServerCall.Listener<>() {};
}
context = context.withValue(RequestAttributesUtil.REMOTE_ADDRESS_CONTEXT_KEY, maybeRemoteAddress.get());
}
{
final Optional<List<Locale.LanguageRange>> maybeAcceptLanguage =
grpcClientConnectionManager.getAcceptableLanguages(localAddress);
if (maybeAcceptLanguage.isPresent()) {
context = context.withValue(RequestAttributesUtil.ACCEPT_LANGUAGE_CONTEXT_KEY, maybeAcceptLanguage.get());
}
}
{
final Optional<String> maybeRawUserAgent =
grpcClientConnectionManager.getRawUserAgent(localAddress);
if (maybeRawUserAgent.isPresent()) {
context = context.withValue(RequestAttributesUtil.RAW_USER_AGENT_CONTEXT_KEY, maybeRawUserAgent.get());
}
}
{
final Optional<UserAgent> maybeUserAgent = grpcClientConnectionManager.getUserAgent(localAddress);
if (maybeUserAgent.isPresent()) {
context = context.withValue(RequestAttributesUtil.USER_AGENT_CONTEXT_KEY, maybeUserAgent.get());
}
}
return Contexts.interceptCall(context, call, headers, next);
} else {
throw new AssertionError("Unexpected channel type: " + call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
} }
} }
} }

View File

@ -3,18 +3,13 @@ package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Context; import io.grpc.Context;
import java.net.InetAddress; import java.net.InetAddress;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Locale; import java.util.Locale;
import java.util.Optional; import java.util.Optional;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
public class RequestAttributesUtil { public class RequestAttributesUtil {
static final Context.Key<List<Locale.LanguageRange>> ACCEPT_LANGUAGE_CONTEXT_KEY = Context.key("accept-language"); static final Context.Key<RequestAttributes> REQUEST_ATTRIBUTES_CONTEXT_KEY = Context.key("request-attributes");
static final Context.Key<InetAddress> REMOTE_ADDRESS_CONTEXT_KEY = Context.key("remote-address");
static final Context.Key<String> RAW_USER_AGENT_CONTEXT_KEY = Context.key("unparsed-user-agent");
static final Context.Key<UserAgent> USER_AGENT_CONTEXT_KEY = Context.key("parsed-user-agent");
private static final List<Locale> AVAILABLE_LOCALES = Arrays.asList(Locale.getAvailableLocales()); private static final List<Locale> AVAILABLE_LOCALES = Arrays.asList(Locale.getAvailableLocales());
@ -23,8 +18,8 @@ public class RequestAttributesUtil {
* *
* @return the acceptable languages listed by the remote client; may be empty if unparseable or not specified * @return the acceptable languages listed by the remote client; may be empty if unparseable or not specified
*/ */
public static Optional<List<Locale.LanguageRange>> getAcceptableLanguages() { public static List<Locale.LanguageRange> getAcceptableLanguages() {
return Optional.ofNullable(ACCEPT_LANGUAGE_CONTEXT_KEY.get()); return REQUEST_ATTRIBUTES_CONTEXT_KEY.get().acceptLanguage();
} }
/** /**
@ -35,9 +30,7 @@ public class RequestAttributesUtil {
* @return a list of distinct locales acceptable to the remote client and available in this JVM * @return a list of distinct locales acceptable to the remote client and available in this JVM
*/ */
public static List<Locale> getAvailableAcceptedLocales() { public static List<Locale> getAvailableAcceptedLocales() {
return getAcceptableLanguages() return Locale.filter(getAcceptableLanguages(), AVAILABLE_LOCALES);
.map(languageRanges -> Locale.filter(languageRanges, AVAILABLE_LOCALES))
.orElseGet(Collections::emptyList);
} }
/** /**
@ -46,16 +39,7 @@ public class RequestAttributesUtil {
* @return the remote address of the remote client * @return the remote address of the remote client
*/ */
public static InetAddress getRemoteAddress() { public static InetAddress getRemoteAddress() {
return REMOTE_ADDRESS_CONTEXT_KEY.get(); return REQUEST_ATTRIBUTES_CONTEXT_KEY.get().remoteAddress();
}
/**
* Returns the parsed user-agent of the remote client in the current gRPC request context.
*
* @return the parsed user-agent of the remote client; may be empty if unparseable or not specified
*/
public static Optional<UserAgent> getUserAgent() {
return Optional.ofNullable(USER_AGENT_CONTEXT_KEY.get());
} }
/** /**
@ -63,7 +47,7 @@ public class RequestAttributesUtil {
* *
* @return the unparsed user-agent of the remote client; may be empty if not specified * @return the unparsed user-agent of the remote client; may be empty if not specified
*/ */
public static Optional<String> getRawUserAgent() { public static Optional<String> getUserAgent() {
return Optional.ofNullable(RAW_USER_AGENT_CONTEXT_KEY.get()); return Optional.ofNullable(REQUEST_ATTRIBUTES_CONTEXT_KEY.get().userAgent());
} }
} }

View File

@ -0,0 +1,39 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.Status;
public class ServerInterceptorUtil {
@SuppressWarnings("rawtypes")
private static final ServerCall.Listener NO_OP_LISTENER = new ServerCall.Listener<>() {};
private static final Metadata EMPTY_TRAILERS = new Metadata();
private ServerInterceptorUtil() {
}
/**
* Closes the given server call with the given status, returning a no-op listener.
*
* @param call the server call to close
* @param status the status with which to close the call
*
* @return a no-op server call listener
*
* @param <ReqT> the type of request object handled by the server call
* @param <RespT> the type of response object returned by the server call
*/
public static <ReqT, RespT> ServerCall.Listener<ReqT> closeWithStatus(final ServerCall<ReqT, RespT> call, final Status status) {
call.close(status, EMPTY_TRAILERS);
//noinspection unchecked
return NO_OP_LISTENER;
}
}

View File

@ -8,7 +8,7 @@ package org.whispersystems.textsecuregcm.grpc;
import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.internalError; import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.internalError;
import com.google.protobuf.Descriptors; import com.google.protobuf.Descriptors;
import com.google.protobuf.GeneratedMessageV3; import com.google.protobuf.Message;
import io.grpc.ForwardingServerCallListener; import io.grpc.ForwardingServerCallListener;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.ServerCall; import io.grpc.ServerCall;
@ -75,7 +75,7 @@ public class ValidatingInterceptor implements ServerInterceptor {
} }
private void validateMessage(final Object message) throws StatusException { private void validateMessage(final Object message) throws StatusException {
if (message instanceof GeneratedMessageV3 msg) { if (message instanceof Message msg) {
try { try {
for (final Descriptors.FieldDescriptor fd: msg.getDescriptorForType().getFields()) { for (final Descriptors.FieldDescriptor fd: msg.getDescriptorForType().getFields()) {
for (final Map.Entry<Descriptors.FieldDescriptor, Object> entry: fd.getOptions().getAllFields().entrySet()) { for (final Map.Entry<Descriptors.FieldDescriptor, Object> entry: fd.getOptions().getAllFields().entrySet()) {

View File

@ -12,8 +12,10 @@ import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCountUtil;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Optional;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
/** /**
* An "establish local connection" handler waits for a Noise handshake to complete upstream in the pipeline, buffering * An "establish local connection" handler waits for a Noise handshake to complete upstream in the pipeline, buffering
@ -48,12 +50,12 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
@Override @Override
public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) { public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) {
if (event instanceof NoiseIdentityDeterminedEvent noiseIdentityDeterminedEvent) { if (event instanceof NoiseIdentityDeterminedEvent(final Optional<AuthenticatedDevice> authenticatedDevice)) {
// We assume that we'll only get a completed handshake event if the handshake met all authentication requirements // We assume that we'll only get a completed handshake event if the handshake met all authentication requirements
// for the requested service. If the handshake doesn't have an authenticated device, we assume we're trying to // for the requested service. If the handshake doesn't have an authenticated device, we assume we're trying to
// connect to the anonymous service. If it does have an authenticated device, we assume we're aiming for the // connect to the anonymous service. If it does have an authenticated device, we assume we're aiming for the
// authenticated service. // authenticated service.
final LocalAddress grpcServerAddress = noiseIdentityDeterminedEvent.authenticatedDevice().isPresent() final LocalAddress grpcServerAddress = authenticatedDevice.isPresent()
? authenticatedGrpcServerAddress ? authenticatedGrpcServerAddress
: anonymousGrpcServerAddress; : anonymousGrpcServerAddress;
@ -72,7 +74,7 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
if (localChannelFuture.isSuccess()) { if (localChannelFuture.isSuccess()) {
grpcClientConnectionManager.handleConnectionEstablished((LocalChannel) localChannelFuture.channel(), grpcClientConnectionManager.handleConnectionEstablished((LocalChannel) localChannelFuture.channel(),
remoteChannelContext.channel(), remoteChannelContext.channel(),
noiseIdentityDeterminedEvent.authenticatedDevice()); authenticatedDevice);
// Close the local connection if the remote channel closes and vice versa // Close the local connection if the remote channel closes and vice versa
remoteChannelContext.channel().closeFuture().addListener(closeFuture -> localChannelFuture.channel().close()); remoteChannelContext.channel().closeFuture().addListener(closeFuture -> localChannelFuture.channel().close());

View File

@ -1,6 +1,8 @@
package org.whispersystems.textsecuregcm.grpc.net; package org.whispersystems.textsecuregcm.grpc.net;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import io.grpc.Grpc;
import io.grpc.ServerCall;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelFutureListener;
import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalAddress;
@ -23,15 +25,26 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener; import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent; import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; import org.whispersystems.textsecuregcm.util.ClosableEpoch;
/** /**
* A client connection manager associates a local connection to a local gRPC server with a remote connection through a * A client connection manager associates a local connection to a local gRPC server with a remote connection through a
* Noise-over-WebSocket tunnel. It provides access to metadata associated with the remote connection, including the * Noise tunnel. It provides access to metadata associated with the remote connection, including the authenticated
* authenticated identity of the device that opened the connection (for non-anonymous connections). It can also close * identity of the device that opened the connection (for non-anonymous connections). It can also close connections
* connections associated with a given device if that device's credentials have changed and clients must reauthenticate. * associated with a given device if that device's credentials have changed and clients must reauthenticate.
* <p>
* In general, all {@link ServerCall}s <em>must</em> have a local address that in turn <em>should</em> be resolvable to
* a remote channel, which <em>must</em> have associated request attributes and authentication status. It is possible
* that a server call's local address may not be resolvable to a remote channel if the remote channel closed in the
* narrow window between a server call being created and the start of call execution, in which case accessor methods
* in this class will throw a {@link ChannelNotFoundException}.
* <p>
* A gRPC client connection manager's methods for getting request attributes accept {@link ServerCall} entities to
* identify connections. In general, these methods should only be called from {@link io.grpc.ServerInterceptor}s.
* Methods for requesting connection closure accept an {@link AuthenticatedDevice} to identify the connection and may
* be called from any application code.
*/ */
public class GrpcClientConnectionManager implements DisconnectionRequestListener { public class GrpcClientConnectionManager implements DisconnectionRequestListener {
@ -43,94 +56,93 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
AttributeKey.valueOf(GrpcClientConnectionManager.class, "authenticatedDevice"); AttributeKey.valueOf(GrpcClientConnectionManager.class, "authenticatedDevice");
@VisibleForTesting @VisibleForTesting
static final AttributeKey<InetAddress> REMOTE_ADDRESS_ATTRIBUTE_KEY = public static final AttributeKey<RequestAttributes> REQUEST_ATTRIBUTES_KEY =
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "remoteAddress"); AttributeKey.valueOf(GrpcClientConnectionManager.class, "requestAttributes");
@VisibleForTesting @VisibleForTesting
static final AttributeKey<String> RAW_USER_AGENT_ATTRIBUTE_KEY = static final AttributeKey<ClosableEpoch> EPOCH_ATTRIBUTE_KEY =
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "rawUserAgent"); AttributeKey.valueOf(GrpcClientConnectionManager.class, "epoch");
@VisibleForTesting
static final AttributeKey<UserAgent> PARSED_USER_AGENT_ATTRIBUTE_KEY =
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "userAgent");
@VisibleForTesting
static final AttributeKey<List<Locale.LanguageRange>> ACCEPT_LANGUAGE_ATTRIBUTE_KEY =
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "acceptLanguage");
private static final Logger log = LoggerFactory.getLogger(GrpcClientConnectionManager.class); private static final Logger log = LoggerFactory.getLogger(GrpcClientConnectionManager.class);
/** /**
* Returns the authenticated device associated with the given local address, if any. An authenticated device is * Returns the authenticated device associated with the given server call, if any. If the connection is anonymous
* available if and only if the given local address maps to an active local connection and that connection is * (i.e. unauthenticated), the returned value will be empty.
* authenticated (i.e. not anonymous).
* *
* @param localAddress the local address for which to find an authenticated device * @param serverCall the gRPC server call for which to find an authenticated device
* *
* @return the authenticated device associated with the given local address, if any * @return the authenticated device associated with the given local address, if any
*
* @throws ChannelNotFoundException if the server call is not associated with a known channel; in practice, this
* generally indicates that the channel has closed while request processing is still in progress
*/ */
public Optional<AuthenticatedDevice> getAuthenticatedDevice(final LocalAddress localAddress) { public Optional<AuthenticatedDevice> getAuthenticatedDevice(final ServerCall<?, ?> serverCall)
return getAuthenticatedDevice(remoteChannelsByLocalAddress.get(localAddress)); throws ChannelNotFoundException {
return getAuthenticatedDevice(getRemoteChannel(serverCall));
} }
private Optional<AuthenticatedDevice> getAuthenticatedDevice(@Nullable final Channel remoteChannel) { @VisibleForTesting
return Optional.ofNullable(remoteChannel) Optional<AuthenticatedDevice> getAuthenticatedDevice(final Channel remoteChannel) {
.map(channel -> channel.attr(AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).get()); return Optional.ofNullable(remoteChannel.attr(AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).get());
} }
/** /**
* Returns the parsed acceptable languages associated with the given local address, if any. Acceptable languages may * Returns the request attributes associated with the given server call.
* be unavailable if the local connection associated with the given local address has already closed, if the client
* did not provide a list of acceptable languages, or the list provided by the client could not be parsed.
* *
* @param localAddress the local address for which to find acceptable languages * @param serverCall the gRPC server call for which to retrieve request attributes
* *
* @return the acceptable languages associated with the given local address, if any * @return the request attributes associated with the given server call
*
* @throws ChannelNotFoundException if the server call is not associated with a known channel; in practice, this
* generally indicates that the channel has closed while request processing is still in progress
*/ */
public Optional<List<Locale.LanguageRange>> getAcceptableLanguages(final LocalAddress localAddress) { public RequestAttributes getRequestAttributes(final ServerCall<?, ?> serverCall) throws ChannelNotFoundException {
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress)) return getRequestAttributes(getRemoteChannel(serverCall));
.map(remoteChannel -> remoteChannel.attr(ACCEPT_LANGUAGE_ATTRIBUTE_KEY).get()); }
@VisibleForTesting
RequestAttributes getRequestAttributes(final Channel remoteChannel) {
final RequestAttributes requestAttributes = remoteChannel.attr(REQUEST_ATTRIBUTES_KEY).get();
if (requestAttributes == null) {
throw new IllegalStateException("Channel does not have request attributes");
}
return requestAttributes;
} }
/** /**
* Returns the remote address associated with the given local address, if any. A remote address may be unavailable if * Handles the start of a server call, incrementing the active call count for the remote channel associated with the
* the local connection associated with the given local address has already closed. * given server call.
* *
* @param localAddress the local address for which to find a remote address * @param serverCall the server call to start
* *
* @return the remote address associated with the given local address, if any * @return {@code true} if the call should start normally or {@code false} if the call should be aborted because the
* underlying channel is closing
*/ */
public Optional<InetAddress> getRemoteAddress(final LocalAddress localAddress) { public boolean handleServerCallStart(final ServerCall<?, ?> serverCall) {
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress)) try {
.map(remoteChannel -> remoteChannel.attr(REMOTE_ADDRESS_ATTRIBUTE_KEY).get()); return getRemoteChannel(serverCall).attr(EPOCH_ATTRIBUTE_KEY).get().tryArrive();
} catch (final ChannelNotFoundException e) {
// This would only happen if the channel had already closed, which is certainly possible. In this case, the call
// should certainly not proceed.
return false;
}
} }
/** /**
* Returns the unparsed user agent provided by the client that opened the connection associated with the given local * Handles completion (successful or not) of a server call, decrementing the active call count for the remote channel
* address. This method may return an empty value if no active local connection is associated with the given local * associated with the given server call.
* address.
* *
* @param localAddress the local address for which to find a User-Agent string * @param serverCall the server call to complete
*
* @return the user agent string associated with the given local address
*/ */
public Optional<String> getRawUserAgent(final LocalAddress localAddress) { public void handleServerCallComplete(final ServerCall<?, ?> serverCall) {
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress)) try {
.map(remoteChannel -> remoteChannel.attr(RAW_USER_AGENT_ATTRIBUTE_KEY).get()); getRemoteChannel(serverCall).attr(EPOCH_ATTRIBUTE_KEY).get().depart();
} } catch (final ChannelNotFoundException ignored) {
// In practice, we'd only get here if the channel has already closed, so we can just ignore the exception
/** }
* Returns the parsed user agent provided by the client that opened the connection associated with the given local
* address. This method may return an empty value if no active local connection is associated with the given local
* address or if the client's user-agent string was not recognized.
*
* @param localAddress the local address for which to find a User-Agent string
*
* @return the user agent associated with the given local address
*/
public Optional<UserAgent> getUserAgent(final LocalAddress localAddress) {
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
.map(remoteChannel -> remoteChannel.attr(PARSED_USER_AGENT_ATTRIBUTE_KEY).get());
} }
/** /**
@ -139,11 +151,19 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
* @param authenticatedDevice the authenticated device for which to close connections * @param authenticatedDevice the authenticated device for which to close connections
*/ */
public void closeConnection(final AuthenticatedDevice authenticatedDevice) { public void closeConnection(final AuthenticatedDevice authenticatedDevice) {
// Channels will actually get removed from the list/map by their closeFuture listeners // Channels will actually get removed from the list/map by their closeFuture listeners. We copy the list to avoid
remoteChannelsByAuthenticatedDevice.getOrDefault(authenticatedDevice, Collections.emptyList()).forEach(channel -> // concurrent modification; it's possible (though practically unlikely) that a channel can close and remove itself
channel.writeAndFlush(new CloseWebSocketFrame(ApplicationWebSocketCloseReason.REAUTHENTICATION_REQUIRED // from the list while we're still iterating, resulting in a `ConcurrentModificationException`.
.toWebSocketCloseStatus("Reauthentication required"))) final List<Channel> channelsToClose =
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE)); new ArrayList<>(remoteChannelsByAuthenticatedDevice.getOrDefault(authenticatedDevice, Collections.emptyList()));
channelsToClose.forEach(channel -> channel.attr(EPOCH_ATTRIBUTE_KEY).get().close());
}
private static void closeRemoteChannel(final Channel channel) {
channel.writeAndFlush(new CloseWebSocketFrame(ApplicationWebSocketCloseReason.REAUTHENTICATION_REQUIRED
.toWebSocketCloseStatus("Reauthentication required")))
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
} }
@VisibleForTesting @VisibleForTesting
@ -151,11 +171,32 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
return remoteChannelsByAuthenticatedDevice.get(authenticatedDevice); return remoteChannelsByAuthenticatedDevice.get(authenticatedDevice);
} }
private Channel getRemoteChannel(final ServerCall<?, ?> serverCall) throws ChannelNotFoundException {
return getRemoteChannel(getLocalAddress(serverCall));
}
@VisibleForTesting @VisibleForTesting
Channel getRemoteChannelByLocalAddress(final LocalAddress localAddress) { Channel getRemoteChannel(final LocalAddress localAddress) throws ChannelNotFoundException {
final Channel remoteChannel = remoteChannelsByLocalAddress.get(localAddress);
if (remoteChannel == null) {
throw new ChannelNotFoundException();
}
return remoteChannelsByLocalAddress.get(localAddress); return remoteChannelsByLocalAddress.get(localAddress);
} }
private static LocalAddress getLocalAddress(final ServerCall<?, ?> serverCall) {
// In this server, gRPC's "remote" channel is actually a local channel that proxies to a distinct Noise channel.
// The gRPC "remote" address is the "local address" for the proxy connection, and the local address uniquely maps to
// a proxied Noise channel.
if (!(serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress)) {
throw new IllegalArgumentException("Unexpected channel type: " + serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
}
return localAddress;
}
/** /**
* Handles successful completion of a WebSocket handshake and associates attributes and headers from the handshake * Handles successful completion of a WebSocket handshake and associates attributes and headers from the handshake
* request with the channel via which the handshake took place. * request with the channel via which the handshake took place.
@ -166,30 +207,23 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
* @param acceptLanguageHeader the value of the Accept-Language header provided in the handshake request; may be * @param acceptLanguageHeader the value of the Accept-Language header provided in the handshake request; may be
* {@code null} * {@code null}
*/ */
static void handleWebSocketHandshakeComplete(final Channel channel, static void handleHandshakeComplete(final Channel channel,
final InetAddress preferredRemoteAddress, final InetAddress preferredRemoteAddress,
@Nullable final String userAgentHeader, @Nullable final String userAgentHeader,
@Nullable final String acceptLanguageHeader) { @Nullable final String acceptLanguageHeader) {
channel.attr(GrpcClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).set(preferredRemoteAddress); @Nullable List<Locale.LanguageRange> acceptLanguages = Collections.emptyList();
if (StringUtils.isNotBlank(userAgentHeader)) {
channel.attr(GrpcClientConnectionManager.RAW_USER_AGENT_ATTRIBUTE_KEY).set(userAgentHeader);
try {
channel.attr(GrpcClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY)
.set(UserAgentUtil.parseUserAgentString(userAgentHeader));
} catch (final UnrecognizedUserAgentException ignored) {
}
}
if (StringUtils.isNotBlank(acceptLanguageHeader)) { if (StringUtils.isNotBlank(acceptLanguageHeader)) {
try { try {
channel.attr(GrpcClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).set(Locale.LanguageRange.parse(acceptLanguageHeader)); acceptLanguages = Locale.LanguageRange.parse(acceptLanguageHeader);
} catch (final IllegalArgumentException e) { } catch (final IllegalArgumentException e) {
log.debug("Invalid Accept-Language header from User-Agent {}: {}", userAgentHeader, acceptLanguageHeader, e); log.debug("Invalid Accept-Language header from User-Agent {}: {}", userAgentHeader, acceptLanguageHeader, e);
} }
} }
channel.attr(REQUEST_ATTRIBUTES_KEY)
.set(new RequestAttributes(preferredRemoteAddress, userAgentHeader, acceptLanguages));
} }
/** /**
@ -207,6 +241,9 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
maybeAuthenticatedDevice.ifPresent(authenticatedDevice -> maybeAuthenticatedDevice.ifPresent(authenticatedDevice ->
remoteChannel.attr(GrpcClientConnectionManager.AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).set(authenticatedDevice)); remoteChannel.attr(GrpcClientConnectionManager.AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).set(authenticatedDevice));
remoteChannel.attr(EPOCH_ATTRIBUTE_KEY)
.set(new ClosableEpoch(() -> closeRemoteChannel(remoteChannel)));
remoteChannelsByLocalAddress.put(localChannel.localAddress(), remoteChannel); remoteChannelsByLocalAddress.put(localChannel.localAddress(), remoteChannel);
getAuthenticatedDevice(remoteChannel).ifPresent(authenticatedDevice -> getAuthenticatedDevice(remoteChannel).ifPresent(authenticatedDevice ->

View File

@ -74,7 +74,7 @@ class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
preferredRemoteAddress = maybePreferredRemoteAddress.get(); preferredRemoteAddress = maybePreferredRemoteAddress.get();
} }
GrpcClientConnectionManager.handleWebSocketHandshakeComplete(context.channel(), GrpcClientConnectionManager.handleHandshakeComplete(context.channel(),
preferredRemoteAddress, preferredRemoteAddress,
handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.USER_AGENT), handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.USER_AGENT),
handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.ACCEPT_LANGUAGE)); handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.ACCEPT_LANGUAGE));

View File

@ -11,7 +11,6 @@ import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.in
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import com.google.protobuf.Descriptors; import com.google.protobuf.Descriptors;
import com.google.protobuf.GeneratedMessageV3;
import com.google.protobuf.Message; import com.google.protobuf.Message;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.StatusException; import io.grpc.StatusException;
@ -49,7 +48,7 @@ public abstract class BaseFieldValidator<T> implements FieldValidator {
public void validate( public void validate(
final Object extensionValue, final Object extensionValue,
final Descriptors.FieldDescriptor fd, final Descriptors.FieldDescriptor fd,
final GeneratedMessageV3 msg) throws StatusException { final Message msg) throws StatusException {
try { try {
final T extensionValueTyped = resolveExtensionValue(extensionValue); final T extensionValueTyped = resolveExtensionValue(extensionValue);
@ -116,7 +115,7 @@ public abstract class BaseFieldValidator<T> implements FieldValidator {
protected void validateRepeatedField( protected void validateRepeatedField(
final T extensionValue, final T extensionValue,
final Descriptors.FieldDescriptor fd, final Descriptors.FieldDescriptor fd,
final GeneratedMessageV3 msg) throws StatusException { final Message msg) throws StatusException {
throw internalError("`validateRepeatedField` method needs to be implemented"); throw internalError("`validateRepeatedField` method needs to be implemented");
} }

View File

@ -9,7 +9,7 @@ import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.in
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import com.google.protobuf.Descriptors; import com.google.protobuf.Descriptors;
import com.google.protobuf.GeneratedMessageV3; import com.google.protobuf.Message;
import io.grpc.StatusException; import io.grpc.StatusException;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
@ -53,7 +53,7 @@ public class ExactlySizeFieldValidator extends BaseFieldValidator<Set<Integer>>
protected void validateRepeatedField( protected void validateRepeatedField(
final Set<Integer> permittedSizes, final Set<Integer> permittedSizes,
final Descriptors.FieldDescriptor fd, final Descriptors.FieldDescriptor fd,
final GeneratedMessageV3 msg) throws StatusException { final Message msg) throws StatusException {
final int size = msg.getRepeatedFieldCount(fd); final int size = msg.getRepeatedFieldCount(fd);
if (permittedSizes.contains(size)) { if (permittedSizes.contains(size)) {
return; return;

View File

@ -6,11 +6,11 @@
package org.whispersystems.textsecuregcm.grpc.validators; package org.whispersystems.textsecuregcm.grpc.validators;
import com.google.protobuf.Descriptors; import com.google.protobuf.Descriptors;
import com.google.protobuf.GeneratedMessageV3; import com.google.protobuf.Message;
import io.grpc.StatusException; import io.grpc.StatusException;
public interface FieldValidator { public interface FieldValidator {
void validate(Object extensionValue, Descriptors.FieldDescriptor fd, GeneratedMessageV3 msg) void validate(Object extensionValue, Descriptors.FieldDescriptor fd, Message msg)
throws StatusException; throws StatusException;
} }

View File

@ -9,7 +9,7 @@ import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.in
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import com.google.protobuf.Descriptors; import com.google.protobuf.Descriptors;
import com.google.protobuf.GeneratedMessageV3; import com.google.protobuf.Message;
import io.grpc.StatusException; import io.grpc.StatusException;
import java.util.Set; import java.util.Set;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
@ -52,7 +52,7 @@ public class NonEmptyFieldValidator extends BaseFieldValidator<Boolean> {
protected void validateRepeatedField( protected void validateRepeatedField(
final Boolean extensionValue, final Boolean extensionValue,
final Descriptors.FieldDescriptor fd, final Descriptors.FieldDescriptor fd,
final GeneratedMessageV3 msg) throws StatusException { final Message msg) throws StatusException {
if (msg.getRepeatedFieldCount(fd) > 0) { if (msg.getRepeatedFieldCount(fd) > 0) {
return; return;
} }

View File

@ -5,7 +5,7 @@
package org.whispersystems.textsecuregcm.grpc.validators; package org.whispersystems.textsecuregcm.grpc.validators;
public record Range(int min, int max) { public record Range(long min, long max) {
public Range { public Range {
if (min > max) { if (min > max) {
throw new IllegalArgumentException("invalid range values: expected min <= max but have [%d, %d],".formatted(min, max)); throw new IllegalArgumentException("invalid range values: expected min <= max but have [%d, %d],".formatted(min, max));

View File

@ -39,8 +39,8 @@ public class RangeFieldValidator extends BaseFieldValidator<Range> {
@Override @Override
protected Range resolveExtensionValue(final Object extensionValue) throws StatusException { protected Range resolveExtensionValue(final Object extensionValue) throws StatusException {
final ValueRangeConstraint rangeConstraint = (ValueRangeConstraint) extensionValue; final ValueRangeConstraint rangeConstraint = (ValueRangeConstraint) extensionValue;
final int min = rangeConstraint.hasMin() ? rangeConstraint.getMin() : Integer.MIN_VALUE; final long min = rangeConstraint.hasMin() ? rangeConstraint.getMin() : Long.MIN_VALUE;
final int max = rangeConstraint.hasMax() ? rangeConstraint.getMax() : Integer.MAX_VALUE; final long max = rangeConstraint.hasMax() ? rangeConstraint.getMax() : Long.MAX_VALUE;
return new Range(min, max); return new Range(min, max);
} }

View File

@ -9,7 +9,7 @@ import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.in
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import com.google.protobuf.Descriptors; import com.google.protobuf.Descriptors;
import com.google.protobuf.GeneratedMessageV3; import com.google.protobuf.Message;
import io.grpc.StatusException; import io.grpc.StatusException;
import java.util.Set; import java.util.Set;
import org.signal.chat.require.SizeConstraint; import org.signal.chat.require.SizeConstraint;
@ -48,7 +48,7 @@ public class SizeFieldValidator extends BaseFieldValidator<Range> {
} }
@Override @Override
protected void validateRepeatedField(final Range range, final Descriptors.FieldDescriptor fd, final GeneratedMessageV3 msg) throws StatusException { protected void validateRepeatedField(final Range range, final Descriptors.FieldDescriptor fd, final Message msg) throws StatusException {
final int size = msg.getRepeatedFieldCount(fd); final int size = msg.getRepeatedFieldCount(fd);
if (size < range.min() || size > range.max()) { if (size < range.min() || size > range.max()) {
throw invalidArgument("field value is [%d] but expected to be within the [%d, %d] range".formatted( throw invalidArgument("field value is [%d] but expected to be within the [%d, %d] range".formatted(

View File

@ -0,0 +1,44 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.metrics;
import org.apache.commons.lang3.StringUtils;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import java.util.Optional;
public class DevicePlatformUtil {
private DevicePlatformUtil() {
}
/**
* Returns the most likely client platform for a device.
*
* @param device the device for which to find a client platform
*
* @return the most likely client platform for the given device or empty if no likely platform could be determined
*/
public static Optional<ClientPlatform> getDevicePlatform(final Device device) {
final Optional<ClientPlatform> clientPlatform;
if (StringUtils.isNotBlank(device.getGcmId())) {
clientPlatform = Optional.of(ClientPlatform.ANDROID);
} else if (StringUtils.isNotBlank(device.getApnId())) {
clientPlatform = Optional.of(ClientPlatform.IOS);
} else {
clientPlatform = Optional.empty();
}
return clientPlatform.or(() -> Optional.ofNullable(
switch (device.getUserAgent()) {
case "OWA" -> ClientPlatform.ANDROID;
case "OWI", "OWP" -> ClientPlatform.IOS;
case "OWD" -> ClientPlatform.DESKTOP;
case null, default -> null;
}));
}
}

View File

@ -67,7 +67,7 @@ public class OpenWebSocketCounter {
try { try {
final ClientPlatform clientPlatform = final ClientPlatform clientPlatform =
UserAgentUtil.parseUserAgentString(context.getClient().getUserAgent()).getPlatform(); UserAgentUtil.parseUserAgentString(context.getClient().getUserAgent()).platform();
calculatedOpenWebSocketCounter = openWebsocketsByClientPlatform.get(clientPlatform); calculatedOpenWebSocketCounter = openWebsocketsByClientPlatform.get(clientPlatform);
calculatedDurationTimer = durationTimersByClientPlatform.get(clientPlatform); calculatedDurationTimer = durationTimersByClientPlatform.get(clientPlatform);

View File

@ -9,6 +9,7 @@ import io.micrometer.core.instrument.Tag;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import org.apache.commons.lang3.StringUtils;
import org.whispersystems.textsecuregcm.WhisperServerVersion; import org.whispersystems.textsecuregcm.WhisperServerVersion;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
@ -48,15 +49,15 @@ public class UserAgentTagUtil {
} }
public static Tag getPlatformTag(@Nullable final UserAgent userAgent) { public static Tag getPlatformTag(@Nullable final UserAgent userAgent) {
return Tag.of(PLATFORM_TAG, userAgent != null ? userAgent.getPlatform().name().toLowerCase() : "unrecognized"); return Tag.of(PLATFORM_TAG, userAgent != null ? userAgent.platform().name().toLowerCase() : "unrecognized");
} }
public static Optional<Tag> getClientVersionTag(final String userAgentString, final ClientReleaseManager clientReleaseManager) { public static Optional<Tag> getClientVersionTag(final String userAgentString, final ClientReleaseManager clientReleaseManager) {
try { try {
final UserAgent userAgent = UserAgentUtil.parseUserAgentString(userAgentString); final UserAgent userAgent = UserAgentUtil.parseUserAgentString(userAgentString);
if (clientReleaseManager.isVersionActive(userAgent.getPlatform(), userAgent.getVersion())) { if (clientReleaseManager.isVersionActive(userAgent.platform(), userAgent.version())) {
return Optional.of(Tag.of(VERSION_TAG, userAgent.getVersion().toString())); return Optional.of(Tag.of(VERSION_TAG, userAgent.version().toString()));
} }
} catch (final UnrecognizedUserAgentException ignored) { } catch (final UnrecognizedUserAgentException ignored) {
} }
@ -70,10 +71,8 @@ public class UserAgentTagUtil {
try { try {
final UserAgent userAgent = UserAgentUtil.parseUserAgentString(userAgentString); final UserAgent userAgent = UserAgentUtil.parseUserAgentString(userAgentString);
platform = userAgent.getPlatform().name().toLowerCase(); platform = userAgent.platform().name().toLowerCase();
libsignal = userAgent.getAdditionalSpecifiers() libsignal = StringUtils.contains(userAgent.additionalSpecifiers(), "libsignal");
.map(additionalSpecifiers -> additionalSpecifiers.contains("libsignal"))
.orElse(false);
} catch (final UnrecognizedUserAgentException e) { } catch (final UnrecognizedUserAgentException e) {
platform = "unrecognized"; platform = "unrecognized";
libsignal = false; libsignal = false;

View File

@ -20,6 +20,7 @@ import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.signal.libsignal.protocol.util.Pair; import org.signal.libsignal.protocol.util.Pair;
@ -35,7 +36,6 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import javax.annotation.Nullable;
/** /**
* A MessageSender sends Signal messages to destination devices. Messages may be "normal" user-to-user messages, * A MessageSender sends Signal messages to destination devices. Messages may be "normal" user-to-user messages,
@ -56,6 +56,7 @@ public class MessageSender {
// Note that these names deliberately reference `MessageController` for metric continuity // Note that these names deliberately reference `MessageController` for metric continuity
private static final String REJECT_OVERSIZE_MESSAGE_COUNTER_NAME = name(MessageController.class, "rejectOversizeMessage"); private static final String REJECT_OVERSIZE_MESSAGE_COUNTER_NAME = name(MessageController.class, "rejectOversizeMessage");
private static final String CONTENT_SIZE_DISTRIBUTION_NAME = MetricsUtil.name(MessageController.class, "messageContentSize"); private static final String CONTENT_SIZE_DISTRIBUTION_NAME = MetricsUtil.name(MessageController.class, "messageContentSize");
private static final String EMPTY_MESSAGE_LIST_COUNTER_NAME = MetricsUtil.name(MessageSender.class, "emptyMessageList");
private static final String SEND_COUNTER_NAME = name(MessageSender.class, "sendMessage"); private static final String SEND_COUNTER_NAME = name(MessageSender.class, "sendMessage");
private static final String EPHEMERAL_TAG_NAME = "ephemeral"; private static final String EPHEMERAL_TAG_NAME = "ephemeral";
@ -85,6 +86,8 @@ public class MessageSender {
* @param destinationIdentifier the service identifier to which the messages are addressed * @param destinationIdentifier the service identifier to which the messages are addressed
* @param messagesByDeviceId a map of device IDs to message payloads * @param messagesByDeviceId a map of device IDs to message payloads
* @param registrationIdsByDeviceId a map of device IDs to device registration IDs * @param registrationIdsByDeviceId a map of device IDs to device registration IDs
* @param syncMessageSenderDeviceId if the message is a sync message (i.e. a message to other devices linked to the
* caller's own account), contains the ID of the device that sent the message
* @param userAgent the User-Agent string for the sender; may be {@code null} if not known * @param userAgent the User-Agent string for the sender; may be {@code null} if not known
* *
* @throws MismatchedDevicesException if the given bundle of messages did not include a message for all required * @throws MismatchedDevicesException if the given bundle of messages did not include a message for all required
@ -96,34 +99,48 @@ public class MessageSender {
final ServiceIdentifier destinationIdentifier, final ServiceIdentifier destinationIdentifier,
final Map<Byte, Envelope> messagesByDeviceId, final Map<Byte, Envelope> messagesByDeviceId,
final Map<Byte, Integer> registrationIdsByDeviceId, final Map<Byte, Integer> registrationIdsByDeviceId,
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<Byte> syncMessageSenderDeviceId,
@Nullable final String userAgent) throws MismatchedDevicesException, MessageTooLargeException { @Nullable final String userAgent) throws MismatchedDevicesException, MessageTooLargeException {
if (messagesByDeviceId.isEmpty()) {
return;
}
if (!destination.isIdentifiedBy(destinationIdentifier)) { if (!destination.isIdentifiedBy(destinationIdentifier)) {
throw new IllegalArgumentException("Destination account not identified by destination service identifier"); throw new IllegalArgumentException("Destination account not identified by destination service identifier");
} }
final Envelope firstMessage = messagesByDeviceId.values().iterator().next(); final Tag platformTag = UserAgentTagUtil.getPlatformTag(userAgent);
final boolean isSyncMessage = StringUtils.isNotBlank(firstMessage.getSourceServiceId()) && if (messagesByDeviceId.isEmpty()) {
destination.isIdentifiedBy(ServiceIdentifier.valueOf(firstMessage.getSourceServiceId())); Metrics.counter(EMPTY_MESSAGE_LIST_COUNTER_NAME,
Tags.of("sync", String.valueOf(syncMessageSenderDeviceId.isPresent())).and(platformTag)).increment();
}
final boolean isStory = firstMessage.getStory(); final byte excludedDeviceId;
if (syncMessageSenderDeviceId.isPresent()) {
if (messagesByDeviceId.values().stream().anyMatch(message -> StringUtils.isBlank(message.getSourceServiceId()) ||
!destination.isIdentifiedBy(ServiceIdentifier.valueOf(message.getSourceServiceId())))) {
validateIndividualMessageContentLength(messagesByDeviceId.values(), isSyncMessage, isStory, userAgent); throw new IllegalArgumentException("Sync message sender device ID specified, but one or more messages are not addressed to sender");
}
excludedDeviceId = syncMessageSenderDeviceId.get();
} else {
if (messagesByDeviceId.values().stream().anyMatch(message -> StringUtils.isNotBlank(message.getSourceServiceId()) &&
destination.isIdentifiedBy(ServiceIdentifier.valueOf(message.getSourceServiceId())))) {
throw new IllegalArgumentException("Sync message sender device ID not specified, but one or more messages are addressed to sender");
}
excludedDeviceId = NO_EXCLUDED_DEVICE_ID;
}
final Optional<MismatchedDevices> maybeMismatchedDevices = getMismatchedDevices(destination, final Optional<MismatchedDevices> maybeMismatchedDevices = getMismatchedDevices(destination,
destinationIdentifier, destinationIdentifier,
registrationIdsByDeviceId, registrationIdsByDeviceId,
isSyncMessage ? (byte) firstMessage.getSourceDevice() : NO_EXCLUDED_DEVICE_ID); excludedDeviceId);
if (maybeMismatchedDevices.isPresent()) { if (maybeMismatchedDevices.isPresent()) {
throw new MismatchedDevicesException(maybeMismatchedDevices.get()); throw new MismatchedDevicesException(maybeMismatchedDevices.get());
} }
validateIndividualMessageContentLength(messagesByDeviceId.values(), syncMessageSenderDeviceId.isPresent(), userAgent);
messagesManager.insert(destination.getIdentifier(IdentityType.ACI), messagesByDeviceId) messagesManager.insert(destination.getIdentifier(IdentityType.ACI), messagesByDeviceId)
.forEach((deviceId, destinationPresent) -> { .forEach((deviceId, destinationPresent) -> {
final Envelope message = messagesByDeviceId.get(deviceId); final Envelope message = messagesByDeviceId.get(deviceId);
@ -142,7 +159,7 @@ public class MessageSender {
STORY_TAG_NAME, String.valueOf(message.getStory()), STORY_TAG_NAME, String.valueOf(message.getStory()),
SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceServiceId()), SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceServiceId()),
MULTI_RECIPIENT_TAG_NAME, "false") MULTI_RECIPIENT_TAG_NAME, "false")
.and(UserAgentTagUtil.getPlatformTag(userAgent)); .and(platformTag);
Metrics.counter(SEND_COUNTER_NAME, tags).increment(); Metrics.counter(SEND_COUNTER_NAME, tags).increment();
}); });
@ -308,14 +325,13 @@ public class MessageSender {
private static void validateIndividualMessageContentLength(final Iterable<Envelope> messages, private static void validateIndividualMessageContentLength(final Iterable<Envelope> messages,
final boolean isSyncMessage, final boolean isSyncMessage,
final boolean isStory,
@Nullable final String userAgent) throws MessageTooLargeException { @Nullable final String userAgent) throws MessageTooLargeException {
for (final Envelope message : messages) { for (final Envelope message : messages) {
MessageSender.validateContentLength(message.getContent().size(), MessageSender.validateContentLength(message.getContent().size(),
false, false,
isSyncMessage, isSyncMessage,
isStory, message.getStory(),
userAgent); userAgent);
} }
} }

View File

@ -0,0 +1,128 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.push;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.util.function.Tuple2;
import reactor.util.function.Tuples;
public class MessageUtil {
public static final int DEFAULT_MAX_FETCH_ACCOUNT_CONCURRENCY = 8;
private MessageUtil() {
}
/**
* Finds account records for all recipients named in the given multi-recipient manager. Note that the returned map
* of recipients to account records will omit entries for recipients that could not be resolved to active accounts;
* callers that require full resolution should check for a missing entries and take appropriate action.
*
* @param accountsManager the {@code AccountsManager} instance to use to find account records
* @param multiRecipientMessage the message for which to resolve recipients
*
* @return a map of recipients to account records
*
* @see #getUnresolvedRecipients(SealedSenderMultiRecipientMessage, Map)
*/
public static Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolveRecipients(
final AccountsManager accountsManager,
final SealedSenderMultiRecipientMessage multiRecipientMessage) {
return resolveRecipients(accountsManager, multiRecipientMessage, DEFAULT_MAX_FETCH_ACCOUNT_CONCURRENCY);
}
/**
* Finds account records for all recipients named in the given multi-recipient manager. Note that the returned map
* of recipients to account records will omit entries for recipients that could not be resolved to active accounts;
* callers that require full resolution should check for a missing entries and take appropriate action.
*
* @param accountsManager the {@code AccountsManager} instance to use to find account records
* @param multiRecipientMessage the message for which to resolve recipients
* @param maxFetchAccountConcurrency the maximum number of concurrent account-retrieval operations
*
* @return a map of recipients to account records
*
* @see #getUnresolvedRecipients(SealedSenderMultiRecipientMessage, Map)
*/
public static Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolveRecipients(
final AccountsManager accountsManager,
final SealedSenderMultiRecipientMessage multiRecipientMessage,
final int maxFetchAccountConcurrency) {
return Flux.fromIterable(multiRecipientMessage.getRecipients().entrySet())
.flatMap(serviceIdAndRecipient -> {
final ServiceIdentifier serviceIdentifier =
ServiceIdentifier.fromLibsignal(serviceIdAndRecipient.getKey());
return Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(serviceIdentifier))
.flatMap(Mono::justOrEmpty)
.map(account -> Tuples.of(serviceIdAndRecipient.getValue(), account));
}, maxFetchAccountConcurrency)
.collectMap(Tuple2::getT1, Tuple2::getT2)
.blockOptional()
.orElse(Collections.emptyMap());
}
/**
* Returns a list of recipients missing from the map of resolved recipients for a multi-recipient message.
*
* @param multiRecipientMessage the multi-recipient message
* @param resolvedRecipients the map of resolved recipients to check for missing entries
*
* @return a list of {@code ServiceIdentifiers} belonging to multi-recipient message recipients that are not present
* in the given map of {@code resolvedRecipients}
*/
public static List<ServiceIdentifier> getUnresolvedRecipients(
final SealedSenderMultiRecipientMessage multiRecipientMessage,
final Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolvedRecipients) {
return multiRecipientMessage.getRecipients().entrySet().stream()
.filter(entry -> !resolvedRecipients.containsKey(entry.getValue()))
.map(entry -> ServiceIdentifier.fromLibsignal(entry.getKey()))
.toList();
}
/**
* Checks if a multi-recipient message contains duplicate recipients.
*
* @param multiRecipientMessage the message to check for duplicate recipients
*
* @return {@code true} if the message contains duplicate recipients or {@code false} otherwise
*/
public static boolean hasDuplicateDevices(final SealedSenderMultiRecipientMessage multiRecipientMessage) {
final boolean[] usedDeviceIds = new boolean[Device.MAXIMUM_DEVICE_ID + 1];
for (final SealedSenderMultiRecipientMessage.Recipient recipient : multiRecipientMessage.getRecipients().values()) {
if (recipient.getDevices().length == 1) {
// A recipient can't have repeated devices if they only have one device
continue;
}
Arrays.fill(usedDeviceIds, false);
for (final byte deviceId : recipient.getDevices()) {
if (usedDeviceIds[deviceId]) {
return true;
}
usedDeviceIds[deviceId] = true;
}
}
return false;
}
}

View File

@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.push;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics; import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics;
import java.util.Map; import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.slf4j.Logger; import org.slf4j.Logger;
@ -70,6 +71,7 @@ public class ReceiptSender {
destinationIdentifier, destinationIdentifier,
messagesByDeviceId, messagesByDeviceId,
registrationIdsByDeviceId, registrationIdsByDeviceId,
Optional.empty(),
UserAgentTagUtil.SERVER_UA); UserAgentTagUtil.SERVER_UA);
} catch (final Exception e) { } catch (final Exception e) {
logger.warn("Could not send delivery receipt", e); logger.warn("Could not send delivery receipt", e);

View File

@ -38,8 +38,8 @@ public class SchedulingUtil {
final LocalTime preferredTime, final LocalTime preferredTime,
final Clock clock) { final Clock clock) {
final ZonedDateTime candidateNotificationTime = getZoneOffset(account, clock) final ZonedDateTime candidateNotificationTime = getZoneId(account, clock)
.map(zoneOffset -> ZonedDateTime.now(zoneOffset).with(preferredTime)) .map(zoneId -> ZonedDateTime.now(clock.withZone(zoneId)).with(preferredTime))
.orElseGet(() -> { .orElseGet(() -> {
// We couldn't find a reasonable timezone for the account for some reason, so make an educated guess at a // We couldn't find a reasonable timezone for the account for some reason, so make an educated guess at a
// reasonable time to send a notification based on the account's creation time. // reasonable time to send a notification based on the account's creation time.
@ -59,7 +59,7 @@ public class SchedulingUtil {
} }
@VisibleForTesting @VisibleForTesting
static Optional<ZoneOffset> getZoneOffset(final Account account, final Clock clock) { static Optional<ZoneId> getZoneId(final Account account, final Clock clock) {
try { try {
final Phonenumber.PhoneNumber phoneNumber = PhoneNumberUtil.getInstance().parse(account.getNumber(), null); final Phonenumber.PhoneNumber phoneNumber = PhoneNumberUtil.getInstance().parse(account.getNumber(), null);
@ -70,7 +70,7 @@ public class SchedulingUtil {
return Optional.empty(); return Optional.empty();
} }
final List<ZoneOffset> sortedZoneOffsets = timeZonesForNumber final List<ZoneId> sortedZoneOffsets = timeZonesForNumber
.stream() .stream()
.map(id -> { .map(id -> {
try { try {
@ -80,9 +80,6 @@ public class SchedulingUtil {
} }
}) })
.filter(Objects::nonNull) .filter(Objects::nonNull)
.map(ZoneId::getRules)
.distinct()
.map(zoneRules -> zoneRules.getOffset(clock.instant()))
.sorted() .sorted()
.toList(); .toList();

View File

@ -6,19 +6,27 @@
package org.whispersystems.textsecuregcm.spam; package org.whispersystems.textsecuregcm.spam;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.StatusException;
import javax.annotation.Nullable;
import java.util.Optional; import java.util.Optional;
/** /**
* A combination of a gRPC status and response message to communicate to callers that a message has been flagged as * A combination of a gRPC status and response message to communicate to callers that a message has been flagged as
* potential spam. * potential spam.
* *
* @param status The gRPC status for this response. If the status is {@link Status#OK}, then a response object will be
* available via {@link #response}. Otherwise, callers should transmit the status as an error to clients.
* @param response a response object to send to clients; will be present if {@link #status} is not {@link Status#OK}
*
* @param <R> the type of response object * @param <R> the type of response object
*/ */
public record GrpcResponse<R>(Status status, Optional<R> response) { public class GrpcResponse<R> {
private final Status status;
@Nullable
private final R response;
private GrpcResponse(final Status status, @Nullable final R response) {
this.status = status;
this.response = response;
}
/** /**
* Constructs a new response object with the given status and no response message. * Constructs a new response object with the given status and no response message.
@ -30,7 +38,7 @@ public record GrpcResponse<R>(Status status, Optional<R> response) {
* @param <R> the type of response object * @param <R> the type of response object
*/ */
public static <R> GrpcResponse<R> withStatus(final Status status) { public static <R> GrpcResponse<R> withStatus(final Status status) {
return new GrpcResponse<>(status, Optional.empty()); return new GrpcResponse<>(status, null);
} }
/** /**
@ -43,6 +51,22 @@ public record GrpcResponse<R>(Status status, Optional<R> response) {
* @param <R> the type of response object * @param <R> the type of response object
*/ */
public static <R> GrpcResponse<R> withResponse(final R response) { public static <R> GrpcResponse<R> withResponse(final R response) {
return new GrpcResponse<>(Status.OK, Optional.of(response)); return new GrpcResponse<>(Status.OK, response);
}
/**
* Returns the message body contained within this response or throws the contained status as a {@link StatusException}
* if no message body is specified.
*
* @return the message body contained within this response
*
* @throws StatusException if no message body is specified
*/
public R getResponseOrThrowStatus() throws StatusException {
if (response != null) {
return response;
}
throw status.asException();
} }
} }

View File

@ -5,8 +5,10 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import java.time.Clock;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.apache.commons.lang3.ObjectUtils; import org.apache.commons.lang3.ObjectUtils;
@ -28,12 +30,16 @@ public class ChangeNumberManager {
private static final Logger logger = LoggerFactory.getLogger(ChangeNumberManager.class); private static final Logger logger = LoggerFactory.getLogger(ChangeNumberManager.class);
private final MessageSender messageSender; private final MessageSender messageSender;
private final AccountsManager accountsManager; private final AccountsManager accountsManager;
private final Clock clock;
public ChangeNumberManager( public ChangeNumberManager(
final MessageSender messageSender, final MessageSender messageSender,
final AccountsManager accountsManager) { final AccountsManager accountsManager,
final Clock clock) {
this.messageSender = messageSender; this.messageSender = messageSender;
this.accountsManager = accountsManager; this.accountsManager = accountsManager;
this.clock = clock;
} }
public Account changeNumber(final Account account, final String number, public Account changeNumber(final Account account, final String number,
@ -96,7 +102,7 @@ public class ChangeNumberManager {
final String senderUserAgent) throws MessageTooLargeException, MismatchedDevicesException { final String senderUserAgent) throws MessageTooLargeException, MismatchedDevicesException {
try { try {
final long serverTimestamp = System.currentTimeMillis(); final long serverTimestamp = clock.millis();
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(account.getUuid()); final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(account.getUuid());
final Map<Byte, Envelope> messagesByDeviceId = deviceMessages.stream() final Map<Byte, Envelope> messagesByDeviceId = deviceMessages.stream()
@ -113,10 +119,15 @@ public class ChangeNumberManager {
.setEphemeral(false) .setEphemeral(false)
.build())); .build()));
final Map<Byte, Integer> registrationIdsByDeviceId = account.getDevices().stream() final Map<Byte, Integer> registrationIdsByDeviceId = deviceMessages.stream()
.collect(Collectors.toMap(Device::getId, Device::getRegistrationId)); .collect(Collectors.toMap(IncomingMessage::destinationDeviceId, IncomingMessage::destinationRegistrationId));
messageSender.sendMessages(account, serviceIdentifier, messagesByDeviceId, registrationIdsByDeviceId, senderUserAgent); messageSender.sendMessages(account,
serviceIdentifier,
messagesByDeviceId,
registrationIdsByDeviceId,
Optional.of(Device.PRIMARY_ID),
senderUserAgent);
} catch (final RuntimeException e) { } catch (final RuntimeException e) {
logger.warn("Changed number but could not send all device messages on {}", account.getUuid(), e); logger.warn("Changed number but could not send all device messages on {}", account.getUuid(), e);
throw e; throw e;

View File

@ -12,10 +12,13 @@ import io.dropwizard.lifecycle.Managed;
import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.DistributionSummary; import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags;
import io.micrometer.core.instrument.Timer; import io.micrometer.core.instrument.Timer;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.util.List; import java.util.List;
import java.util.Locale;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLong;
@ -25,6 +28,7 @@ import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.metrics.DevicePlatformUtil;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
@ -44,6 +48,8 @@ public class MessagePersister implements Managed {
private volatile boolean running; private volatile boolean running;
private static final String OVERSIZED_QUEUE_COUNTER_NAME = name(MessagePersister.class, "persistQueueOversized"); private static final String OVERSIZED_QUEUE_COUNTER_NAME = name(MessagePersister.class, "persistQueueOversized");
private static final String PERSISTED_MESSAGE_COUNTER_NAME = name(MessagePersister.class, "persistMessage");
private static final String PERSISTED_BYTES_COUNTER_NAME = name(MessagePersister.class, "persistBytes");
private static final Timer GET_QUEUES_TIMER = Metrics.timer(name(MessagePersister.class, "getQueues")); private static final Timer GET_QUEUES_TIMER = Metrics.timer(name(MessagePersister.class, "getQueues"));
private static final Timer PERSIST_QUEUE_TIMER = Metrics.timer(name(MessagePersister.class, "persistQueue")); private static final Timer PERSIST_QUEUE_TIMER = Metrics.timer(name(MessagePersister.class, "persistQueue"));
@ -57,10 +63,7 @@ public class MessagePersister implements Managed {
.publishPercentileHistogram(true) .publishPercentileHistogram(true)
.register(Metrics.globalRegistry); .register(Metrics.globalRegistry);
private static final DistributionSummary QUEUE_SIZE_DISTRIBUTION_SUMMARY = DistributionSummary.builder( private static final String QUEUE_SIZE_DISTRIBUTION_SUMMARY_NAME = name(MessagePersister.class, "queueSize");
name(MessagePersister.class, "queueSize"))
.publishPercentileHistogram(true)
.register(Metrics.globalRegistry);
static final int QUEUE_BATCH_LIMIT = 100; static final int QUEUE_BATCH_LIMIT = 100;
static final int MESSAGE_BATCH_LIMIT = 100; static final int MESSAGE_BATCH_LIMIT = 100;
@ -139,6 +142,7 @@ public class MessagePersister implements Managed {
@VisibleForTesting @VisibleForTesting
int persistNextQueues(final Instant currentTime) { int persistNextQueues(final Instant currentTime) {
final int slot = messagesCache.getNextSlotToPersist(); final int slot = messagesCache.getNextSlotToPersist();
final String shard = messagesCache.shardForSlot(slot);
List<String> queuesToPersist; List<String> queuesToPersist;
int queuesPersisted = 0; int queuesPersisted = 0;
@ -162,10 +166,11 @@ public class MessagePersister implements Managed {
continue; continue;
} }
try { try {
persistQueue(maybeAccount.get(), maybeDevice.get()); persistQueue(maybeAccount.get(), maybeDevice.get(), shard);
} catch (final Exception e) { } catch (final Exception e) {
PERSIST_QUEUE_EXCEPTION_METER.increment(); PERSIST_QUEUE_EXCEPTION_METER.increment();
logger.warn("Failed to persist queue {}::{}; will schedule for retry", accountUuid, deviceId, e); logger.warn("Failed to persist queue {}::{} (slot {}, shard {}); will schedule for retry",
accountUuid, deviceId, slot, shard, e);
messagesCache.addQueueToPersist(accountUuid, deviceId); messagesCache.addQueueToPersist(accountUuid, deviceId);
@ -183,10 +188,14 @@ public class MessagePersister implements Managed {
} }
@VisibleForTesting @VisibleForTesting
void persistQueue(final Account account, final Device device) throws MessagePersistenceException { void persistQueue(final Account account, final Device device, final String shard) throws MessagePersistenceException {
final UUID accountUuid = account.getUuid(); final UUID accountUuid = account.getUuid();
final byte deviceId = device.getId(); final byte deviceId = device.getId();
final Tag platformTag = Tag.of("platform", DevicePlatformUtil.getDevicePlatform(device)
.map(platform -> platform.name().toLowerCase(Locale.ROOT))
.orElse("unknown"));
final Timer.Sample sample = Timer.start(); final Timer.Sample sample = Timer.start();
messagesCache.lockQueueForPersistence(accountUuid, deviceId); messagesCache.lockQueueForPersistence(accountUuid, deviceId);
@ -200,6 +209,16 @@ public class MessagePersister implements Managed {
do { do {
messages = messagesCache.getMessagesToPersist(accountUuid, deviceId, MESSAGE_BATCH_LIMIT); messages = messagesCache.getMessagesToPersist(accountUuid, deviceId, MESSAGE_BATCH_LIMIT);
final int urgentMessageCount = (int) messages.stream().filter(MessageProtos.Envelope::getUrgent).count();
final int nonUrgentMessageCount = messages.size() - urgentMessageCount;
final Tags tags = Tags.of(platformTag, Tag.of("shard", shard));
Metrics.counter(PERSISTED_MESSAGE_COUNTER_NAME, tags.and("urgent", "true")).increment(urgentMessageCount);
Metrics.counter(PERSISTED_MESSAGE_COUNTER_NAME, tags.and("urgent", "false")).increment(nonUrgentMessageCount);
Metrics.counter(PERSISTED_BYTES_COUNTER_NAME, tags)
.increment(messages.stream().mapToInt(MessageProtos.Envelope::getSerializedSize).sum());
int messagesRemovedFromCache = messagesManager.persistMessages(accountUuid, device, messages); int messagesRemovedFromCache = messagesManager.persistMessages(accountUuid, device, messages);
messageCount += messages.size(); messageCount += messages.size();
@ -215,7 +234,11 @@ public class MessagePersister implements Managed {
} while (!messages.isEmpty()); } while (!messages.isEmpty());
QUEUE_SIZE_DISTRIBUTION_SUMMARY.record(messageCount); DistributionSummary.builder(QUEUE_SIZE_DISTRIBUTION_SUMMARY_NAME)
.tags(Tags.of(platformTag))
.publishPercentileHistogram(true)
.register(Metrics.globalRegistry)
.record(messageCount);
} catch (ItemCollectionSizeLimitExceededException e) { } catch (ItemCollectionSizeLimitExceededException e) {
final boolean isPrimary = deviceId == Device.PRIMARY_ID; final boolean isPrimary = deviceId == Device.PRIMARY_ID;
Metrics.counter(OVERSIZED_QUEUE_COUNTER_NAME, "primary", String.valueOf(isPrimary)).increment(); Metrics.counter(OVERSIZED_QUEUE_COUNTER_NAME, "primary", String.valueOf(isPrimary)).increment();
@ -234,7 +257,6 @@ public class MessagePersister implements Managed {
messagesCache.unlockQueueForPersistence(accountUuid, deviceId); messagesCache.unlockQueueForPersistence(accountUuid, deviceId);
sample.stop(PERSIST_QUEUE_TIMER); sample.stop(PERSIST_QUEUE_TIMER);
} }
} }
private void trimQueue(final Account account, byte deviceId) { private void trimQueue(final Account account, byte deviceId) {

View File

@ -15,6 +15,8 @@ import io.lettuce.core.Range;
import io.lettuce.core.ScoredValue; import io.lettuce.core.ScoredValue;
import io.lettuce.core.ZAddArgs; import io.lettuce.core.ZAddArgs;
import io.lettuce.core.cluster.SlotHash; import io.lettuce.core.cluster.SlotHash;
import io.lettuce.core.cluster.models.partitions.ClusterPartitionParser;
import io.lettuce.core.cluster.models.partitions.Partitions;
import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tag;
@ -668,6 +670,16 @@ public class MessagesCache {
.thenRun(() -> sample.stop(clearQueueTimer)); .thenRun(() -> sample.stop(clearQueueTimer));
} }
// expensiveuse for rare error logging only
public String shardForSlot(int slot) {
try {
return redisCluster.withBinaryCluster(
connection -> ClusterPartitionParser.parse(connection.sync().clusterNodes()).getPartitionBySlot(slot).getUri().getHost());
} catch (Throwable ignored) {
return "unknown";
}
}
int getNextSlotToPersist() { int getNextSlotToPersist() {
return (int) (redisCluster.withCluster(connection -> connection.sync().incr(NEXT_SLOT_TO_PERSIST_KEY)) return (int) (redisCluster.withCluster(connection -> connection.sync().incr(NEXT_SLOT_TO_PERSIST_KEY))
% SlotHash.SLOT_COUNT); % SlotHash.SLOT_COUNT);

View File

@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.subscriptions;
import com.google.common.base.Strings; import com.google.common.base.Strings;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.stripe.Stripe;
import com.stripe.StripeClient; import com.stripe.StripeClient;
import com.stripe.exception.CardException; import com.stripe.exception.CardException;
import com.stripe.exception.IdempotencyException; import com.stripe.exception.IdempotencyException;
@ -71,6 +72,7 @@ import javax.crypto.spec.SecretKeySpec;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.WhisperServerVersion;
import org.whispersystems.textsecuregcm.storage.PaymentTime; import org.whispersystems.textsecuregcm.storage.PaymentTime;
import org.whispersystems.textsecuregcm.storage.SubscriptionException; import org.whispersystems.textsecuregcm.storage.SubscriptionException;
import org.whispersystems.textsecuregcm.util.Conversions; import org.whispersystems.textsecuregcm.util.Conversions;
@ -97,6 +99,9 @@ public class StripeManager implements CustomerAwareSubscriptionPaymentProcessor
if (Strings.isNullOrEmpty(apiKey)) { if (Strings.isNullOrEmpty(apiKey)) {
throw new IllegalArgumentException("apiKey cannot be empty"); throw new IllegalArgumentException("apiKey cannot be empty");
} }
Stripe.setAppInfo("Signal-Server", WhisperServerVersion.getServerVersion());
this.stripeClient = new StripeClient(apiKey); this.stripeClient = new StripeClient(apiKey);
this.executor = Objects.requireNonNull(executor); this.executor = Objects.requireNonNull(executor);
this.idempotencyKeyGenerator = Objects.requireNonNull(idempotencyKeyGenerator); this.idempotencyKeyGenerator = Objects.requireNonNull(idempotencyKeyGenerator);

View File

@ -0,0 +1,92 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import com.google.common.annotations.VisibleForTesting;
import java.util.concurrent.atomic.AtomicInteger;
/**
* A closable epoch is a concurrency construct that measures the number of callers in some critical section. A closable
* epoch can be closed to prevent new callers from entering the critical section, and takes a specific action when the
* critical section is empty after closure.
*/
public class ClosableEpoch {
private final Runnable onCloseHandler;
private final AtomicInteger state = new AtomicInteger();
private static final int CLOSING_BIT_MASK = 0x00000001;
/**
* Constructs a new closable epoch that will execute the given handler when the epoch is closed and all callers have
* departed the critical section. The handler will be executed on the thread that calls {@link #close()} if the
* critical section is empty at the time of the call or on the last thread to call {@link #depart()} otherwise.
* Callers should provide handlers that delegate execution to a specific thread/executor if more precise control over
* which thread runs the handler is required.
*
* @param onCloseHandler a handler to run when the epoch is closed and all callers have departed the critical section
*/
public ClosableEpoch(final Runnable onCloseHandler) {
this.onCloseHandler = onCloseHandler;
}
/**
* Announces the arrival of a caller at the start of the critical section. If the caller is allowed to enter the
* critical section, the epoch's internal caller counter is incremented accordingly.
*
* @return {@code true} if the caller is allowed to enter the critical section or {@code false} if it is not allowed
* to enter the critical section because this epoch is closing
*/
public boolean tryArrive() {
// Increment the number of active callers if and only if we're not closing. We add 2 because the lowest bit encodes
// the "closing" state, and the bits above it encode the actual call count. More verbosely, we're doing
// `state += (1 << 1)` to avoid overwriting the closing state bit.
return !isClosing(state.updateAndGet(state -> isClosing(state) ? state : state + 2));
}
/**
* Announces the departure of a caller from the critical section. If the epoch is closing and the caller is the last
* to depart the critical section, then the epoch will fire its {@code onCloseHandler}.
*/
public void depart() {
// Decrement the active caller count unconditionally. As with `tryActive`, we work in increments of 2 to "dodge" the
// "is closing" bit. If the call count is zero and we're closing then `state` will just have the "closing" bit set.
if (state.addAndGet(-2) == CLOSING_BIT_MASK) {
onCloseHandler.run();
}
}
/**
* Closes this epoch, preventing new callers from entering the critical section. If the critical section is empty when
* this method is called, it will trigger the {@code onCloseHandler} immediately. Otherwise, the
* {@code onCloseHandler} will fire when the last caller departs the critical section.
*
* @throws IllegalStateException if this epoch is already closed; note that this exception is thrown on a
* "best-effort" basis to help callers detect bugs
*/
public void close() {
// Note that this is not airtight and is a "best-effort" check
if (isClosing(state.get())) {
throw new IllegalStateException("Epoch already closed");
}
// Set the "closing" bit. If the closing bit is the only bit set, then the call count is zero and we can call the
// "on close" handler.
if (state.updateAndGet(state -> state | CLOSING_BIT_MASK) == CLOSING_BIT_MASK) {
onCloseHandler.run();
}
}
@VisibleForTesting
int getActiveCallers() {
return state.get() >> 1;
}
private static boolean isClosing(final int state) {
return (state & CLOSING_BIT_MASK) != 0;
}
}

View File

@ -46,7 +46,7 @@ public class LoggingUnhandledExceptionMapper extends LoggingExceptionMapper<Thro
// streamline the user-agent if it is recognized // streamline the user-agent if it is recognized
final UserAgent ua = UserAgentUtil.parseUserAgentString(userAgent); final UserAgent ua = UserAgentUtil.parseUserAgentString(userAgent);
userAgent = String.format("%s %s", ua.getPlatform(), ua.getVersion()); userAgent = String.format("%s %s", ua.platform(), ua.version());
} catch (final UnrecognizedUserAgentException ignored) { } catch (final UnrecognizedUserAgentException ignored) {
} catch (final Exception e) { } catch (final Exception e) {

View File

@ -6,58 +6,8 @@
package org.whispersystems.textsecuregcm.util.ua; package org.whispersystems.textsecuregcm.util.ua;
import com.vdurmont.semver4j.Semver; import com.vdurmont.semver4j.Semver;
import javax.annotation.Nullable;
import java.util.Objects; import java.util.Objects;
import java.util.Optional;
public class UserAgent { public record UserAgent(ClientPlatform platform, Semver version, @Nullable String additionalSpecifiers) {
private final ClientPlatform platform;
private final Semver version;
private final String additionalSpecifiers;
public UserAgent(final ClientPlatform platform, final Semver version) {
this(platform, version, null);
}
public UserAgent(final ClientPlatform platform, final Semver version, final String additionalSpecifiers) {
this.platform = platform;
this.version = version;
this.additionalSpecifiers = additionalSpecifiers;
}
public ClientPlatform getPlatform() {
return platform;
}
public Semver getVersion() {
return version;
}
public Optional<String> getAdditionalSpecifiers() {
return Optional.ofNullable(additionalSpecifiers);
}
@Override
public boolean equals(final Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
final UserAgent userAgent = (UserAgent)o;
return platform == userAgent.platform &&
version.equals(userAgent.version) &&
Objects.equals(additionalSpecifiers, userAgent.additionalSpecifiers);
}
@Override
public int hashCode() {
return Objects.hash(platform, version, additionalSpecifiers);
}
@Override
public String toString() {
return "UserAgent{" +
"platform=" + platform +
", version=" + version +
", additionalSpecifiers='" + additionalSpecifiers + '\'' +
'}';
}
} }

View File

@ -5,7 +5,6 @@
package org.whispersystems.textsecuregcm.util.ua; package org.whispersystems.textsecuregcm.util.ua;
import com.google.common.annotations.VisibleForTesting;
import com.vdurmont.semver4j.Semver; import com.vdurmont.semver4j.Semver;
import java.util.regex.Matcher; import java.util.regex.Matcher;
import java.util.regex.Pattern; import java.util.regex.Pattern;
@ -21,10 +20,10 @@ public class UserAgentUtil {
} }
try { try {
final UserAgent standardUserAgent = parseStandardUserAgentString(userAgentString); final Matcher matcher = STANDARD_UA_PATTERN.matcher(userAgentString);
if (standardUserAgent != null) { if (matcher.matches()) {
return standardUserAgent; return new UserAgent(ClientPlatform.valueOf(matcher.group(1).toUpperCase()), new Semver(matcher.group(2)), StringUtils.stripToNull(matcher.group(4)));
} }
} catch (final Exception e) { } catch (final Exception e) {
throw new UnrecognizedUserAgentException(e); throw new UnrecognizedUserAgentException(e);
@ -32,15 +31,4 @@ public class UserAgentUtil {
throw new UnrecognizedUserAgentException(); throw new UnrecognizedUserAgentException();
} }
@VisibleForTesting
static UserAgent parseStandardUserAgentString(final String userAgentString) {
final Matcher matcher = STANDARD_UA_PATTERN.matcher(userAgentString);
if (matcher.matches()) {
return new UserAgent(ClientPlatform.valueOf(matcher.group(1).toUpperCase()), new Semver(matcher.group(2)), StringUtils.stripToNull(matcher.group(4)));
}
return null;
}
} }

View File

@ -24,10 +24,12 @@ service Messages {
* destination account. * destination account.
* *
* This RPC may fail with a `NOT_FOUND` status if the destination account was * This RPC may fail with a `NOT_FOUND` status if the destination account was
* not found. It may also fail with a `RESOURCE_EXHAUSTED` status if a rate * not found. It may also fail with an `INVALID_ARGUMENT` status if the
* limit for sending messages has been exceeded, in which case a `retry-after` * destination account is the same as the authenticated caller (callers should
* header containing an ISO 8601 duration string may be present in the * use `SendSyncMessage` to send messages to themselves). It may also fail
* response trailers. * with a `RESOURCE_EXHAUSTED` status if a rate limit for sending messages has
* been exceeded, in which case a `retry-after` header containing an ISO 8601
* duration string may be present in the response trailers.
* *
* Note that message delivery may not succeed even if this RPC returns an `OK` * Note that message delivery may not succeed even if this RPC returns an `OK`
* status; callers must check the response object to verify that the message * status; callers must check the response object to verify that the message
@ -142,9 +144,12 @@ message IndividualRecipientMessageBundle {
/** /**
* The time, in milliseconds since the epoch, at which this message was * The time, in milliseconds since the epoch, at which this message was
* originally sent from the perspective of the sender. * originally sent from the perspective of the sender. Note that the maximum
* allowable timestamp for JavaScript clients is less than Long.MAX_VALUE; see
* https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Date#the_epoch_timestamps_and_invalid_date
* for additional details and discussion.
*/ */
uint64 timestamp = 1; uint64 timestamp = 1 [(require.range).min = 1, (require.range).max = 8640000000000000];
/** /**
* A map of device IDs to individual messages. Generally, callers must include * A map of device IDs to individual messages. Generally, callers must include
@ -327,9 +332,12 @@ message MultiRecipientMessage {
/** /**
* The time, in milliseconds since the epoch, at which this message was * The time, in milliseconds since the epoch, at which this message was
* originally sent from the perspective of the sender. * originally sent from the perspective of the sender. Note that the maximum
* allowable timestamp for JavaScript clients is less than Long.MAX_VALUE; see
* https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Date#the_epoch_timestamps_and_invalid_date
* for additional details and discussion.
*/ */
uint64 timestamp = 1; uint64 timestamp = 1 [(require.range).min = 1, (require.range).max = 8640000000000000];
/** /**
* The serialized multi-recipient message payload. * The serialized multi-recipient message payload.

View File

@ -149,8 +149,8 @@ message SizeConstraint {
} }
message ValueRangeConstraint { message ValueRangeConstraint {
optional int32 min = 1; optional int64 min = 1;
optional int32 max = 2; optional int64 max = 2;
} }
extend google.protobuf.ServiceOptions { extend google.protobuf.ServiceOptions {

View File

@ -5,8 +5,11 @@
package org.whispersystems.textsecuregcm.auth; package org.whispersystems.textsecuregcm.auth;
import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; import static com.github.tomakehurst.wiremock.client.WireMock.created;
import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson;
import static com.github.tomakehurst.wiremock.client.WireMock.post; import static com.github.tomakehurst.wiremock.client.WireMock.post;
import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor;
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig; import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
@ -15,17 +18,19 @@ import static org.mockito.Mockito.when;
import com.github.tomakehurst.wiremock.junit5.WireMockExtension; import com.github.tomakehurst.wiremock.junit5.WireMockExtension;
import io.netty.resolver.dns.DnsNameResolver; import io.netty.resolver.dns.DnsNameResolver;
import io.netty.util.concurrent.Future; import io.netty.util.concurrent.GlobalEventExecutor;
import io.netty.util.concurrent.SucceededFuture;
import java.io.IOException; import java.io.IOException;
import java.net.InetAddress; import java.net.InetAddress;
import java.security.cert.CertificateException; import java.time.Duration;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.concurrent.CancellationException; import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -35,31 +40,41 @@ import org.whispersystems.textsecuregcm.configuration.RetryConfiguration;
public class CloudflareTurnCredentialsManagerTest { public class CloudflareTurnCredentialsManagerTest {
@RegisterExtension @RegisterExtension
private final WireMockExtension wireMock = WireMockExtension.newInstance() private static final WireMockExtension wireMock = WireMockExtension.newInstance()
.options(wireMockConfig().dynamicPort().dynamicHttpsPort()) .options(wireMockConfig().dynamicPort().dynamicHttpsPort())
.build(); .build();
private static final String GET_CREDENTIALS_PATH = "/v1/turn/keys/LMNOP/credentials/generate";
private static final String TURN_HOSTNAME = "localhost";
private ExecutorService httpExecutor; private ExecutorService httpExecutor;
private ScheduledExecutorService retryExecutor; private ScheduledExecutorService retryExecutor;
private DnsNameResolver dnsResolver; private DnsNameResolver dnsResolver;
private Future<List<InetAddress>> dnsResult;
private CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager = null; private CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager;
private static final String GET_CREDENTIALS_PATH = "/v1/turn/keys/LMNOP/credentials/generate";
private static final String TURN_HOSTNAME = "localhost";
private static final String API_TOKEN = RandomStringUtils.insecure().nextAlphanumeric(16);
private static final String USERNAME = RandomStringUtils.insecure().nextAlphanumeric(16);
private static final String CREDENTIAL = RandomStringUtils.insecure().nextAlphanumeric(16);
private static final List<String> CLOUDFLARE_TURN_URLS = List.of("turn:cf.example.com");
private static final Duration REQUESTED_CREDENTIAL_TTL = Duration.ofSeconds(100);
private static final Duration CLIENT_CREDENTIAL_TTL = REQUESTED_CREDENTIAL_TTL.dividedBy(2);
private static final List<String> IP_URL_PATTERNS = List.of("turn:%s", "turn:%s:80?transport=tcp", "turns:%s:443?transport=tcp");
@BeforeEach @BeforeEach
void setUp() throws CertificateException { void setUp() {
httpExecutor = Executors.newSingleThreadExecutor(); httpExecutor = Executors.newSingleThreadExecutor();
retryExecutor = Executors.newSingleThreadScheduledExecutor(); retryExecutor = Executors.newSingleThreadScheduledExecutor();
dnsResolver = mock(DnsNameResolver.class); dnsResolver = mock(DnsNameResolver.class);
dnsResult = mock(Future.class);
cloudflareTurnCredentialsManager = new CloudflareTurnCredentialsManager( cloudflareTurnCredentialsManager = new CloudflareTurnCredentialsManager(
"API_TOKEN", API_TOKEN,
"http://localhost:" + wireMock.getPort() + GET_CREDENTIALS_PATH, "http://localhost:" + wireMock.getPort() + GET_CREDENTIALS_PATH,
100, REQUESTED_CREDENTIAL_TTL,
List.of("turn:cf.example.com"), CLIENT_CREDENTIAL_TTL,
List.of("turn:%s", "turn:%s:80?transport=tcp", "turns:%s:443?transport=tcp"), CLOUDFLARE_TURN_URLS,
IP_URL_PATTERNS,
TURN_HOSTNAME, TURN_HOSTNAME,
2, 2,
new CircuitBreakerConfiguration(), new CircuitBreakerConfiguration(),
@ -73,26 +88,61 @@ public class CloudflareTurnCredentialsManagerTest {
@AfterEach @AfterEach
void tearDown() throws InterruptedException { void tearDown() throws InterruptedException {
httpExecutor.shutdown(); httpExecutor.shutdown();
httpExecutor.awaitTermination(1, TimeUnit.SECONDS);
retryExecutor.shutdown(); retryExecutor.shutdown();
//noinspection ResultOfMethodCallIgnored
httpExecutor.awaitTermination(1, TimeUnit.SECONDS);
//noinspection ResultOfMethodCallIgnored
retryExecutor.awaitTermination(1, TimeUnit.SECONDS); retryExecutor.awaitTermination(1, TimeUnit.SECONDS);
} }
@Test @Test
public void testSuccess() throws IOException, CancellationException, ExecutionException, InterruptedException { public void testSuccess() throws IOException, CancellationException {
wireMock.stubFor(post(urlEqualTo(GET_CREDENTIALS_PATH)) wireMock.stubFor(post(urlEqualTo(GET_CREDENTIALS_PATH))
.willReturn(aResponse().withStatus(201).withHeader("Content-Type", new String[]{"application/json"}).withBody("{\"iceServers\":{\"urls\":[\"turn:cloudflare.example.com:3478?transport=udp\"],\"username\":\"ABC\",\"credential\":\"XYZ\"}}"))); .willReturn(created()
when(dnsResult.get()) .withHeader("Content-Type", "application/json")
.thenReturn(List.of(InetAddress.getByName("127.0.0.1"), InetAddress.getByName("::1"))); .withBody("""
{
"iceServers": {
"urls": [
"turn:cloudflare.example.com:3478?transport=udp"
],
"username": "%s",
"credential": "%s"
}
}
""".formatted(USERNAME, CREDENTIAL))));
when(dnsResolver.resolveAll(TURN_HOSTNAME)) when(dnsResolver.resolveAll(TURN_HOSTNAME))
.thenReturn(dnsResult); .thenReturn(new SucceededFuture<>(GlobalEventExecutor.INSTANCE,
List.of(InetAddress.getByName("127.0.0.1"), InetAddress.getByName("::1"))));
TurnToken token = cloudflareTurnCredentialsManager.retrieveFromCloudflare(); TurnToken token = cloudflareTurnCredentialsManager.retrieveFromCloudflare();
assertThat(token.username()).isEqualTo("ABC"); wireMock.verify(postRequestedFor(urlEqualTo(GET_CREDENTIALS_PATH))
assertThat(token.password()).isEqualTo("XYZ"); .withHeader("Content-Type", equalTo("application/json"))
assertThat(token.hostname()).isEqualTo("localhost"); .withHeader("Authorization", equalTo("Bearer " + API_TOKEN))
assertThat(token.urlsWithIps()).containsAll(List.of("turn:127.0.0.1", "turn:127.0.0.1:80?transport=tcp", "turns:127.0.0.1:443?transport=tcp", "turn:[0:0:0:0:0:0:0:1]", "turn:[0:0:0:0:0:0:0:1]:80?transport=tcp", "turns:[0:0:0:0:0:0:0:1]:443?transport=tcp"));; .withRequestBody(equalToJson("""
assertThat(token.urls()).isEqualTo(List.of("turn:cf.example.com")); {
"ttl": %d
}
""".formatted(REQUESTED_CREDENTIAL_TTL.toSeconds()))));
assertThat(token.username()).isEqualTo(USERNAME);
assertThat(token.password()).isEqualTo(CREDENTIAL);
assertThat(token.hostname()).isEqualTo(TURN_HOSTNAME);
assertThat(token.urls()).isEqualTo(CLOUDFLARE_TURN_URLS);
assertThat(token.ttlSeconds()).isEqualTo(CLIENT_CREDENTIAL_TTL.toSeconds());
final List<String> expectedUrlsWithIps = new ArrayList<>();
for (final String ip : new String[] {"127.0.0.1", "[0:0:0:0:0:0:0:1]"}) {
for (final String pattern : IP_URL_PATTERNS) {
expectedUrlsWithIps.add(pattern.formatted(ip));
}
}
assertThat(token.urlsWithIps()).containsExactlyElementsOf(expectedUrlsWithIps);
} }
} }

View File

@ -3,6 +3,7 @@ package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.Status; import io.grpc.Status;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse; import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils; import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
@ -22,7 +23,7 @@ class ProhibitAuthenticationInterceptorTest extends AbstractAuthenticationInterc
} }
@Test @Test
void interceptCall() { void interceptCall() throws ChannelNotFoundException {
final GrpcClientConnectionManager grpcClientConnectionManager = getClientConnectionManager(); final GrpcClientConnectionManager grpcClientConnectionManager = getClientConnectionManager();
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty()); when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty());
@ -34,6 +35,10 @@ class ProhibitAuthenticationInterceptorTest extends AbstractAuthenticationInterc
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID); final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice)); when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice));
GrpcTestUtils.assertStatusException(Status.UNAUTHENTICATED, this::getAuthenticatedDevice); GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getAuthenticatedDevice);
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenThrow(ChannelNotFoundException.class);
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, this::getAuthenticatedDevice);
} }
} }

View File

@ -9,6 +9,7 @@ import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse; import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils; import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
@ -22,12 +23,12 @@ class RequireAuthenticationInterceptorTest extends AbstractAuthenticationInterce
} }
@Test @Test
void interceptCall() { void interceptCall() throws ChannelNotFoundException {
final GrpcClientConnectionManager grpcClientConnectionManager = getClientConnectionManager(); final GrpcClientConnectionManager grpcClientConnectionManager = getClientConnectionManager();
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty()); when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty());
GrpcTestUtils.assertStatusException(Status.UNAUTHENTICATED, this::getAuthenticatedDevice); GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getAuthenticatedDevice);
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID); final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice)); when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice));
@ -35,5 +36,9 @@ class RequireAuthenticationInterceptorTest extends AbstractAuthenticationInterce
final GetAuthenticatedDeviceResponse response = getAuthenticatedDevice(); final GetAuthenticatedDeviceResponse response = getAuthenticatedDevice();
assertEquals(UUIDUtil.toByteString(authenticatedDevice.accountIdentifier()), response.getAccountIdentifier()); assertEquals(UUIDUtil.toByteString(authenticatedDevice.accountIdentifier()), response.getAccountIdentifier());
assertEquals(authenticatedDevice.deviceId(), response.getDeviceId()); assertEquals(authenticatedDevice.deviceId(), response.getDeviceId());
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenThrow(ChannelNotFoundException.class);
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, this::getAuthenticatedDevice);
} }
} }

View File

@ -40,14 +40,15 @@ class CallRoutingControllerV2Test {
private static final TurnToken CLOUDFLARE_TURN_TOKEN = new TurnToken( private static final TurnToken CLOUDFLARE_TURN_TOKEN = new TurnToken(
"ABC", "ABC",
"XYZ", "XYZ",
43_200,
List.of("turn:cloudflare.example.com:3478?transport=udp"), List.of("turn:cloudflare.example.com:3478?transport=udp"),
null, null,
"cf.example.com"); "cf.example.com");
private static final RateLimiters rateLimiters = mock(RateLimiters.class); private static final RateLimiters rateLimiters = mock(RateLimiters.class);
private static final RateLimiter getCallEndpointLimiter = mock(RateLimiter.class); private static final RateLimiter getCallEndpointLimiter = mock(RateLimiter.class);
private static final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager = mock( private static final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager =
CloudflareTurnCredentialsManager.class); mock(CloudflareTurnCredentialsManager.class);
private static final ResourceExtension resources = ResourceExtension.builder() private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter()) .addProvider(AuthHelper.getAuthFilter())
@ -66,21 +67,14 @@ class CallRoutingControllerV2Test {
@AfterEach @AfterEach
void tearDown() { void tearDown() {
reset( rateLimiters, getCallEndpointLimiter); reset(rateLimiters, getCallEndpointLimiter);
}
void initializeMocksWith(TurnToken cloudflareToken) {
try {
when(cloudflareTurnCredentialsManager.retrieveFromCloudflare()).thenReturn(cloudflareToken);
} catch (IOException ignored) {
}
} }
@Test @Test
void testGetRelaysBothRouting() { void testGetRelaysBothRouting() throws IOException {
initializeMocksWith(CLOUDFLARE_TURN_TOKEN); when(cloudflareTurnCredentialsManager.retrieveFromCloudflare()).thenReturn(CLOUDFLARE_TURN_TOKEN);
try (Response rawResponse = resources.getJerseyTest() try (final Response rawResponse = resources.getJerseyTest()
.target(GET_CALL_RELAYS_PATH) .target(GET_CALL_RELAYS_PATH)
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
@ -88,11 +82,8 @@ class CallRoutingControllerV2Test {
assertThat(rawResponse.getStatus()).isEqualTo(200); assertThat(rawResponse.getStatus()).isEqualTo(200);
CallRoutingControllerV2.GetCallingRelaysResponse response = rawResponse.readEntity( assertThat(rawResponse.readEntity(GetCallingRelaysResponse.class).relays())
CallRoutingControllerV2.GetCallingRelaysResponse.class); .isEqualTo(List.of(CLOUDFLARE_TURN_TOKEN));
List<TurnToken> relays = response.relays();
assertThat(relays).isEqualTo(List.of(CLOUDFLARE_TURN_TOKEN));
} }
} }

View File

@ -292,7 +292,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(200))); assertThat("Good Response", response.getStatus(), is(equalTo(200)));
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class); @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), any()); verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), eq(Optional.empty()), any());
assertEquals(1, captor.getValue().size()); assertEquals(1, captor.getValue().size());
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
@ -319,7 +319,19 @@ class MessageControllerTest {
IncomingMessageList.class), IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE))) { MediaType.APPLICATION_JSON_TYPE))) {
assertThat(response.getStatus(), is(equalTo(sendToPni ? 403 : 200))); if (sendToPni) {
assertThat(response.getStatus(), is(equalTo(403)));
verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any(), any());
} else {
assertThat(response.getStatus(), is(equalTo(200)));
verify(messageSender).sendMessages(any(),
eq(serviceIdentifier),
any(),
any(),
eq(Optional.of(Device.PRIMARY_ID)),
any());
}
} }
} }
@ -337,7 +349,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(200))); assertThat("Good Response", response.getStatus(), is(equalTo(200)));
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class); @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), any()); verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), eq(Optional.empty()), any());
assertEquals(1, captor.getValue().size()); assertEquals(1, captor.getValue().size());
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
@ -362,7 +374,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(200))); assertThat("Good Response", response.getStatus(), is(equalTo(200)));
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class); @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), any()); verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), eq(Optional.empty()), any());
assertEquals(1, captor.getValue().size()); assertEquals(1, captor.getValue().size());
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
@ -400,7 +412,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(200))); assertThat("Good Response", response.getStatus(), is(equalTo(200)));
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class); @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), any()); verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), eq(Optional.empty()), any());
assertEquals(1, captor.getValue().size()); assertEquals(1, captor.getValue().size());
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
@ -439,7 +451,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(expectedResponse))); assertThat("Good Response", response.getStatus(), is(equalTo(expectedResponse)));
if (expectedResponse == 200) { if (expectedResponse == 200) {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class); @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), any()); verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), eq(Optional.empty()), any());
assertEquals(1, captor.getValue().size()); assertEquals(1, captor.getValue().size());
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
@ -536,7 +548,7 @@ class MessageControllerTest {
@Test @Test
void testMultiDeviceMissing() throws Exception { void testMultiDeviceMissing() throws Exception {
doThrow(new MismatchedDevicesException(new MismatchedDevices(Set.of((byte) 2, (byte) 3), Collections.emptySet(), Collections.emptySet()))) doThrow(new MismatchedDevicesException(new MismatchedDevices(Set.of((byte) 2, (byte) 3), Collections.emptySet(), Collections.emptySet())))
.when(messageSender).sendMessages(any(), any(), any(), any(), any()); .when(messageSender).sendMessages(any(), any(), any(), any(), any(), any());
try (final Response response = try (final Response response =
resources.getJerseyTest() resources.getJerseyTest()
@ -558,7 +570,7 @@ class MessageControllerTest {
@Test @Test
void testMultiDeviceExtra() throws Exception { void testMultiDeviceExtra() throws Exception {
doThrow(new MismatchedDevicesException(new MismatchedDevices(Set.of((byte) 2), Set.of((byte) 4), Collections.emptySet()))) doThrow(new MismatchedDevicesException(new MismatchedDevices(Set.of((byte) 2), Set.of((byte) 4), Collections.emptySet())))
.when(messageSender).sendMessages(any(), any(), any(), any(), any()); .when(messageSender).sendMessages(any(), any(), any(), any(), any(), any());
try (final Response response = try (final Response response =
resources.getJerseyTest() resources.getJerseyTest()
@ -609,7 +621,7 @@ class MessageControllerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> envelopeCaptor = @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class); ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(Account.class), any(), envelopeCaptor.capture(), any(), any()); verify(messageSender).sendMessages(any(Account.class), any(), envelopeCaptor.capture(), any(), eq(Optional.empty()), any());
assertEquals(3, envelopeCaptor.getValue().size()); assertEquals(3, envelopeCaptor.getValue().size());
@ -633,7 +645,7 @@ class MessageControllerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> envelopeCaptor = @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class); ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(Account.class), any(), envelopeCaptor.capture(), any(), any()); verify(messageSender).sendMessages(any(Account.class), any(), envelopeCaptor.capture(), any(), eq(Optional.empty()), any());
assertEquals(3, envelopeCaptor.getValue().size()); assertEquals(3, envelopeCaptor.getValue().size());
@ -658,6 +670,7 @@ class MessageControllerTest {
any(), any(),
argThat(messagesByDeviceId -> messagesByDeviceId.size() == 3), argThat(messagesByDeviceId -> messagesByDeviceId.size() == 3),
any(), any(),
eq(Optional.empty()),
any()); any());
} }
} }
@ -665,7 +678,7 @@ class MessageControllerTest {
@Test @Test
void testRegistrationIdMismatch() throws Exception { void testRegistrationIdMismatch() throws Exception {
doThrow(new MismatchedDevicesException(new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of((byte) 2)))) doThrow(new MismatchedDevicesException(new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of((byte) 2))))
.when(messageSender).sendMessages(any(), any(), any(), any(), any()); .when(messageSender).sendMessages(any(), any(), any(), any(), any(), any());
try (final Response response = try (final Response response =
resources.getJerseyTest().target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID)) resources.getJerseyTest().target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID))
@ -1090,7 +1103,7 @@ class MessageControllerTest {
@Test @Test
void testValidateContentLength() throws MismatchedDevicesException, MessageTooLargeException, IOException { void testValidateContentLength() throws MismatchedDevicesException, MessageTooLargeException, IOException {
doThrow(new MessageTooLargeException()).when(messageSender).sendMessages(any(), any(), any(), any(), any()); doThrow(new MessageTooLargeException()).when(messageSender).sendMessages(any(), any(), any(), any(), any(), any());
try (final Response response = try (final Response response =
resources.getJerseyTest() resources.getJerseyTest()
@ -1119,10 +1132,10 @@ class MessageControllerTest {
if (expectOk) { if (expectOk) {
assertEquals(200, response.getStatus()); assertEquals(200, response.getStatus());
verify(messageSender).sendMessages(any(), any(), any(), any(), any()); verify(messageSender).sendMessages(any(), any(), any(), any(), any(), any());
} else { } else {
assertEquals(422, response.getStatus()); assertEquals(422, response.getStatus());
verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any()); verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any(), any());
} }
} }
} }

View File

@ -1140,6 +1140,58 @@ class ProfileControllerTest {
new AccountBadge("TEST2", Instant.ofEpochSecond(42 + 86400), false), new AccountBadge("TEST2", Instant.ofEpochSecond(42 + 86400), false),
new AccountBadge("TEST3", Instant.ofEpochSecond(42 + 86400), false)); new AccountBadge("TEST3", Instant.ofEpochSecond(42 + 86400), false));
} }
}
@Test
void testSetProfileBadgeAfterUpdateTries() throws Exception {
final ProfileKeyCommitment commitment = new ProfileKey(new byte[32]).getCommitment(
new ServiceId.Aci(AuthHelper.VALID_UUID));
final byte[] name = TestRandomUtil.nextBytes(81);
final byte[] emoji = TestRandomUtil.nextBytes(60);
final byte[] about = TestRandomUtil.nextBytes(156);
final String version = versionHex("anotherversion");
clearInvocations(AuthHelper.VALID_ACCOUNT_TWO);
reset(accountsManager);
final int accountsManagerUpdateRetryCount = 2;
AccountsHelper.setupMockUpdateWithRetries(accountsManager, accountsManagerUpdateRetryCount);
// set up two invocations -- one for each AccountsManager#update try
when(AuthHelper.VALID_ACCOUNT_TWO.getBadges())
.thenReturn(List.of(
new AccountBadge("TEST2", Instant.ofEpochSecond(42 + 86400), true),
new AccountBadge("TEST3", Instant.ofEpochSecond(42 + 86400), true)
))
.thenReturn(List.of(
new AccountBadge("TEST2", Instant.ofEpochSecond(42 + 86400), true),
new AccountBadge("TEST3", Instant.ofEpochSecond(42 + 86400), true),
new AccountBadge("TEST4", Instant.ofEpochSecond(43 + 86400), true)
));
try (final Response response = resources.getJerseyTest()
.target("/v1/profile/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID_TWO, AuthHelper.VALID_PASSWORD_TWO))
.put(Entity.entity(new CreateProfileRequest(commitment, version, name, emoji, about, null, false, false,
Optional.of(List.of("TEST1")), null), MediaType.APPLICATION_JSON_TYPE))) {
assertThat(response.getStatus()).isEqualTo(200);
assertThat(response.hasEntity()).isFalse();
//noinspection unchecked
final ArgumentCaptor<List<AccountBadge>> badgeCaptor = ArgumentCaptor.forClass(List.class);
verify(AuthHelper.VALID_ACCOUNT_TWO, times(accountsManagerUpdateRetryCount)).setBadges(refEq(clock), badgeCaptor.capture());
// since the stubbing of getBadges() is brittle, we need to verify the number of invocations, to protect against upstream changes
verify(AuthHelper.VALID_ACCOUNT_TWO, times(accountsManagerUpdateRetryCount)).getBadges();
final List<AccountBadge> badges = badgeCaptor.getValue();
assertThat(badges).isNotNull().hasSize(4).containsOnly(
new AccountBadge("TEST1", Instant.ofEpochSecond(42 + 86400), true),
new AccountBadge("TEST2", Instant.ofEpochSecond(42 + 86400), false),
new AccountBadge("TEST3", Instant.ofEpochSecond(42 + 86400), false),
new AccountBadge("TEST4", Instant.ofEpochSecond(43 + 86400), false));
}
} }
@ParameterizedTest @ParameterizedTest

View File

@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.filters;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import com.google.common.net.InetAddresses;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import io.dropwizard.core.Application; import io.dropwizard.core.Application;
import io.dropwizard.core.Configuration; import io.dropwizard.core.Configuration;
@ -24,7 +25,6 @@ import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path; import jakarta.ws.rs.Path;
import jakarta.ws.rs.client.Client; import jakarta.ws.rs.client.Client;
import jakarta.ws.rs.core.Response; import jakarta.ws.rs.core.Response;
import java.net.InetAddress;
import java.util.Collections; import java.util.Collections;
import java.util.EnumSet; import java.util.EnumSet;
import java.util.Set; import java.util.Set;
@ -39,6 +39,7 @@ import org.signal.chat.rpc.EchoServiceGrpc;
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl; import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils; import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.MockRequestAttributesInterceptor; import org.whispersystems.textsecuregcm.grpc.MockRequestAttributesInterceptor;
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.util.InetAddressRange; import org.whispersystems.textsecuregcm.util.InetAddressRange;
@ExtendWith(DropwizardExtensionsSupport.class) @ExtendWith(DropwizardExtensionsSupport.class)
@ -157,7 +158,7 @@ class ExternalRequestFilterTest {
@BeforeEach @BeforeEach
void setUp() throws Exception { void setUp() throws Exception {
final MockRequestAttributesInterceptor mockRequestAttributesInterceptor = new MockRequestAttributesInterceptor(); final MockRequestAttributesInterceptor mockRequestAttributesInterceptor = new MockRequestAttributesInterceptor();
mockRequestAttributesInterceptor.setRemoteAddress(InetAddress.getByName("127.0.0.1")); mockRequestAttributesInterceptor.setRequestAttributes(new RequestAttributes(InetAddresses.forString("127.0.0.1"), null, null));
testServer = InProcessServerBuilder.forName("ExternalRequestFilterTest") testServer = InProcessServerBuilder.forName("ExternalRequestFilterTest")
.directExecutor() .directExecutor()

View File

@ -15,6 +15,7 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.google.common.net.HttpHeaders; import com.google.common.net.HttpHeaders;
import com.google.common.net.InetAddresses;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import com.vdurmont.semver4j.Semver; import com.vdurmont.semver4j.Semver;
import io.grpc.ManagedChannel; import io.grpc.ManagedChannel;
@ -40,11 +41,10 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRemoteDeprecationConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRemoteDeprecationConfiguration;
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl; import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
import org.whispersystems.textsecuregcm.grpc.MockRequestAttributesInterceptor; import org.whispersystems.textsecuregcm.grpc.MockRequestAttributesInterceptor;
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.grpc.StatusConstants; import org.whispersystems.textsecuregcm.grpc.StatusConstants;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
class RemoteDeprecationFilterTest { class RemoteDeprecationFilterTest {
@ -130,11 +130,7 @@ class RemoteDeprecationFilterTest {
@MethodSource(value="testFilter") @MethodSource(value="testFilter")
void testGrpcFilter(final String userAgentString, final boolean expectDeprecation) throws IOException, InterruptedException { void testGrpcFilter(final String userAgentString, final boolean expectDeprecation) throws IOException, InterruptedException {
final MockRequestAttributesInterceptor mockRequestAttributesInterceptor = new MockRequestAttributesInterceptor(); final MockRequestAttributesInterceptor mockRequestAttributesInterceptor = new MockRequestAttributesInterceptor();
mockRequestAttributesInterceptor.setRequestAttributes(new RequestAttributes(InetAddresses.forString("127.0.0.1"), userAgentString, null));
try {
mockRequestAttributesInterceptor.setUserAgent(UserAgentUtil.parseUserAgentString(userAgentString));
} catch (UnrecognizedUserAgentException ignored) {
}
final Server testServer = InProcessServerBuilder.forName("RemoteDeprecationFilterTest") final Server testServer = InProcessServerBuilder.forName("RemoteDeprecationFilterTest")
.directExecutor() .directExecutor()

View File

@ -72,7 +72,8 @@ class AccountsAnonymousGrpcServiceTest extends
when(rateLimiter.validateReactive(anyString())).thenReturn(Mono.empty()); when(rateLimiter.validateReactive(anyString())).thenReturn(Mono.empty());
getMockRequestAttributesInterceptor().setRemoteAddress(InetAddresses.forString("127.0.0.1")); getMockRequestAttributesInterceptor().setRequestAttributes(
new RequestAttributes(InetAddresses.forString("127.0.0.1"), null, null));
return new AccountsAnonymousGrpcService(accountsManager, rateLimiters); return new AccountsAnonymousGrpcService(accountsManager, rateLimiters);
} }

View File

@ -0,0 +1,88 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.Status;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
class ChannelShutdownInterceptorTest {
private GrpcClientConnectionManager grpcClientConnectionManager;
private ChannelShutdownInterceptor channelShutdownInterceptor;
private ServerCallHandler<String, String> nextCallHandler;
private static final Metadata HEADERS = new Metadata();
@BeforeEach
void setUp() {
grpcClientConnectionManager = mock(GrpcClientConnectionManager.class);
channelShutdownInterceptor = new ChannelShutdownInterceptor(grpcClientConnectionManager);
//noinspection unchecked
nextCallHandler = mock(ServerCallHandler.class);
//noinspection unchecked
when(nextCallHandler.startCall(any(), any())).thenReturn(mock(ServerCall.Listener.class));
}
@Test
void interceptCallComplete() {
@SuppressWarnings("unchecked") final ServerCall<String, String> serverCall = mock(ServerCall.class);
when(grpcClientConnectionManager.handleServerCallStart(serverCall)).thenReturn(true);
final ServerCall.Listener<String> serverCallListener =
channelShutdownInterceptor.interceptCall(serverCall, HEADERS, nextCallHandler);
serverCallListener.onComplete();
verify(grpcClientConnectionManager).handleServerCallStart(serverCall);
verify(grpcClientConnectionManager).handleServerCallComplete(serverCall);
verify(serverCall, never()).close(any(), any());
}
@Test
void interceptCallCancelled() {
@SuppressWarnings("unchecked") final ServerCall<String, String> serverCall = mock(ServerCall.class);
when(grpcClientConnectionManager.handleServerCallStart(serverCall)).thenReturn(true);
final ServerCall.Listener<String> serverCallListener =
channelShutdownInterceptor.interceptCall(serverCall, HEADERS, nextCallHandler);
serverCallListener.onCancel();
verify(grpcClientConnectionManager).handleServerCallStart(serverCall);
verify(grpcClientConnectionManager).handleServerCallComplete(serverCall);
verify(serverCall, never()).close(any(), any());
}
@Test
void interceptCallChannelClosing() {
@SuppressWarnings("unchecked") final ServerCall<String, String> serverCall = mock(ServerCall.class);
when(grpcClientConnectionManager.handleServerCallStart(serverCall)).thenReturn(false);
channelShutdownInterceptor.interceptCall(serverCall, HEADERS, nextCallHandler);
verify(grpcClientConnectionManager).handleServerCallStart(serverCall);
verify(grpcClientConnectionManager, never()).handleServerCallComplete(serverCall);
verify(serverCall).close(eq(Status.UNAVAILABLE), any());
}
}

View File

@ -12,14 +12,38 @@ import org.signal.chat.rpc.EchoServiceGrpc;
public class EchoServiceImpl extends EchoServiceGrpc.EchoServiceImplBase { public class EchoServiceImpl extends EchoServiceGrpc.EchoServiceImplBase {
@Override @Override
public void echo(EchoRequest req, StreamObserver<EchoResponse> responseObserver) { public void echo(final EchoRequest echoRequest, final StreamObserver<EchoResponse> responseObserver) {
responseObserver.onNext(EchoResponse.newBuilder().setPayload(req.getPayload()).build()); responseObserver.onNext(buildResponse(echoRequest));
responseObserver.onCompleted(); responseObserver.onCompleted();
} }
@Override @Override
public void echo2(EchoRequest req, StreamObserver<EchoResponse> responseObserver) { public void echo2(final EchoRequest echoRequest, final StreamObserver<EchoResponse> responseObserver) {
responseObserver.onNext(EchoResponse.newBuilder().setPayload(req.getPayload()).build()); responseObserver.onNext(buildResponse(echoRequest));
responseObserver.onCompleted(); responseObserver.onCompleted();
} }
@Override
public StreamObserver<EchoRequest> echoStream(final StreamObserver<EchoResponse> responseObserver) {
return new StreamObserver<>() {
@Override
public void onNext(final EchoRequest echoRequest) {
responseObserver.onNext(buildResponse(echoRequest));
}
@Override
public void onError(final Throwable throwable) {
responseObserver.onError(throwable);
}
@Override
public void onCompleted() {
responseObserver.onCompleted();
}
};
}
private static EchoResponse buildResponse(final EchoRequest echoRequest) {
return EchoResponse.newBuilder().setPayload(echoRequest.getPayload()).build();
}
} }

View File

@ -0,0 +1,618 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString;
import io.grpc.Status;
import java.time.Duration;
import java.time.Instant;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.mockito.Mock;
import org.signal.chat.messages.AuthenticatedSenderMessageType;
import org.signal.chat.messages.ChallengeRequired;
import org.signal.chat.messages.IndividualRecipientMessageBundle;
import org.signal.chat.messages.MessagesGrpc;
import org.signal.chat.messages.MismatchedDevices;
import org.signal.chat.messages.SendAuthenticatedSenderMessageRequest;
import org.signal.chat.messages.SendMessageResponse;
import org.signal.chat.messages.SendSyncMessageRequest;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.CardinalityEstimator;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.MessageTooLargeException;
import org.whispersystems.textsecuregcm.spam.GrpcResponse;
import org.whispersystems.textsecuregcm.spam.MessageType;
import org.whispersystems.textsecuregcm.spam.SpamCheckResult;
import org.whispersystems.textsecuregcm.spam.SpamChecker;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
import org.whispersystems.textsecuregcm.util.TestClock;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
class MessagesGrpcServiceTest extends SimpleBaseGrpcTest<MessagesGrpcService, MessagesGrpc.MessagesBlockingStub> {
@Mock
private AccountsManager accountsManager;
@Mock
private RateLimiters rateLimiters;
@Mock
private MessageSender messageSender;
@Mock
private CardinalityEstimator messageByteLimitEstimator;
@Mock
private SpamChecker spamChecker;
@Mock
private RateLimiter rateLimiter;
@Mock
private Account authenticatedAccount;
@Mock
private Device authenticatedDevice;
@Mock
private Device linkedDevice;
@Mock
private Device secondLinkedDevice;
private static final int AUTHENTICATED_REGISTRATION_ID = 7;
private static final byte LINKED_DEVICE_ID = AUTHENTICATED_DEVICE_ID + 1;
private static final int LINKED_DEVICE_REGISTRATION_ID = 13;
private static final byte SECOND_LINKED_DEVICE_ID = LINKED_DEVICE_ID + 1;
private static final int SECOND_LINKED_DEVICE_REGISTRATION_ID = 19;
private static final TestClock CLOCK = TestClock.pinned(Instant.now());
@Override
protected MessagesGrpcService createServiceBeforeEachTest() {
return new MessagesGrpcService(accountsManager,
rateLimiters,
messageSender,
messageByteLimitEstimator,
spamChecker,
CLOCK);
}
@BeforeEach
void setUp() {
when(accountsManager.getByServiceIdentifier(any())).thenReturn(Optional.empty());
when(rateLimiters.getInboundMessageBytes()).thenReturn(rateLimiter);
when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter);
when(spamChecker.checkForIndividualRecipientSpamGrpc(any(), any(), any(), any()))
.thenReturn(new SpamCheckResult<>(Optional.empty(), Optional.empty()));
when(authenticatedDevice.getId()).thenReturn(AUTHENTICATED_DEVICE_ID);
when(authenticatedDevice.getRegistrationId()).thenReturn(AUTHENTICATED_REGISTRATION_ID);
when(linkedDevice.getId()).thenReturn(LINKED_DEVICE_ID);
when(linkedDevice.getRegistrationId()).thenReturn(LINKED_DEVICE_REGISTRATION_ID);
when(secondLinkedDevice.getId()).thenReturn(SECOND_LINKED_DEVICE_ID);
when(secondLinkedDevice.getRegistrationId()).thenReturn(SECOND_LINKED_DEVICE_REGISTRATION_ID);
when(authenticatedAccount.getUuid()).thenReturn(AUTHENTICATED_ACI);
when(authenticatedAccount.getIdentifier(IdentityType.ACI)).thenReturn(AUTHENTICATED_ACI);
when(authenticatedAccount.getDevice(anyByte())).thenReturn(Optional.empty());
when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(authenticatedDevice));
when(authenticatedAccount.getDevice(LINKED_DEVICE_ID)).thenReturn(Optional.of(linkedDevice));
when(authenticatedAccount.getDevice(SECOND_LINKED_DEVICE_ID)).thenReturn(Optional.of(secondLinkedDevice));
when(authenticatedAccount.getDevices()).thenReturn(List.of(authenticatedDevice, linkedDevice, secondLinkedDevice));
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(AUTHENTICATED_ACI)))
.thenReturn(Optional.of(authenticatedAccount));
}
@Nested
class SingleRecipient {
@CartesianTest
void sendMessage(@CartesianTest.Enum(mode = CartesianTest.Enum.Mode.EXCLUDE, names = {"UNSPECIFIED", "UNRECOGNIZED"}) final AuthenticatedSenderMessageType messageType,
@CartesianTest.Values(booleans = {true, false}) final boolean ephemeral,
@CartesianTest.Values(booleans = {true, false}) final boolean urgent,
@CartesianTest.Values(booleans = {true, false}) final boolean includeReportSpamToken)
throws MessageTooLargeException, MismatchedDevicesException {
final byte deviceId = Device.PRIMARY_ID;
final int registrationId = 7;
final Device destinationDevice = DevicesHelper.createDevice(deviceId, CLOCK.millis(), registrationId);
final Account destinationAccount = mock(Account.class);
when(destinationAccount.getDevices()).thenReturn(List.of(destinationDevice));
when(destinationAccount.getDevice(deviceId)).thenReturn(Optional.of(destinationDevice));
final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount));
final byte[] reportSpamToken = TestRandomUtil.nextBytes(64);
if (includeReportSpamToken) {
when(spamChecker.checkForIndividualRecipientSpamGrpc(any(), any(), any(), any()))
.thenReturn(new SpamCheckResult<>(Optional.empty(), Optional.of(reportSpamToken)));
}
final byte[] payload = TestRandomUtil.nextBytes(128);
final Map<Byte, IndividualRecipientMessageBundle.Message> messages =
Map.of(deviceId, IndividualRecipientMessageBundle.Message.newBuilder()
.setRegistrationId(registrationId)
.setPayload(ByteString.copyFrom(payload))
.build());
final SendMessageResponse response = authenticatedServiceStub().sendMessage(
generateRequest(serviceIdentifier, messageType, ephemeral, urgent, messages));
assertEquals(SendMessageResponse.newBuilder().build(), response);
final MessageProtos.Envelope.Type expectedEnvelopeType = switch (messageType) {
case DOUBLE_RATCHET -> MessageProtos.Envelope.Type.CIPHERTEXT;
case PREKEY_MESSAGE -> MessageProtos.Envelope.Type.PREKEY_BUNDLE;
case PLAINTEXT_CONTENT -> MessageProtos.Envelope.Type.PLAINTEXT_CONTENT;
case UNSPECIFIED, UNRECOGNIZED -> throw new IllegalArgumentException("Unexpected message type: " + messageType);
};
final MessageProtos.Envelope.Builder expectedEnvelopeBuilder = MessageProtos.Envelope.newBuilder()
.setType(expectedEnvelopeType)
.setSourceServiceId(new AciServiceIdentifier(AUTHENTICATED_ACI).toServiceIdentifierString())
.setSourceDevice(AUTHENTICATED_DEVICE_ID)
.setDestinationServiceId(serviceIdentifier.toServiceIdentifierString())
.setClientTimestamp(CLOCK.millis())
.setServerTimestamp(CLOCK.millis())
.setEphemeral(ephemeral)
.setUrgent(urgent)
.setContent(ByteString.copyFrom(payload));
if (includeReportSpamToken) {
expectedEnvelopeBuilder.setReportSpamToken(ByteString.copyFrom(reportSpamToken));
}
verify(spamChecker).checkForIndividualRecipientSpamGrpc(MessageType.INDIVIDUAL_IDENTIFIED_SENDER,
Optional.of(new AuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID)),
Optional.of(destinationAccount),
serviceIdentifier);
verify(messageSender).sendMessages(destinationAccount,
serviceIdentifier,
Map.of(deviceId, expectedEnvelopeBuilder.build()),
Map.of(deviceId, registrationId),
Optional.empty(),
null);
}
@Test
void mismatchedDevices() throws MessageTooLargeException, MismatchedDevicesException {
final byte missingDeviceId = Device.PRIMARY_ID;
final byte extraDeviceId = missingDeviceId + 1;
final byte staleDeviceId = extraDeviceId + 1;
final Account destinationAccount = mock(Account.class);
final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount));
final Map<Byte, IndividualRecipientMessageBundle.Message> messages = Map.of(
staleDeviceId, IndividualRecipientMessageBundle.Message.newBuilder()
.setRegistrationId(Device.PRIMARY_ID)
.setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128)))
.build());
doThrow(new MismatchedDevicesException(new org.whispersystems.textsecuregcm.controllers.MismatchedDevices(
Set.of(missingDeviceId), Set.of(extraDeviceId), Set.of(staleDeviceId))))
.when(messageSender).sendMessages(any(), any(), any(), any(), any(), any());
final SendMessageResponse response = authenticatedServiceStub().sendMessage(
generateRequest(serviceIdentifier, AuthenticatedSenderMessageType.DOUBLE_RATCHET, false, true, messages));
final SendMessageResponse expectedResponse = SendMessageResponse.newBuilder()
.setMismatchedDevices(MismatchedDevices.newBuilder()
.setServiceIdentifier(ServiceIdentifierUtil.toGrpcServiceIdentifier(serviceIdentifier))
.addMissingDevices(missingDeviceId)
.addStaleDevices(staleDeviceId)
.addExtraDevices(extraDeviceId)
.build())
.build();
assertEquals(expectedResponse, response);
}
@Test
void destinationNotFound() throws MessageTooLargeException, MismatchedDevicesException {
final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
final Map<Byte, IndividualRecipientMessageBundle.Message> messages =
Map.of(Device.PRIMARY_ID, IndividualRecipientMessageBundle.Message.newBuilder()
.setRegistrationId(1234)
.setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128)))
.build());
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.NOT_FOUND,
() -> authenticatedServiceStub().sendMessage(
generateRequest(serviceIdentifier, AuthenticatedSenderMessageType.DOUBLE_RATCHET, false, true, messages)));
verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any(), any());
}
@Test
void rateLimited() throws RateLimitExceededException, MessageTooLargeException, MismatchedDevicesException {
final byte deviceId = Device.PRIMARY_ID;
final int registrationId = 7;
final Device destinationDevice = DevicesHelper.createDevice(deviceId, CLOCK.millis(), registrationId);
final Account destinationAccount = mock(Account.class);
when(destinationAccount.getDevices()).thenReturn(List.of(destinationDevice));
when(destinationAccount.getDevice(deviceId)).thenReturn(Optional.of(destinationDevice));
final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount));
final Duration retryDuration = Duration.ofHours(7);
doThrow(new RateLimitExceededException(retryDuration))
.when(rateLimiter).validate(eq(serviceIdentifier.uuid()), anyInt());
final Map<Byte, IndividualRecipientMessageBundle.Message> messages =
Map.of(deviceId, IndividualRecipientMessageBundle.Message.newBuilder()
.setRegistrationId(registrationId)
.setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128)))
.build());
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertRateLimitExceeded(retryDuration,
() -> authenticatedServiceStub().sendMessage(
generateRequest(serviceIdentifier, AuthenticatedSenderMessageType.DOUBLE_RATCHET, false, true, messages)));
verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any(), any());
verify(messageByteLimitEstimator).add(serviceIdentifier.uuid().toString());
}
@Test
void oversizedMessage() throws MessageTooLargeException, MismatchedDevicesException {
final byte missingDeviceId = Device.PRIMARY_ID;
final byte extraDeviceId = missingDeviceId + 1;
final byte staleDeviceId = extraDeviceId + 1;
final Account destinationAccount = mock(Account.class);
final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount));
final Map<Byte, IndividualRecipientMessageBundle.Message> messages = Map.of(
staleDeviceId, IndividualRecipientMessageBundle.Message.newBuilder()
.setRegistrationId(Device.PRIMARY_ID)
.setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128)))
.build());
doThrow(new MessageTooLargeException())
.when(messageSender).sendMessages(any(), any(), any(), any(), any(), any());
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.INVALID_ARGUMENT,
() -> authenticatedServiceStub().sendMessage(
generateRequest(serviceIdentifier, AuthenticatedSenderMessageType.DOUBLE_RATCHET, false, true, messages)));
}
@Test
void spamWithStatus() throws MessageTooLargeException, MismatchedDevicesException {
final byte deviceId = Device.PRIMARY_ID;
final int registrationId = 7;
final Device destinationDevice = DevicesHelper.createDevice(deviceId, CLOCK.millis(), registrationId);
final Account destinationAccount = mock(Account.class);
when(destinationAccount.getDevices()).thenReturn(List.of(destinationDevice));
when(destinationAccount.getDevice(deviceId)).thenReturn(Optional.of(destinationDevice));
final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount));
final Map<Byte, IndividualRecipientMessageBundle.Message> messages =
Map.of(deviceId, IndividualRecipientMessageBundle.Message.newBuilder()
.setRegistrationId(registrationId)
.setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128)))
.build());
when(spamChecker.checkForIndividualRecipientSpamGrpc(any(), any(), any(), any()))
.thenReturn(new SpamCheckResult<>(Optional.of(GrpcResponse.withStatus(Status.RESOURCE_EXHAUSTED)), Optional.empty()));
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.RESOURCE_EXHAUSTED,
() -> authenticatedServiceStub().sendMessage(
generateRequest(serviceIdentifier, AuthenticatedSenderMessageType.DOUBLE_RATCHET, false, true, messages)));
verify(spamChecker).checkForIndividualRecipientSpamGrpc(MessageType.INDIVIDUAL_IDENTIFIED_SENDER,
Optional.of(new AuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID)),
Optional.of(destinationAccount),
serviceIdentifier);
verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any(), any());
}
@Test
void spamWithResponse() throws MessageTooLargeException, MismatchedDevicesException {
final byte deviceId = Device.PRIMARY_ID;
final int registrationId = 7;
final Device destinationDevice = DevicesHelper.createDevice(deviceId, CLOCK.millis(), registrationId);
final Account destinationAccount = mock(Account.class);
when(destinationAccount.getDevices()).thenReturn(List.of(destinationDevice));
when(destinationAccount.getDevice(deviceId)).thenReturn(Optional.of(destinationDevice));
final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount));
final Map<Byte, IndividualRecipientMessageBundle.Message> messages =
Map.of(deviceId, IndividualRecipientMessageBundle.Message.newBuilder()
.setRegistrationId(registrationId)
.setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128)))
.build());
final SendMessageResponse response = SendMessageResponse.newBuilder()
.setChallengeRequired(ChallengeRequired.newBuilder()
.addChallengeOptions(ChallengeRequired.ChallengeType.CAPTCHA))
.build();
when(spamChecker.checkForIndividualRecipientSpamGrpc(any(), any(), any(), any()))
.thenReturn(new SpamCheckResult<>(Optional.of(GrpcResponse.withResponse(response)), Optional.empty()));
assertEquals(response, authenticatedServiceStub().sendMessage(
generateRequest(serviceIdentifier, AuthenticatedSenderMessageType.DOUBLE_RATCHET, false, true, messages)));
verify(spamChecker).checkForIndividualRecipientSpamGrpc(MessageType.INDIVIDUAL_IDENTIFIED_SENDER,
Optional.of(new AuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID)),
Optional.of(destinationAccount),
serviceIdentifier);
verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any(), any());
}
private static SendAuthenticatedSenderMessageRequest generateRequest(final ServiceIdentifier serviceIdentifier,
final AuthenticatedSenderMessageType messageType,
final boolean ephemeral,
final boolean urgent,
final Map<Byte, IndividualRecipientMessageBundle.Message> messages) {
final IndividualRecipientMessageBundle.Builder messageBundleBuilder = IndividualRecipientMessageBundle.newBuilder()
.setTimestamp(CLOCK.millis());
messages.forEach(messageBundleBuilder::putMessages);
final SendAuthenticatedSenderMessageRequest.Builder requestBuilder = SendAuthenticatedSenderMessageRequest.newBuilder()
.setDestination(ServiceIdentifierUtil.toGrpcServiceIdentifier(serviceIdentifier))
.setType(messageType)
.setMessages(messageBundleBuilder)
.setEphemeral(ephemeral)
.setUrgent(urgent);
return requestBuilder.build();
}
}
@Nested
class Sync {
@CartesianTest
void sendMessage(@CartesianTest.Enum(mode = CartesianTest.Enum.Mode.EXCLUDE, names = {"UNSPECIFIED", "UNRECOGNIZED"}) final AuthenticatedSenderMessageType messageType,
@CartesianTest.Values(booleans = {true, false}) final boolean urgent,
@CartesianTest.Values(booleans = {true, false}) final boolean includeReportSpamToken)
throws MessageTooLargeException, MismatchedDevicesException {
final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(AUTHENTICATED_ACI);
final byte[] payload = TestRandomUtil.nextBytes(128);
final Map<Byte, IndividualRecipientMessageBundle.Message> messages =
Map.of(LINKED_DEVICE_ID, IndividualRecipientMessageBundle.Message.newBuilder()
.setRegistrationId(LINKED_DEVICE_REGISTRATION_ID)
.setPayload(ByteString.copyFrom(payload))
.build(),
SECOND_LINKED_DEVICE_ID, IndividualRecipientMessageBundle.Message.newBuilder()
.setRegistrationId(SECOND_LINKED_DEVICE_REGISTRATION_ID)
.setPayload(ByteString.copyFrom(payload))
.build());
final byte[] reportSpamToken = TestRandomUtil.nextBytes(64);
if (includeReportSpamToken) {
when(spamChecker.checkForIndividualRecipientSpamGrpc(any(), any(), any(), any()))
.thenReturn(new SpamCheckResult<>(Optional.empty(), Optional.of(reportSpamToken)));
}
final SendMessageResponse response =
authenticatedServiceStub().sendSyncMessage(generateRequest(messageType, urgent, messages));
assertEquals(SendMessageResponse.newBuilder().build(), response);
final MessageProtos.Envelope.Type expectedEnvelopeType = switch (messageType) {
case DOUBLE_RATCHET -> MessageProtos.Envelope.Type.CIPHERTEXT;
case PREKEY_MESSAGE -> MessageProtos.Envelope.Type.PREKEY_BUNDLE;
case PLAINTEXT_CONTENT -> MessageProtos.Envelope.Type.PLAINTEXT_CONTENT;
case UNSPECIFIED, UNRECOGNIZED -> throw new IllegalArgumentException("Unexpected message type: " + messageType);
};
final Map<Byte, MessageProtos.Envelope> expectedEnvelopes = new HashMap<>(Map.of(
LINKED_DEVICE_ID, MessageProtos.Envelope.newBuilder()
.setType(expectedEnvelopeType)
.setSourceServiceId(serviceIdentifier.toServiceIdentifierString())
.setSourceDevice(AUTHENTICATED_DEVICE_ID)
.setDestinationServiceId(serviceIdentifier.toServiceIdentifierString())
.setClientTimestamp(CLOCK.millis())
.setServerTimestamp(CLOCK.millis())
.setEphemeral(false)
.setUrgent(urgent)
.setContent(ByteString.copyFrom(payload))
.build(),
SECOND_LINKED_DEVICE_ID, MessageProtos.Envelope.newBuilder()
.setType(expectedEnvelopeType)
.setSourceServiceId(serviceIdentifier.toServiceIdentifierString())
.setSourceDevice(AUTHENTICATED_DEVICE_ID)
.setDestinationServiceId(serviceIdentifier.toServiceIdentifierString())
.setClientTimestamp(CLOCK.millis())
.setServerTimestamp(CLOCK.millis())
.setEphemeral(false)
.setUrgent(urgent)
.setContent(ByteString.copyFrom(payload))
.build()
));
if (includeReportSpamToken) {
expectedEnvelopes.replaceAll((deviceId, envelope) ->
envelope.toBuilder().setReportSpamToken(ByteString.copyFrom(reportSpamToken)).build());
}
verify(spamChecker).checkForIndividualRecipientSpamGrpc(MessageType.SYNC,
Optional.of(new AuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID)),
Optional.of(authenticatedAccount),
serviceIdentifier);
verify(messageSender).sendMessages(authenticatedAccount,
serviceIdentifier,
expectedEnvelopes,
Map.of(LINKED_DEVICE_ID, LINKED_DEVICE_REGISTRATION_ID,
SECOND_LINKED_DEVICE_ID, SECOND_LINKED_DEVICE_REGISTRATION_ID),
Optional.of(AUTHENTICATED_DEVICE_ID),
null);
}
@Test
void mismatchedDevices() throws MessageTooLargeException, MismatchedDevicesException {
final byte missingDeviceId = Device.PRIMARY_ID;
final byte extraDeviceId = missingDeviceId + 1;
final byte staleDeviceId = extraDeviceId + 1;
final Map<Byte, IndividualRecipientMessageBundle.Message> messages = Map.of(
staleDeviceId, IndividualRecipientMessageBundle.Message.newBuilder()
.setRegistrationId(Device.PRIMARY_ID)
.setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128)))
.build());
doThrow(new MismatchedDevicesException(new org.whispersystems.textsecuregcm.controllers.MismatchedDevices(
Set.of(missingDeviceId), Set.of(extraDeviceId), Set.of(staleDeviceId))))
.when(messageSender).sendMessages(any(), any(), any(), any(), any(), any());
final SendMessageResponse response = authenticatedServiceStub().sendSyncMessage(
generateRequest(AuthenticatedSenderMessageType.DOUBLE_RATCHET, true, messages));
final SendMessageResponse expectedResponse = SendMessageResponse.newBuilder()
.setMismatchedDevices(MismatchedDevices.newBuilder()
.setServiceIdentifier(ServiceIdentifierUtil.toGrpcServiceIdentifier(new AciServiceIdentifier(AUTHENTICATED_ACI)))
.addMissingDevices(missingDeviceId)
.addStaleDevices(staleDeviceId)
.addExtraDevices(extraDeviceId)
.build())
.build();
assertEquals(expectedResponse, response);
}
@Test
void rateLimited() throws RateLimitExceededException, MessageTooLargeException, MismatchedDevicesException {
final Duration retryDuration = Duration.ofHours(7);
doThrow(new RateLimitExceededException(retryDuration))
.when(rateLimiter).validate(eq(AUTHENTICATED_ACI), anyInt());
final Map<Byte, IndividualRecipientMessageBundle.Message> messages =
Map.of(AUTHENTICATED_DEVICE_ID, IndividualRecipientMessageBundle.Message.newBuilder()
.setRegistrationId(AUTHENTICATED_REGISTRATION_ID)
.setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128)))
.build());
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertRateLimitExceeded(retryDuration,
() -> authenticatedServiceStub().sendSyncMessage(
generateRequest(AuthenticatedSenderMessageType.DOUBLE_RATCHET, true, messages)));
verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any(), any());
verify(messageByteLimitEstimator).add(AUTHENTICATED_ACI.toString());
}
@Test
void oversizedMessage() throws MessageTooLargeException, MismatchedDevicesException {
final byte missingDeviceId = Device.PRIMARY_ID;
final byte extraDeviceId = missingDeviceId + 1;
final byte staleDeviceId = extraDeviceId + 1;
final Account destinationAccount = mock(Account.class);
final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount));
final Map<Byte, IndividualRecipientMessageBundle.Message> messages = Map.of(
staleDeviceId, IndividualRecipientMessageBundle.Message.newBuilder()
.setRegistrationId(Device.PRIMARY_ID)
.setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128)))
.build());
doThrow(new MessageTooLargeException())
.when(messageSender).sendMessages(any(), any(), any(), any(), any(), any());
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.INVALID_ARGUMENT,
() -> authenticatedServiceStub().sendSyncMessage(
generateRequest(AuthenticatedSenderMessageType.DOUBLE_RATCHET, true, messages)));
}
private static SendSyncMessageRequest generateRequest(final AuthenticatedSenderMessageType messageType,
final boolean urgent,
final Map<Byte, IndividualRecipientMessageBundle.Message> messages) {
final IndividualRecipientMessageBundle.Builder messageBundleBuilder = IndividualRecipientMessageBundle.newBuilder()
.setTimestamp(CLOCK.millis());
messages.forEach(messageBundleBuilder::putMessages);
final SendSyncMessageRequest.Builder requestBuilder = SendSyncMessageRequest.newBuilder()
.setType(messageType)
.setMessages(messageBundleBuilder)
.setUrgent(urgent);
return requestBuilder.build();
}
}
}

View File

@ -5,6 +5,7 @@
package org.whispersystems.textsecuregcm.grpc; package org.whispersystems.textsecuregcm.grpc;
import com.google.common.net.InetAddresses;
import io.grpc.Context; import io.grpc.Context;
import io.grpc.Contexts; import io.grpc.Contexts;
import io.grpc.Metadata; import io.grpc.Metadata;
@ -19,25 +20,10 @@ import org.whispersystems.textsecuregcm.util.ua.UserAgent;
public class MockRequestAttributesInterceptor implements ServerInterceptor { public class MockRequestAttributesInterceptor implements ServerInterceptor {
@Nullable private RequestAttributes requestAttributes = new RequestAttributes(InetAddresses.forString("127.0.0.1"), null, null);
private InetAddress remoteAddress;
@Nullable public void setRequestAttributes(final RequestAttributes requestAttributes) {
private UserAgent userAgent; this.requestAttributes = requestAttributes;
@Nullable
private List<Locale.LanguageRange> acceptLanguage;
public void setRemoteAddress(@Nullable final InetAddress remoteAddress) {
this.remoteAddress = remoteAddress;
}
public void setUserAgent(@Nullable final UserAgent userAgent) {
this.userAgent = userAgent;
}
public void setAcceptLanguage(@Nullable final List<Locale.LanguageRange> acceptLanguage) {
this.acceptLanguage = acceptLanguage;
} }
@Override @Override
@ -45,20 +31,7 @@ public class MockRequestAttributesInterceptor implements ServerInterceptor {
final Metadata headers, final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) { final ServerCallHandler<ReqT, RespT> next) {
Context context = Context.current(); return Contexts.interceptCall(Context.current()
.withValue(RequestAttributesUtil.REQUEST_ATTRIBUTES_CONTEXT_KEY, requestAttributes), serverCall, headers, next);
if (remoteAddress != null) {
context = context.withValue(RequestAttributesUtil.REMOTE_ADDRESS_CONTEXT_KEY, remoteAddress);
}
if (userAgent != null) {
context = context.withValue(RequestAttributesUtil.USER_AGENT_CONTEXT_KEY, userAgent);
}
if (acceptLanguage != null) {
context = context.withValue(RequestAttributesUtil.ACCEPT_LANGUAGE_CONTEXT_KEY, acceptLanguage);
}
return Contexts.interceptCall(context, serverCall, headers, next);
} }
} }

View File

@ -15,6 +15,7 @@ import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException; import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
import com.google.common.net.InetAddresses;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import io.grpc.Status; import io.grpc.Status;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
@ -75,8 +76,6 @@ import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper; import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper;
import org.whispersystems.textsecuregcm.util.TestRandomUtil; import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil; import org.whispersystems.textsecuregcm.util.UUIDUtil;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<ProfileAnonymousGrpcService, ProfileAnonymousGrpc.ProfileAnonymousBlockingStub> { public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<ProfileAnonymousGrpcService, ProfileAnonymousGrpc.ProfileAnonymousBlockingStub> {
@ -96,13 +95,9 @@ public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<ProfileA
@Override @Override
protected ProfileAnonymousGrpcService createServiceBeforeEachTest() { protected ProfileAnonymousGrpcService createServiceBeforeEachTest() {
getMockRequestAttributesInterceptor().setAcceptLanguage(Locale.LanguageRange.parse("en-us")); getMockRequestAttributesInterceptor().setRequestAttributes(new RequestAttributes(InetAddresses.forString("127.0.0.1"),
"Signal-Android/1.2.3",
try { Locale.LanguageRange.parse("en-us")));
getMockRequestAttributesInterceptor().setUserAgent(UserAgentUtil.parseUserAgentString("Signal-Android/1.2.3"));
} catch (final UnrecognizedUserAgentException e) {
throw new IllegalArgumentException(e);
}
return new ProfileAnonymousGrpcService( return new ProfileAnonymousGrpcService(
accountsManager, accountsManager,

View File

@ -15,13 +15,16 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.refEq;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded; import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException; import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
import com.google.common.net.InetAddresses;
import com.google.i18n.phonenumbers.PhoneNumberUtil; import com.google.i18n.phonenumbers.PhoneNumberUtil;
import com.google.i18n.phonenumbers.Phonenumber; import com.google.i18n.phonenumbers.Phonenumber;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
@ -30,6 +33,7 @@ import java.nio.charset.StandardCharsets;
import java.time.Clock; import java.time.Clock;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.time.ZoneId;
import java.time.temporal.ChronoUnit; import java.time.temporal.ChronoUnit;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
@ -93,18 +97,18 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.s3.PolicySigner; import org.whispersystems.textsecuregcm.s3.PolicySigner;
import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator; import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountBadge;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DeviceCapability; import org.whispersystems.textsecuregcm.storage.DeviceCapability;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.ProfilesManager; import org.whispersystems.textsecuregcm.storage.ProfilesManager;
import org.whispersystems.textsecuregcm.storage.VersionedProfile; import org.whispersystems.textsecuregcm.storage.VersionedProfile;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper; import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper;
import org.whispersystems.textsecuregcm.util.MockUtils; import org.whispersystems.textsecuregcm.util.MockUtils;
import org.whispersystems.textsecuregcm.util.TestRandomUtil; import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil; import org.whispersystems.textsecuregcm.util.UUIDUtil;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
@ -144,6 +148,8 @@ public class ProfileGrpcServiceTest extends SimpleBaseGrpcTest<ProfileGrpcServic
@Mock @Mock
private ServerZkProfileOperations serverZkProfileOperations; private ServerZkProfileOperations serverZkProfileOperations;
private Clock clock;
@Override @Override
protected ProfileGrpcService createServiceBeforeEachTest() { protected ProfileGrpcService createServiceBeforeEachTest() {
@SuppressWarnings("unchecked") final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager = mock(DynamicConfigurationManager.class); @SuppressWarnings("unchecked") final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
@ -170,13 +176,9 @@ public class ProfileGrpcServiceTest extends SimpleBaseGrpcTest<ProfileGrpcServic
PhoneNumberUtil.getInstance().getExampleNumber("US"), PhoneNumberUtil.getInstance().getExampleNumber("US"),
PhoneNumberUtil.PhoneNumberFormat.E164); PhoneNumberUtil.PhoneNumberFormat.E164);
getMockRequestAttributesInterceptor().setAcceptLanguage(Locale.LanguageRange.parse("en-us")); getMockRequestAttributesInterceptor().setRequestAttributes(new RequestAttributes(InetAddresses.forString("127.0.0.1"),
"Signal-Android/1.2.3",
try { Locale.LanguageRange.parse("en-us")));
getMockRequestAttributesInterceptor().setUserAgent(UserAgentUtil.parseUserAgentString("Signal-Android/1.2.3"));
} catch (final UnrecognizedUserAgentException e) {
throw new IllegalArgumentException(e);
}
when(rateLimiters.getProfileLimiter()).thenReturn(rateLimiter); when(rateLimiters.getProfileLimiter()).thenReturn(rateLimiter);
when(rateLimiter.validateReactive(any(UUID.class))).thenReturn(Mono.empty()); when(rateLimiter.validateReactive(any(UUID.class))).thenReturn(Mono.empty());
@ -203,8 +205,10 @@ public class ProfileGrpcServiceTest extends SimpleBaseGrpcTest<ProfileGrpcServic
when(asyncS3client.deleteObject(any(DeleteObjectRequest.class))).thenReturn(CompletableFuture.completedFuture(null)); when(asyncS3client.deleteObject(any(DeleteObjectRequest.class))).thenReturn(CompletableFuture.completedFuture(null));
clock = Clock.fixed(Instant.ofEpochSecond(42), ZoneId.of("Etc/UTC"));
return new ProfileGrpcService( return new ProfileGrpcService(
Clock.systemUTC(), clock,
accountsManager, accountsManager,
profilesManager, profilesManager,
dynamicConfigurationManager, dynamicConfigurationManager,
@ -392,6 +396,42 @@ public class ProfileGrpcServiceTest extends SimpleBaseGrpcTest<ProfileGrpcServic
} }
} }
@Test
void setProfileBadges() throws InvalidInputException {
final byte[] commitment = new ProfileKey(new byte[32]).getCommitment(new ServiceId.Aci(AUTHENTICATED_ACI)).serialize();
final SetProfileRequest request = SetProfileRequest.newBuilder()
.setVersion(VERSION)
.setName(ByteString.copyFrom(VALID_NAME))
.setAvatarChange(AvatarChange.AVATAR_CHANGE_UNCHANGED)
.addAllBadgeIds(List.of("TEST3"))
.setCommitment(ByteString.copyFrom(commitment))
.build();
final int accountsManagerUpdateRetryCount = 2;
AccountsHelper.setupMockUpdateWithRetries(accountsManager, accountsManagerUpdateRetryCount);
// set up two invocations -- one for each AccountsManager#update try
when(account.getBadges())
.thenReturn(List.of(new AccountBadge("TEST3", Instant.ofEpochSecond(41), false)))
.thenReturn(List.of(new AccountBadge("TEST2", Instant.ofEpochSecond(41), true),
new AccountBadge("TEST3", Instant.ofEpochSecond(41), false)));
//noinspection ResultOfMethodCallIgnored
authenticatedServiceStub().setProfile(request);
//noinspection unchecked
final ArgumentCaptor<List<AccountBadge>> badgeCaptor = ArgumentCaptor.forClass(List.class);
verify(account, times(2)).setBadges(refEq(clock), badgeCaptor.capture());
// since the stubbing of getBadges() is brittle, we need to verify the number of invocations, to protect against upstream changes
verify(account, times(accountsManagerUpdateRetryCount)).getBadges();
assertEquals(List.of(
new AccountBadge("TEST3", Instant.ofEpochSecond(41), true),
new AccountBadge("TEST2", Instant.ofEpochSecond(41), false)),
badgeCaptor.getValue());
}
@ParameterizedTest @ParameterizedTest
@EnumSource(value = org.signal.chat.common.IdentityType.class, names = {"IDENTITY_TYPE_ACI", "IDENTITY_TYPE_PNI"}) @EnumSource(value = org.signal.chat.common.IdentityType.class, names = {"IDENTITY_TYPE_ACI", "IDENTITY_TYPE_PNI"})
void getUnversionedProfile(final IdentityType identityType) { void getUnversionedProfile(final IdentityType identityType) {

View File

@ -6,7 +6,6 @@ import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.signal.chat.rpc.GetRequestAttributesRequest; import org.signal.chat.rpc.GetRequestAttributesRequest;
import org.signal.chat.rpc.GetRequestAttributesResponse; import org.signal.chat.rpc.GetRequestAttributesResponse;
import org.signal.chat.rpc.RequestAttributesGrpc; import org.signal.chat.rpc.RequestAttributesGrpc;
import org.signal.chat.rpc.UserAgent;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticationUtil; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticationUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil; import org.whispersystems.textsecuregcm.util.UUIDUtil;
@ -19,21 +18,15 @@ public class RequestAttributesServiceImpl extends RequestAttributesGrpc.RequestA
final GetRequestAttributesResponse.Builder responseBuilder = GetRequestAttributesResponse.newBuilder(); final GetRequestAttributesResponse.Builder responseBuilder = GetRequestAttributesResponse.newBuilder();
RequestAttributesUtil.getAcceptableLanguages().ifPresent(acceptableLanguages -> RequestAttributesUtil.getAcceptableLanguages()
acceptableLanguages.forEach(languageRange -> responseBuilder.addAcceptableLanguages(languageRange.toString()))); .forEach(languageRange -> responseBuilder.addAcceptableLanguages(languageRange.toString()));
RequestAttributesUtil.getAvailableAcceptedLocales().forEach(locale -> RequestAttributesUtil.getAvailableAcceptedLocales().forEach(locale ->
responseBuilder.addAvailableAcceptedLocales(locale.toLanguageTag())); responseBuilder.addAvailableAcceptedLocales(locale.toLanguageTag()));
responseBuilder.setRemoteAddress(RequestAttributesUtil.getRemoteAddress().getHostAddress()); responseBuilder.setRemoteAddress(RequestAttributesUtil.getRemoteAddress().getHostAddress());
RequestAttributesUtil.getUserAgent().ifPresent(userAgent -> responseBuilder.setUserAgent(UserAgent.newBuilder() RequestAttributesUtil.getUserAgent().ifPresent(responseBuilder::setUserAgent);
.setPlatform(userAgent.getPlatform().toString())
.setVersion(userAgent.getVersion().toString())
.setAdditionalSpecifiers(userAgent.getAdditionalSpecifiers().orElse(""))
.build()));
RequestAttributesUtil.getRawUserAgent().ifPresent(responseBuilder::setRawUserAgent);
responseObserver.onNext(responseBuilder.build()); responseObserver.onNext(responseBuilder.build());
responseObserver.onCompleted(); responseObserver.onCompleted();

View File

@ -3,172 +3,84 @@ package org.whispersystems.textsecuregcm.grpc;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.common.net.InetAddresses; import com.google.common.net.InetAddresses;
import io.grpc.ManagedChannel; import io.grpc.Context;
import io.grpc.Server; import java.net.InetAddress;
import io.grpc.Status; import java.util.Collections;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.NettyServerBuilder;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.Locale; import java.util.Locale;
import java.util.Optional; import java.util.Optional;
import org.junit.jupiter.api.AfterAll; import java.util.concurrent.Callable;
import org.junit.jupiter.api.AfterEach; import javax.annotation.Nullable;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.signal.chat.rpc.GetRequestAttributesRequest;
import org.signal.chat.rpc.GetRequestAttributesResponse;
import org.signal.chat.rpc.RequestAttributesGrpc;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
class RequestAttributesUtilTest { class RequestAttributesUtilTest {
private static DefaultEventLoopGroup eventLoopGroup; private static final InetAddress REMOTE_ADDRESS = InetAddresses.forString("127.0.0.1");
private GrpcClientConnectionManager grpcClientConnectionManager; @Test
void getAcceptableLanguages() throws Exception {
assertEquals(Collections.emptyList(),
callWithRequestAttributes(buildRequestAttributes(Collections.emptyList()),
RequestAttributesUtil::getAcceptableLanguages));
private Server server; assertEquals(Locale.LanguageRange.parse("en,ja"),
private ManagedChannel managedChannel; callWithRequestAttributes(buildRequestAttributes(Locale.LanguageRange.parse("en,ja")),
RequestAttributesUtil::getAcceptableLanguages));
@BeforeAll
static void setUpBeforeAll() {
eventLoopGroup = new DefaultEventLoopGroup();
}
@BeforeEach
void setUp() throws IOException {
final LocalAddress serverAddress = new LocalAddress("test-request-metadata-server");
grpcClientConnectionManager = mock(GrpcClientConnectionManager.class);
when(grpcClientConnectionManager.getRemoteAddress(any()))
.thenReturn(Optional.of(InetAddresses.forString("127.0.0.1")));
// `RequestAttributesInterceptor` operates on `LocalAddresses`, so we need to do some slightly fancy plumbing to make
// sure that we're using local channels and addresses
server = NettyServerBuilder.forAddress(serverAddress)
.channelType(LocalServerChannel.class)
.bossEventLoopGroup(eventLoopGroup)
.workerEventLoopGroup(eventLoopGroup)
.intercept(new RequestAttributesInterceptor(grpcClientConnectionManager))
.addService(new RequestAttributesServiceImpl())
.build()
.start();
managedChannel = NettyChannelBuilder.forAddress(serverAddress)
.channelType(LocalChannel.class)
.eventLoopGroup(eventLoopGroup)
.usePlaintext()
.build();
}
@AfterEach
void tearDown() {
managedChannel.shutdown();
server.shutdown();
}
@AfterAll
static void tearDownAfterAll() throws InterruptedException {
eventLoopGroup.shutdownGracefully().await();
} }
@Test @Test
void getAcceptableLanguages() { void getAvailableAcceptedLocales() throws Exception {
when(grpcClientConnectionManager.getAcceptableLanguages(any())) assertEquals(Collections.emptyList(),
.thenReturn(Optional.empty()); callWithRequestAttributes(buildRequestAttributes(Collections.emptyList()),
RequestAttributesUtil::getAvailableAcceptedLocales));
assertTrue(getRequestAttributes().getAcceptableLanguagesList().isEmpty()); final List<Locale> availableAcceptedLocales =
callWithRequestAttributes(buildRequestAttributes(Locale.LanguageRange.parse("en,ja")),
RequestAttributesUtil::getAvailableAcceptedLocales);
when(grpcClientConnectionManager.getAcceptableLanguages(any())) assertFalse(availableAcceptedLocales.isEmpty());
.thenReturn(Optional.of(Locale.LanguageRange.parse("en,ja")));
assertEquals(List.of("en", "ja"), getRequestAttributes().getAcceptableLanguagesList()); availableAcceptedLocales.forEach(locale ->
assertTrue("en".equals(locale.getLanguage()) || "ja".equals(locale.getLanguage())));
} }
@Test @Test
void getAvailableAcceptedLocales() { void getRemoteAddress() throws Exception {
when(grpcClientConnectionManager.getAcceptableLanguages(any())) assertEquals(REMOTE_ADDRESS,
.thenReturn(Optional.empty()); callWithRequestAttributes(new RequestAttributes(REMOTE_ADDRESS, null, null),
RequestAttributesUtil::getRemoteAddress));
assertTrue(getRequestAttributes().getAvailableAcceptedLocalesList().isEmpty());
when(grpcClientConnectionManager.getAcceptableLanguages(any()))
.thenReturn(Optional.of(Locale.LanguageRange.parse("en,ja")));
final GetRequestAttributesResponse response = getRequestAttributes();
assertFalse(response.getAvailableAcceptedLocalesList().isEmpty());
response.getAvailableAcceptedLocalesList().forEach(languageTag -> {
final Locale locale = Locale.forLanguageTag(languageTag);
assertTrue("en".equals(locale.getLanguage()) || "ja".equals(locale.getLanguage()));
});
} }
@Test @Test
void getRemoteAddress() { void getUserAgent() throws Exception {
when(grpcClientConnectionManager.getRemoteAddress(any())) assertEquals(Optional.empty(),
.thenReturn(Optional.empty()); callWithRequestAttributes(buildRequestAttributes((String) null),
RequestAttributesUtil::getUserAgent));
GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getRequestAttributes); assertEquals(Optional.of("Signal-Desktop/1.2.3 Linux"),
callWithRequestAttributes(buildRequestAttributes("Signal-Desktop/1.2.3 Linux"),
final String remoteAddressString = "6.7.8.9"; RequestAttributesUtil::getUserAgent));
when(grpcClientConnectionManager.getRemoteAddress(any()))
.thenReturn(Optional.of(InetAddresses.forString(remoteAddressString)));
assertEquals(remoteAddressString, getRequestAttributes().getRemoteAddress());
} }
@Test private static <V> V callWithRequestAttributes(final RequestAttributes requestAttributes, final Callable<V> callable) throws Exception {
void getUserAgent() throws UnrecognizedUserAgentException { return Context.current()
when(grpcClientConnectionManager.getUserAgent(any())) .withValue(RequestAttributesUtil.REQUEST_ATTRIBUTES_CONTEXT_KEY, requestAttributes)
.thenReturn(Optional.empty()); .call(callable);
assertFalse(getRequestAttributes().hasUserAgent());
final UserAgent userAgent = UserAgentUtil.parseUserAgentString("Signal-Desktop/1.2.3 Linux");
when(grpcClientConnectionManager.getUserAgent(any()))
.thenReturn(Optional.of(userAgent));
final GetRequestAttributesResponse response = getRequestAttributes();
assertTrue(response.hasUserAgent());
assertEquals("DESKTOP", response.getUserAgent().getPlatform());
assertEquals("1.2.3", response.getUserAgent().getVersion());
assertEquals("Linux", response.getUserAgent().getAdditionalSpecifiers());
} }
@Test private static RequestAttributes buildRequestAttributes(final String userAgent) {
void getRawUserAgent() { return buildRequestAttributes(userAgent, Collections.emptyList());
when(grpcClientConnectionManager.getRawUserAgent(any()))
.thenReturn(Optional.empty());
assertTrue(getRequestAttributes().getRawUserAgent().isBlank());
final String userAgentString = "Signal-Desktop/1.2.3 Linux";
when(grpcClientConnectionManager.getRawUserAgent(any()))
.thenReturn(Optional.of(userAgentString));
assertEquals(userAgentString, getRequestAttributes().getRawUserAgent());
} }
private GetRequestAttributesResponse getRequestAttributes() { private static RequestAttributes buildRequestAttributes(final List<Locale.LanguageRange> acceptLanguage) {
return RequestAttributesGrpc.newBlockingStub(managedChannel) return buildRequestAttributes(null, acceptLanguage);
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()); }
private static RequestAttributes buildRequestAttributes(@Nullable final String userAgent,
final List<Locale.LanguageRange> acceptLanguage) {
return new RequestAttributes(REMOTE_ADDRESS, userAgent, acceptLanguage);
} }
} }

View File

@ -1,7 +1,11 @@
package org.whispersystems.textsecuregcm.grpc.net; package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import com.google.common.net.InetAddresses; import com.google.common.net.InetAddresses;
import com.vdurmont.semver4j.Semver;
import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap; import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel; import io.netty.channel.Channel;
@ -12,6 +16,12 @@ import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel; import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel; import io.netty.channel.local.LocalServerChannel;
import java.net.InetAddress;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.UUID;
import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeAll;
@ -21,20 +31,9 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import javax.annotation.Nullable;
import java.net.InetAddress;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.UUID;
import static org.junit.jupiter.api.Assertions.*;
class GrpcClientConnectionManagerTest { class GrpcClientConnectionManagerTest {
@ -103,7 +102,7 @@ class GrpcClientConnectionManagerTest {
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, maybeAuthenticatedDevice); grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, maybeAuthenticatedDevice);
assertEquals(maybeAuthenticatedDevice, assertEquals(maybeAuthenticatedDevice,
grpcClientConnectionManager.getAuthenticatedDevice(localChannel.localAddress())); grpcClientConnectionManager.getAuthenticatedDevice(remoteChannel));
} }
private static List<Optional<AuthenticatedDevice>> getAuthenticatedDevice() { private static List<Optional<AuthenticatedDevice>> getAuthenticatedDevice() {
@ -114,170 +113,115 @@ class GrpcClientConnectionManagerTest {
} }
@Test @Test
void getAcceptableLanguages() { void getRequestAttributes() {
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty()); grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertEquals(Optional.empty(), assertThrows(IllegalStateException.class, () -> grpcClientConnectionManager.getRequestAttributes(remoteChannel));
grpcClientConnectionManager.getAcceptableLanguages(localChannel.localAddress()));
final List<Locale.LanguageRange> acceptLanguageRanges = Locale.LanguageRange.parse("en,ja"); final RequestAttributes requestAttributes = new RequestAttributes(InetAddresses.forString("6.7.8.9"), null, null);
remoteChannel.attr(GrpcClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).set(acceptLanguageRanges); remoteChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY).set(requestAttributes);
assertEquals(Optional.of(acceptLanguageRanges), assertEquals(requestAttributes, grpcClientConnectionManager.getRequestAttributes(remoteChannel));
grpcClientConnectionManager.getAcceptableLanguages(localChannel.localAddress()));
} }
@Test @Test
void getRemoteAddress() { void closeConnection() throws InterruptedException, ChannelNotFoundException {
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertEquals(Optional.empty(),
grpcClientConnectionManager.getRemoteAddress(localChannel.localAddress()));
final InetAddress remoteAddress = InetAddresses.forString("6.7.8.9");
remoteChannel.attr(GrpcClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).set(remoteAddress);
assertEquals(Optional.of(remoteAddress),
grpcClientConnectionManager.getRemoteAddress(localChannel.localAddress()));
}
@Test
void getUserAgent() throws UnrecognizedUserAgentException {
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertEquals(Optional.empty(),
grpcClientConnectionManager.getUserAgent(localChannel.localAddress()));
final UserAgent userAgent = UserAgentUtil.parseUserAgentString("Signal-Desktop/1.2.3 Linux");
remoteChannel.attr(GrpcClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY).set(userAgent);
assertEquals(Optional.of(userAgent),
grpcClientConnectionManager.getUserAgent(localChannel.localAddress()));
}
@Test
void closeConnection() throws InterruptedException {
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID); final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice)); grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice));
assertTrue(remoteChannel.isOpen()); assertTrue(remoteChannel.isOpen());
assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertEquals(List.of(remoteChannel), assertEquals(List.of(remoteChannel),
grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
remoteChannel.close().await(); remoteChannel.close().await();
assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); assertThrows(ChannelNotFoundException.class,
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
} }
@Test @ParameterizedTest
void handleWebSocketHandshakeCompleteRemoteAddress() { @MethodSource
void handleHandshakeCompleteRequestAttributes(final InetAddress preferredRemoteAddress,
final String userAgentHeader,
final String acceptLanguageHeader,
final RequestAttributes expectedRequestAttributes) {
final EmbeddedChannel embeddedChannel = new EmbeddedChannel(); final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
final InetAddress preferredRemoteAddress = InetAddresses.forString("192.168.1.1"); GrpcClientConnectionManager.handleHandshakeComplete(embeddedChannel,
GrpcClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel,
preferredRemoteAddress, preferredRemoteAddress,
null,
null);
assertEquals(preferredRemoteAddress,
embeddedChannel.attr(GrpcClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).get());
}
@ParameterizedTest
@MethodSource
void handleWebSocketHandshakeCompleteUserAgent(@Nullable final String userAgentHeader,
@Nullable final UserAgent expectedParsedUserAgent) {
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
GrpcClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel,
InetAddresses.forString("127.0.0.1"),
userAgentHeader, userAgentHeader,
null);
assertEquals(userAgentHeader,
embeddedChannel.attr(GrpcClientConnectionManager.RAW_USER_AGENT_ATTRIBUTE_KEY).get());
assertEquals(expectedParsedUserAgent,
embeddedChannel.attr(GrpcClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY).get());
}
private static List<Arguments> handleWebSocketHandshakeCompleteUserAgent() {
return List.of(
// Recognized user-agent
Arguments.of("Signal-Desktop/1.2.3 Linux", new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "Linux")),
// Unrecognized user-agent
Arguments.of("Not a valid user-agent string", null),
// Missing user-agent
Arguments.of(null, null)
);
}
@ParameterizedTest
@MethodSource
void handleWebSocketHandshakeCompleteAcceptLanguage(@Nullable final String acceptLanguageHeader,
@Nullable final List<Locale.LanguageRange> expectedLanguageRanges) {
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
GrpcClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel,
InetAddresses.forString("127.0.0.1"),
null,
acceptLanguageHeader); acceptLanguageHeader);
assertEquals(expectedLanguageRanges, assertEquals(expectedRequestAttributes,
embeddedChannel.attr(GrpcClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).get()); embeddedChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY).get());
} }
private static List<Arguments> handleWebSocketHandshakeCompleteAcceptLanguage() { private static List<Arguments> handleHandshakeCompleteRequestAttributes() {
final InetAddress preferredRemoteAddress = InetAddresses.forString("192.168.1.1");
return List.of( return List.of(
// Parseable list Arguments.argumentSet("Null User-Agent and Accept-Language headers",
Arguments.of("ja,en;q=0.4", Locale.LanguageRange.parse("ja,en;q=0.4")), preferredRemoteAddress, null, null,
new RequestAttributes(preferredRemoteAddress, null, Collections.emptyList())),
// Unparsable list Arguments.argumentSet("Recognized User-Agent and null Accept-Language header",
Arguments.of("This is not a valid language preference list", null), preferredRemoteAddress, "Signal-Desktop/1.2.3 Linux", null,
new RequestAttributes(preferredRemoteAddress, "Signal-Desktop/1.2.3 Linux", Collections.emptyList())),
// Missing list Arguments.argumentSet("Unparsable User-Agent and null Accept-Language header",
Arguments.of(null, null) preferredRemoteAddress, "Not a valid user-agent string", null,
new RequestAttributes(preferredRemoteAddress, "Not a valid user-agent string", Collections.emptyList())),
Arguments.argumentSet("Null User-Agent and parsable Accept-Language header",
preferredRemoteAddress, null, "ja,en;q=0.4",
new RequestAttributes(preferredRemoteAddress, null, Locale.LanguageRange.parse("ja,en;q=0.4"))),
Arguments.argumentSet("Null User-Agent and unparsable Accept-Language header",
preferredRemoteAddress, null, "This is not a valid language preference list",
new RequestAttributes(preferredRemoteAddress, null, Collections.emptyList()))
); );
} }
@Test @Test
void handleConnectionEstablishedAuthenticated() throws InterruptedException { void handleConnectionEstablishedAuthenticated() throws InterruptedException, ChannelNotFoundException {
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID); final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); assertThrows(ChannelNotFoundException.class,
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice)); grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice));
assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertEquals(List.of(remoteChannel), grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); assertEquals(List.of(remoteChannel), grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
remoteChannel.close().await(); remoteChannel.close().await();
assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); assertThrows(ChannelNotFoundException.class,
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
} }
@Test @Test
void handleConnectionEstablishedAnonymous() throws InterruptedException { void handleConnectionEstablishedAnonymous() throws InterruptedException, ChannelNotFoundException {
assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); assertThrows(ChannelNotFoundException.class,
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty()); grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
remoteChannel.close().await(); remoteChannel.close().await();
assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress())); assertThrows(ChannelNotFoundException.class,
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
} }
} }

View File

@ -1,6 +1,7 @@
package org.whispersystems.textsecuregcm.grpc.net; package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.ArgumentMatchers.anyByte;
@ -8,10 +9,12 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString;
import io.grpc.ManagedChannel; import io.grpc.ManagedChannel;
import io.grpc.ServerBuilder; import io.grpc.ServerBuilder;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.netty.NettyChannelBuilder; import io.grpc.netty.NettyChannelBuilder;
import io.grpc.stub.StreamObserver;
import io.netty.channel.DefaultEventLoopGroup; import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel; import io.netty.channel.local.LocalChannel;
@ -61,6 +64,9 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import org.signal.chat.rpc.EchoRequest;
import org.signal.chat.rpc.EchoResponse;
import org.signal.chat.rpc.EchoServiceGrpc;
import org.signal.chat.rpc.GetAuthenticatedDeviceRequest; import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse; import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.signal.chat.rpc.GetRequestAttributesRequest; import org.signal.chat.rpc.GetRequestAttributesRequest;
@ -71,6 +77,8 @@ import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor; import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor; import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.grpc.ChannelShutdownInterceptor;
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils; import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor; import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor;
import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl; import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl;
@ -83,6 +91,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
private static NioEventLoopGroup nioEventLoopGroup; private static NioEventLoopGroup nioEventLoopGroup;
private static DefaultEventLoopGroup defaultEventLoopGroup; private static DefaultEventLoopGroup defaultEventLoopGroup;
private static ExecutorService delegatedTaskExecutor; private static ExecutorService delegatedTaskExecutor;
private static ExecutorService serverCallExecutor;
private static X509Certificate serverTlsCertificate; private static X509Certificate serverTlsCertificate;
@ -136,7 +145,8 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
static void setUpBeforeAll() throws CertificateException { static void setUpBeforeAll() throws CertificateException {
nioEventLoopGroup = new NioEventLoopGroup(); nioEventLoopGroup = new NioEventLoopGroup();
defaultEventLoopGroup = new DefaultEventLoopGroup(); defaultEventLoopGroup = new DefaultEventLoopGroup();
delegatedTaskExecutor = Executors.newSingleThreadExecutor(); delegatedTaskExecutor = Executors.newVirtualThreadPerTaskExecutor();
serverCallExecutor = Executors.newVirtualThreadPerTaskExecutor();
final CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509"); final CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509");
serverTlsCertificate = (X509Certificate) certificateFactory.generateCertificate( serverTlsCertificate = (X509Certificate) certificateFactory.generateCertificate(
@ -171,7 +181,11 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, defaultEventLoopGroup) { authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, defaultEventLoopGroup) {
@Override @Override
protected void configureServer(final ServerBuilder<?> serverBuilder) { protected void configureServer(final ServerBuilder<?> serverBuilder) {
serverBuilder.addService(new RequestAttributesServiceImpl()) serverBuilder
.executor(serverCallExecutor)
.addService(new RequestAttributesServiceImpl())
.addService(new EchoServiceImpl())
.intercept(new ChannelShutdownInterceptor(grpcClientConnectionManager))
.intercept(new RequestAttributesInterceptor(grpcClientConnectionManager)) .intercept(new RequestAttributesInterceptor(grpcClientConnectionManager))
.intercept(new RequireAuthenticationInterceptor(grpcClientConnectionManager)); .intercept(new RequireAuthenticationInterceptor(grpcClientConnectionManager));
} }
@ -182,7 +196,9 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, defaultEventLoopGroup) { anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, defaultEventLoopGroup) {
@Override @Override
protected void configureServer(final ServerBuilder<?> serverBuilder) { protected void configureServer(final ServerBuilder<?> serverBuilder) {
serverBuilder.addService(new RequestAttributesServiceImpl()) serverBuilder
.executor(serverCallExecutor)
.addService(new RequestAttributesServiceImpl())
.intercept(new RequestAttributesInterceptor(grpcClientConnectionManager)) .intercept(new RequestAttributesInterceptor(grpcClientConnectionManager))
.intercept(new ProhibitAuthenticationInterceptor(grpcClientConnectionManager)); .intercept(new ProhibitAuthenticationInterceptor(grpcClientConnectionManager));
} }
@ -195,7 +211,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
serverTlsPrivateKey, serverTlsPrivateKey,
nioEventLoopGroup, nioEventLoopGroup,
delegatedTaskExecutor, delegatedTaskExecutor,
grpcClientConnectionManager, grpcClientConnectionManager,
clientPublicKeysManager, clientPublicKeysManager,
serverKeyPair, serverKeyPair,
authenticatedGrpcServerAddress, authenticatedGrpcServerAddress,
@ -209,7 +225,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
null, null,
nioEventLoopGroup, nioEventLoopGroup,
delegatedTaskExecutor, delegatedTaskExecutor,
grpcClientConnectionManager, grpcClientConnectionManager,
clientPublicKeysManager, clientPublicKeysManager,
serverKeyPair, serverKeyPair,
authenticatedGrpcServerAddress, authenticatedGrpcServerAddress,
@ -235,6 +251,10 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
delegatedTaskExecutor.shutdown(); delegatedTaskExecutor.shutdown();
//noinspection ResultOfMethodCallIgnored //noinspection ResultOfMethodCallIgnored
delegatedTaskExecutor.awaitTermination(1, TimeUnit.SECONDS); delegatedTaskExecutor.awaitTermination(1, TimeUnit.SECONDS);
serverCallExecutor.shutdown();
//noinspection ResultOfMethodCallIgnored
serverCallExecutor.awaitTermination(1, TimeUnit.SECONDS);
} }
@ParameterizedTest @ParameterizedTest
@ -523,10 +543,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
assertEquals(remoteAddress, response.getRemoteAddress()); assertEquals(remoteAddress, response.getRemoteAddress());
assertEquals(List.of(acceptLanguage), response.getAcceptableLanguagesList()); assertEquals(List.of(acceptLanguage), response.getAcceptableLanguagesList());
assertEquals(userAgent, response.getUserAgent());
assertEquals("DESKTOP", response.getUserAgent().getPlatform());
assertEquals("1.2.3", response.getUserAgent().getVersion());
assertEquals("Linux", response.getUserAgent().getAdditionalSpecifiers());
} finally { } finally {
channel.shutdown(); channel.shutdown();
} }
@ -582,6 +599,89 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
} }
} }
@Test
void waitForCallCompletion() throws InterruptedException {
final CountDownLatch connectionCloseLatch = new CountDownLatch(1);
final AtomicInteger serverCloseStatusCode = new AtomicInteger(0);
final AtomicBoolean closedByServer = new AtomicBoolean(false);
final WebSocketCloseListener webSocketCloseListener = new WebSocketCloseListener() {
@Override
public void handleWebSocketClosedByClient(final int statusCode) {
serverCloseStatusCode.set(statusCode);
closedByServer.set(false);
connectionCloseLatch.countDown();
}
@Override
public void handleWebSocketClosedByServer(final int statusCode) {
serverCloseStatusCode.set(statusCode);
closedByServer.set(true);
connectionCloseLatch.countDown();
}
};
try (final NoiseWebSocketTunnelClient client = authenticated()
.setWebSocketCloseListener(webSocketCloseListener)
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
assertEquals(DEVICE_ID, response.getDeviceId());
final CountDownLatch responseCountDownLatch = new CountDownLatch(1);
// Start an open-ended server call and leave it in a non-complete state
final StreamObserver<EchoRequest> echoRequestStreamObserver = EchoServiceGrpc.newStub(channel).echoStream(
new StreamObserver<>() {
@Override
public void onNext(final EchoResponse echoResponse) {
responseCountDownLatch.countDown();
}
@Override
public void onError(final Throwable throwable) {
}
@Override
public void onCompleted() {
}
});
// Requests are transmitted asynchronously; it's possible that we'll issue the "close connection" request before
// the request even starts. Make sure we've done at least one request/response pair to ensure that the call has
// truly started before requesting connection closure.
echoRequestStreamObserver.onNext(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("Test")).build());
assertTrue(responseCountDownLatch.await(1, TimeUnit.SECONDS));
grpcClientConnectionManager.closeConnection(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID));
assertFalse(connectionCloseLatch.await(1, TimeUnit.SECONDS),
"Channel should not close until active requests have finished");
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, () -> EchoServiceGrpc.newBlockingStub(channel)
.echo(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("Test")).build()));
// Complete the open-ended server call
echoRequestStreamObserver.onCompleted();
assertTrue(connectionCloseLatch.await(1, TimeUnit.SECONDS),
"Channel should close once active requests have finished");
assertTrue(closedByServer.get());
assertEquals(4004, serverCloseStatusCode.get());
} finally {
channel.shutdown();
}
}
}
private NoiseWebSocketTunnelClient.Builder anonymous() { private NoiseWebSocketTunnelClient.Builder anonymous() {
return new NoiseWebSocketTunnelClient return new NoiseWebSocketTunnelClient
.Builder(tlsNoiseWebSocketTunnelServer.getLocalAddress(), nioEventLoopGroup, serverKeyPair.getPublicKey()) .Builder(tlsNoiseWebSocketTunnelServer.getLocalAddress(), nioEventLoopGroup, serverKeyPair.getPublicKey())

View File

@ -4,6 +4,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.params.provider.Arguments.argumentSet;
import static org.junit.jupiter.params.provider.Arguments.arguments; import static org.junit.jupiter.params.provider.Arguments.arguments;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
@ -16,6 +17,7 @@ import io.netty.channel.local.LocalAddress;
import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.util.Attribute;
import java.net.InetAddress; import java.net.InetAddress;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.SocketAddress; import java.net.SocketAddress;
@ -31,6 +33,7 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest { class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
@ -134,8 +137,13 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
embeddedChannel.setRemoteAddress(remoteAddress); embeddedChannel.setRemoteAddress(remoteAddress);
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent); embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
assertEquals(expectedRemoteAddress, assertEquals(expectedRemoteAddress,
embeddedChannel.attr(GrpcClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).get()); Optional.ofNullable(embeddedChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY))
.map(Attribute::get)
.map(RequestAttributes::remoteAddress)
.orElse(null));
} }
private static List<Arguments> getRemoteAddress() { private static List<Arguments> getRemoteAddress() {
@ -144,53 +152,53 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
final InetAddress proxyAddress = InetAddresses.forString("4.3.2.1"); final InetAddress proxyAddress = InetAddresses.forString("4.3.2.1");
return List.of( return List.of(
// Recognized proxy, single forwarded-for address argumentSet("Recognized proxy, single forwarded-for address",
Arguments.of(new DefaultHttpHeaders() new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET) .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()), .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
remoteAddress, remoteAddress,
clientAddress), clientAddress),
// Recognized proxy, multiple forwarded-for addresses argumentSet("Recognized proxy, multiple forwarded-for addresses",
Arguments.of(new DefaultHttpHeaders() new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET) .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress() + "," + proxyAddress.getHostAddress()), .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress() + "," + proxyAddress.getHostAddress()),
remoteAddress, remoteAddress,
proxyAddress), proxyAddress),
// No recognized proxy header, single forwarded-for address argumentSet("No recognized proxy header, single forwarded-for address",
Arguments.of(new DefaultHttpHeaders() new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()), .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
remoteAddress, remoteAddress,
remoteAddress.getAddress()), remoteAddress.getAddress()),
// No recognized proxy header, no forwarded-for address argumentSet("No recognized proxy header, no forwarded-for address",
Arguments.of(new DefaultHttpHeaders(), new DefaultHttpHeaders(),
remoteAddress, remoteAddress,
remoteAddress.getAddress()), remoteAddress.getAddress()),
// Incorrect proxy header, single forwarded-for address argumentSet("Incorrect proxy header, single forwarded-for address",
Arguments.of(new DefaultHttpHeaders() new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET + "-incorrect") .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET + "-incorrect")
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()), .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
remoteAddress, remoteAddress,
remoteAddress.getAddress()), remoteAddress.getAddress()),
// Recognized proxy, no forwarded-for address argumentSet("Recognized proxy, no forwarded-for address",
Arguments.of(new DefaultHttpHeaders() new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET), .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET),
remoteAddress, remoteAddress,
remoteAddress.getAddress()), remoteAddress.getAddress()),
// Recognized proxy, bogus forwarded-for address argumentSet("Recognized proxy, bogus forwarded-for address",
Arguments.of(new DefaultHttpHeaders() new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET) .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, "not a valid address"), .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, "not a valid address"),
remoteAddress, remoteAddress,
null), null),
// No forwarded-for address, non-InetSocketAddress remote address argumentSet("No forwarded-for address, non-InetSocketAddress remote address",
Arguments.of(new DefaultHttpHeaders() new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET), .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET),
new LocalAddress("local-address"), new LocalAddress("local-address"),
null) null)

View File

@ -104,6 +104,7 @@ class MessageSenderTest {
serviceIdentifier, serviceIdentifier,
Map.of(device.getId(), message), Map.of(device.getId(), message),
Map.of(device.getId(), registrationId), Map.of(device.getId(), registrationId),
Optional.empty(),
null)); null));
final MessageProtos.Envelope expectedMessage = ephemeral final MessageProtos.Envelope expectedMessage = ephemeral
@ -144,6 +145,7 @@ class MessageSenderTest {
serviceIdentifier, serviceIdentifier,
Map.of(device.getId(), message), Map.of(device.getId(), message),
Map.of(device.getId(), registrationId + 1), Map.of(device.getId(), registrationId + 1),
Optional.empty(),
null)); null));
assertEquals(new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of(deviceId)), assertEquals(new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of(deviceId)),
@ -344,4 +346,64 @@ class MessageSenderTest {
Optional.of(new MismatchedDevices(Set.of(primaryDeviceId), Set.of(extraDeviceId), Set.of(linkedDeviceId)))) Optional.of(new MismatchedDevices(Set.of(primaryDeviceId), Set.of(extraDeviceId), Set.of(linkedDeviceId))))
); );
} }
@Test
void sendMessageEmptyMessageList() {
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
final Device device = mock(Device.class);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
final Account account = mock(Account.class);
when(account.getDevices()).thenReturn(List.of(device));
when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true);
assertThrows(MismatchedDevicesException.class, () -> messageSender.sendMessages(account,
serviceIdentifier,
Collections.emptyMap(),
Collections.emptyMap(),
Optional.empty(),
null));
assertDoesNotThrow(() -> messageSender.sendMessages(account,
serviceIdentifier,
Collections.emptyMap(),
Collections.emptyMap(),
Optional.of(Device.PRIMARY_ID),
null));
}
@Test
void sendSyncMessageMismatchedAddressing() {
final UUID accountIdentifier = UUID.randomUUID();
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier);
final byte deviceId = Device.PRIMARY_ID;
final Account account = mock(Account.class);
when(account.getUuid()).thenReturn(accountIdentifier);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true);
final Account nonSyncDestination = mock(Account.class);
when(nonSyncDestination.isIdentifiedBy(any())).thenReturn(true);
assertThrows(IllegalArgumentException.class, () -> messageSender.sendMessages(nonSyncDestination,
new AciServiceIdentifier(UUID.randomUUID()),
Map.of(deviceId, MessageProtos.Envelope.newBuilder().build()),
Map.of(deviceId, 17),
Optional.of(deviceId),
null),
"Should throw an IllegalArgumentException for inter-account messages with a sync message device ID");
assertThrows(IllegalArgumentException.class, () -> messageSender.sendMessages(account,
serviceIdentifier,
Map.of(deviceId, MessageProtos.Envelope.newBuilder()
.setSourceServiceId(serviceIdentifier.toServiceIdentifierString())
.setSourceDevice(deviceId)
.build()),
Map.of(deviceId, 17),
Optional.empty(),
null),
"Should throw an IllegalArgumentException for self-addressed messages without a sync message device ID");
}
} }

View File

@ -6,8 +6,10 @@ import static org.mockito.Mockito.when;
import com.google.i18n.phonenumbers.PhoneNumberUtil; import com.google.i18n.phonenumbers.PhoneNumberUtil;
import java.time.Clock; import java.time.Clock;
import java.time.LocalDateTime;
import java.time.LocalTime; import java.time.LocalTime;
import java.time.ZoneId; import java.time.ZoneId;
import java.time.ZoneOffset;
import java.time.ZonedDateTime; import java.time.ZonedDateTime;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
@ -65,4 +67,26 @@ class SchedulingUtilTest {
Clock.fixed(afterNotificationTime.toInstant(), ZoneId.systemDefault()))); Clock.fixed(afterNotificationTime.toInstant(), ZoneId.systemDefault())));
} }
} }
@Test
void getNextRecommendedNotificationTimeDaylightSavings() {
final Account account = mock(Account.class);
// The account has a phone number that can be resolved to a region with known timezones
when(account.getNumber()).thenReturn(PhoneNumberUtil.getInstance().format(
PhoneNumberUtil.getInstance().getExampleNumber("DE"), PhoneNumberUtil.PhoneNumberFormat.E164));
final LocalDateTime afterNotificationTime = LocalDateTime.of(2025, 3, 29, 15, 0);
final ZoneId berlinZoneId = ZoneId.of("Europe/Berlin");
final ZoneOffset berlineZoneOffset = berlinZoneId.getRules().getOffset(afterNotificationTime);
// Daylight Savings Time started on 2025-03-30 at 2:00AM in Germany.
// Instantiating a ZonedDateTime with a zone ID factors in daylight savings when we adjust the time.
final ZonedDateTime afterNotificationTimeWithZoneId = ZonedDateTime.of(afterNotificationTime, berlinZoneId);
assertEquals(
afterNotificationTimeWithZoneId.with(LocalTime.of(14, 0)).plusDays(1).toInstant(),
SchedulingUtil.getNextRecommendedNotificationTime(account, LocalTime.of(14, 0),
Clock.fixed(afterNotificationTime.toInstant(berlineZoneOffset), berlinZoneId)));
}
} }

View File

@ -19,7 +19,6 @@ import static org.whispersystems.textsecuregcm.tests.util.DevicesHelper.createDe
import com.fasterxml.jackson.annotation.JsonFilter; import com.fasterxml.jackson.annotation.JsonFilter;
import java.lang.annotation.Annotation; import java.lang.annotation.Annotation;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.time.Clock;
import java.time.Instant; import java.time.Instant;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
@ -187,7 +186,7 @@ class AccountTest {
@Test @Test
void addAndRemoveBadges() { void addAndRemoveBadges() {
final Account account = AccountsHelper.generateTestAccount("+14151234567", UUID.randomUUID(), UUID.randomUUID(), List.of(createDevice(Device.PRIMARY_ID)), new byte[0]); final Account account = AccountsHelper.generateTestAccount("+14151234567", UUID.randomUUID(), UUID.randomUUID(), List.of(createDevice(Device.PRIMARY_ID)), new byte[0]);
final Clock clock = TestClock.pinned(Instant.ofEpochSecond(40)); final TestClock clock = TestClock.pinned(Instant.ofEpochSecond(40));
account.addBadge(clock, new AccountBadge("foo", Instant.ofEpochSecond(42), false)); account.addBadge(clock, new AccountBadge("foo", Instant.ofEpochSecond(42), false));
account.addBadge(clock, new AccountBadge("bar", Instant.ofEpochSecond(44), true)); account.addBadge(clock, new AccountBadge("bar", Instant.ofEpochSecond(44), true));
@ -214,6 +213,17 @@ class AccountTest {
assertThat(badge.expiration().getEpochSecond()).isEqualTo(51); assertThat(badge.expiration().getEpochSecond()).isEqualTo(51);
assertThat(badge.visible()).isTrue(); assertThat(badge.visible()).isTrue();
}); });
clock.pin(Instant.ofEpochSecond(52));
// for a merged badge, visible = true is preferred
account.addBadge(clock, new AccountBadge("foo", Instant.ofEpochSecond(53), false));
assertThat(account.getBadges()).hasSize(1).element(0).satisfies(badge -> {
assertThat(badge.id()).isEqualTo("foo");
assertThat(badge.expiration().getEpochSecond()).isEqualTo(53);
assertThat(badge.visible()).isTrue();
});
} }
@Test @Test

View File

@ -9,13 +9,16 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
@ -35,8 +38,10 @@ import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.TestClock;
public class ChangeNumberManagerTest { public class ChangeNumberManagerTest {
private AccountsManager accountsManager; private AccountsManager accountsManager;
@ -45,11 +50,13 @@ public class ChangeNumberManagerTest {
private Map<Account, UUID> updatedPhoneNumberIdentifiersByAccount; private Map<Account, UUID> updatedPhoneNumberIdentifiersByAccount;
private static final TestClock CLOCK = TestClock.pinned(Instant.now());
@BeforeEach @BeforeEach
void setUp() throws Exception { void setUp() throws Exception {
accountsManager = mock(AccountsManager.class); accountsManager = mock(AccountsManager.class);
messageSender = mock(MessageSender.class); messageSender = mock(MessageSender.class);
changeNumberManager = new ChangeNumberManager(messageSender, accountsManager); changeNumberManager = new ChangeNumberManager(messageSender, accountsManager, CLOCK);
updatedPhoneNumberIdentifiersByAccount = new HashMap<>(); updatedPhoneNumberIdentifiersByAccount = new HashMap<>();
@ -103,7 +110,7 @@ public class ChangeNumberManagerTest {
changeNumberManager.changeNumber(account, "+18025551234", null, null, null, null, null, null); changeNumberManager.changeNumber(account, "+18025551234", null, null, null, null, null, null);
verify(accountsManager).changeNumber(account, "+18025551234", null, null, null, null); verify(accountsManager).changeNumber(account, "+18025551234", null, null, null, null);
verify(accountsManager, never()).updateDevice(any(), anyByte(), any()); verify(accountsManager, never()).updateDevice(any(), anyByte(), any());
verify(messageSender, never()).sendMessages(eq(account), any(), any(), any(), any()); verify(messageSender, never()).sendMessages(eq(account), any(), any(), any(), any(), any());
} }
@Test @Test
@ -117,7 +124,7 @@ public class ChangeNumberManagerTest {
changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap(), null); changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap(), null);
verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap()); verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap());
verify(messageSender, never()).sendMessages(eq(account), any(), any(), any(), any()); verify(messageSender, never()).sendMessages(eq(account), any(), any(), any(), any(), any());
} }
@Test @Test
@ -132,45 +139,59 @@ public class ChangeNumberManagerTest {
when(account.getUuid()).thenReturn(aci); when(account.getUuid()).thenReturn(aci);
when(account.getPhoneNumberIdentifier()).thenReturn(pni); when(account.getPhoneNumberIdentifier()).thenReturn(pni);
final Device d2 = mock(Device.class); final Device primaryDevice = mock(Device.class);
final byte deviceId2 = 2; when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(d2.getId()).thenReturn(deviceId2); when(primaryDevice.getRegistrationId()).thenReturn(7);
when(account.getDevice(deviceId2)).thenReturn(Optional.of(d2)); final Device linkedDevice = mock(Device.class);
when(account.getDevices()).thenReturn(List.of(d2)); final byte linkedDeviceId = Device.PRIMARY_ID + 1;
final int linkedDeviceRegistrationId = 17;
when(linkedDevice.getId()).thenReturn(linkedDeviceId);
when(linkedDevice.getRegistrationId()).thenReturn(linkedDeviceRegistrationId);
when(account.getDevice(anyByte())).thenReturn(Optional.empty());
when(account.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(primaryDevice));
when(account.getDevice(linkedDeviceId)).thenReturn(Optional.of(linkedDevice));
when(account.getDevices()).thenReturn(List.of(primaryDevice, linkedDevice));
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final Map<Byte, ECSignedPreKey> prekeys = Map.of(Device.PRIMARY_ID, final Map<Byte, ECSignedPreKey> prekeys = Map.of(Device.PRIMARY_ID,
KeysHelper.signedECPreKey(1, pniIdentityKeyPair), KeysHelper.signedECPreKey(1, pniIdentityKeyPair),
deviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); linkedDeviceId, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, deviceId2, 19); final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, linkedDeviceId, 19);
final IncomingMessage msg = mock(IncomingMessage.class); final IncomingMessage msg = mock(IncomingMessage.class);
when(msg.destinationDeviceId()).thenReturn(deviceId2); when(msg.type()).thenReturn(1);
when(msg.destinationDeviceId()).thenReturn(linkedDeviceId);
when(msg.destinationRegistrationId()).thenReturn(linkedDeviceRegistrationId);
when(msg.content()).thenReturn(new byte[]{1}); when(msg.content()).thenReturn(new byte[]{1});
changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, null, List.of(msg), registrationIds, null); changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, null, List.of(msg), registrationIds, null);
verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, null, registrationIds); verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, null, registrationIds);
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor = final MessageProtos.Envelope expectedEnvelope = MessageProtos.Envelope.newBuilder()
ArgumentCaptor.forClass(Map.class); .setType(MessageProtos.Envelope.Type.forNumber(msg.type()))
.setClientTimestamp(CLOCK.millis())
.setServerTimestamp(CLOCK.millis())
.setDestinationServiceId(new AciServiceIdentifier(aci).toServiceIdentifierString())
.setContent(ByteString.copyFrom(msg.content()))
.setSourceServiceId(new AciServiceIdentifier(aci).toServiceIdentifierString())
.setSourceDevice(Device.PRIMARY_ID)
.setUpdatedPni(updatedPhoneNumberIdentifiersByAccount.get(account).toString())
.setUrgent(true)
.setEphemeral(false)
.build();
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any()); verify(messageSender).sendMessages(argThat(a -> a.getUuid().equals(aci)),
eq(new AciServiceIdentifier(aci)),
assertEquals(1, envelopeCaptor.getValue().size()); eq(Map.of(linkedDeviceId, expectedEnvelope)),
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); eq(Map.of(linkedDeviceId, linkedDeviceRegistrationId)),
eq(Optional.of(Device.PRIMARY_ID)),
final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2); any());
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice());
assertEquals(updatedPhoneNumberIdentifiersByAccount.get(account), UUID.fromString(envelope.getUpdatedPni()));
} }
@Test @Test
void changeNumberSetPrimaryDevicePrekeyPqAndSendMessages() throws Exception { void changeNumberSetPrimaryDevicePrekeyPqAndSendMessages() throws Exception {
final String originalE164 = "+18005551234"; final String originalE164 = "+18005551234";
@ -210,7 +231,7 @@ public class ChangeNumberManagerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor = @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class); ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any()); verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any(), any());
assertEquals(1, envelopeCaptor.getValue().size()); assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
@ -261,7 +282,7 @@ public class ChangeNumberManagerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor = @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class); ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any()); verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any(), any());
assertEquals(1, envelopeCaptor.getValue().size()); assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
@ -308,7 +329,7 @@ public class ChangeNumberManagerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor = @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class); ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any()); verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any(), any());
assertEquals(1, envelopeCaptor.getValue().size()); assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
@ -357,7 +378,7 @@ public class ChangeNumberManagerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor = @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class); ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any()); verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any(), any());
assertEquals(1, envelopeCaptor.getValue().size()); assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());

View File

@ -257,7 +257,7 @@ class MessagePersisterTest {
assertTimeoutPreemptively(Duration.ofSeconds(1), () -> assertTimeoutPreemptively(Duration.ofSeconds(1), () ->
assertThrows(MessagePersistenceException.class, assertThrows(MessagePersistenceException.class,
() -> messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE))); () -> messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE, "test")));
} }
@Test @Test
@ -298,7 +298,7 @@ class MessagePersisterTest {
when(messagesManager.persistMessages(any(UUID.class), any(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build()); when(messagesManager.persistMessages(any(UUID.class), any(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build());
assertTimeoutPreemptively(Duration.ofSeconds(1), () -> assertTimeoutPreemptively(Duration.ofSeconds(1), () ->
messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE)); messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE, "test"));
verify(accountsManager, exactly()).removeDevice(destinationAccount, DESTINATION_DEVICE_ID); verify(accountsManager, exactly()).removeDevice(destinationAccount, DESTINATION_DEVICE_ID);
} }
@ -400,7 +400,7 @@ class MessagePersisterTest {
when(messagesManager.persistMessages(any(UUID.class), any(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build()); when(messagesManager.persistMessages(any(UUID.class), any(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build());
when(accountsManager.removeDevice(destinationAccount, DESTINATION_DEVICE_ID)).thenReturn(CompletableFuture.failedFuture(new TimeoutException())); when(accountsManager.removeDevice(destinationAccount, DESTINATION_DEVICE_ID)).thenReturn(CompletableFuture.failedFuture(new TimeoutException()));
assertThrows(CompletionException.class, () -> messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE)); assertThrows(CompletionException.class, () -> messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE, "test"));
} }
@SuppressWarnings("SameParameterValue") @SuppressWarnings("SameParameterValue")

View File

@ -31,8 +31,8 @@ import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
public class AccountsHelper { public class AccountsHelper {
@ -62,6 +62,71 @@ public class AccountsHelper {
setupMockUpdate(mockAccountsManager, false); setupMockUpdate(mockAccountsManager, false);
} }
/**
* Sets up stubbing for:
* <ul>
* <li>{@link AccountsManager#update(Account, Consumer)}</li>
* <li>{@link AccountsManager#updateAsync(Account, Consumer)}</li>
* <li>{@link AccountsManager#updateDevice(Account, byte, Consumer)}</li>
* <li>{@link AccountsManager#updateDeviceAsync(Account, byte, Consumer)}</li>
* </ul>
*
* with multiple calls to the {@link Consumer<Account>}. This simulates retries from {@link org.whispersystems.textsecuregcm.storage.ContestedOptimisticLockException}.
* Callers will typically set up stubbing for relevant {@link Account} methods with multiple {@link org.mockito.stubbing.OngoingStubbing#thenReturn(Object)}
* calls:
* <pre>
* // example stubbing
* when(account.getNextDeviceId())
* .thenReturn(2)
* .thenReturn(3);
* </pre>
*/
@SuppressWarnings("unchecked")
public static void setupMockUpdateWithRetries(final AccountsManager mockAccountsManager, final int retryCount) {
when(mockAccountsManager.update(any(), any())).thenAnswer(answer -> {
final Account account = answer.getArgument(0, Account.class);
for (int i = 0; i < retryCount; i++) {
answer.getArgument(1, Consumer.class).accept(account);
}
return copyAndMarkStale(account);
});
when(mockAccountsManager.updateAsync(any(), any())).thenAnswer(answer -> {
final Account account = answer.getArgument(0, Account.class);
for (int i = 0; i < retryCount; i++) {
answer.getArgument(1, Consumer.class).accept(account);
}
return CompletableFuture.completedFuture(copyAndMarkStale(account));
});
when(mockAccountsManager.updateDevice(any(), anyByte(), any())).thenAnswer(answer -> {
final Account account = answer.getArgument(0, Account.class);
final byte deviceId = answer.getArgument(1, Byte.class);
for (int i = 0; i < retryCount; i++) {
account.getDevice(deviceId).ifPresent(answer.getArgument(2, Consumer.class));
}
return copyAndMarkStale(account);
});
when(mockAccountsManager.updateDeviceAsync(any(), anyByte(), any())).thenAnswer(answer -> {
final Account account = answer.getArgument(0, Account.class);
final byte deviceId = answer.getArgument(1, Byte.class);
for (int i = 0; i < retryCount; i++) {
account.getDevice(deviceId).ifPresent(answer.getArgument(2, Consumer.class));
}
return CompletableFuture.completedFuture(copyAndMarkStale(account));
});
}
@SuppressWarnings("unchecked")
private static void setupMockUpdate(final AccountsManager mockAccountsManager, final boolean markStale) { private static void setupMockUpdate(final AccountsManager mockAccountsManager, final boolean markStale) {
when(mockAccountsManager.update(any(), any())).thenAnswer(answer -> { when(mockAccountsManager.update(any(), any())).thenAnswer(answer -> {
final Account account = answer.getArgument(0, Account.class); final Account account = answer.getArgument(0, Account.class);

View File

@ -0,0 +1,101 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.atomic.AtomicBoolean;
import static org.junit.jupiter.api.Assertions.*;
class ClosableEpochTest {
@Test
void close() {
{
final AtomicBoolean closed = new AtomicBoolean(false);
final ClosableEpoch closableEpoch = new ClosableEpoch(() -> closed.set(true));
assertTrue(closableEpoch.tryArrive(), "New callers should be allowed to arrive before closure");
assertEquals(1, closableEpoch.getActiveCallers());
closableEpoch.close();
assertFalse(closableEpoch.tryArrive(), "New callers should not be allowed to arrive after closure");
assertEquals(1, closableEpoch.getActiveCallers());
assertFalse(closed.get(), "Close handler should not fire until all callers have departed");
closableEpoch.depart();
assertTrue(closed.get(), "Close handler should fire after last caller departs");
assertEquals(0, closableEpoch.getActiveCallers());
assertThrows(IllegalStateException.class, closableEpoch::close,
"Double-closing a epoch should throw an exception");
}
{
final AtomicBoolean closed = new AtomicBoolean(false);
final ClosableEpoch closableEpoch = new ClosableEpoch(() -> closed.set(true));
closableEpoch.close();
assertTrue(closed.get(), "Empty epoch should fire close handler immediately on closure");
assertEquals(0, closableEpoch.getActiveCallers());
}
}
@Test
@Timeout(value = 5, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
void closeConcurrent() throws InterruptedException {
final AtomicBoolean closed = new AtomicBoolean(false);
final ClosableEpoch closableEpoch = new ClosableEpoch(() -> {
synchronized (closed) {
closed.set(true);
closed.notifyAll();
}
});
final int threadCount = 128;
final CyclicBarrier cyclicBarrier = new CyclicBarrier(threadCount);
// Spawn a bunch of threads doing some simulated work. Close the epoch roughly halfway through. Some threads should
// successfully enter the critical section and others should be rejected.
for (int t = 0; t < threadCount; t++) {
final boolean shouldClose = t == threadCount / 2;
Thread.ofVirtual().start(() -> {
try {
// Wait for all threads to reach the proverbial starting line
cyclicBarrier.await();
} catch (final InterruptedException | BrokenBarrierException ignored) {
}
if (shouldClose) {
closableEpoch.close();
}
if (closableEpoch.tryArrive()) {
// Perform some simulated "work"
try {
Thread.sleep(1);
} catch (final InterruptedException ignored) {
} finally {
closableEpoch.depart();
}
}
});
}
while (!closed.get()) {
synchronized (closed) {
closed.wait();
}
}
assertEquals(0, closableEpoch.getActiveCallers());
}
}

View File

@ -13,28 +13,20 @@ import java.util.stream.Stream;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import javax.annotation.Nullable;
class UserAgentUtilTest { class UserAgentUtilTest {
@ParameterizedTest
@MethodSource
void testParseBogusUserAgentString(final String userAgentString) {
assertThrows(UnrecognizedUserAgentException.class, () -> UserAgentUtil.parseUserAgentString(userAgentString));
}
@SuppressWarnings("unused")
private static Stream<String> testParseBogusUserAgentString() {
return Stream.of(
null,
"This is obviously not a reasonable User-Agent string.",
"Signal-Android/4.6-8.3.unreasonableversionstring-17"
);
}
@ParameterizedTest @ParameterizedTest
@MethodSource("argumentsForTestParseStandardUserAgentString") @MethodSource("argumentsForTestParseStandardUserAgentString")
void testParseStandardUserAgentString(final String userAgentString, final UserAgent expectedUserAgent) { void testParseStandardUserAgentString(final String userAgentString, @Nullable final UserAgent expectedUserAgent)
assertEquals(expectedUserAgent, UserAgentUtil.parseStandardUserAgentString(userAgentString)); throws UnrecognizedUserAgentException {
if (expectedUserAgent != null) {
assertEquals(expectedUserAgent, UserAgentUtil.parseUserAgentString(userAgentString));
} else {
assertThrows(UnrecognizedUserAgentException.class, () -> UserAgentUtil.parseUserAgentString(userAgentString));
}
} }
private static Stream<Arguments> argumentsForTestParseStandardUserAgentString() { private static Stream<Arguments> argumentsForTestParseStandardUserAgentString() {
@ -42,18 +34,18 @@ class UserAgentUtilTest {
Arguments.of("This is obviously not a reasonable User-Agent string.", null), Arguments.of("This is obviously not a reasonable User-Agent string.", null),
Arguments.of("Signal-Android/4.68.3 Android/25", Arguments.of("Signal-Android/4.68.3 Android/25",
new UserAgent(ClientPlatform.ANDROID, new Semver("4.68.3"), "Android/25")), new UserAgent(ClientPlatform.ANDROID, new Semver("4.68.3"), "Android/25")),
Arguments.of("Signal-Android/4.68.3", new UserAgent(ClientPlatform.ANDROID, new Semver("4.68.3"))), Arguments.of("Signal-Android/4.68.3", new UserAgent(ClientPlatform.ANDROID, new Semver("4.68.3"), null)),
Arguments.of("Signal-Desktop/1.2.3 Linux", new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "Linux")), Arguments.of("Signal-Desktop/1.2.3 Linux", new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "Linux")),
Arguments.of("Signal-Desktop/1.2.3 macOS", new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "macOS")), Arguments.of("Signal-Desktop/1.2.3 macOS", new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "macOS")),
Arguments.of("Signal-Desktop/1.2.3 Windows", Arguments.of("Signal-Desktop/1.2.3 Windows",
new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "Windows")), new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "Windows")),
Arguments.of("Signal-Desktop/1.2.3", new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"))), Arguments.of("Signal-Desktop/1.2.3", new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), null)),
Arguments.of("Signal-Desktop/1.32.0-beta.3", Arguments.of("Signal-Desktop/1.32.0-beta.3",
new UserAgent(ClientPlatform.DESKTOP, new Semver("1.32.0-beta.3"))), new UserAgent(ClientPlatform.DESKTOP, new Semver("1.32.0-beta.3"), null)),
Arguments.of("Signal-iOS/3.9.0 (iPhone; iOS 12.2; Scale/3.00)", Arguments.of("Signal-iOS/3.9.0 (iPhone; iOS 12.2; Scale/3.00)",
new UserAgent(ClientPlatform.IOS, new Semver("3.9.0"), "(iPhone; iOS 12.2; Scale/3.00)")), new UserAgent(ClientPlatform.IOS, new Semver("3.9.0"), "(iPhone; iOS 12.2; Scale/3.00)")),
Arguments.of("Signal-iOS/3.9.0 iOS/14.2", new UserAgent(ClientPlatform.IOS, new Semver("3.9.0"), "iOS/14.2")), Arguments.of("Signal-iOS/3.9.0 iOS/14.2", new UserAgent(ClientPlatform.IOS, new Semver("3.9.0"), "iOS/14.2")),
Arguments.of("Signal-iOS/3.9.0", new UserAgent(ClientPlatform.IOS, new Semver("3.9.0"))), Arguments.of("Signal-iOS/3.9.0", new UserAgent(ClientPlatform.IOS, new Semver("3.9.0"), null)),
Arguments.of("Signal-Android/7.11.23-nightly-1982-06-28-07-07-07 tonic/0.31", Arguments.of("Signal-Android/7.11.23-nightly-1982-06-28-07-07-07 tonic/0.31",
new UserAgent(ClientPlatform.ANDROID, new Semver("7.11.23-nightly-1982-06-28-07-07-07"), "tonic/0.31")), new UserAgent(ClientPlatform.ANDROID, new Semver("7.11.23-nightly-1982-06-28-07-07-07"), "tonic/0.31")),
Arguments.of("Signal-Android/7.11.23-nightly-1982-06-28-07-07-07 Android/42 tonic/0.31", Arguments.of("Signal-Android/7.11.23-nightly-1982-06-28-07-07-07 Android/42 tonic/0.31",

View File

@ -13,6 +13,7 @@ package org.signal.chat.rpc;
service EchoService { service EchoService {
rpc echo (EchoRequest) returns (EchoResponse) {} rpc echo (EchoRequest) returns (EchoResponse) {}
rpc echo2 (EchoRequest) returns (EchoResponse) {} rpc echo2 (EchoRequest) returns (EchoResponse) {}
rpc echoStream (stream EchoRequest) returns (stream EchoResponse) {}
} }
message EchoRequest { message EchoRequest {

View File

@ -23,14 +23,7 @@ message GetRequestAttributesResponse {
repeated string acceptable_languages = 1; repeated string acceptable_languages = 1;
repeated string available_accepted_locales = 2; repeated string available_accepted_locales = 2;
string remote_address = 3; string remote_address = 3;
string raw_user_agent = 4; string user_agent = 4;
UserAgent user_agent = 5;
}
message UserAgent {
string platform = 1;
string version = 2;
string additional_specifiers = 3;
} }
message GetAuthenticatedDeviceRequest { message GetAuthenticatedDeviceRequest {

View File

@ -470,7 +470,8 @@ turn:
cloudflare: cloudflare:
apiToken: secret://turn.cloudflare.apiToken apiToken: secret://turn.cloudflare.apiToken
endpoint: https://rtc.live.cloudflare.com/v1/turn/keys/LMNOP/credentials/generate endpoint: https://rtc.live.cloudflare.com/v1/turn/keys/LMNOP/credentials/generate
ttl: 86400 requestedCredentialTtl: PT24H
clientCredentialTtl: PT12H
urls: urls:
- turn:turn.example.com:80 - turn:turn.example.com:80
urlsWithIps: urlsWithIps:

@ -1 +1 @@
Subproject commit 8f566196d763c8eb1f3c8fcefd5be3c35ff8d148 Subproject commit d9852e294a853b88c7feaa748e17fee38acbf849