Compare commits

..

75 Commits

Author SHA1 Message Date
Jonathan Klabunde Tomer 74ee1c8c4f Update to the latest version of the spam filter 2025-05-21 10:46:18 -07:00
Jonathan Klabunde Tomer 35604cf151
Simplify rate limiters by making them all dynamic 2025-05-21 10:29:26 -07:00
Ravi Khadiwala aafcd63a9f Decrease the page size for OPK queries
A single element is almost always enough
2025-05-20 11:21:20 -04:00
Jon Chambers 43a534f05b Add a command for regenerating account constraint tables 2025-05-20 11:21:02 -04:00
Jon Chambers 9ec66dac7f Make `getRegistrationId` identity-type-aware 2025-05-14 14:39:11 -04:00
Jon Chambers 13fc0ffbca Assume that PNI registration IDs are always present on `Device` records 2025-05-14 14:39:11 -04:00
Jon Chambers 93ba6616d1 Perform device list validations in the scope of a pessimistic account lock 2025-05-14 14:39:11 -04:00
Jon Chambers a4b98f38a6 Use a `Callable` for tasks performed within the scope of a pessimistic lock 2025-05-14 14:39:11 -04:00
Jon Chambers b95d08aaea Drop `PqKeysUtil` 2025-05-14 14:39:11 -04:00
Jon Chambers b400d49e77 Require PQ keys when changing numbers or distributing key material 2025-05-14 14:39:11 -04:00
Jon Chambers e43487155f Remove commands for removing accounts/devices without PQ or PNI key material 2025-05-14 14:39:11 -04:00
Jon Chambers dee3723d97 Remove an unused user-agent argument 2025-05-14 14:39:11 -04:00
Jon Chambers b7e986f43c Add an integration test for changing phone numbers 2025-05-14 14:39:11 -04:00
Jon Chambers 664fb23e97 Resolve warnings/suggestions throughout `AccountsTest` 2025-05-14 11:30:59 -04:00
Chris Eager 714ef128a1
Compare using PNI in account reclamation 2025-05-13 16:41:42 -07:00
Ravi Khadiwala 7cf3fce624 Log unexpected account reclaim mismatches 2025-05-13 14:17:18 -05:00
ravi-signal 0cc5431867
Update noise-gRPC protocol errors 2025-05-13 14:16:23 -05:00
Ravi Khadiwala b8d5b2c8ea Match account idle duration in RemoveExpiredBackupsCommand 2025-05-13 14:15:50 -05:00
Ravi Khadiwala 894ca6d290 remove ANDROID_SKIP_LOW_URGENCY_PUSH_EXPERIMENT 2025-05-13 13:59:28 -05:00
Ravi Khadiwala 847b25f695 Add experiment to coalesce android notifications 2025-05-13 13:59:28 -05:00
Ravi Khadiwala 703a05cb15 Support scheduling background FCMs 2025-05-13 13:59:28 -05:00
Jon Chambers 30c194c557
Exclude `RateLimitExceededException` from fail-open checks 2025-05-12 15:24:57 -07:00
Jonathan Klabunde Tomer cc7b030a41
Send disconnection requests after non-API device unlinks 2025-05-06 13:36:41 -07:00
Jon Chambers 7a91c4d5b7 Correct metric names 2025-05-05 13:53:22 -04:00
Jon Chambers 287da6e7e3 Ignore already-locked accounts in PNI key cleanup operations 2025-05-05 13:53:22 -04:00
Katherine 7cf89764e7
Update `FullTreeHead` to use `FullAuditorTreeHead` 2025-05-05 10:44:57 -07:00
Jon Chambers d316c72beb
Add commands for removing accounts/devices without PNI key material 2025-05-05 12:10:47 -04:00
Katherine Yen 82d187cc45 Update key transparency protobufs 2025-05-02 10:40:53 -04:00
Jon Chambers 0c240d21d2 Update to the latest version of the spam filter 2025-05-02 10:40:07 -04:00
Jon Chambers 009252c831 Configure IP-keyed rate limiters to fail open 2025-05-02 10:30:29 -04:00
Jon Chambers 0c1146aaa5 Configure rate limiters with large initial capacities to fail open 2025-05-02 10:30:29 -04:00
Jon Chambers 4fd06594a0 Configure fast-regenerating rate limiters to fail open 2025-05-02 10:30:29 -04:00
Jon Chambers 4e175be88f Allow the "inbound message bytes" limiter to fail open 2025-05-02 10:30:29 -04:00
Jon Chambers 771a700acd Configure fail-open policy on individual rate limiters 2025-05-02 10:30:29 -04:00
Jon Chambers e9bd5da2c3 Allow fail-open behavior for a wider range of exceptions 2025-05-02 10:30:29 -04:00
Jon Chambers f64244f33a Remove an unused TURN rate limiter 2025-05-02 10:30:29 -04:00
Ravi Khadiwala ed1417c3e3 Update to the latest version of the spam filter 2025-04-30 15:06:03 -05:00
ravi-signal 0398e02690
Add NoiseDirect framing protocol 2025-04-30 15:05:05 -05:00
Chris Eager e285bf1a52 Fix test by using generic `exists` command 2025-04-29 13:05:10 -05:00
Ameya Lokare 2c9219d4f7 Update to the latest version of the spam filter 2025-04-29 10:57:05 -07:00
Jon Chambers 26b3b75054 Only fetch last-resort PQ keys for accounts with linked devices 2025-04-28 16:59:08 -04:00
Jon Chambers cdb651b68f
Add commands for removing devices without PQ keys 2025-04-28 15:45:27 -04:00
Ameya Lokare 91a36f4421 Update to the latest version of the spam filter 2025-04-28 11:59:43 -07:00
Jonathan Klabunde Tomer 21c1d71551 take advantage of list non-nullitude 2025-04-25 10:06:42 -05:00
Jonathan Klabunde Tomer 38befdb260 default lists to empty 2025-04-25 10:06:42 -05:00
Jonathan Klabunde Tomer 63c79173b2 limit prekey uploads to 100 2025-04-25 10:06:42 -05:00
Ameya Lokare d2ad003891 Remove free memory and OS memory gauges 2025-04-25 10:05:29 -05:00
Chris Eager eb89773819 Remove unused parameter 2025-04-25 10:05:18 -05:00
Chris Eager 403abd84f6 Run test action on pull_request events 2025-04-25 10:05:08 -05:00
Jon Chambers f62f79c95c Add a counter for cases where clients use both an authenticated identity and UAK when fetching profiles 2025-04-24 11:47:43 -04:00
Jon Chambers 144c4c9223 Add a "sync" dimension to the "sent message" counter 2025-04-24 10:33:39 -05:00
Ravi Khadiwala ab4fc4f459 Add skip low urgency push experiment 2025-04-24 10:32:46 -05:00
Jonathan Klabunde Tomer 51569ce0a5
Use cached partition topology for metrics/logs 2025-04-24 08:29:58 -07:00
Jon Chambers f191c68efc Close remote connections only after all active server calls have completed 2025-04-22 17:00:48 -04:00
Jon Chambers bb8ce6d981 Introduce `ClosableEpoch` 2025-04-22 17:00:48 -04:00
Katherine e0ee75e0d0
Fix Daylight Savings bug in recommended notification time calculation 2025-04-22 16:56:10 -04:00
Jon Chambers 1ef3a230a1 Tag queue size distribution with client platform 2025-04-22 16:55:16 -04:00
Jon Chambers b1805d4bf1 Add a "persisted bytes" counter 2025-04-22 16:55:16 -04:00
Jon Chambers cac979c7fd Count individual persisted messages 2025-04-22 16:55:16 -04:00
Jon Chambers 4072dcdda5 Introduce `DevicePlatformUtil` 2025-04-22 16:55:16 -04:00
Jonathan Klabunde Tomer ed382fff6d log slot number and shard host of message persister failures 2025-04-22 16:55:16 -04:00
Jon Chambers 23bb8277d5 Update to the latest version of the spam filter 2025-04-18 15:56:17 -04:00
Jon Chambers 8099d6465c
Clarify guarantees around remote channnel/request attribute presence 2025-04-18 15:44:21 -04:00
Jon Chambers 28a0b9e84e
Include a TURN credential TTL for clients in `GetCallingRelaysResponse` 2025-04-17 10:30:58 -04:00
Chris Eager 9287aaf7ce Add app info to Stripe API calls 2025-04-17 09:30:34 -05:00
Chris Eager 0585f862cb Add regression test for set profile badges calculation 2025-04-17 09:29:11 -05:00
Chris Eager 7cac6f6f72 Remove extraneous account fetch in POST /v1/donation/redeem-receipt 2025-04-17 09:28:57 -05:00
Jon Chambers 57be4d798b Add a counter for attempts to send empty message lists 2025-04-17 10:27:46 -04:00
Jon Chambers 05c74f1997 Simplify `UserAgentUtil` 2025-04-17 10:27:24 -04:00
Jon Chambers f5e49b6db7 Convert `UserAgent` to a record 2025-04-15 14:58:09 -04:00
Jon Chambers 3c40e72d27
Fix registration ID map construction when changing numbers 2025-04-15 14:57:28 -04:00
Ravi Khadiwala 2f2ae7cec5 simplify story tag calculation 2025-04-11 14:04:09 -05:00
Chris Eager b236b53dc3 set profile: move updated badge calculation into account updater lambda 2025-04-11 14:03:05 -05:00
Katherine eb71e30046
Update to protobuf 4.x 2025-04-10 13:05:23 -04:00
Jon Chambers aa5fd52302 Explicitly pass sync message sender device ID as an argument to `sendMessage` 2025-04-10 11:40:32 -04:00
211 changed files with 6304 additions and 4132 deletions

View File

@ -1,6 +1,7 @@
name: Service CI
on:
pull_request:
push:
branches-ignore:
- gh-pages

View File

@ -72,14 +72,12 @@ public final class Operations {
}
public static TestUser newRegisteredUser(final String number) {
final byte[] registrationPassword = randomBytes(32);
final byte[] registrationPassword = populateRandomRecoveryPassword(number);
final String accountPassword = Base64.getEncoder().encodeToString(randomBytes(32));
final TestUser user = TestUser.create(number, accountPassword, registrationPassword);
final AccountAttributes accountAttributes = user.accountAttributes();
INTEGRATION_TOOLS.populateRecoveryPassword(number, registrationPassword).join();
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
@ -108,6 +106,7 @@ public final class Operations {
}
public record PrescribedVerificationNumber(String number, String verificationCode) {}
public static PrescribedVerificationNumber prescribedVerificationNumber() {
return new PrescribedVerificationNumber(
CONFIG.prescribedRegistrationNumber(),
@ -123,6 +122,13 @@ public final class Operations {
.orElseThrow(() -> new RuntimeException("push challenge not found for the verification session"));
}
public static byte[] populateRandomRecoveryPassword(final String number) {
final byte[] recoveryPassword = randomBytes(32);
INTEGRATION_TOOLS.populateRecoveryPassword(number, recoveryPassword).join();
return recoveryPassword;
}
public static <T> T sendEmptyRequestAuthenticated(
final String endpoint,
final String method,
@ -329,15 +335,15 @@ public final class Operations {
}
}
private static ECSignedPreKey generateSignedECPreKey(long id, final ECKeyPair identityKeyPair) {
public static ECSignedPreKey generateSignedECPreKey(final long id, final ECKeyPair identityKeyPair) {
final ECPublicKey pubKey = Curve.generateKeyPair().getPublicKey();
final byte[] sig = identityKeyPair.getPrivateKey().calculateSignature(pubKey.serialize());
return new ECSignedPreKey(id, pubKey, sig);
final byte[] signature = identityKeyPair.getPrivateKey().calculateSignature(pubKey.serialize());
return new ECSignedPreKey(id, pubKey, signature);
}
private static KEMSignedPreKey generateSignedKEMPreKey(long id, final ECKeyPair identityKeyPair) {
public static KEMSignedPreKey generateSignedKEMPreKey(final long id, final ECKeyPair identityKeyPair) {
final KEMPublicKey pubKey = KEMKeyPair.generate(KEMKeyType.KYBER_1024).getPublicKey();
final byte[] sig = identityKeyPair.getPrivateKey().calculateSignature(pubKey.serialize());
return new KEMSignedPreKey(id, pubKey, sig);
final byte[] signature = identityKeyPair.getPrivateKey().calculateSignature(pubKey.serialize());
return new KEMSignedPreKey(id, pubKey, signature);
}
}

View File

@ -6,27 +6,35 @@
package org.signal.integration;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.http.HttpStatus;
import org.junit.jupiter.api.Test;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.usernames.BaseUsernameException;
import org.signal.libsignal.usernames.Username;
import org.whispersystems.textsecuregcm.entities.AccountIdentifierResponse;
import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse;
import org.whispersystems.textsecuregcm.entities.ChangeNumberRequest;
import org.whispersystems.textsecuregcm.entities.ConfirmUsernameHashRequest;
import org.whispersystems.textsecuregcm.entities.ReserveUsernameHashRequest;
import org.whispersystems.textsecuregcm.entities.ReserveUsernameHashResponse;
import org.whispersystems.textsecuregcm.entities.UsernameHashResponse;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.storage.Device;
public class AccountTest {
@Test
public void testCreateAccount() throws Exception {
public void testCreateAccount() {
final TestUser user = Operations.newRegisteredUser("+19995550101");
try {
final Pair<Integer, AccountIdentityResponse> execute = Operations.apiGet("/v1/accounts/whoami")
@ -39,7 +47,7 @@ public class AccountTest {
}
@Test
public void testCreateAccountAtomic() throws Exception {
public void testCreateAccountAtomic() {
final TestUser user = Operations.newRegisteredUser("+19995550201");
try {
final Pair<Integer, AccountIdentityResponse> execute = Operations.apiGet("/v1/accounts/whoami")
@ -51,6 +59,33 @@ public class AccountTest {
}
}
@Test
public void changePhoneNumber() {
final TestUser user = Operations.newRegisteredUser("+19995550301");
final String targetNumber = "+19995550302";
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final ChangeNumberRequest changeNumberRequest = new ChangeNumberRequest(null,
Operations.populateRandomRecoveryPassword(targetNumber),
targetNumber,
null,
new IdentityKey(pniIdentityKeyPair.getPublicKey()),
Collections.emptyList(),
Map.of(Device.PRIMARY_ID, Operations.generateSignedECPreKey(1, pniIdentityKeyPair)),
Map.of(Device.PRIMARY_ID, Operations.generateSignedKEMPreKey(2, pniIdentityKeyPair)),
Map.of(Device.PRIMARY_ID, 17));
final AccountIdentityResponse accountIdentityResponse =
Operations.apiPut("/v2/accounts/number", changeNumberRequest)
.authorized(user)
.executeExpectSuccess(AccountIdentityResponse.class);
assertEquals(user.aciUuid(), accountIdentityResponse.uuid());
assertNotEquals(user.pniUuid(), accountIdentityResponse.pni());
assertEquals(targetNumber, accountIdentityResponse.number());
}
@Test
public void testUsernameOperations() throws Exception {
final TestUser user = Operations.newRegisteredUser("+19995550102");

17
pom.xml
View File

@ -46,7 +46,7 @@
<!-- can be updated to latest version with Dropwizard 5 (Jetty 12); will then need to disable telemetry -->
<dynamodblocal.version>2.2.1</dynamodblocal.version>
<google-cloud-libraries.version>26.57.0</google-cloud-libraries.version>
<grpc.version>1.69.0</grpc.version> <!-- should be kept in sync with the value from Google libraries-bom -->
<grpc.version>1.70.0</grpc.version> <!-- should be kept in sync with the value from Google libraries-bom -->
<gson.version>2.12.1</gson.version>
<!-- several libraries (AWS, Google Cloud) use Apache http components transitively, and we need to align them -->
<httpcore.version>4.4.16</httpcore.version>
@ -65,9 +65,9 @@
<luajava.version>3.5.0</luajava.version>
<micrometer.version>1.14.5</micrometer.version>
<netty.version>4.1.119.Final</netty.version>
<!-- Must be greater than or equal to the value from Google libraries-bom
since some of its libraries generate code. See https://protobuf.dev/support/cross-version-runtime-guarantee/. -->
<protobuf.version>3.25.5</protobuf.version>
<!-- Must be less than or equal to the value from Google libraries-bom which controls the protobuf runtime version.
See https://protobuf.dev/support/cross-version-runtime-guarantee/. -->
<protoc.version>4.29.4</protoc.version>
<pushy.version>0.15.4</pushy.version>
<reactive.grpc.version>1.2.4</reactive.grpc.version>
<reactor-bom.version>2024.0.4</reactor-bom.version> <!-- 3.7.4, see https://github.com/reactor/reactor#bom-versioning-scheme -->
@ -127,7 +127,7 @@
</dependency>
<dependency>
<groupId>com.google.cloud</groupId>
<artifactId>libraries-bom-protobuf3</artifactId>
<artifactId>libraries-bom</artifactId>
<version>${google-cloud-libraries.version}</version>
<type>pom</type>
<scope>import</scope>
@ -175,11 +175,6 @@
<artifactId>pushy-dropwizard-metrics-listener</artifactId>
<version>${pushy.version}</version>
</dependency>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
<version>${protobuf.version}</version>
</dependency>
<dependency>
<groupId>com.googlecode.libphonenumber</groupId>
<artifactId>libphonenumber</artifactId>
@ -443,7 +438,7 @@
<version>0.6.1</version>
<configuration>
<checkStaleness>false</checkStaleness>
<protocArtifact>com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier}</protocArtifact>
<protocArtifact>com.google.protobuf:protoc:${protoc.version}:exe:${os.detected.classifier}</protocArtifact>
<pluginId>grpc-java</pluginId>
<pluginArtifact>io.grpc:protoc-gen-grpc-java:${grpc.version}:exe:${os.detected.classifier}</pluginArtifact>

View File

@ -482,7 +482,8 @@ turn:
- turn:%s
- turn:%s:80?transport=tcp
- turns:%s:443?transport=tcp
ttl: 86400
requestedCredentialTtl: PT24H
clientCredentialTtl: PT12H
hostname: turn.cloudflare.example.com
numHttpClients: 1

View File

@ -407,10 +407,6 @@ public class WhisperServerConfiguration extends Configuration {
return rateLimitersCluster;
}
public Map<String, RateLimiterConfig> getLimitsConfiguration() {
return limits;
}
public FcmConfiguration getFcmConfiguration() {
return fcm;
}

View File

@ -154,7 +154,7 @@ import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.grpc.net.ManagedDefaultEventLoopGroup;
import org.whispersystems.textsecuregcm.grpc.net.ManagedLocalGrpcServer;
import org.whispersystems.textsecuregcm.grpc.net.ManagedNioEventLoopGroup;
import org.whispersystems.textsecuregcm.grpc.net.NoiseWebSocketTunnelServer;
import org.whispersystems.textsecuregcm.grpc.net.websocket.NoiseWebSocketTunnelServer;
import org.whispersystems.textsecuregcm.jetty.JettyHttpConfigurationCustomizer;
import org.whispersystems.textsecuregcm.keytransparency.KeyTransparencyServiceClient;
import org.whispersystems.textsecuregcm.limits.CardinalityEstimator;
@ -269,6 +269,7 @@ import org.whispersystems.textsecuregcm.workers.IdleDeviceNotificationSchedulerF
import org.whispersystems.textsecuregcm.workers.MessagePersisterServiceCommand;
import org.whispersystems.textsecuregcm.workers.NotifyIdleDevicesCommand;
import org.whispersystems.textsecuregcm.workers.ProcessScheduledJobsServiceCommand;
import org.whispersystems.textsecuregcm.workers.RegenerateAccountConstraintDataCommand;
import org.whispersystems.textsecuregcm.workers.RemoveExpiredAccountsCommand;
import org.whispersystems.textsecuregcm.workers.RemoveExpiredBackupsCommand;
import org.whispersystems.textsecuregcm.workers.RemoveExpiredLinkedDevicesCommand;
@ -335,6 +336,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
bootstrap.addCommand(new ProcessScheduledJobsServiceCommand("process-idle-device-notification-jobs",
"Processes scheduled jobs to send notifications to idle devices",
new IdleDeviceNotificationSchedulerFactory()));
bootstrap.addCommand(new RegenerateAccountConstraintDataCommand());
}
@Override
@ -633,11 +636,10 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
PushNotificationScheduler pushNotificationScheduler = new PushNotificationScheduler(pushSchedulerCluster,
apnSender, fcmSender, accountsManager, 0, 0);
PushNotificationManager pushNotificationManager =
new PushNotificationManager(accountsManager, apnSender, fcmSender, pushNotificationScheduler);
new PushNotificationManager(accountsManager, apnSender, fcmSender, pushNotificationScheduler, experimentEnrollmentManager);
WebSocketConnectionEventManager webSocketConnectionEventManager =
new WebSocketConnectionEventManager(accountsManager, pushNotificationManager, messagesCluster, clientEventExecutor, asyncOperationQueueingExecutor);
RateLimiters rateLimiters = RateLimiters.createAndValidate(config.getLimitsConfiguration(),
dynamicConfigurationManager, rateLimitersCluster);
RateLimiters rateLimiters = RateLimiters.create(dynamicConfigurationManager, rateLimitersCluster);
ProvisioningManager provisioningManager = new ProvisioningManager(pubsubClient);
IssuedReceiptsManager issuedReceiptsManager = new IssuedReceiptsManager(
config.getDynamoDbTables().getIssuedReceipts().getTableName(),
@ -668,12 +670,13 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
final AccountAuthenticator accountAuthenticator = new AccountAuthenticator(accountsManager);
final MessageSender messageSender = new MessageSender(messagesManager, pushNotificationManager);
final MessageSender messageSender = new MessageSender(messagesManager, pushNotificationManager, experimentEnrollmentManager);
final ReceiptSender receiptSender = new ReceiptSender(accountsManager, messageSender, receiptSenderExecutor);
final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager = new CloudflareTurnCredentialsManager(
config.getTurnConfiguration().cloudflare().apiToken().value(),
config.getTurnConfiguration().cloudflare().endpoint(),
config.getTurnConfiguration().cloudflare().ttl(),
config.getTurnConfiguration().cloudflare().requestedCredentialTtl(),
config.getTurnConfiguration().cloudflare().clientCredentialTtl(),
config.getTurnConfiguration().cloudflare().urls(),
config.getTurnConfiguration().cloudflare().urlsWithIps(),
config.getTurnConfiguration().cloudflare().hostname(),
@ -693,7 +696,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
PushChallengeManager pushChallengeManager = new PushChallengeManager(pushNotificationManager,
pushChallengeDynamoDb);
ChangeNumberManager changeNumberManager = new ChangeNumberManager(messageSender, accountsManager);
ChangeNumberManager changeNumberManager = new ChangeNumberManager(messageSender, accountsManager, Clock.systemUTC());
HttpClient currencyClient = HttpClient.newBuilder().version(HttpClient.Version.HTTP_2).connectTimeout(Duration.ofSeconds(10)).build();
FixerClient fixerClient = config.getPaymentsServiceConfiguration().externalClients()
@ -987,7 +990,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
webSocketEnvironment.setConnectListener(
new AuthenticatedConnectListener(receiptSender, messagesManager, messageMetrics, pushNotificationManager,
pushNotificationScheduler, webSocketConnectionEventManager, websocketScheduledExecutor,
messageDeliveryScheduler, clientReleaseManager, messageDeliveryLoopMonitor));
messageDeliveryScheduler, clientReleaseManager, messageDeliveryLoopMonitor, experimentEnrollmentManager));
webSocketEnvironment.jersey()
.register(new WebsocketRefreshApplicationEventListener(accountsManager, disconnectionRequestManager));
webSocketEnvironment.jersey().register(new RateLimitByIpFilter(rateLimiters));

View File

@ -15,6 +15,7 @@ import java.net.Inet6Address;
import java.net.URI;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletionException;
@ -39,16 +40,18 @@ public class CloudflareTurnCredentialsManager {
private final List<String> cloudflareTurnUrls;
private final List<String> cloudflareTurnUrlsWithIps;
private final String cloudflareTurnHostname;
private final HttpRequest request;
private final HttpRequest getCredentialsRequest;
private final FaultTolerantHttpClient cloudflareTurnClient;
private final DnsNameResolver dnsNameResolver;
record CredentialRequest(long ttl) {}
private final Duration clientCredentialTtl;
record CloudflareTurnResponse(IceServer iceServers) {
private record CredentialRequest(long ttl) {}
record IceServer(
private record CloudflareTurnResponse(IceServer iceServers) {
private record IceServer(
String username,
String credential,
List<String> urls) {
@ -56,10 +59,17 @@ public class CloudflareTurnCredentialsManager {
}
public CloudflareTurnCredentialsManager(final String cloudflareTurnApiToken,
final String cloudflareTurnEndpoint, final long cloudflareTurnTtl, final List<String> cloudflareTurnUrls,
final List<String> cloudflareTurnUrlsWithIps, final String cloudflareTurnHostname,
final int cloudflareTurnNumHttpClients, final CircuitBreakerConfiguration circuitBreaker,
final ExecutorService executor, final RetryConfiguration retry, final ScheduledExecutorService retryExecutor,
final String cloudflareTurnEndpoint,
final Duration requestedCredentialTtl,
final Duration clientCredentialTtl,
final List<String> cloudflareTurnUrls,
final List<String> cloudflareTurnUrlsWithIps,
final String cloudflareTurnHostname,
final int cloudflareTurnNumHttpClients,
final CircuitBreakerConfiguration circuitBreaker,
final ExecutorService executor,
final RetryConfiguration retry,
final ScheduledExecutorService retryExecutor,
final DnsNameResolver dnsNameResolver) {
this.cloudflareTurnClient = FaultTolerantHttpClient.newBuilder()
@ -75,17 +85,24 @@ public class CloudflareTurnCredentialsManager {
this.cloudflareTurnHostname = cloudflareTurnHostname;
this.dnsNameResolver = dnsNameResolver;
final String credentialsRequestBody;
try {
final String body = SystemMapper.jsonMapper().writeValueAsString(new CredentialRequest(cloudflareTurnTtl));
this.request = HttpRequest.newBuilder()
credentialsRequestBody =
SystemMapper.jsonMapper().writeValueAsString(new CredentialRequest(requestedCredentialTtl.toSeconds()));
} catch (final JsonProcessingException e) {
throw new IllegalArgumentException(e);
}
// We repeat the same request to Cloudflare every time, so we can construct it once and re-use it
this.getCredentialsRequest = HttpRequest.newBuilder()
.uri(URI.create(cloudflareTurnEndpoint))
.header("Content-Type", "application/json")
.header("Authorization", String.format("Bearer %s", cloudflareTurnApiToken))
.POST(HttpRequest.BodyPublishers.ofString(body))
.POST(HttpRequest.BodyPublishers.ofString(credentialsRequestBody))
.build();
} catch (JsonProcessingException e) {
throw new IllegalArgumentException(e);
}
this.clientCredentialTtl = clientCredentialTtl;
}
public TurnToken retrieveFromCloudflare() throws IOException {
@ -105,7 +122,7 @@ public class CloudflareTurnCredentialsManager {
final Timer.Sample sample = Timer.start();
final HttpResponse<String> response;
try {
response = cloudflareTurnClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()).join();
response = cloudflareTurnClient.sendAsync(getCredentialsRequest, HttpResponse.BodyHandlers.ofString()).join();
sample.stop(Timer.builder(CREDENTIAL_FETCH_TIMER_NAME)
.publishPercentileHistogram(true)
.tags("outcome", "success")
@ -130,6 +147,7 @@ public class CloudflareTurnCredentialsManager {
return new TurnToken(
cloudflareTurnResponse.iceServers().username(),
cloudflareTurnResponse.iceServers().credential(),
clientCredentialTtl.toSeconds(),
cloudflareTurnUrls == null ? Collections.emptyList() : cloudflareTurnUrls,
cloudflareTurnComposedUrls,
cloudflareTurnHostname

View File

@ -5,13 +5,15 @@
package org.whispersystems.textsecuregcm.auth;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.List;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.List;
public record TurnToken(
String username,
String password,
@JsonProperty("ttl") long ttlSeconds,
@Nonnull List<String> urls,
@Nonnull List<String> urlsWithIps,
@Nullable String hostname) {

View File

@ -1,34 +1,22 @@
package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.netty.channel.local.LocalAddress;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import java.util.Optional;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
abstract class AbstractAuthenticationInterceptor implements ServerInterceptor {
private final GrpcClientConnectionManager grpcClientConnectionManager;
private static final Metadata EMPTY_TRAILERS = new Metadata();
AbstractAuthenticationInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
this.grpcClientConnectionManager = grpcClientConnectionManager;
}
protected Optional<AuthenticatedDevice> getAuthenticatedDevice(final ServerCall<?, ?> call) {
if (call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress) {
return grpcClientConnectionManager.getAuthenticatedDevice(localAddress);
} else {
throw new AssertionError("Unexpected channel type: " + call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
}
}
protected Optional<AuthenticatedDevice> getAuthenticatedDevice(final ServerCall<?, ?> call)
throws ChannelNotFoundException {
protected <ReqT, RespT> ServerCall.Listener<ReqT> closeAsUnauthenticated(final ServerCall<ReqT, RespT> call) {
call.close(Status.UNAUTHENTICATED, EMPTY_TRAILERS);
return new ServerCall.Listener<>() {};
return grpcClientConnectionManager.getAuthenticatedDevice(call);
}
}

View File

@ -3,12 +3,17 @@ package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.Status;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.ServerInterceptorUtil;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
/**
* A "prohibit authentication" interceptor ensures that requests to endpoints that should be invoked anonymously do not
* originate from a channel that is associated with an authenticated device. Calls with an associated authenticated
* device are closed with an {@code UNAUTHENTICATED} status.
* device are closed with an {@code UNAUTHENTICATED} status. If a call's authentication status cannot be determined
* (i.e. because the underlying remote channel closed before the {@code ServerCall} started), the interceptor will
* reject the call with a status of {@code UNAVAILABLE}.
*/
public class ProhibitAuthenticationInterceptor extends AbstractAuthenticationInterceptor {
@ -21,8 +26,15 @@ public class ProhibitAuthenticationInterceptor extends AbstractAuthenticationInt
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
try {
return getAuthenticatedDevice(call)
.map(ignored -> closeAsUnauthenticated(call))
// Status.INTERNAL may seem a little surprising here, but if a caller is reaching an authentication-prohibited
// service via an authenticated connection, then that's actually a server configuration issue and not a
// problem with the client's request.
.map(ignored -> ServerInterceptorUtil.closeWithStatus(call, Status.INTERNAL))
.orElseGet(() -> next.startCall(call, headers));
} catch (final ChannelNotFoundException e) {
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
}
}
}

View File

@ -5,12 +5,16 @@ import io.grpc.Contexts;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.Status;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.ServerInterceptorUtil;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
/**
* A "require authentication" interceptor requires that requests be issued from a connection that is associated with an
* authenticated device. Calls without an associated authenticated device are closed with an {@code UNAUTHENTICATED}
* status.
* status. If a call's authentication status cannot be determined (i.e. because the underlying remote channel closed
* before the {@code ServerCall} started), the interceptor will reject the call with a status of {@code UNAVAILABLE}.
*/
public class RequireAuthenticationInterceptor extends AbstractAuthenticationInterceptor {
@ -23,10 +27,17 @@ public class RequireAuthenticationInterceptor extends AbstractAuthenticationInte
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
try {
return getAuthenticatedDevice(call)
.map(authenticatedDevice -> Contexts.interceptCall(Context.current()
.withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE, authenticatedDevice),
call, headers, next))
.orElseGet(() -> closeAsUnauthenticated(call));
// Status.INTERNAL may seem a little surprising here, but if a caller is reaching an authentication-required
// service via an unauthenticated connection, then that's actually a server configuration issue and not a
// problem with the client's request.
.orElseGet(() -> ServerInterceptorUtil.closeWithStatus(call, Status.INTERNAL));
} catch (final ChannelNotFoundException e) {
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
}
}
}

View File

@ -6,16 +6,36 @@
package org.whispersystems.textsecuregcm.configuration;
import jakarta.validation.Valid;
import jakarta.validation.constraints.AssertTrue;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
import java.time.Duration;
import java.util.List;
import jakarta.validation.constraints.Positive;
import org.whispersystems.textsecuregcm.configuration.secrets.SecretString;
/**
* Configuration properties for Cloudflare TURN integration.
*
* @param apiToken the API token to use when requesting TURN tokens from Cloudflare
* @param endpoint the URI of the Cloudflare API endpoint that vends TURN tokens
* @param requestedCredentialTtl the lifetime of TURN tokens to request from Cloudflare
* @param clientCredentialTtl the time clients may cache a TURN token; must be less than or equal to {@link #requestedCredentialTtl}
* @param urls a collection of TURN URLs to include verbatim in responses to clients
* @param urlsWithIps a collection of {@link String#format(String, Object...)} patterns to be populated with resolved IP
* addresses for {@link #hostname} in responses to clients; each pattern must include a single
* {@code %s} placeholder for the IP address
* @param circuitBreaker a circuit breaker for requests to Cloudflare
* @param retry a retry policy for requests to Cloudflare
* @param hostname the hostname to resolve to IP addresses for use with {@link #urlsWithIps}; also transmitted to
* clients for use as an SNI when connecting to pre-resolved hosts
* @param numHttpClients the number of parallel HTTP clients to use to communicate with Cloudflare
*/
public record CloudflareTurnConfiguration(@NotNull SecretString apiToken,
@NotBlank String endpoint,
@NotBlank long ttl,
@NotNull Duration requestedCredentialTtl,
@NotNull Duration clientCredentialTtl,
@NotNull @NotEmpty @Valid List<@NotBlank String> urls,
@NotNull @NotEmpty @Valid List<@NotBlank String> urlsWithIps,
@NotNull @Valid CircuitBreakerConfiguration circuitBreaker,
@ -35,4 +55,9 @@ public record CloudflareTurnConfiguration(@NotNull SecretString apiToken,
retry = new RetryConfiguration();
}
}
@AssertTrue
public boolean isClientTtlShorterThanRequestedTtl() {
return clientCredentialTtl.compareTo(requestedCredentialTtl) <= 0;
}
}

View File

@ -46,10 +46,6 @@ public class DynamicConfiguration {
@Valid
DynamicMessagePersisterConfiguration messagePersister = new DynamicMessagePersisterConfiguration();
@JsonProperty
@Valid
DynamicRateLimitPolicy rateLimitPolicy = new DynamicRateLimitPolicy(false);
@JsonProperty
@Valid
DynamicRegistrationConfiguration registrationConfiguration = new DynamicRegistrationConfiguration(false);
@ -100,10 +96,6 @@ public class DynamicConfiguration {
return messagePersister;
}
public DynamicRateLimitPolicy getRateLimitPolicy() {
return rateLimitPolicy;
}
public DynamicRegistrationConfiguration getRegistrationConfiguration() {
return registrationConfiguration;
}

View File

@ -1,8 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.configuration.dynamic;
public record DynamicRateLimitPolicy(boolean failOpen) {}

View File

@ -102,9 +102,9 @@ public class AccountControllerV2 {
name = "Retry-After",
description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed"))
public AccountIdentityResponse changeNumber(@Mutable @Auth final AuthenticatedDevice authenticatedDevice,
@NotNull @Valid final ChangeNumberRequest request, @HeaderParam(HttpHeaders.USER_AGENT) final String userAgentString,
@Context final ContainerRequestContext requestContext)
throws RateLimitExceededException, InterruptedException {
@NotNull @Valid final ChangeNumberRequest request,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgentString,
@Context final ContainerRequestContext requestContext) throws RateLimitExceededException, InterruptedException {
if (!authenticatedDevice.getAuthenticatedDevice().isPrimary()) {
throw new ForbiddenException();

View File

@ -15,16 +15,12 @@ import io.swagger.v3.oas.annotations.responses.ApiResponse;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.container.ContainerRequestContext;
import jakarta.ws.rs.core.Context;
import jakarta.ws.rs.core.MediaType;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.CloudflareTurnCredentialsManager;
import org.whispersystems.textsecuregcm.auth.TurnToken;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.websocket.auth.ReadOnly;
@ -32,14 +28,16 @@ import org.whispersystems.websocket.auth.ReadOnly;
@Path("/v2/calling")
public class CallRoutingControllerV2 {
private static final Counter CLOUDFLARE_TURN_ERROR_COUNTER = Metrics.counter(name(CallRoutingControllerV2.class, "cloudflareTurnError"));
private final RateLimiters rateLimiters;
private final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager;
private static final Counter CLOUDFLARE_TURN_ERROR_COUNTER =
Metrics.counter(name(CallRoutingControllerV2.class, "cloudflareTurnError"));
public CallRoutingControllerV2(
final RateLimiters rateLimiters,
final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager
) {
final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager) {
this.rateLimiters = rateLimiters;
this.cloudflareTurnCredentialsManager = cloudflareTurnCredentialsManager;
}
@ -58,25 +56,17 @@ public class CallRoutingControllerV2 {
@ApiResponse(responseCode = "401", description = "Account authentication check failed.")
@ApiResponse(responseCode = "422", description = "Invalid request format.")
@ApiResponse(responseCode = "429", description = "Rate limited.")
public GetCallingRelaysResponse getCallingRelays(
final @ReadOnly @Auth AuthenticatedDevice auth
) throws RateLimitExceededException, IOException {
UUID aci = auth.getAccount().getUuid();
public GetCallingRelaysResponse getCallingRelays(final @ReadOnly @Auth AuthenticatedDevice auth)
throws RateLimitExceededException, IOException {
final UUID aci = auth.getAccount().getUuid();
rateLimiters.getCallEndpointLimiter().validate(aci);
List<TurnToken> tokens = new ArrayList<>();
try {
tokens.add(cloudflareTurnCredentialsManager.retrieveFromCloudflare());
} catch (Exception e) {
CallRoutingControllerV2.CLOUDFLARE_TURN_ERROR_COUNTER.increment();
return new GetCallingRelaysResponse(List.of(cloudflareTurnCredentialsManager.retrieveFromCloudflare()));
} catch (final Exception e) {
CLOUDFLARE_TURN_ERROR_COUNTER.increment();
throw e;
}
return new GetCallingRelaysResponse(tokens);
}
public record GetCallingRelaysResponse(
List<TurnToken> relays
) {
}
}

View File

@ -44,7 +44,6 @@ import java.util.EnumMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
@ -52,7 +51,6 @@ import java.util.concurrent.CompletionStage;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.glassfish.jersey.server.ContainerRequest;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.BasicAuthorizationHeader;
@ -74,6 +72,7 @@ import org.whispersystems.textsecuregcm.entities.TransferArchiveUploadedRequest;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.limits.RateLimitedByIp;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.DevicePlatformUtil;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account;
@ -402,7 +401,7 @@ public class DeviceController {
private AtomicInteger getCounterForLinkedDeviceListeners(final String userAgent) {
try {
return linkedDeviceListenersByPlatform.get(UserAgentUtil.parseUserAgentString(userAgent).getPlatform());
return linkedDeviceListenersByPlatform.get(UserAgentUtil.parseUserAgentString(userAgent).platform());
} catch (final UnrecognizedUserAgentException ignored) {
return linkedDeviceListenersForUnrecognizedPlatforms;
}
@ -600,25 +599,9 @@ public class DeviceController {
}
private static io.micrometer.core.instrument.Tag primaryPlatformTag(final Account account) {
final Device primaryDevice = account.getPrimaryDevice();
Optional<ClientPlatform> clientPlatform = Optional.empty();
if (StringUtils.isNotBlank(primaryDevice.getGcmId())) {
clientPlatform = Optional.of(ClientPlatform.ANDROID);
} else if (StringUtils.isNotBlank(primaryDevice.getApnId())) {
clientPlatform = Optional.of(ClientPlatform.IOS);
}
clientPlatform = clientPlatform.or(() -> Optional.ofNullable(
switch (primaryDevice.getUserAgent()) {
case "OWA" -> ClientPlatform.ANDROID;
case "OWI", "OWP" -> ClientPlatform.IOS;
case "OWD" -> ClientPlatform.DESKTOP;
case null, default -> null;
}));
return io.micrometer.core.instrument.Tag.of(
"primaryPlatform",
clientPlatform
DevicePlatformUtil.getDevicePlatform(account.getPrimaryDevice())
.map(p -> p.name().toLowerCase(Locale.ROOT))
.orElse("unknown"));
}

View File

@ -98,19 +98,18 @@ public class DonationController {
receiptSerial, receiptExpiration.getEpochSecond(), receiptLevel, auth.getAccount().getUuid())
.thenCompose(receiptMatched -> {
if (!receiptMatched) {
return CompletableFuture.completedFuture(
Response.status(Status.BAD_REQUEST).entity("receipt serial is already redeemed")
.type(MediaType.TEXT_PLAIN_TYPE).build());
}
return accountsManager.getByAccountIdentifierAsync(auth.getAccount().getUuid())
.thenCompose(optionalAccount ->
optionalAccount.map(account -> accountsManager.updateAsync(account, a -> {
return accountsManager.updateAsync(auth.getAccount(), a -> {
a.addBadge(clock, new AccountBadge(badgeId, receiptExpiration, request.isVisible()));
if (request.isPrimary()) {
a.makeBadgePrimaryIfExists(clock, badgeId);
}
})).orElse(CompletableFuture.completedFuture(null)))
})
.thenApply(ignored -> Response.ok().build());
});
}).thenCompose(Function.identity());

View File

@ -0,0 +1,13 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import org.whispersystems.textsecuregcm.auth.TurnToken;
import java.util.List;
public record GetCallingRelaysResponse(List<TurnToken> relays) {
}

View File

@ -152,7 +152,7 @@ public class KeysController {
final List<CompletableFuture<Void>> storeFutures = new ArrayList<>(4);
if (setKeysRequest.preKeys() != null && !setKeysRequest.preKeys().isEmpty()) {
if (!setKeysRequest.preKeys().isEmpty()) {
Metrics.counter(STORE_KEYS_COUNTER_NAME,
Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "ec")))
.increment();
@ -168,7 +168,7 @@ public class KeysController {
storeFutures.add(keysManager.storeEcSignedPreKeys(identifier, device.getId(), setKeysRequest.signedPreKey()));
}
if (setKeysRequest.pqPreKeys() != null && !setKeysRequest.pqPreKeys().isEmpty()) {
if (!setKeysRequest.pqPreKeys().isEmpty()) {
Metrics.counter(STORE_KEYS_COUNTER_NAME,
Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "kyber")))
.increment();
@ -192,11 +192,7 @@ public class KeysController {
final IdentityKey identityKey,
@Nullable final String userAgent) {
final List<SignedPreKey<?>> signedPreKeys = new ArrayList<>();
if (setKeysRequest.pqPreKeys() != null) {
signedPreKeys.addAll(setKeysRequest.pqPreKeys());
}
final List<SignedPreKey<?>> signedPreKeys = new ArrayList<>(setKeysRequest.pqPreKeys());
if (setKeysRequest.pqLastResortPreKey() != null) {
signedPreKeys.add(setKeysRequest.pqLastResortPreKey());
@ -244,8 +240,7 @@ public class KeysController {
@ApiResponse(responseCode = "422", description = "Invalid request format")
public CompletableFuture<Response> checkKeys(
@ReadOnly @Auth final AuthenticatedDevice auth,
@RequestBody @NotNull @Valid final CheckKeysRequest checkKeysRequest,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent) {
@RequestBody @NotNull @Valid final CheckKeysRequest checkKeysRequest) {
final UUID identifier = auth.getAccount().getIdentifier(checkKeysRequest.identityType());
final byte deviceId = auth.getAuthenticatedDevice().getId();
@ -386,10 +381,7 @@ public class KeysController {
.increment();
if (signedEcPreKey != null || unsignedEcPreKey != null || pqPreKey != null) {
final int registrationId = switch (targetIdentifier.identityType()) {
case ACI -> device.getRegistrationId();
case PNI -> device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId());
};
final int registrationId = device.getRegistrationId(targetIdentifier.identityType());
responseItems.add(
new PreKeyResponseItem(device.getId(), registrationId, signedEcPreKey, unsignedEcPreKey,

View File

@ -436,11 +436,16 @@ public class MessageController {
final Map<Byte, Integer> registrationIdsByDeviceId = messages.messages().stream()
.collect(Collectors.toMap(IncomingMessage::destinationDeviceId, IncomingMessage::destinationRegistrationId));
final Optional<Byte> syncMessageSenderDeviceId = messageType == MessageType.SYNC
? Optional.ofNullable(sender).map(authenticatedDevice -> authenticatedDevice.getAuthenticatedDevice().getId())
: Optional.empty();
try {
messageSender.sendMessages(destination,
destinationIdentifier,
messagesByDeviceId,
registrationIdsByDeviceId,
syncMessageSenderDeviceId,
userAgent);
} catch (final MismatchedDevicesException e) {
if (!e.getMismatchedDevices().staleDeviceIds().isEmpty()) {

View File

@ -428,7 +428,7 @@ public class OneTimeDonationController {
@Nullable
private static ClientPlatform getClientPlatform(@Nullable final String userAgentString) {
try {
return UserAgentUtil.parseUserAgentString(userAgentString).getPlatform();
return UserAgentUtil.parseUserAgentString(userAgentString).platform();
} catch (final UnrecognizedUserAgentException e) {
return null;
}

View File

@ -47,6 +47,7 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.glassfish.jersey.server.ManagedAsync;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ServiceId;
@ -123,6 +124,7 @@ public class ProfileController {
private static final String EXPIRING_PROFILE_KEY_CREDENTIAL_TYPE = "expiringProfileKey";
private static final String VERSION_NOT_FOUND_COUNTER_NAME = name(ProfileController.class, "versionNotFound");
private static final String DUPLICATE_AUTHENTICATION_COUNTER_NAME = name(ProfileController.class, "duplicateAuthentication");
public ProfileController(
Clock clock,
@ -204,11 +206,12 @@ public class ProfileController {
.build()));
}
final List<AccountBadge> updatedBadges = request.badges()
.map(badges -> ProfileHelper.mergeBadgeIdsWithExistingAccountBadges(clock, badgeConfigurationMap, badges, auth.getAccount().getBadges()))
.orElseGet(() -> auth.getAccount().getBadges());
accountsManager.update(auth.getAccount(), a -> {
final List<AccountBadge> updatedBadges = request.badges()
.map(badges -> ProfileHelper.mergeBadgeIdsWithExistingAccountBadges(clock, badgeConfigurationMap, badges, a.getBadges()))
.orElseGet(a::getBadges);
a.setBadges(clock, updatedBadges);
a.setCurrentProfileVersion(request.version());
});
@ -229,11 +232,12 @@ public class ProfileController {
@HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional<Anonymous> accessKey,
@Context ContainerRequestContext containerRequestContext,
@PathParam("identifier") AciServiceIdentifier accountIdentifier,
@PathParam("version") String version)
@PathParam("version") String version,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent)
throws RateLimitExceededException {
final Optional<Account> maybeRequester = auth.map(AuthenticatedDevice::getAccount);
final Account targetAccount = verifyPermissionToReceiveProfile(maybeRequester, accessKey, accountIdentifier);
final Account targetAccount = verifyPermissionToReceiveProfile(maybeRequester, accessKey, accountIdentifier, "getVersionedProfile", userAgent);
return buildVersionedProfileResponse(targetAccount,
version,
@ -252,7 +256,8 @@ public class ProfileController {
@PathParam("identifier") AciServiceIdentifier accountIdentifier,
@PathParam("version") String version,
@PathParam("credentialRequest") String credentialRequest,
@QueryParam("credentialType") String credentialType)
@QueryParam("credentialType") String credentialType,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent)
throws RateLimitExceededException {
if (!EXPIRING_PROFILE_KEY_CREDENTIAL_TYPE.equals(credentialType)) {
@ -260,7 +265,7 @@ public class ProfileController {
}
final Optional<Account> maybeRequester = auth.map(AuthenticatedDevice::getAccount);
final Account targetAccount = verifyPermissionToReceiveProfile(maybeRequester, accessKey, accountIdentifier);
final Account targetAccount = verifyPermissionToReceiveProfile(maybeRequester, accessKey, accountIdentifier, "credentialRequest", userAgent);
final boolean isSelf = maybeRequester.map(requester -> ProfileHelper.isSelfProfileRequest(requester.getUuid(), accountIdentifier)).orElse(false);
return buildExpiringProfileKeyCredentialProfileResponse(targetAccount,
@ -282,8 +287,7 @@ public class ProfileController {
@HeaderParam(HeaderUtils.GROUP_SEND_TOKEN) Optional<GroupSendTokenHeader> groupSendToken,
@Context ContainerRequestContext containerRequestContext,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent,
@PathParam("identifier") ServiceIdentifier identifier,
@QueryParam("ca") boolean useCaCertificate)
@PathParam("identifier") ServiceIdentifier identifier)
throws RateLimitExceededException {
final Optional<Account> maybeRequester = auth.map(AuthenticatedDevice::getAccount);
@ -302,7 +306,7 @@ public class ProfileController {
}
} else {
targetAccount = verifyPermissionToReceiveProfile(
maybeRequester, accessKey.filter(ignored -> identifier.identityType() == IdentityType.ACI), identifier);
maybeRequester, accessKey.filter(ignored -> identifier.identityType() == IdentityType.ACI), identifier, "getUnversionedProfile", userAgent);
}
return switch (identifier.identityType()) {
case ACI -> buildBaseProfileResponseForAccountIdentity(targetAccount,
@ -385,7 +389,7 @@ public class ProfileController {
profileKeyCredentialResponse = ProfileHelper.getExpiringProfileKeyCredential(HexFormat.of().parseHex(encodedCredentialRequest),
profile, new ServiceId.Aci(account.getUuid()), zkProfileOperations);
} catch (VerificationFailedException | InvalidInputException e) {
throw new BadRequestException(Response.status(Response.Status.BAD_REQUEST).build(), e);
throw new BadRequestException(e);
}
return profileKeyCredentialResponse;
})
@ -473,7 +477,15 @@ public class ProfileController {
*/
private Account verifyPermissionToReceiveProfile(final Optional<Account> maybeRequester,
final Optional<Anonymous> maybeAccessKey,
final ServiceIdentifier accountIdentifier) throws RateLimitExceededException {
final ServiceIdentifier accountIdentifier,
final String endpoint,
@Nullable final String userAgent) throws RateLimitExceededException {
if (maybeRequester.isPresent() && maybeAccessKey.isPresent()) {
Metrics.counter(DUPLICATE_AUTHENTICATION_COUNTER_NAME,
Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), io.micrometer.core.instrument.Tag.of("endpoint", endpoint)))
.increment();
}
if (maybeRequester.isPresent()) {
rateLimiters.getProfileLimiter().validate(maybeRequester.get().getUuid());

View File

@ -755,7 +755,7 @@ public class SubscriptionController {
@Nullable
private static ClientPlatform getClientPlatform(@Nullable final String userAgentString) {
try {
return UserAgentUtil.parseUserAgentString(userAgentString).getPlatform();
return UserAgentUtil.parseUserAgentString(userAgentString).platform();
} catch (final UnrecognizedUserAgentException e) {
return null;
}

View File

@ -7,30 +7,30 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import org.whispersystems.textsecuregcm.util.ByteArrayBase64UrlAdapter;
import io.swagger.v3.oas.annotations.media.Schema;
import java.util.UUID;
import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.util.ByteArrayBase64UrlAdapter;
public record AccountIdentityResponse(
@Schema(description="the account identifier for this account")
@Schema(description = "the account identifier for this account")
UUID uuid,
@Schema(description="the phone number associated with this account")
@Schema(description = "the phone number associated with this account")
String number,
@Schema(description="the account identifier for this account's phone-number identity")
@Schema(description = "the account identifier for this account's phone-number identity")
UUID pni,
@Schema(description="a hash of this account's username, if set")
@Schema(description = "a hash of this account's username, if set")
@JsonSerialize(using = ByteArrayBase64UrlAdapter.Serializing.class)
@JsonDeserialize(using = ByteArrayBase64UrlAdapter.Deserializing.class)
@Nullable byte[] usernameHash,
@Schema(description="this account's username link handle, if set")
@Schema(description = "this account's username link handle, if set")
@Nullable UUID usernameLinkHandle,
@Schema(description="whether any of this account's devices support storage")
@Schema(description = "whether any of this account's devices support storage")
boolean storageCapable,
@Schema(description = "entitlements for this account and their current expirations")

View File

@ -17,6 +17,7 @@ import javax.annotation.Nullable;
import jakarta.validation.Valid;
import jakarta.validation.constraints.AssertTrue;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
import org.signal.libsignal.protocol.IdentityKey;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
@ -51,34 +52,27 @@ public record ChangeNumberRequest(
arraySchema=@Schema(description="""
A list of synchronization messages to send to companion devices to supply the private keysManager
associated with the new identity key and their new prekeys.
Exactly one message must be supplied for each enabled device other than the sending (primary) device."""))
Exactly one message must be supplied for each device other than the sending (primary) device."""))
@NotNull @Valid List<@NotNull @Valid IncomingMessage> deviceMessages,
@Schema(description="""
A new signed elliptic-curve prekey for each enabled device on the account, including this one.
A new signed elliptic-curve prekey for each device on the account, including this one.
Each must be accompanied by a valid signature from the new identity key in this request.""")
@NotNull @Valid Map<Byte, @NotNull @Valid ECSignedPreKey> devicePniSignedPrekeys,
@NotNull @NotEmpty @Valid Map<Byte, @NotNull @Valid ECSignedPreKey> devicePniSignedPrekeys,
@Schema(description="""
A new signed post-quantum last-resort prekey for each enabled device on the account, including this one.
May be absent, in which case the last resort PQ prekeys for each device will be deleted if any had been stored.
If present, must contain one prekey per enabled device including this one.
Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped.
A new signed post-quantum last-resort prekey for each device on the account, including this one.
Each must be accompanied by a valid signature from the new identity key in this request.""")
@Valid Map<Byte, @NotNull @Valid KEMSignedPreKey> devicePniPqLastResortPrekeys,
@NotNull @NotEmpty @Valid Map<Byte, @NotNull @Valid KEMSignedPreKey> devicePniPqLastResortPrekeys,
@Schema(description="the new phone-number-identity registration ID for each enabled device on the account, including this one")
@NotNull Map<Byte, Integer> pniRegistrationIds) implements PhoneVerificationRequest {
@Schema(description="the new phone-number-identity registration ID for each device on the account, including this one")
@NotNull @NotEmpty Map<Byte, Integer> pniRegistrationIds) implements PhoneVerificationRequest {
public boolean isSignatureValidOnEachSignedPreKey(@Nullable final String userAgent) {
List<SignedPreKey<?>> spks = new ArrayList<>();
if (devicePniSignedPrekeys != null) {
spks.addAll(devicePniSignedPrekeys.values());
}
if (devicePniPqLastResortPrekeys != null) {
final List<SignedPreKey<?>> spks = new ArrayList<>(devicePniSignedPrekeys.values());
spks.addAll(devicePniPqLastResortPrekeys.values());
}
return spks.isEmpty() || PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey, spks, userAgent, "change-number");
return PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey, spks, userAgent, "change-number");
}
@AssertTrue

View File

@ -9,6 +9,7 @@ import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import io.swagger.v3.oas.annotations.media.ArraySchema;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.Valid;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
import java.util.ArrayList;
import java.util.List;
@ -29,36 +30,36 @@ public record PhoneNumberIdentityKeyDistributionRequest(
arraySchema=@Schema(description="""
A list of synchronization messages to send to companion devices to supply the private keys
associated with the new identity key and their new prekeys.
Exactly one message must be supplied for each enabled device other than the sending (primary) device.
Exactly one message must be supplied for each device other than the sending (primary) device.
"""))
List<@NotNull @Valid IncomingMessage> deviceMessages,
@NotNull
@NotEmpty
@Valid
@Schema(description="""
A new signed elliptic-curve prekey for each enabled device on the account, including this one.
A new signed elliptic-curve prekey for each device on the account, including this one.
Each must be accompanied by a valid signature from the new identity key in this request.""")
Map<Byte, @NotNull @Valid ECSignedPreKey> devicePniSignedPrekeys,
@NotNull
@NotEmpty
@Valid
@Schema(description="""
A new signed post-quantum last-resort prekey for each enabled device on the account, including this one.
May be absent, in which case the last resort PQ prekeys for each device will be deleted if any had been stored.
If present, must contain one prekey per enabled device including this one.
Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped.
A new signed post-quantum last-resort prekey for each device on the account, including this one.
Each must be accompanied by a valid signature from the new identity key in this request.""")
@Valid Map<Byte, @NotNull @Valid KEMSignedPreKey> devicePniPqLastResortPrekeys,
Map<Byte, @NotNull @Valid KEMSignedPreKey> devicePniPqLastResortPrekeys,
@NotNull
@NotEmpty
@Valid
@Schema(description="The new registration ID to use for the phone-number identity of each device, including this one.")
Map<Byte, Integer> pniRegistrationIds) {
public boolean isSignatureValidOnEachSignedPreKey(@Nullable final String userAgent) {
List<SignedPreKey<?>> spks = new ArrayList<>(devicePniSignedPrekeys.values());
if (devicePniPqLastResortPrekeys != null) {
spks.addAll(devicePniPqLastResortPrekeys.values());
}
return spks.isEmpty() || PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey, spks, userAgent, "distribute-pni-keys");
}
final List<SignedPreKey<?>> signedPreKeys = new ArrayList<>(devicePniSignedPrekeys.values());
signedPreKeys.addAll(devicePniPqLastResortPrekeys.values());
return PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey, signedPreKeys, userAgent, "distribute-pni-keys");
}
}

View File

@ -5,11 +5,15 @@
package org.whispersystems.textsecuregcm.entities;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;
import jakarta.validation.Valid;
import java.util.List;
public record SetKeysRequest(
@NotNull
@Valid
@Size(max=100)
@Schema(description = """
A list of unsigned elliptic-curve prekeys to use for this device. If present and not empty, replaces all stored
unsigned EC prekeys for the device; if absent or empty, any stored unsigned EC prekeys for the device are not
@ -25,7 +29,9 @@ public record SetKeysRequest(
""")
ECSignedPreKey signedPreKey,
@NotNull
@Valid
@Size(max=100)
@Schema(description = """
A list of signed post-quantum one-time prekeys to use for this device. Each key must have a valid signature from
the identity key in this request. If present and not empty, replaces all stored unsigned PQ prekeys for the
@ -40,4 +46,16 @@ public record SetKeysRequest(
deleted. If present, must have a valid signature from the identity key in this request.
""")
KEMSignedPreKey pqLastResortPreKey) {
public SetKeysRequest {
// Its a little counter-intuitive, but this compact constructor allows a default value
// to be used when one isnt specified, allowing the field to still be
// validated as @NotNull
if (preKeys == null) {
preKeys = List.of();
}
if (pqPreKeys == null) {
pqPreKeys = List.of();
}
}
}

View File

@ -81,7 +81,16 @@ public class RemoteDeprecationFilter implements Filter, ServerInterceptor {
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
if (shouldBlock(RequestAttributesUtil.getUserAgent().orElse(null))) {
@Nullable final UserAgent userAgent = RequestAttributesUtil.getUserAgent()
.map(userAgentString -> {
try {
return UserAgentUtil.parseUserAgentString(userAgentString);
} catch (final UnrecognizedUserAgentException e) {
return null;
}
}).orElse(null);
if (shouldBlock(userAgent)) {
call.close(StatusConstants.UPGRADE_NEEDED_STATUS, new Metadata());
return new ServerCall.Listener<>() {};
} else {
@ -108,28 +117,28 @@ public class RemoteDeprecationFilter implements Filter, ServerInterceptor {
return true;
}
if (blockedVersionsByPlatform.containsKey(userAgent.getPlatform())) {
if (blockedVersionsByPlatform.get(userAgent.getPlatform()).contains(userAgent.getVersion())) {
if (blockedVersionsByPlatform.containsKey(userAgent.platform())) {
if (blockedVersionsByPlatform.get(userAgent.platform()).contains(userAgent.version())) {
recordDeprecation(userAgent, BLOCKED_CLIENT_REASON);
shouldBlock = true;
}
}
if (minimumVersionsByPlatform.containsKey(userAgent.getPlatform())) {
if (userAgent.getVersion().isLowerThan(minimumVersionsByPlatform.get(userAgent.getPlatform()))) {
if (minimumVersionsByPlatform.containsKey(userAgent.platform())) {
if (userAgent.version().isLowerThan(minimumVersionsByPlatform.get(userAgent.platform()))) {
recordDeprecation(userAgent, EXPIRED_CLIENT_REASON);
shouldBlock = true;
}
}
if (versionsPendingBlockByPlatform.containsKey(userAgent.getPlatform())) {
if (versionsPendingBlockByPlatform.get(userAgent.getPlatform()).contains(userAgent.getVersion())) {
if (versionsPendingBlockByPlatform.containsKey(userAgent.platform())) {
if (versionsPendingBlockByPlatform.get(userAgent.platform()).contains(userAgent.version())) {
recordPendingDeprecation(userAgent, BLOCKED_CLIENT_REASON);
}
}
if (versionsPendingDeprecationByPlatform.containsKey(userAgent.getPlatform())) {
if (userAgent.getVersion().isLowerThan(versionsPendingDeprecationByPlatform.get(userAgent.getPlatform()))) {
if (versionsPendingDeprecationByPlatform.containsKey(userAgent.platform())) {
if (userAgent.version().isLowerThan(versionsPendingDeprecationByPlatform.get(userAgent.platform()))) {
recordPendingDeprecation(userAgent, EXPIRED_CLIENT_REASON);
}
}
@ -139,13 +148,13 @@ public class RemoteDeprecationFilter implements Filter, ServerInterceptor {
private void recordDeprecation(final UserAgent userAgent, final String reason) {
Metrics.counter(DEPRECATED_CLIENT_COUNTER_NAME,
PLATFORM_TAG, userAgent != null ? userAgent.getPlatform().name().toLowerCase() : "unrecognized",
PLATFORM_TAG, userAgent != null ? userAgent.platform().name().toLowerCase() : "unrecognized",
REASON_TAG_NAME, reason).increment();
}
private void recordPendingDeprecation(final UserAgent userAgent, final String reason) {
Metrics.counter(PENDING_DEPRECATION_COUNTER_NAME,
PLATFORM_TAG, userAgent.getPlatform().name().toLowerCase(),
PLATFORM_TAG, userAgent.platform().name().toLowerCase(),
REASON_TAG_NAME, reason).increment();
}
}

View File

@ -15,8 +15,6 @@ import jakarta.ws.rs.container.ContainerRequestFilter;
import jakarta.ws.rs.core.SecurityContext;
import java.io.IOException;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
@ -70,8 +68,8 @@ public class RestDeprecationFilter implements ContainerRequestFilter {
try {
final UserAgent userAgent = UserAgentUtil.parseUserAgentString(userAgentString);
final ClientPlatform platform = userAgent.getPlatform();
final Semver version = userAgent.getVersion();
final ClientPlatform platform = userAgent.platform();
final Semver version = userAgent.version();
if (!minimumRestFreeVersion.containsKey(platform)) {
return;
}

View File

@ -0,0 +1,12 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
/**
* Indicates that a remote channel was not found for a given server call or remote address.
*/
public class ChannelNotFoundException extends Exception {
}

View File

@ -0,0 +1,55 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Context;
import io.grpc.ForwardingServerCallListener;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.netty.channel.local.LocalAddress;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
/**
* Then channel shutdown interceptor rejects new requests if a channel is shutting down and works in tandem with
* {@link GrpcClientConnectionManager} to maintain an active call count for each channel otherwise.
*/
public class ChannelShutdownInterceptor implements ServerInterceptor {
private final GrpcClientConnectionManager grpcClientConnectionManager;
public ChannelShutdownInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
this.grpcClientConnectionManager = grpcClientConnectionManager;
}
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
if (!grpcClientConnectionManager.handleServerCallStart(call)) {
// Don't allow new calls if the connection is getting ready to close
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
}
return new ForwardingServerCallListener.SimpleForwardingServerCallListener<>(next.startCall(call, headers)) {
@Override
public void onComplete() {
grpcClientConnectionManager.handleServerCallComplete(call);
super.onComplete();
}
@Override
public void onCancel() {
grpcClientConnectionManager.handleServerCallComplete(call);
super.onCancel();
}
};
}
}

View File

@ -10,8 +10,12 @@ import org.whispersystems.textsecuregcm.storage.Device;
public class DeviceIdUtil {
public static boolean isValid(int deviceId) {
return deviceId >= Device.PRIMARY_ID && deviceId <= Byte.MAX_VALUE;
}
static byte validate(int deviceId) {
if (deviceId < Device.PRIMARY_ID || deviceId > Byte.MAX_VALUE) {
if (!isValid(deviceId)) {
throw Status.INVALID_ARGUMENT.withDescription("Device ID is out of range").asRuntimeException();
}

View File

@ -187,7 +187,8 @@ public class MessagesAnonymousGrpcService extends SimpleMessagesAnonymousGrpc.Me
destination,
destinationServiceIdentifier,
messagesByDeviceId,
registrationIdsByDeviceId);
registrationIdsByDeviceId,
Optional.empty());
}
@Override
@ -252,7 +253,7 @@ public class MessagesAnonymousGrpcService extends SimpleMessagesAnonymousGrpc.Me
story,
ephemeral,
urgent,
RequestAttributesUtil.getRawUserAgent().orElse(null));
RequestAttributesUtil.getUserAgent().orElse(null));
final SendMultiRecipientMessageResponse.Builder responseBuilder = SendMultiRecipientMessageResponse.newBuilder();

View File

@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Status;
import io.grpc.StatusException;
import java.util.Map;
import java.util.Optional;
import org.signal.chat.messages.MismatchedDevices;
import org.signal.chat.messages.SendMessageResponse;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
@ -31,6 +32,8 @@ public class MessagesGrpcHelper {
* @param destinationServiceIdentifier the service identifier for the destination account
* @param messagesByDeviceId a map of device IDs to message payloads
* @param registrationIdsByDeviceId a map of device IDs to device registration IDs
* @param syncMessageSenderDeviceId if the message is a sync message (i.e. a message to other devices linked to the
* caller's own account), contains the ID of the device that sent the message
*
* @return a response object to send to callers
*
@ -42,14 +45,17 @@ public class MessagesGrpcHelper {
final Account destination,
final ServiceIdentifier destinationServiceIdentifier,
final Map<Byte, MessageProtos.Envelope> messagesByDeviceId,
final Map<Byte, Integer> registrationIdsByDeviceId) throws StatusException, RateLimitExceededException {
final Map<Byte, Integer> registrationIdsByDeviceId,
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<Byte> syncMessageSenderDeviceId)
throws StatusException, RateLimitExceededException {
try {
messageSender.sendMessages(destination,
destinationServiceIdentifier,
messagesByDeviceId,
registrationIdsByDeviceId,
RequestAttributesUtil.getRawUserAgent().orElse(null));
syncMessageSenderDeviceId,
RequestAttributesUtil.getUserAgent().orElse(null));
return SEND_MESSAGE_SUCCESS_RESPONSE;
} catch (final MismatchedDevicesException e) {

View File

@ -172,7 +172,8 @@ public class MessagesGrpcService extends SimpleMessagesGrpc.MessagesImplBase {
destination,
destinationServiceIdentifier,
messagesByDeviceId,
registrationIdsByDeviceId);
registrationIdsByDeviceId,
messageType == MessageType.SYNC ? Optional.of(sender.deviceId()) : Optional.empty());
}
private static MessageProtos.Envelope.Type getEnvelopeType(final AuthenticatedSenderMessageType type) {

View File

@ -145,11 +145,13 @@ public class ProfileGrpcService extends ReactorProfileGrpc.ProfileImplBase {
request.getCommitment().toByteArray())));
final List<Mono<?>> updates = new ArrayList<>(2);
final List<AccountBadge> updatedBadges = Optional.of(request.getBadgeIdsList())
.map(badges -> ProfileHelper.mergeBadgeIdsWithExistingAccountBadges(clock, badgeConfigurationMap, badges, account.getBadges()))
.orElseGet(account::getBadges);
updates.add(Mono.fromFuture(() -> accountsManager.updateAsync(account, a -> {
final List<AccountBadge> updatedBadges = Optional.of(request.getBadgeIdsList())
.map(badges -> ProfileHelper.mergeBadgeIdsWithExistingAccountBadges(clock, badgeConfigurationMap, badges, a.getBadges()))
.orElseGet(a::getBadges);
a.setBadges(clock, updatedBadges);
a.setCurrentProfileVersion(request.getVersion());
})));

View File

@ -0,0 +1,16 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import java.net.InetAddress;
import java.util.List;
import java.util.Locale;
import javax.annotation.Nullable;
public record RequestAttributes(InetAddress remoteAddress,
@Nullable String userAgent,
List<Locale.LanguageRange> acceptLanguage) {
}

View File

@ -2,28 +2,25 @@ package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.netty.channel.local.LocalAddress;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import java.net.InetAddress;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
/**
* The request attributes interceptor makes request attributes from the underlying remote channel available to service
* implementations by attaching them to a {@link Context} attribute that can be read via {@link RequestAttributesUtil}.
* All server calls should have request attributes, and calls will be rejected with a status of {@code UNAVAILABLE} if
* request attributes are unavailable (i.e. the underlying channel closed before the {@code ServerCall} started).
*
* @see RequestAttributesUtil
*/
public class RequestAttributesInterceptor implements ServerInterceptor {
private final GrpcClientConnectionManager grpcClientConnectionManager;
private static final Logger log = LoggerFactory.getLogger(RequestAttributesInterceptor.class);
public RequestAttributesInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
this.grpcClientConnectionManager = grpcClientConnectionManager;
}
@ -33,52 +30,12 @@ public class RequestAttributesInterceptor implements ServerInterceptor {
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
if (call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress) {
Context context = Context.current();
{
final Optional<InetAddress> maybeRemoteAddress = grpcClientConnectionManager.getRemoteAddress(localAddress);
if (maybeRemoteAddress.isEmpty()) {
// We should never have a call from a party whose remote address we can't identify
log.warn("No remote address available");
call.close(Status.INTERNAL, new Metadata());
return new ServerCall.Listener<>() {};
}
context = context.withValue(RequestAttributesUtil.REMOTE_ADDRESS_CONTEXT_KEY, maybeRemoteAddress.get());
}
{
final Optional<List<Locale.LanguageRange>> maybeAcceptLanguage =
grpcClientConnectionManager.getAcceptableLanguages(localAddress);
if (maybeAcceptLanguage.isPresent()) {
context = context.withValue(RequestAttributesUtil.ACCEPT_LANGUAGE_CONTEXT_KEY, maybeAcceptLanguage.get());
}
}
{
final Optional<String> maybeRawUserAgent =
grpcClientConnectionManager.getRawUserAgent(localAddress);
if (maybeRawUserAgent.isPresent()) {
context = context.withValue(RequestAttributesUtil.RAW_USER_AGENT_CONTEXT_KEY, maybeRawUserAgent.get());
}
}
{
final Optional<UserAgent> maybeUserAgent = grpcClientConnectionManager.getUserAgent(localAddress);
if (maybeUserAgent.isPresent()) {
context = context.withValue(RequestAttributesUtil.USER_AGENT_CONTEXT_KEY, maybeUserAgent.get());
}
}
return Contexts.interceptCall(context, call, headers, next);
} else {
throw new AssertionError("Unexpected channel type: " + call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
try {
return Contexts.interceptCall(Context.current()
.withValue(RequestAttributesUtil.REQUEST_ATTRIBUTES_CONTEXT_KEY,
grpcClientConnectionManager.getRequestAttributes(call)), call, headers, next);
} catch (final ChannelNotFoundException e) {
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
}
}
}

View File

@ -3,18 +3,13 @@ package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Context;
import java.net.InetAddress;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
public class RequestAttributesUtil {
static final Context.Key<List<Locale.LanguageRange>> ACCEPT_LANGUAGE_CONTEXT_KEY = Context.key("accept-language");
static final Context.Key<InetAddress> REMOTE_ADDRESS_CONTEXT_KEY = Context.key("remote-address");
static final Context.Key<String> RAW_USER_AGENT_CONTEXT_KEY = Context.key("unparsed-user-agent");
static final Context.Key<UserAgent> USER_AGENT_CONTEXT_KEY = Context.key("parsed-user-agent");
static final Context.Key<RequestAttributes> REQUEST_ATTRIBUTES_CONTEXT_KEY = Context.key("request-attributes");
private static final List<Locale> AVAILABLE_LOCALES = Arrays.asList(Locale.getAvailableLocales());
@ -23,8 +18,8 @@ public class RequestAttributesUtil {
*
* @return the acceptable languages listed by the remote client; may be empty if unparseable or not specified
*/
public static Optional<List<Locale.LanguageRange>> getAcceptableLanguages() {
return Optional.ofNullable(ACCEPT_LANGUAGE_CONTEXT_KEY.get());
public static List<Locale.LanguageRange> getAcceptableLanguages() {
return REQUEST_ATTRIBUTES_CONTEXT_KEY.get().acceptLanguage();
}
/**
@ -35,9 +30,7 @@ public class RequestAttributesUtil {
* @return a list of distinct locales acceptable to the remote client and available in this JVM
*/
public static List<Locale> getAvailableAcceptedLocales() {
return getAcceptableLanguages()
.map(languageRanges -> Locale.filter(languageRanges, AVAILABLE_LOCALES))
.orElseGet(Collections::emptyList);
return Locale.filter(getAcceptableLanguages(), AVAILABLE_LOCALES);
}
/**
@ -46,16 +39,7 @@ public class RequestAttributesUtil {
* @return the remote address of the remote client
*/
public static InetAddress getRemoteAddress() {
return REMOTE_ADDRESS_CONTEXT_KEY.get();
}
/**
* Returns the parsed user-agent of the remote client in the current gRPC request context.
*
* @return the parsed user-agent of the remote client; may be empty if unparseable or not specified
*/
public static Optional<UserAgent> getUserAgent() {
return Optional.ofNullable(USER_AGENT_CONTEXT_KEY.get());
return REQUEST_ATTRIBUTES_CONTEXT_KEY.get().remoteAddress();
}
/**
@ -63,7 +47,7 @@ public class RequestAttributesUtil {
*
* @return the unparsed user-agent of the remote client; may be empty if not specified
*/
public static Optional<String> getRawUserAgent() {
return Optional.ofNullable(RAW_USER_AGENT_CONTEXT_KEY.get());
public static Optional<String> getUserAgent() {
return Optional.ofNullable(REQUEST_ATTRIBUTES_CONTEXT_KEY.get().userAgent());
}
}

View File

@ -0,0 +1,39 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.Status;
public class ServerInterceptorUtil {
@SuppressWarnings("rawtypes")
private static final ServerCall.Listener NO_OP_LISTENER = new ServerCall.Listener<>() {};
private static final Metadata EMPTY_TRAILERS = new Metadata();
private ServerInterceptorUtil() {
}
/**
* Closes the given server call with the given status, returning a no-op listener.
*
* @param call the server call to close
* @param status the status with which to close the call
*
* @return a no-op server call listener
*
* @param <ReqT> the type of request object handled by the server call
* @param <RespT> the type of response object returned by the server call
*/
public static <ReqT, RespT> ServerCall.Listener<ReqT> closeWithStatus(final ServerCall<ReqT, RespT> call, final Status status) {
call.close(status, EMPTY_TRAILERS);
//noinspection unchecked
return NO_OP_LISTENER;
}
}

View File

@ -8,7 +8,7 @@ package org.whispersystems.textsecuregcm.grpc;
import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.internalError;
import com.google.protobuf.Descriptors;
import com.google.protobuf.GeneratedMessageV3;
import com.google.protobuf.Message;
import io.grpc.ForwardingServerCallListener;
import io.grpc.Metadata;
import io.grpc.ServerCall;
@ -75,7 +75,7 @@ public class ValidatingInterceptor implements ServerInterceptor {
}
private void validateMessage(final Object message) throws StatusException {
if (message instanceof GeneratedMessageV3 msg) {
if (message instanceof Message msg) {
try {
for (final Descriptors.FieldDescriptor fd: msg.getDescriptorForType().getFields()) {
for (final Map.Entry<Descriptors.FieldDescriptor, Object> entry: fd.getOptions().getAllFields().entrySet()) {

View File

@ -1,7 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
/**
* Indicates that an attempt to authenticate a remote client failed for some reason.
*/
class ClientAuthenticationException extends Exception {
}

View File

@ -3,60 +3,39 @@ package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import javax.crypto.BadPaddingException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
/**
* An error handler serves as a general backstop for exceptions elsewhere in the pipeline. If the client has completed a
* WebSocket handshake, the error handler will send appropriate WebSocket closure codes to the client in an attempt to
* identify the problem. If the client has not completed a WebSocket handshake, the handler simply closes the
* connection.
* An error handler serves as a general backstop for exceptions elsewhere in the pipeline. It translates exceptions
* thrown in inbound handlers into {@link OutboundCloseErrorMessage}s.
*/
class ErrorHandler extends ChannelInboundHandlerAdapter {
private boolean websocketHandshakeComplete = false;
public class ErrorHandler extends ChannelInboundHandlerAdapter {
private static final Logger log = LoggerFactory.getLogger(ErrorHandler.class);
@Override
public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception {
if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
setWebsocketHandshakeComplete();
}
context.fireUserEventTriggered(event);
}
protected void setWebsocketHandshakeComplete() {
this.websocketHandshakeComplete = true;
}
private static OutboundCloseErrorMessage NOISE_ENCRYPTION_ERROR_CLOSE = new OutboundCloseErrorMessage(
OutboundCloseErrorMessage.Code.NOISE_ERROR,
"Noise encryption error");
@Override
public void exceptionCaught(final ChannelHandlerContext context, final Throwable cause) {
if (websocketHandshakeComplete) {
final WebSocketCloseStatus webSocketCloseStatus = switch (ExceptionUtils.unwrap(cause)) {
case NoiseHandshakeException e -> ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.toWebSocketCloseStatus(e.getMessage());
case ClientAuthenticationException ignored -> ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.toWebSocketCloseStatus("Not authenticated");
case BadPaddingException ignored -> ApplicationWebSocketCloseReason.NOISE_ENCRYPTION_ERROR.toWebSocketCloseStatus("Noise encryption error");
case NoiseException ignored -> ApplicationWebSocketCloseReason.NOISE_ENCRYPTION_ERROR.toWebSocketCloseStatus("Noise encryption error");
final OutboundCloseErrorMessage closeMessage = switch (ExceptionUtils.unwrap(cause)) {
case NoiseHandshakeException e -> new OutboundCloseErrorMessage(
OutboundCloseErrorMessage.Code.NOISE_HANDSHAKE_ERROR,
e.getMessage());
case BadPaddingException ignored -> NOISE_ENCRYPTION_ERROR_CLOSE;
case NoiseException ignored -> NOISE_ENCRYPTION_ERROR_CLOSE;
default -> {
log.warn("An unexpected exception reached the end of the pipeline", cause);
yield WebSocketCloseStatus.INTERNAL_SERVER_ERROR;
yield new OutboundCloseErrorMessage(
OutboundCloseErrorMessage.Code.INTERNAL_SERVER_ERROR,
cause.getMessage());
}
};
context.writeAndFlush(new CloseWebSocketFrame(webSocketCloseStatus))
context.writeAndFlush(closeMessage)
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
} else {
log.debug("Error occurred before websocket handshake complete", cause);
// We haven't completed a websocket handshake, so we can't really communicate errors in a semantically-meaningful
// way; just close the connection instead.
context.close();
}
}
}

View File

@ -7,20 +7,21 @@ import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
import io.netty.util.ReferenceCountUtil;
import java.net.InetAddress;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
/**
* An "establish local connection" handler waits for a Noise handshake to complete upstream in the pipeline, buffering
* any inbound messages until the connection is fully-established, and then opens a proxy connection to a local gRPC
* server.
*/
class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
public class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
private final GrpcClientConnectionManager grpcClientConnectionManager;
@ -48,15 +49,20 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
@Override
public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) {
if (event instanceof NoiseIdentityDeterminedEvent noiseIdentityDeterminedEvent) {
if (event instanceof NoiseIdentityDeterminedEvent(
final Optional<AuthenticatedDevice> authenticatedDevice,
InetAddress remoteAddress, String userAgent, String acceptLanguage)) {
// We assume that we'll only get a completed handshake event if the handshake met all authentication requirements
// for the requested service. If the handshake doesn't have an authenticated device, we assume we're trying to
// connect to the anonymous service. If it does have an authenticated device, we assume we're aiming for the
// authenticated service.
final LocalAddress grpcServerAddress = noiseIdentityDeterminedEvent.authenticatedDevice().isPresent()
final LocalAddress grpcServerAddress = authenticatedDevice.isPresent()
? authenticatedGrpcServerAddress
: anonymousGrpcServerAddress;
GrpcClientConnectionManager.handleHandshakeInitiated(
remoteChannelContext.channel(), remoteAddress, userAgent, acceptLanguage);
new Bootstrap()
.remoteAddress(grpcServerAddress)
.channel(LocalChannel.class)
@ -72,12 +78,14 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
if (localChannelFuture.isSuccess()) {
grpcClientConnectionManager.handleConnectionEstablished((LocalChannel) localChannelFuture.channel(),
remoteChannelContext.channel(),
noiseIdentityDeterminedEvent.authenticatedDevice());
authenticatedDevice);
// Close the local connection if the remote channel closes and vice versa
remoteChannelContext.channel().closeFuture().addListener(closeFuture -> localChannelFuture.channel().close());
localChannelFuture.channel().closeFuture().addListener(closeFuture ->
remoteChannelContext.write(new CloseWebSocketFrame(WebSocketCloseStatus.SERVICE_RESTART)));
remoteChannelContext.channel()
.write(new OutboundCloseErrorMessage(OutboundCloseErrorMessage.Code.SERVER_CLOSED, "server closed"))
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE));
remoteChannelContext.pipeline()
.addAfter(remoteChannelContext.name(), null, new ProxyHandler(localChannelFuture.channel()));

View File

@ -1,11 +1,12 @@
package org.whispersystems.textsecuregcm.grpc.net;
import com.google.common.annotations.VisibleForTesting;
import io.grpc.Grpc;
import io.grpc.ServerCall;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.util.AttributeKey;
import java.net.InetAddress;
import java.util.ArrayList;
@ -23,15 +24,26 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.util.ClosableEpoch;
/**
* A client connection manager associates a local connection to a local gRPC server with a remote connection through a
* Noise-over-WebSocket tunnel. It provides access to metadata associated with the remote connection, including the
* authenticated identity of the device that opened the connection (for non-anonymous connections). It can also close
* connections associated with a given device if that device's credentials have changed and clients must reauthenticate.
* Noise tunnel. It provides access to metadata associated with the remote connection, including the authenticated
* identity of the device that opened the connection (for non-anonymous connections). It can also close connections
* associated with a given device if that device's credentials have changed and clients must reauthenticate.
* <p>
* In general, all {@link ServerCall}s <em>must</em> have a local address that in turn <em>should</em> be resolvable to
* a remote channel, which <em>must</em> have associated request attributes and authentication status. It is possible
* that a server call's local address may not be resolvable to a remote channel if the remote channel closed in the
* narrow window between a server call being created and the start of call execution, in which case accessor methods
* in this class will throw a {@link ChannelNotFoundException}.
* <p>
* A gRPC client connection manager's methods for getting request attributes accept {@link ServerCall} entities to
* identify connections. In general, these methods should only be called from {@link io.grpc.ServerInterceptor}s.
* Methods for requesting connection closure accept an {@link AuthenticatedDevice} to identify the connection and may
* be called from any application code.
*/
public class GrpcClientConnectionManager implements DisconnectionRequestListener {
@ -43,94 +55,96 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
AttributeKey.valueOf(GrpcClientConnectionManager.class, "authenticatedDevice");
@VisibleForTesting
static final AttributeKey<InetAddress> REMOTE_ADDRESS_ATTRIBUTE_KEY =
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "remoteAddress");
public static final AttributeKey<RequestAttributes> REQUEST_ATTRIBUTES_KEY =
AttributeKey.valueOf(GrpcClientConnectionManager.class, "requestAttributes");
@VisibleForTesting
static final AttributeKey<String> RAW_USER_AGENT_ATTRIBUTE_KEY =
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "rawUserAgent");
static final AttributeKey<ClosableEpoch> EPOCH_ATTRIBUTE_KEY =
AttributeKey.valueOf(GrpcClientConnectionManager.class, "epoch");
@VisibleForTesting
static final AttributeKey<UserAgent> PARSED_USER_AGENT_ATTRIBUTE_KEY =
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "userAgent");
@VisibleForTesting
static final AttributeKey<List<Locale.LanguageRange>> ACCEPT_LANGUAGE_ATTRIBUTE_KEY =
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "acceptLanguage");
private static OutboundCloseErrorMessage SERVER_CLOSED =
new OutboundCloseErrorMessage(OutboundCloseErrorMessage.Code.SERVER_CLOSED, "server closed");
private static final Logger log = LoggerFactory.getLogger(GrpcClientConnectionManager.class);
/**
* Returns the authenticated device associated with the given local address, if any. An authenticated device is
* available if and only if the given local address maps to an active local connection and that connection is
* authenticated (i.e. not anonymous).
* Returns the authenticated device associated with the given server call, if any. If the connection is anonymous
* (i.e. unauthenticated), the returned value will be empty.
*
* @param localAddress the local address for which to find an authenticated device
* @param serverCall the gRPC server call for which to find an authenticated device
*
* @return the authenticated device associated with the given local address, if any
*
* @throws ChannelNotFoundException if the server call is not associated with a known channel; in practice, this
* generally indicates that the channel has closed while request processing is still in progress
*/
public Optional<AuthenticatedDevice> getAuthenticatedDevice(final LocalAddress localAddress) {
return getAuthenticatedDevice(remoteChannelsByLocalAddress.get(localAddress));
public Optional<AuthenticatedDevice> getAuthenticatedDevice(final ServerCall<?, ?> serverCall)
throws ChannelNotFoundException {
return getAuthenticatedDevice(getRemoteChannel(serverCall));
}
private Optional<AuthenticatedDevice> getAuthenticatedDevice(@Nullable final Channel remoteChannel) {
return Optional.ofNullable(remoteChannel)
.map(channel -> channel.attr(AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).get());
@VisibleForTesting
Optional<AuthenticatedDevice> getAuthenticatedDevice(final Channel remoteChannel) {
return Optional.ofNullable(remoteChannel.attr(AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).get());
}
/**
* Returns the parsed acceptable languages associated with the given local address, if any. Acceptable languages may
* be unavailable if the local connection associated with the given local address has already closed, if the client
* did not provide a list of acceptable languages, or the list provided by the client could not be parsed.
* Returns the request attributes associated with the given server call.
*
* @param localAddress the local address for which to find acceptable languages
* @param serverCall the gRPC server call for which to retrieve request attributes
*
* @return the acceptable languages associated with the given local address, if any
* @return the request attributes associated with the given server call
*
* @throws ChannelNotFoundException if the server call is not associated with a known channel; in practice, this
* generally indicates that the channel has closed while request processing is still in progress
*/
public Optional<List<Locale.LanguageRange>> getAcceptableLanguages(final LocalAddress localAddress) {
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
.map(remoteChannel -> remoteChannel.attr(ACCEPT_LANGUAGE_ATTRIBUTE_KEY).get());
public RequestAttributes getRequestAttributes(final ServerCall<?, ?> serverCall) throws ChannelNotFoundException {
return getRequestAttributes(getRemoteChannel(serverCall));
}
@VisibleForTesting
RequestAttributes getRequestAttributes(final Channel remoteChannel) {
final RequestAttributes requestAttributes = remoteChannel.attr(REQUEST_ATTRIBUTES_KEY).get();
if (requestAttributes == null) {
throw new IllegalStateException("Channel does not have request attributes");
}
return requestAttributes;
}
/**
* Returns the remote address associated with the given local address, if any. A remote address may be unavailable if
* the local connection associated with the given local address has already closed.
* Handles the start of a server call, incrementing the active call count for the remote channel associated with the
* given server call.
*
* @param localAddress the local address for which to find a remote address
* @param serverCall the server call to start
*
* @return the remote address associated with the given local address, if any
* @return {@code true} if the call should start normally or {@code false} if the call should be aborted because the
* underlying channel is closing
*/
public Optional<InetAddress> getRemoteAddress(final LocalAddress localAddress) {
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
.map(remoteChannel -> remoteChannel.attr(REMOTE_ADDRESS_ATTRIBUTE_KEY).get());
public boolean handleServerCallStart(final ServerCall<?, ?> serverCall) {
try {
return getRemoteChannel(serverCall).attr(EPOCH_ATTRIBUTE_KEY).get().tryArrive();
} catch (final ChannelNotFoundException e) {
// This would only happen if the channel had already closed, which is certainly possible. In this case, the call
// should certainly not proceed.
return false;
}
}
/**
* Returns the unparsed user agent provided by the client that opened the connection associated with the given local
* address. This method may return an empty value if no active local connection is associated with the given local
* address.
* Handles completion (successful or not) of a server call, decrementing the active call count for the remote channel
* associated with the given server call.
*
* @param localAddress the local address for which to find a User-Agent string
*
* @return the user agent string associated with the given local address
* @param serverCall the server call to complete
*/
public Optional<String> getRawUserAgent(final LocalAddress localAddress) {
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
.map(remoteChannel -> remoteChannel.attr(RAW_USER_AGENT_ATTRIBUTE_KEY).get());
public void handleServerCallComplete(final ServerCall<?, ?> serverCall) {
try {
getRemoteChannel(serverCall).attr(EPOCH_ATTRIBUTE_KEY).get().depart();
} catch (final ChannelNotFoundException ignored) {
// In practice, we'd only get here if the channel has already closed, so we can just ignore the exception
}
/**
* Returns the parsed user agent provided by the client that opened the connection associated with the given local
* address. This method may return an empty value if no active local connection is associated with the given local
* address or if the client's user-agent string was not recognized.
*
* @param localAddress the local address for which to find a User-Agent string
*
* @return the user agent associated with the given local address
*/
public Optional<UserAgent> getUserAgent(final LocalAddress localAddress) {
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
.map(remoteChannel -> remoteChannel.attr(PARSED_USER_AGENT_ATTRIBUTE_KEY).get());
}
/**
@ -145,10 +159,11 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
final List<Channel> channelsToClose =
new ArrayList<>(remoteChannelsByAuthenticatedDevice.getOrDefault(authenticatedDevice, Collections.emptyList()));
channelsToClose.forEach(channel ->
channel.writeAndFlush(new CloseWebSocketFrame(ApplicationWebSocketCloseReason.REAUTHENTICATION_REQUIRED
.toWebSocketCloseStatus("Reauthentication required")))
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE));
channelsToClose.forEach(channel -> channel.attr(EPOCH_ATTRIBUTE_KEY).get().close());
}
private static void closeRemoteChannel(final Channel channel) {
channel.writeAndFlush(SERVER_CLOSED).addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
}
@VisibleForTesting
@ -156,53 +171,66 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
return remoteChannelsByAuthenticatedDevice.get(authenticatedDevice);
}
private Channel getRemoteChannel(final ServerCall<?, ?> serverCall) throws ChannelNotFoundException {
return getRemoteChannel(getLocalAddress(serverCall));
}
@VisibleForTesting
Channel getRemoteChannelByLocalAddress(final LocalAddress localAddress) {
Channel getRemoteChannel(final LocalAddress localAddress) throws ChannelNotFoundException {
final Channel remoteChannel = remoteChannelsByLocalAddress.get(localAddress);
if (remoteChannel == null) {
throw new ChannelNotFoundException();
}
return remoteChannelsByLocalAddress.get(localAddress);
}
private static LocalAddress getLocalAddress(final ServerCall<?, ?> serverCall) {
// In this server, gRPC's "remote" channel is actually a local channel that proxies to a distinct Noise channel.
// The gRPC "remote" address is the "local address" for the proxy connection, and the local address uniquely maps to
// a proxied Noise channel.
if (!(serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress)) {
throw new IllegalArgumentException("Unexpected channel type: " + serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
}
return localAddress;
}
/**
* Handles successful completion of a WebSocket handshake and associates attributes and headers from the handshake
* Handles receipt of a handshake message and associates attributes and headers from the handshake
* request with the channel via which the handshake took place.
*
* @param channel the channel that completed a WebSocket handshake
* @param channel the channel where the handshake was initiated
* @param preferredRemoteAddress the preferred remote address (potentially from a request header) for the handshake
* @param userAgentHeader the value of the User-Agent header provided in the handshake request; may be {@code null}
* @param acceptLanguageHeader the value of the Accept-Language header provided in the handshake request; may be
* {@code null}
*/
static void handleWebSocketHandshakeComplete(final Channel channel,
public static void handleHandshakeInitiated(final Channel channel,
final InetAddress preferredRemoteAddress,
@Nullable final String userAgentHeader,
@Nullable final String acceptLanguageHeader) {
channel.attr(GrpcClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).set(preferredRemoteAddress);
if (StringUtils.isNotBlank(userAgentHeader)) {
channel.attr(GrpcClientConnectionManager.RAW_USER_AGENT_ATTRIBUTE_KEY).set(userAgentHeader);
try {
channel.attr(GrpcClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY)
.set(UserAgentUtil.parseUserAgentString(userAgentHeader));
} catch (final UnrecognizedUserAgentException ignored) {
}
}
@Nullable List<Locale.LanguageRange> acceptLanguages = Collections.emptyList();
if (StringUtils.isNotBlank(acceptLanguageHeader)) {
try {
channel.attr(GrpcClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).set(Locale.LanguageRange.parse(acceptLanguageHeader));
acceptLanguages = Locale.LanguageRange.parse(acceptLanguageHeader);
} catch (final IllegalArgumentException e) {
log.debug("Invalid Accept-Language header from User-Agent {}: {}", userAgentHeader, acceptLanguageHeader, e);
}
}
channel.attr(REQUEST_ATTRIBUTES_KEY)
.set(new RequestAttributes(preferredRemoteAddress, userAgentHeader, acceptLanguages));
}
/**
* Handles successful establishment of a Noise-over-WebSocket connection from a remote client to a local gRPC server.
* Handles successful establishment of a Noise connection from a remote client to a local gRPC server.
*
* @param localChannel the newly-opened local channel between the Noise-over-WebSocket tunnel and the local gRPC
* server
* @param remoteChannel the channel from the remote client to the Noise-over-WebSocket tunnel
* @param localChannel the newly-opened local channel between the Noise tunnel and the local gRPC server
* @param remoteChannel the channel from the remote client to the Noise tunnel
* @param maybeAuthenticatedDevice the authenticated device (if any) associated with the new connection
*/
void handleConnectionEstablished(final LocalChannel localChannel,
@ -212,6 +240,9 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
maybeAuthenticatedDevice.ifPresent(authenticatedDevice ->
remoteChannel.attr(GrpcClientConnectionManager.AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).set(authenticatedDevice));
remoteChannel.attr(EPOCH_ATTRIBUTE_KEY)
.set(new ClosableEpoch(() -> closeRemoteChannel(remoteChannel)));
remoteChannelsByLocalAddress.put(localChannel.localAddress(), remoteChannel);
getAuthenticatedDevice(remoteChannel).ifPresent(authenticatedDevice ->

View File

@ -4,7 +4,7 @@
*/
package org.whispersystems.textsecuregcm.grpc.net;
enum HandshakePattern {
public enum HandshakePattern {
NK("Noise_NK_25519_ChaChaPoly_BLAKE2b"),
IK("Noise_IK_25519_ChaChaPoly_BLAKE2b");

View File

@ -1,34 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
/**
* A NoiseAnonymousHandler is a netty pipeline element that handles the responder side of an unauthenticated handshake
* and noise encryption/decryption.
* <p>
* A noise NK handshake must be used for unauthenticated connections. Optionally, the initiator can also include an
* initial request in their payload. If provided, this allows the server to begin processing the request without an
* initial message delay (fast open).
* <p>
* Once the handler receives the handshake initiator message, it will fire a {@link NoiseIdentityDeterminedEvent}
* indicating that initiator connected anonymously.
*/
class NoiseAnonymousHandler extends NoiseHandler {
public NoiseAnonymousHandler(final ECKeyPair ecKeyPair) {
super(new NoiseHandshakeHelper(HandshakePattern.NK, ecKeyPair));
}
@Override
CompletableFuture<HandshakeResult> handleHandshakePayload(final ChannelHandlerContext context,
final Optional<byte[]> initiatorPublicKey, final ByteBuf handshakePayload) {
return CompletableFuture.completedFuture(new HandshakeResult(
handshakePayload,
Optional.empty()
));
}
}

View File

@ -1,96 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.util.ReferenceCountUtil;
import java.security.MessageDigest;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
/**
* A NoiseAuthenticatedHandler is a netty pipeline element that handles the responder side of an authenticated handshake
* and noise encryption/decryption. Authenticated handshakes are noise IK handshakes where the initiator's static public
* key is authenticated by the responder.
* <p>
* The authenticated handshake requires the initiator to provide a payload with their first handshake message that
* includes their account identifier and device id in network byte-order. Optionally, the initiator can also include an
* initial request in their payload. If provided, this allows the server to begin processing the request without an
* initial message delay (fast open).
* <pre>
* +-----------------+----------------+------------------------+
* | UUID (16) | deviceId (1) | request bytes (N) |
* +-----------------+----------------+------------------------+
* </pre>
* <p>
* For a successful handshake, the static key provided in the handshake message must match the server's stored public
* key for the device identified by the provided ACI and deviceId.
* <p>
* As soon as the handler authenticates the caller, it will fire a {@link NoiseIdentityDeterminedEvent}.
*/
class NoiseAuthenticatedHandler extends NoiseHandler {
private final ClientPublicKeysManager clientPublicKeysManager;
NoiseAuthenticatedHandler(final ClientPublicKeysManager clientPublicKeysManager,
final ECKeyPair ecKeyPair) {
super(new NoiseHandshakeHelper(HandshakePattern.IK, ecKeyPair));
this.clientPublicKeysManager = clientPublicKeysManager;
}
@Override
CompletableFuture<HandshakeResult> handleHandshakePayload(
final ChannelHandlerContext context,
final Optional<byte[]> initiatorPublicKey,
final ByteBuf handshakePayload) throws NoiseHandshakeException {
if (handshakePayload.readableBytes() < 17) {
throw new NoiseHandshakeException("Invalid handshake payload");
}
final byte[] publicKeyFromClient = initiatorPublicKey
.orElseThrow(() -> new IllegalStateException("No remote public key"));
// Advances the read index by 16 bytes
final UUID accountIdentifier = parseUUID(handshakePayload);
// Advances the read index by 1 byte
final byte deviceId = handshakePayload.readByte();
final ByteBuf fastOpenRequest = handshakePayload.slice();
return clientPublicKeysManager
.findPublicKey(accountIdentifier, deviceId)
.handleAsync((storedPublicKey, throwable) -> {
if (throwable != null) {
ReferenceCountUtil.release(fastOpenRequest);
throw ExceptionUtils.wrap(throwable);
}
final boolean valid = storedPublicKey
.map(spk -> MessageDigest.isEqual(publicKeyFromClient, spk.getPublicKeyBytes()))
.orElse(false);
if (!valid) {
throw ExceptionUtils.wrap(new ClientAuthenticationException());
}
return new HandshakeResult(
fastOpenRequest,
Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId)));
}, context.executor());
}
/**
* Parse a {@link UUID} out of bytes, advancing the readerIdx by 16
*
* @param bytes The {@link ByteBuf} to read from
* @return The parsed UUID
* @throws NoiseHandshakeException If a UUID could not be parsed from bytes
*/
private UUID parseUUID(final ByteBuf bytes) throws NoiseHandshakeException {
if (bytes.readableBytes() < 16) {
throw new NoiseHandshakeException("Could not parse account identifier");
}
return new UUID(bytes.readLong(), bytes.readLong());
}
}

View File

@ -1,10 +1,12 @@
package org.whispersystems.textsecuregcm.grpc.net;
import org.whispersystems.textsecuregcm.util.NoStackTraceException;
/**
* Indicates that some problem occurred while processing an encrypted noise message (e.g. an unexpected message size/
* format or a general encryption error).
*/
class NoiseException extends Exception {
public class NoiseException extends NoStackTraceException {
public NoiseException(final String message) {
super(message);
}

View File

@ -11,135 +11,50 @@ import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.PromiseCombiner;
import io.netty.util.internal.EmptyArrays;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import javax.crypto.BadPaddingException;
import javax.crypto.ShortBufferException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
/**
* A bidirectional {@link io.netty.channel.ChannelHandler} that establishes a noise session with an initiator, decrypts
* inbound messages, and encrypts outbound messages
* A bidirectional {@link io.netty.channel.ChannelHandler} that decrypts inbound messages, and encrypts outbound
* messages
*/
abstract class NoiseHandler extends ChannelDuplexHandler {
public class NoiseHandler extends ChannelDuplexHandler {
private static final Logger log = LoggerFactory.getLogger(NoiseHandler.class);
private final CipherStatePair cipherStatePair;
private enum State {
// Waiting for handshake to complete
HANDSHAKE,
// Can freely exchange encrypted noise messages on an established session
TRANSPORT,
// Finished with error
ERROR
NoiseHandler(CipherStatePair cipherStatePair) {
this.cipherStatePair = cipherStatePair;
}
private final NoiseHandshakeHelper handshakeHelper;
private State state = State.HANDSHAKE;
private CipherStatePair cipherStatePair;
NoiseHandler(NoiseHandshakeHelper handshakeHelper) {
this.handshakeHelper = handshakeHelper;
}
/**
* The result of processing an initiator handshake payload
*
* @param fastOpenRequest A fast-open request included in the handshake. If none was present, this should be an
* empty ByteBuf
* @param authenticatedDevice If present, the successfully authenticated initiator identity
*/
record HandshakeResult(ByteBuf fastOpenRequest, Optional<AuthenticatedDevice> authenticatedDevice) {}
/**
* Parse and potentially authenticate the initiator handshake message
*
* @param context A {@link ChannelHandlerContext}
* @param initiatorPublicKey The initiator's static public key, if a handshake pattern that includes it was used
* @param handshakePayload The handshake payload provided in the initiator message
* @return A {@link HandshakeResult} that includes an authenticated device and a parsed fast-open request if one was
* present in the handshake payload.
* @throws NoiseHandshakeException If the handshake payload was invalid
* @throws ClientAuthenticationException If the initiatorPublicKey could not be authenticated
*/
abstract CompletableFuture<HandshakeResult> handleHandshakePayload(
final ChannelHandlerContext context,
final Optional<byte[]> initiatorPublicKey,
final ByteBuf handshakePayload) throws NoiseHandshakeException, ClientAuthenticationException;
@Override
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
try {
if (message instanceof BinaryWebSocketFrame frame) {
if (frame.content().readableBytes() > Noise.MAX_PACKET_LEN) {
final String error = "Invalid noise message length " + frame.content().readableBytes();
throw state == State.HANDSHAKE ? new NoiseHandshakeException(error) : new NoiseException(error);
if (message instanceof ByteBuf frame) {
if (frame.readableBytes() > Noise.MAX_PACKET_LEN) {
throw new NoiseException("Invalid noise message length " + frame.readableBytes());
}
// We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array.
// We'll need to copy it to a heap buffer.
handleInboundMessage(context, ByteBufUtil.getBytes(frame.content()));
handleInboundDataMessage(context, ByteBufUtil.getBytes(frame));
} else {
// Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an
// error
// Anything except ByteBufs should have been filtered out of the pipeline by now; treat this as an error
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
}
} catch (Exception e) {
fail(context, e);
} finally {
ReferenceCountUtil.release(message);
}
}
private void handleInboundMessage(final ChannelHandlerContext context, final byte[] frameBytes)
throws NoiseHandshakeException, ShortBufferException, BadPaddingException, ClientAuthenticationException {
switch (state) {
// Got an initiator handshake message
case HANDSHAKE -> {
final ByteBuf payload = handshakeHelper.read(frameBytes);
handleHandshakePayload(context, handshakeHelper.remotePublicKey(), payload).whenCompleteAsync(
(result, throwable) -> {
if (state == State.ERROR) {
return;
}
if (throwable != null) {
fail(context, ExceptionUtils.unwrap(throwable));
return;
}
context.fireUserEventTriggered(new NoiseIdentityDeterminedEvent(result.authenticatedDevice()));
// Now that we've authenticated, write the handshake response
byte[] handshakeMessage = handshakeHelper.write(EmptyArrays.EMPTY_BYTES);
context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(handshakeMessage)))
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
// The handshake is complete. We can start intercepting read/write for noise encryption/decryption
this.state = State.TRANSPORT;
this.cipherStatePair = handshakeHelper.getHandshakeState().split();
if (result.fastOpenRequest().isReadable()) {
// The handshake had a fast-open request. Forward the plaintext of the request to the server, we'll
// encrypt the response when the server writes back through us
context.fireChannelRead(result.fastOpenRequest());
} else {
ReferenceCountUtil.release(result.fastOpenRequest());
}
}, context.executor());
}
// Got a client message that should be decrypted and forwarded
case TRANSPORT -> {
private void handleInboundDataMessage(final ChannelHandlerContext context, final byte[] frameBytes)
throws ShortBufferException, BadPaddingException {
final CipherState cipherState = cipherStatePair.getReceiver();
// Overwrite the ciphertext with the plaintext to avoid an extra allocation for a dedicated plaintext buffer
final int plaintextLength = cipherState.decryptWithAd(null,
@ -151,21 +66,6 @@ abstract class NoiseHandler extends ChannelDuplexHandler {
context.fireChannelRead(Unpooled.wrappedBuffer(frameBytes, 0, plaintextLength));
}
// The session is already in an error state, drop the message
case ERROR -> {
}
}
}
/**
* Set the state to the error state (so subsequent messages fast-fail) and propagate the failure reason on the
* context
*/
private void fail(final ChannelHandlerContext context, final Throwable cause) {
this.state = State.ERROR;
context.fireExceptionCaught(cause);
}
@Override
public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise)
throws Exception {
@ -193,19 +93,27 @@ abstract class NoiseHandler extends ChannelDuplexHandler {
// Overwrite the plaintext with the ciphertext to avoid an extra allocation for a dedicated ciphertext buffer
cipherState.encryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, plaintextLength);
pc.add(context.write(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(noiseBuffer))));
pc.add(context.write(Unpooled.wrappedBuffer(noiseBuffer)));
}
pc.finish(promise);
} finally {
ReferenceCountUtil.release(byteBuf);
}
} else {
if (!(message instanceof WebSocketFrame)) {
// Downstream handlers may write WebSocket frames that don't need to be encrypted (e.g. "close" frames that
// get issued in response to exceptions)
if (!(message instanceof OutboundCloseErrorMessage)) {
// Downstream handlers may write OutboundCloseErrorMessages that don't need to be encrypted (e.g. "close" frames
// that get issued in response to exceptions)
log.warn("Unexpected object in pipeline: {}", message);
}
context.write(message, promise);
}
}
@Override
public void handlerRemoved(ChannelHandlerContext var1) {
if (cipherStatePair != null) {
cipherStatePair.destroy();
}
}
}

View File

@ -1,10 +1,12 @@
package org.whispersystems.textsecuregcm.grpc.net;
import org.whispersystems.textsecuregcm.util.NoStackTraceException;
/**
* Indicates that some problem occurred while completing a Noise handshake (e.g. an unexpected message size/format or
* a general encryption error).
*/
class NoiseHandshakeException extends Exception {
public class NoiseHandshakeException extends NoStackTraceException {
public NoiseHandshakeException(final String message) {
super(message);

View File

@ -0,0 +1,197 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
import com.southernstorm.noise.protocol.Noise;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufInputStream;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.util.ReferenceCountUtil;
import java.io.IOException;
import java.net.InetAddress;
import java.security.MessageDigest;
import java.util.Optional;
import java.util.UUID;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.grpc.DeviceIdUtil;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
/**
* Handles the responder side of a noise handshake and then replaces itself with a {@link NoiseHandler} which will
* encrypt/decrypt subsequent data frames
* <p>
* The handler expects to receive a single inbound message, a {@link NoiseHandshakeInit} that includes the initiator
* handshake message, connection metadata, and the type of handshake determined by the framing layer. This handler
* currently supports two types of handshakes.
* <p>
* The first are IK handshakes where the initiator's static public key is authenticated by the responder. The initiator
* handshake message must contain the ACI and deviceId of the initiator. To be authenticated, the static key provided in
* the handshake message must match the server's stored public key for the device identified by the provided ACI and
* deviceId.
* <p>
* The second are NK handshakes which are anonymous.
* <p>
* Optionally, the initiator can also include an initial request in their payload. If provided, this allows the server
* to begin processing the request without an initial message delay (fast open).
* <p>
* Once the handshake has been validated, a {@link NoiseIdentityDeterminedEvent} will be fired. For an IK handshake,
* this will include the {@link org.whispersystems.textsecuregcm.auth.AuthenticatedDevice} of the initiator. This
* handler will then replace itself with a {@link NoiseHandler} with a noise state pair ready to encrypt/decrypt data
* frames.
*/
public class NoiseHandshakeHandler extends ChannelInboundHandlerAdapter {
private static final byte[] HANDSHAKE_WRONG_PK = NoiseTunnelProtos.HandshakeResponse.newBuilder()
.setCode(NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY)
.build().toByteArray();
private static final byte[] HANDSHAKE_OK = NoiseTunnelProtos.HandshakeResponse.newBuilder()
.setCode(NoiseTunnelProtos.HandshakeResponse.Code.OK)
.build().toByteArray();
// We might get additional messages while we're waiting to process a handshake, so keep track of where we are
private boolean receivedHandshakeInit = false;
private final ClientPublicKeysManager clientPublicKeysManager;
private final ECKeyPair ecKeyPair;
public NoiseHandshakeHandler(final ClientPublicKeysManager clientPublicKeysManager, final ECKeyPair ecKeyPair) {
this.clientPublicKeysManager = clientPublicKeysManager;
this.ecKeyPair = ecKeyPair;
}
@Override
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
try {
if (!(message instanceof NoiseHandshakeInit handshakeInit)) {
// Anything except HandshakeInit should have been filtered out of the pipeline by now; treat this as an error
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
}
if (receivedHandshakeInit) {
throw new NoiseHandshakeException("Should not receive messages until handshake complete");
}
receivedHandshakeInit = true;
if (handshakeInit.content().readableBytes() > Noise.MAX_PACKET_LEN) {
throw new NoiseHandshakeException("Invalid noise message length " + handshakeInit.content().readableBytes());
}
// We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array.
// We'll need to copy it to a heap buffer
handleInboundHandshake(context,
handshakeInit.getRemoteAddress(),
handshakeInit.getHandshakePattern(),
ByteBufUtil.getBytes(handshakeInit.content()));
} finally {
ReferenceCountUtil.release(message);
}
}
private void handleInboundHandshake(
final ChannelHandlerContext context,
final InetAddress remoteAddress,
final HandshakePattern handshakePattern,
final byte[] frameBytes) throws NoiseHandshakeException {
final NoiseHandshakeHelper handshakeHelper = new NoiseHandshakeHelper(handshakePattern, ecKeyPair);
final ByteBuf payload = handshakeHelper.read(frameBytes);
// Parse the handshake message
final NoiseTunnelProtos.HandshakeInit handshakeInit;
try {
handshakeInit = NoiseTunnelProtos.HandshakeInit.parseFrom(new ByteBufInputStream(payload));
} catch (IOException e) {
throw new NoiseHandshakeException("Failed to parse handshake message");
}
switch (handshakePattern) {
case NK -> {
if (handshakeInit.getDeviceId() != 0 || !handshakeInit.getAci().isEmpty()) {
throw new NoiseHandshakeException("Anonymous handshake should not include identifiers");
}
handleAuthenticated(context, handshakeHelper, remoteAddress, handshakeInit, Optional.empty());
}
case IK -> {
final byte[] publicKeyFromClient = handshakeHelper.remotePublicKey()
.orElseThrow(() -> new IllegalStateException("No remote public key"));
final UUID accountIdentifier = aci(handshakeInit);
final byte deviceId = deviceId(handshakeInit);
clientPublicKeysManager
.findPublicKey(accountIdentifier, deviceId)
.whenCompleteAsync((storedPublicKey, throwable) -> {
if (throwable != null) {
context.fireExceptionCaught(ExceptionUtils.unwrap(throwable));
return;
}
final boolean valid = storedPublicKey
.map(spk -> MessageDigest.isEqual(publicKeyFromClient, spk.getPublicKeyBytes()))
.orElse(false);
if (!valid) {
// Write a handshake response indicating that the client used the wrong public key
final byte[] handshakeMessage = handshakeHelper.write(HANDSHAKE_WRONG_PK);
context.writeAndFlush(Unpooled.wrappedBuffer(handshakeMessage))
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
context.fireExceptionCaught(new NoiseHandshakeException("Bad public key"));
return;
}
handleAuthenticated(context,
handshakeHelper, remoteAddress, handshakeInit,
Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId)));
}, context.executor());
}
};
}
private void handleAuthenticated(final ChannelHandlerContext context,
final NoiseHandshakeHelper handshakeHelper,
final InetAddress remoteAddress,
final NoiseTunnelProtos.HandshakeInit handshakeInit,
final Optional<AuthenticatedDevice> maybeAuthenticatedDevice) {
context.fireUserEventTriggered(new NoiseIdentityDeterminedEvent(
maybeAuthenticatedDevice,
remoteAddress,
handshakeInit.getUserAgent(),
handshakeInit.getAcceptLanguage()));
// Now that we've authenticated, write the handshake response
final byte[] handshakeMessage = handshakeHelper.write(HANDSHAKE_OK);
context.writeAndFlush(Unpooled.wrappedBuffer(handshakeMessage))
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
// The handshake is complete. We can start intercepting read/write for noise encryption/decryption
// Note: It may be tempting to swap the before/remove for a replace, but then when we forward the fast open
// request it will go through the NoiseHandler. We want to skip the NoiseHandler because we've already
// decrypted the fastOpen request
context.pipeline()
.addBefore(context.name(), null, new NoiseHandler(handshakeHelper.getHandshakeState().split()));
context.pipeline().remove(NoiseHandshakeHandler.class);
if (!handshakeInit.getFastOpenRequest().isEmpty()) {
// The handshake had a fast-open request. Forward the plaintext of the request to the server, we'll
// encrypt the response when the server writes back through us
context.fireChannelRead(Unpooled.wrappedBuffer(handshakeInit.getFastOpenRequest().asReadOnlyByteBuffer()));
}
}
private static UUID aci(final NoiseTunnelProtos.HandshakeInit handshakePayload) throws NoiseHandshakeException {
try {
return UUIDUtil.fromByteString(handshakePayload.getAci());
} catch (IllegalArgumentException e) {
throw new NoiseHandshakeException("Could not parse aci");
}
}
private static byte deviceId(final NoiseTunnelProtos.HandshakeInit handshakePayload) throws NoiseHandshakeException {
if (!DeviceIdUtil.isValid(handshakePayload.getDeviceId())) {
throw new NoiseHandshakeException("Invalid deviceId");
}
return (byte) handshakePayload.getDeviceId();
}
}

View File

@ -0,0 +1,33 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.DefaultByteBufHolder;
import java.net.InetAddress;
/**
* A message that includes the initiator's handshake message, connection metadata, and the handshake type. The metadata
* and handshake type are extracted from the framing layer, so this allows receivers to be framing layer agnostic.
*/
public class NoiseHandshakeInit extends DefaultByteBufHolder {
private final InetAddress remoteAddress;
private final HandshakePattern handshakePattern;
public NoiseHandshakeInit(
final InetAddress remoteAddress,
final HandshakePattern handshakePattern,
final ByteBuf initiatorHandshakeMessage) {
super(initiatorHandshakeMessage);
this.remoteAddress = remoteAddress;
this.handshakePattern = handshakePattern;
}
public InetAddress getRemoteAddress() {
return remoteAddress;
}
public HandshakePattern getHandshakePattern() {
return handshakePattern;
}
}

View File

@ -1,5 +1,6 @@
package org.whispersystems.textsecuregcm.grpc.net;
import java.net.InetAddress;
import java.util.Optional;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
@ -9,5 +10,12 @@ import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
*
* @param authenticatedDevice the device authenticated as part of the handshake, or empty if the handshake was not of a
* type that performs authentication
* @param remoteAddress the remote address of the connecting client
* @param userAgent the client supplied userAgent
* @param acceptLanguage the client supplied acceptLanguage
*/
record NoiseIdentityDeterminedEvent(Optional<AuthenticatedDevice> authenticatedDevice) {}
public record NoiseIdentityDeterminedEvent(
Optional<AuthenticatedDevice> authenticatedDevice,
InetAddress remoteAddress,
String userAgent,
String acceptLanguage) {}

View File

@ -0,0 +1,31 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
/**
* An error written to the outbound pipeline that indicates the connection should be closed
*/
public record OutboundCloseErrorMessage(Code code, String message) {
public enum Code {
/**
* The server decided to close the connection. This could be because the server is going away, or it could be
* because the credentials for the connected client have been updated.
*/
SERVER_CLOSED,
/**
* There was a noise decryption error after the noise session was established
*/
NOISE_ERROR,
/**
* There was an error establishing the noise handshake
*/
NOISE_HANDSHAKE_ERROR,
INTERNAL_SERVER_ERROR
}
}

View File

@ -8,7 +8,7 @@ import io.netty.channel.ChannelInboundHandlerAdapter;
/**
* A proxy handler writes all data read from one channel to another peer channel.
*/
class ProxyHandler extends ChannelInboundHandlerAdapter {
public class ProxyHandler extends ChannelInboundHandlerAdapter {
private final Channel peerChannel;

View File

@ -0,0 +1,45 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.util.ReferenceCountUtil;
import org.whispersystems.textsecuregcm.grpc.net.NoiseException;
/**
* In the inbound direction, this handler strips the NoiseDirectFrame wrapper we read off the wire and then forwards the
* noise packet to the noise layer as a {@link ByteBuf} for decryption.
* <p>
* In the outbound direction, this handler wraps encrypted noise packet {@link ByteBuf}s in a NoiseDirectFrame wrapper
* so it can be wire serialized. This handler assumes the first outbound message received will correspond to the
* handshake response, and then the subsequent messages are all data frame payloads.
*/
public class NoiseDirectDataFrameCodec extends ChannelDuplexHandler {
@Override
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
if (msg instanceof NoiseDirectFrame frame) {
if (frame.frameType() != NoiseDirectFrame.FrameType.DATA) {
ReferenceCountUtil.release(msg);
throw new NoiseException("Invalid frame type received (expected DATA): " + frame.frameType());
}
ctx.fireChannelRead(frame.content());
} else {
ctx.fireChannelRead(msg);
}
}
@Override
public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
if (msg instanceof ByteBuf bb) {
ctx.write(new NoiseDirectFrame(NoiseDirectFrame.FrameType.DATA, bb), promise);
} else {
ctx.write(msg, promise);
}
}
}

View File

@ -0,0 +1,72 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.DefaultByteBufHolder;
public class NoiseDirectFrame extends DefaultByteBufHolder {
static final byte VERSION = 0x00;
private final FrameType frameType;
public NoiseDirectFrame(final FrameType frameType, final ByteBuf data) {
super(data);
this.frameType = frameType;
}
public FrameType frameType() {
return frameType;
}
public byte versionedFrameTypeByte() {
final byte frameBits = frameType().getFrameBits();
return (byte) ((NoiseDirectFrame.VERSION << 4) | frameBits);
}
public enum FrameType {
/**
* The payload is the initiator message for a Noise NK handshake. If established, the
* session will be unauthenticated.
*/
NK_HANDSHAKE((byte) 1),
/**
* The payload is the initiator message for a Noise IK handshake. If established, the
* session will be authenticated.
*/
IK_HANDSHAKE((byte) 2),
/**
* The payload is an encrypted noise packet.
*/
DATA((byte) 3),
/**
* A frame sent before the connection is closed. The payload is a protobuf indicating why the connection is being
* closed.
*/
CLOSE((byte) 4);
private final byte frameType;
FrameType(byte frameType) {
if (frameType != (0x0F & frameType)) {
throw new IllegalStateException("Frame type must fit in 4 bits");
}
this.frameType = frameType;
}
public byte getFrameBits() {
return frameType;
}
public boolean isHandshake() {
return switch (this) {
case IK_HANDSHAKE, NK_HANDSHAKE -> true;
case DATA, CLOSE -> false;
};
}
}
}

View File

@ -0,0 +1,90 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
import com.southernstorm.noise.protocol.Noise;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.util.ReferenceCountUtil;
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
/**
* Handles conversion between bytes on the wire and {@link NoiseDirectFrame}s. This handler assumes that inbound bytes
* have already been framed using a {@link io.netty.handler.codec.LengthFieldBasedFrameDecoder}
*/
public class NoiseDirectFrameCodec extends ChannelDuplexHandler {
@Override
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
if (msg instanceof ByteBuf byteBuf) {
try {
ctx.fireChannelRead(deserialize(byteBuf));
} catch (Exception e) {
ReferenceCountUtil.release(byteBuf);
throw e;
}
} else {
ctx.fireChannelRead(msg);
}
}
@Override
public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
if (msg instanceof NoiseDirectFrame noiseDirectFrame) {
try {
// Serialize the frame into a newly allocated direct buffer. Since this is the last handler before the
// network, nothing should have to make another copy of this. If later another layer is added, it may be more
// efficient to reuse the input buffer (typically not direct) by using a composite byte buffer
final ByteBuf serialized = serialize(ctx, noiseDirectFrame);
ctx.writeAndFlush(serialized, promise);
} finally {
ReferenceCountUtil.release(noiseDirectFrame);
}
} else {
ctx.write(msg, promise);
}
}
private ByteBuf serialize(
final ChannelHandlerContext ctx,
final NoiseDirectFrame noiseDirectFrame) {
if (noiseDirectFrame.content().readableBytes() > Noise.MAX_PACKET_LEN) {
throw new IllegalStateException("Payload too long: " + noiseDirectFrame.content().readableBytes());
}
// 1 version/frametype byte, 2 length bytes, content
final ByteBuf byteBuf = ctx.alloc().buffer(1 + 2 + noiseDirectFrame.content().readableBytes());
byteBuf.writeByte(noiseDirectFrame.versionedFrameTypeByte());
byteBuf.writeShort(noiseDirectFrame.content().readableBytes());
byteBuf.writeBytes(noiseDirectFrame.content());
return byteBuf;
}
private NoiseDirectFrame deserialize(final ByteBuf byteBuf) throws Exception {
final byte versionAndFrameByte = byteBuf.readByte();
final int version = (versionAndFrameByte & 0xF0) >> 4;
if (version != NoiseDirectFrame.VERSION) {
throw new NoiseHandshakeException("Invalid NoiseDirect version: " + version);
}
final byte frameTypeBits = (byte) (versionAndFrameByte & 0x0F);
final NoiseDirectFrame.FrameType frameType = switch (frameTypeBits) {
case 1 -> NoiseDirectFrame.FrameType.NK_HANDSHAKE;
case 2 -> NoiseDirectFrame.FrameType.IK_HANDSHAKE;
case 3 -> NoiseDirectFrame.FrameType.DATA;
case 4 -> NoiseDirectFrame.FrameType.CLOSE;
default -> throw new NoiseHandshakeException("Invalid NoiseDirect frame type: " + frameTypeBits);
};
final int length = Short.toUnsignedInt(byteBuf.readShort());
if (length != byteBuf.readableBytes()) {
throw new IllegalArgumentException(
"Payload length did not match remaining buffer, should have been guaranteed by a previous handler");
}
return new NoiseDirectFrame(frameType, byteBuf.readSlice(length));
}
}

View File

@ -0,0 +1,53 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.util.ReferenceCountUtil;
import java.io.IOException;
import java.net.InetSocketAddress;
import org.whispersystems.textsecuregcm.grpc.net.HandshakePattern;
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeInit;
/**
* Waits for a Handshake {@link NoiseDirectFrame} and then replaces itself with a {@link NoiseDirectDataFrameCodec} and
* forwards the handshake frame along as a {@link NoiseHandshakeInit} message
*/
public class NoiseDirectHandshakeSelector extends ChannelInboundHandlerAdapter {
@Override
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
if (msg instanceof NoiseDirectFrame frame) {
try {
if (!(ctx.channel().remoteAddress() instanceof InetSocketAddress inetSocketAddress)) {
throw new IOException("Could not determine remote address");
}
// We've received an inbound handshake frame. Pull the framing-protocol specific data the downstream handler
// needs into a NoiseHandshakeInit message and forward that along
final NoiseHandshakeInit handshakeMessage = new NoiseHandshakeInit(inetSocketAddress.getAddress(),
switch (frame.frameType()) {
case DATA -> throw new NoiseHandshakeException("First message must have handshake frame type");
case CLOSE -> throw new IllegalStateException("Close frames should not reach handshake selector");
case IK_HANDSHAKE -> HandshakePattern.IK;
case NK_HANDSHAKE -> HandshakePattern.NK;
}, frame.content());
// Subsequent inbound messages and outbound should be data type frames or close frames. Inbound data frames
// should be unwrapped and forwarded to the noise handler, outbound buffers should be wrapped and forwarded
// for network serialization. Note that we need to install the Data frame handler before firing the read,
// because we may receive an outbound message from the noiseHandler
ctx.pipeline().replace(ctx.name(), null, new NoiseDirectDataFrameCodec());
ctx.fireChannelRead(handshakeMessage);
} catch (Exception e) {
ReferenceCountUtil.release(msg);
throw e;
}
} else {
ctx.fireChannelRead(msg);
}
}
}

View File

@ -0,0 +1,36 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
import io.micrometer.core.instrument.Metrics;
import io.netty.buffer.ByteBufUtil;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.util.ReferenceCountUtil;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
/**
* Watches for inbound close frames and closes the connection in response
*/
public class NoiseDirectInboundCloseHandler extends ChannelInboundHandlerAdapter {
private static String CLIENT_CLOSE_COUNTER_NAME = MetricsUtil.name(ChannelInboundHandlerAdapter.class, "clientClose");
@Override
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.CLOSE) {
try {
final NoiseDirectProtos.CloseReason closeReason = NoiseDirectProtos.CloseReason
.parseFrom(ByteBufUtil.getBytes(ndf.content()));
Metrics.counter(CLIENT_CLOSE_COUNTER_NAME, "reason", closeReason.getCode().name()).increment();
} finally {
ReferenceCountUtil.release(msg);
ctx.close();
}
} else {
ctx.fireChannelRead(msg);
}
}
}

View File

@ -0,0 +1,38 @@
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufOutputStream;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage;
/**
* Translates {@link OutboundCloseErrorMessage}s into {@link NoiseDirectFrame} error frames. After error frames are
* written, the channel is closed
*/
class NoiseDirectOutboundErrorHandler extends ChannelOutboundHandlerAdapter {
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
if (msg instanceof OutboundCloseErrorMessage err) {
final NoiseDirectProtos.CloseReason.Code code = switch (err.code()) {
case SERVER_CLOSED -> NoiseDirectProtos.CloseReason.Code.UNAVAILABLE;
case NOISE_ERROR -> NoiseDirectProtos.CloseReason.Code.ENCRYPTION_ERROR;
case NOISE_HANDSHAKE_ERROR -> NoiseDirectProtos.CloseReason.Code.HANDSHAKE_ERROR;
case INTERNAL_SERVER_ERROR -> NoiseDirectProtos.CloseReason.Code.INTERNAL_ERROR;
};
final NoiseDirectProtos.CloseReason proto = NoiseDirectProtos.CloseReason.newBuilder()
.setCode(code)
.setMessage(err.message())
.build();
final ByteBuf byteBuf = ctx.alloc().buffer(proto.getSerializedSize());
proto.writeTo(new ByteBufOutputStream(byteBuf));
ctx.writeAndFlush(new NoiseDirectFrame(NoiseDirectFrame.FrameType.CLOSE, byteBuf))
.addListener(ChannelFutureListener.CLOSE);
} else {
ctx.write(msg, promise);
}
}
}

View File

@ -0,0 +1,93 @@
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
import com.google.common.annotations.VisibleForTesting;
import com.southernstorm.noise.protocol.Noise;
import io.dropwizard.lifecycle.Managed;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.ServerSocketChannel;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import java.net.InetSocketAddress;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.grpc.net.ErrorHandler;
import org.whispersystems.textsecuregcm.grpc.net.EstablishLocalGrpcConnectionHandler;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.grpc.net.HAProxyMessageHandler;
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeHandler;
import org.whispersystems.textsecuregcm.grpc.net.ProxyProtocolDetectionHandler;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
/**
* A NoiseDirectTunnelServer accepts traffic from the public internet (in the form of Noise packets framed by a custom
* binary framing protocol) and passes it through to a local gRPC server.
*/
public class NoiseDirectTunnelServer implements Managed {
private final ServerBootstrap bootstrap;
private ServerSocketChannel channel;
private static final Logger log = LoggerFactory.getLogger(NoiseDirectTunnelServer.class);
public NoiseDirectTunnelServer(final int port,
final NioEventLoopGroup eventLoopGroup,
final GrpcClientConnectionManager grpcClientConnectionManager,
final ClientPublicKeysManager clientPublicKeysManager,
final ECKeyPair ecKeyPair,
final LocalAddress authenticatedGrpcServerAddress,
final LocalAddress anonymousGrpcServerAddress) {
this.bootstrap = new ServerBootstrap()
.group(eventLoopGroup)
.channel(NioServerSocketChannel.class)
.localAddress(port)
.childHandler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel socketChannel) {
socketChannel.pipeline()
.addLast(new ProxyProtocolDetectionHandler())
.addLast(new HAProxyMessageHandler())
// frame byte followed by a 2-byte length field
.addLast(new LengthFieldBasedFrameDecoder(Noise.MAX_PACKET_LEN, 1, 2))
// Parses NoiseDirectFrames from wire bytes and vice versa
.addLast(new NoiseDirectFrameCodec())
// Terminate the connection if the client sends us a close frame
.addLast(new NoiseDirectInboundCloseHandler())
// Turn generic OutboundCloseErrorMessages into noise direct error frames
.addLast(new NoiseDirectOutboundErrorHandler())
// Forwards the first payload supplemented with handshake metadata, and then replaces itself with a
// NoiseDirectDataFrameCodec to handle subsequent data frames
.addLast(new NoiseDirectHandshakeSelector())
// Performs the noise handshake and then replace itself with a NoiseHandler
.addLast(new NoiseHandshakeHandler(clientPublicKeysManager, ecKeyPair))
// This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler
// once the Noise handshake has completed
.addLast(new EstablishLocalGrpcConnectionHandler(
grpcClientConnectionManager, authenticatedGrpcServerAddress, anonymousGrpcServerAddress))
.addLast(new ErrorHandler());
}
});
}
@VisibleForTesting
public InetSocketAddress getLocalAddress() {
return channel.localAddress();
}
@Override
public void start() throws InterruptedException {
channel = (ServerSocketChannel) bootstrap.bind().await().channel();
}
@Override
public void stop() throws InterruptedException {
if (channel != null) {
channel.close().await();
}
}
}

View File

@ -1,12 +1,10 @@
package org.whispersystems.textsecuregcm.grpc.net;
package org.whispersystems.textsecuregcm.grpc.net.websocket;
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
enum ApplicationWebSocketCloseReason {
NOISE_HANDSHAKE_ERROR(4001),
CLIENT_AUTHENTICATION_ERROR(4002),
NOISE_ENCRYPTION_ERROR(4003),
REAUTHENTICATION_REQUIRED(4004);
NOISE_ENCRYPTION_ERROR(4002);
private final int statusCode;
@ -17,8 +15,4 @@ enum ApplicationWebSocketCloseReason {
public int getStatusCode() {
return statusCode;
}
WebSocketCloseStatus toWebSocketCloseStatus(final String reason) {
return new WebSocketCloseStatus(statusCode, reason);
}
}

View File

@ -1,4 +1,4 @@
package org.whispersystems.textsecuregcm.grpc.net;
package org.whispersystems.textsecuregcm.grpc.net.websocket;
import com.google.common.annotations.VisibleForTesting;
import com.southernstorm.noise.protocol.Noise;
@ -28,6 +28,7 @@ import javax.net.ssl.SSLException;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.grpc.net.*;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
/**
@ -103,10 +104,16 @@ public class NoiseWebSocketTunnelServer implements Managed {
// request and passed it down the pipeline
.addLast(new WebSocketOpeningHandshakeHandler(AUTHENTICATED_SERVICE_PATH, ANONYMOUS_SERVICE_PATH, HEALTH_CHECK_PATH))
.addLast(new WebSocketServerProtocolHandler("/", true))
// Turn generic OutboundCloseErrorMessages into websocket close frames
.addLast(new WebSocketOutboundErrorHandler())
.addLast(new RejectUnsupportedMessagesHandler())
// The WebSocket handshake complete listener will replace itself with an appropriate Noise handshake handler once
// a WebSocket handshake has been completed
.addLast(new WebsocketHandshakeCompleteHandler(clientPublicKeysManager, ecKeyPair, recognizedProxySecret))
.addLast(new WebsocketPayloadCodec())
// The WebSocket handshake complete listener will forward the first payload supplemented with
// data from the websocket handshake completion event, and then remove itself from the pipeline
.addLast(new WebsocketHandshakeCompleteHandler(recognizedProxySecret))
// The NoiseHandshakeHandler will perform the noise handshake and then replace itself with a
// NoiseHandler
.addLast(new NoiseHandshakeHandler(clientPublicKeysManager, ecKeyPair))
// This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler
// once the Noise handshake has completed
.addLast(new EstablishLocalGrpcConnectionHandler(grpcClientConnectionManager, authenticatedGrpcServerAddress, anonymousGrpcServerAddress))

View File

@ -1,4 +1,4 @@
package org.whispersystems.textsecuregcm.grpc.net;
package org.whispersystems.textsecuregcm.grpc.net.websocket;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;

View File

@ -1,4 +1,4 @@
package org.whispersystems.textsecuregcm.grpc.net;
package org.whispersystems.textsecuregcm.grpc.net.websocket;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;

View File

@ -0,0 +1,58 @@
package org.whispersystems.textsecuregcm.grpc.net.websocket;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage;
/**
* Converts {@link OutboundCloseErrorMessage}s written to the pipeline into WebSocket close frames
*/
class WebSocketOutboundErrorHandler extends ChannelDuplexHandler {
private boolean websocketHandshakeComplete = false;
private static final Logger log = LoggerFactory.getLogger(WebSocketOutboundErrorHandler.class);
@Override
public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception {
if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
setWebsocketHandshakeComplete();
}
context.fireUserEventTriggered(event);
}
protected void setWebsocketHandshakeComplete() {
this.websocketHandshakeComplete = true;
}
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
if (msg instanceof OutboundCloseErrorMessage err) {
if (websocketHandshakeComplete) {
final int status = switch (err.code()) {
case SERVER_CLOSED -> WebSocketCloseStatus.SERVICE_RESTART.code();
case NOISE_ERROR -> ApplicationWebSocketCloseReason.NOISE_ENCRYPTION_ERROR.getStatusCode();
case NOISE_HANDSHAKE_ERROR -> ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode();
case INTERNAL_SERVER_ERROR -> WebSocketCloseStatus.INTERNAL_SERVER_ERROR.code();
};
ctx.write(new CloseWebSocketFrame(new WebSocketCloseStatus(status, err.message())), promise)
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
} else {
log.debug("Error {} occurred before websocket handshake complete", err);
// We haven't completed a websocket handshake, so we can't really communicate errors in a semantically-meaningful
// way; just close the connection instead.
ctx.close();
}
} else {
ctx.write(msg, promise);
}
}
}

View File

@ -1,15 +1,15 @@
package org.whispersystems.textsecuregcm.grpc.net;
package org.whispersystems.textsecuregcm.grpc.net.websocket;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.net.InetAddresses;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.util.ReferenceCountUtil;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.nio.charset.StandardCharsets;
@ -17,10 +17,10 @@ import java.security.MessageDigest;
import java.util.Optional;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
import org.whispersystems.textsecuregcm.grpc.net.HandshakePattern;
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeInit;
/**
* A WebSocket handshake handler waits for a WebSocket handshake to complete, then replaces itself with the appropriate
@ -28,10 +28,6 @@ import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
*/
class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
private final ClientPublicKeysManager clientPublicKeysManager;
private final ECKeyPair ecKeyPair;
private final byte[] recognizedProxySecret;
private static final Logger log = LoggerFactory.getLogger(WebsocketHandshakeCompleteHandler.class);
@ -42,12 +38,10 @@ class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
@VisibleForTesting
static final String FORWARDED_FOR_HEADER = "X-Forwarded-For";
WebsocketHandshakeCompleteHandler(final ClientPublicKeysManager clientPublicKeysManager,
final ECKeyPair ecKeyPair,
final String recognizedProxySecret) {
private InetAddress remoteAddress = null;
private HandshakePattern handshakePattern = null;
this.clientPublicKeysManager = clientPublicKeysManager;
this.ecKeyPair = ecKeyPair;
WebsocketHandshakeCompleteHandler(final String recognizedProxySecret) {
// The recognized proxy secret is an arbitrary string, and not an encoded byte sequence (i.e. a base64- or hex-
// encoded value). We convert it into a byte array here for easier constant-time comparisons via
@ -58,8 +52,6 @@ class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
@Override
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) {
final InetAddress preferredRemoteAddress;
{
final Optional<InetAddress> maybePreferredRemoteAddress =
getPreferredRemoteAddress(context, handshakeCompleteEvent);
@ -71,34 +63,41 @@ class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
return;
}
preferredRemoteAddress = maybePreferredRemoteAddress.get();
}
GrpcClientConnectionManager.handleWebSocketHandshakeComplete(context.channel(),
preferredRemoteAddress,
handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.USER_AGENT),
handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.ACCEPT_LANGUAGE));
final ChannelHandler noiseHandshakeHandler = switch (handshakeCompleteEvent.requestUri()) {
case NoiseWebSocketTunnelServer.AUTHENTICATED_SERVICE_PATH ->
new NoiseAuthenticatedHandler(clientPublicKeysManager, ecKeyPair);
case NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH ->
new NoiseAnonymousHandler(ecKeyPair);
default -> {
remoteAddress = maybePreferredRemoteAddress.get();
handshakePattern = switch (handshakeCompleteEvent.requestUri()) {
case NoiseWebSocketTunnelServer.AUTHENTICATED_SERVICE_PATH -> HandshakePattern.IK;
case NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH -> HandshakePattern.NK;
// The WebSocketOpeningHandshakeHandler should have caught all of these cases already; we'll consider it an
// internal error if something slipped through.
throw new IllegalArgumentException("Unexpected URI: " + handshakeCompleteEvent.requestUri());
}
default -> throw new IllegalArgumentException("Unexpected URI: " + handshakeCompleteEvent.requestUri());
};
context.pipeline().replace(WebsocketHandshakeCompleteHandler.this, null, noiseHandshakeHandler);
}
context.fireUserEventTriggered(event);
}
@Override
public void channelRead(final ChannelHandlerContext context, final Object msg) {
try {
if (!(msg instanceof ByteBuf frame)) {
throw new IllegalStateException("Unexpected msg type: " + msg.getClass());
}
if (handshakePattern == null || remoteAddress == null) {
throw new IllegalStateException("Received payload before websocket handshake complete");
}
final NoiseHandshakeInit handshakeMessage =
new NoiseHandshakeInit(remoteAddress, handshakePattern, frame);
context.pipeline().remove(WebsocketHandshakeCompleteHandler.class);
context.fireChannelRead(handshakeMessage);
} catch (Exception e) {
ReferenceCountUtil.release(msg);
throw e;
}
}
private Optional<InetAddress> getPreferredRemoteAddress(final ChannelHandlerContext context,
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) {

View File

@ -0,0 +1,38 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net.websocket;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
/**
* Extracts buffers from inbound BinaryWebsocketFrames before forwarding to a
* {@link org.whispersystems.textsecuregcm.grpc.net.NoiseHandler} for decryption and wraps outbound encrypted noise
* packet buffers in BinaryWebsocketFrames for writing through the websocket layer.
*/
public class WebsocketPayloadCodec extends ChannelDuplexHandler {
@Override
public void channelRead(final ChannelHandlerContext ctx, final Object msg) {
if (msg instanceof BinaryWebSocketFrame frame) {
ctx.fireChannelRead(frame.content());
} else {
ctx.fireChannelRead(msg);
}
}
@Override
public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
if (msg instanceof ByteBuf bb) {
ctx.write(new BinaryWebSocketFrame(bb), promise);
} else {
ctx.write(msg, promise);
}
}
}

View File

@ -11,7 +11,6 @@ import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.in
import com.google.protobuf.ByteString;
import com.google.protobuf.Descriptors;
import com.google.protobuf.GeneratedMessageV3;
import com.google.protobuf.Message;
import io.grpc.Status;
import io.grpc.StatusException;
@ -49,7 +48,7 @@ public abstract class BaseFieldValidator<T> implements FieldValidator {
public void validate(
final Object extensionValue,
final Descriptors.FieldDescriptor fd,
final GeneratedMessageV3 msg) throws StatusException {
final Message msg) throws StatusException {
try {
final T extensionValueTyped = resolveExtensionValue(extensionValue);
@ -116,7 +115,7 @@ public abstract class BaseFieldValidator<T> implements FieldValidator {
protected void validateRepeatedField(
final T extensionValue,
final Descriptors.FieldDescriptor fd,
final GeneratedMessageV3 msg) throws StatusException {
final Message msg) throws StatusException {
throw internalError("`validateRepeatedField` method needs to be implemented");
}

View File

@ -9,7 +9,7 @@ import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.in
import com.google.protobuf.ByteString;
import com.google.protobuf.Descriptors;
import com.google.protobuf.GeneratedMessageV3;
import com.google.protobuf.Message;
import io.grpc.StatusException;
import java.util.List;
import java.util.Set;
@ -53,7 +53,7 @@ public class ExactlySizeFieldValidator extends BaseFieldValidator<Set<Integer>>
protected void validateRepeatedField(
final Set<Integer> permittedSizes,
final Descriptors.FieldDescriptor fd,
final GeneratedMessageV3 msg) throws StatusException {
final Message msg) throws StatusException {
final int size = msg.getRepeatedFieldCount(fd);
if (permittedSizes.contains(size)) {
return;

View File

@ -6,11 +6,11 @@
package org.whispersystems.textsecuregcm.grpc.validators;
import com.google.protobuf.Descriptors;
import com.google.protobuf.GeneratedMessageV3;
import com.google.protobuf.Message;
import io.grpc.StatusException;
public interface FieldValidator {
void validate(Object extensionValue, Descriptors.FieldDescriptor fd, GeneratedMessageV3 msg)
void validate(Object extensionValue, Descriptors.FieldDescriptor fd, Message msg)
throws StatusException;
}

View File

@ -9,7 +9,7 @@ import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.in
import com.google.protobuf.ByteString;
import com.google.protobuf.Descriptors;
import com.google.protobuf.GeneratedMessageV3;
import com.google.protobuf.Message;
import io.grpc.StatusException;
import java.util.Set;
import org.apache.commons.lang3.StringUtils;
@ -52,7 +52,7 @@ public class NonEmptyFieldValidator extends BaseFieldValidator<Boolean> {
protected void validateRepeatedField(
final Boolean extensionValue,
final Descriptors.FieldDescriptor fd,
final GeneratedMessageV3 msg) throws StatusException {
final Message msg) throws StatusException {
if (msg.getRepeatedFieldCount(fd) > 0) {
return;
}

View File

@ -9,7 +9,7 @@ import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.in
import com.google.protobuf.ByteString;
import com.google.protobuf.Descriptors;
import com.google.protobuf.GeneratedMessageV3;
import com.google.protobuf.Message;
import io.grpc.StatusException;
import java.util.Set;
import org.signal.chat.require.SizeConstraint;
@ -48,7 +48,7 @@ public class SizeFieldValidator extends BaseFieldValidator<Range> {
}
@Override
protected void validateRepeatedField(final Range range, final Descriptors.FieldDescriptor fd, final GeneratedMessageV3 msg) throws StatusException {
protected void validateRepeatedField(final Range range, final Descriptors.FieldDescriptor fd, final Message msg) throws StatusException {
final int size = msg.getRepeatedFieldCount(fd);
if (size < range.min() || size > range.max()) {
throw invalidArgument("field value is [%d] but expected to be within the [%d, %d] range".formatted(

View File

@ -10,16 +10,12 @@ import static java.util.Objects.requireNonNull;
import io.lettuce.core.ScriptOutputType;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.lang.invoke.MethodHandles;
import java.time.Clock;
import java.util.Arrays;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
@ -27,25 +23,18 @@ import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
public abstract class BaseRateLimiters<T extends RateLimiterDescriptor> {
private final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
private final Map<T, RateLimiter> rateLimiterByDescriptor;
private final Map<String, RateLimiterConfig> configs;
protected BaseRateLimiters(
final T[] values,
final Map<String, RateLimiterConfig> configs,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final ClusterLuaScript validateScript,
final FaultTolerantRedisClusterClient cacheCluster,
final Clock clock) {
this.configs = configs;
this.rateLimiterByDescriptor = Arrays.stream(values)
.map(descriptor -> Pair.of(
descriptor,
createForDescriptor(descriptor, configs, dynamicConfigurationManager, validateScript, cacheCluster, clock)))
createForDescriptor(descriptor, dynamicConfigurationManager, validateScript, cacheCluster, clock)))
.collect(Collectors.toUnmodifiableMap(Pair::getKey, Pair::getValue));
}
@ -53,22 +42,6 @@ public abstract class BaseRateLimiters<T extends RateLimiterDescriptor> {
return requireNonNull(rateLimiterByDescriptor.get(handle));
}
public void validateValuesAndConfigs() {
final Set<String> ids = rateLimiterByDescriptor.keySet().stream()
.map(RateLimiterDescriptor::id)
.collect(Collectors.toSet());
for (final String key: configs.keySet()) {
if (!ids.contains(key)) {
final String message = String.format(
"Static configuration has an unexpected field '%s' that doesn't match any RateLimiterDescriptor",
key
);
logger.error(message);
throw new IllegalArgumentException(message);
}
}
}
protected static ClusterLuaScript defaultScript(final FaultTolerantRedisClusterClient cacheCluster) {
try {
return ClusterLuaScript.fromResource(
@ -80,21 +53,12 @@ public abstract class BaseRateLimiters<T extends RateLimiterDescriptor> {
private static RateLimiter createForDescriptor(
final RateLimiterDescriptor descriptor,
final Map<String, RateLimiterConfig> configs,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final ClusterLuaScript validateScript,
final FaultTolerantRedisClusterClient cacheCluster,
final Clock clock) {
if (descriptor.isDynamic()) {
final Supplier<RateLimiterConfig> configResolver = () -> {
final RateLimiterConfig config = dynamicConfigurationManager.getConfiguration().getLimits().get(descriptor.id());
return config != null
? config
: configs.getOrDefault(descriptor.id(), descriptor.defaultConfig());
};
return new DynamicRateLimiter(descriptor.id(), dynamicConfigurationManager, configResolver, validateScript, cacheCluster, clock);
}
final RateLimiterConfig cfg = configs.getOrDefault(descriptor.id(), descriptor.defaultConfig());
return new StaticRateLimiter(descriptor.id(), cfg, validateScript, cacheCluster, clock, dynamicConfigurationManager);
final Supplier<RateLimiterConfig> configResolver =
() -> dynamicConfigurationManager.getConfiguration().getLimits().getOrDefault(descriptor.id(), descriptor.defaultConfig());
return new DynamicRateLimiter(descriptor.id(), configResolver, validateScript, cacheCluster, clock);
}
}

View File

@ -7,87 +7,167 @@ package org.whispersystems.textsecuregcm.limits;
import static java.util.Objects.requireNonNull;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import java.time.Clock;
import java.time.Duration;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import org.apache.commons.lang3.tuple.Pair;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.Util;
public class DynamicRateLimiter implements RateLimiter {
private final String name;
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private final Supplier<RateLimiterConfig> configResolver;
private final ClusterLuaScript validateScript;
private final FaultTolerantRedisClusterClient cluster;
private final Clock clock;
private final Counter limitExceededCounter;
private final AtomicReference<Pair<RateLimiterConfig, RateLimiter>> currentHolder = new AtomicReference<>();
private final Clock clock;
public DynamicRateLimiter(
final String name,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final Supplier<RateLimiterConfig> configResolver,
final ClusterLuaScript validateScript,
final FaultTolerantRedisClusterClient cluster,
final Clock clock) {
this.name = requireNonNull(name);
this.dynamicConfigurationManager = dynamicConfigurationManager;
this.configResolver = requireNonNull(configResolver);
this.validateScript = requireNonNull(validateScript);
this.cluster = requireNonNull(cluster);
this.clock = requireNonNull(clock);
this.limitExceededCounter = Metrics.counter(MetricsUtil.name(getClass(), "exceeded"), "rateLimiterName", name);
}
@Override
public void validate(final String key, final int amount) throws RateLimitExceededException {
current().getRight().validate(key, amount);
final RateLimiterConfig config = config();
try {
final long deficitPermitsAmount = executeValidateScript(config, key, amount, true);
if (deficitPermitsAmount > 0) {
limitExceededCounter.increment();
final Duration retryAfter = Duration.ofMillis(
(long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis()));
throw new RateLimitExceededException(retryAfter);
}
} catch (final Exception e) {
if (e instanceof RateLimitExceededException rateLimitExceededException) {
throw rateLimitExceededException;
}
if (!config.failOpen()) {
throw e;
}
}
}
@Override
public CompletionStage<Void> validateAsync(final String key, final int amount) {
return current().getRight().validateAsync(key, amount);
final RateLimiterConfig config = config();
return executeValidateScriptAsync(config, key, amount, true)
.thenCompose(deficitPermitsAmount -> {
if (deficitPermitsAmount == 0) {
return CompletableFuture.completedFuture((Void) null);
}
limitExceededCounter.increment();
final Duration retryAfter = Duration.ofMillis(
(long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis()));
return CompletableFuture.failedFuture(new RateLimitExceededException(retryAfter));
})
.exceptionally(throwable -> {
if (ExceptionUtils.unwrap(throwable) instanceof RateLimitExceededException rateLimitExceededException) {
throw ExceptionUtils.wrap(rateLimitExceededException);
}
if (config.failOpen()) {
return null;
}
throw ExceptionUtils.wrap(throwable);
});
}
@Override
public boolean hasAvailablePermits(final String key, final int permits) {
return current().getRight().hasAvailablePermits(key, permits);
final RateLimiterConfig config = config();
try {
final long deficitPermitsAmount = executeValidateScript(config, key, permits, false);
return deficitPermitsAmount == 0;
} catch (final Exception e) {
if (config.failOpen()) {
return true;
} else {
throw e;
}
}
}
@Override
public CompletionStage<Boolean> hasAvailablePermitsAsync(final String key, final int amount) {
return current().getRight().hasAvailablePermitsAsync(key, amount);
final RateLimiterConfig config = config();
return executeValidateScriptAsync(config, key, amount, false)
.thenApply(deficitPermitsAmount -> deficitPermitsAmount == 0)
.exceptionally(throwable -> {
if (config.failOpen()) {
return true;
}
throw ExceptionUtils.wrap(throwable);
});
}
@Override
public void clear(final String key) {
current().getRight().clear(key);
cluster.useCluster(connection -> connection.sync().del(bucketName(name, key)));
}
@Override
public CompletionStage<Void> clearAsync(final String key) {
return current().getRight().clearAsync(key);
return cluster.withCluster(connection -> connection.async().del(bucketName(name, key)))
.thenRun(Util.NOOP);
}
@Override
public RateLimiterConfig config() {
return current().getLeft();
return configResolver.get();
}
private Pair<RateLimiterConfig, RateLimiter> current() {
final RateLimiterConfig cfg = configResolver.get();
return currentHolder.updateAndGet(p -> p != null && p.getLeft().equals(cfg)
? p
: Pair.of(cfg, new StaticRateLimiter(name, cfg, validateScript, cluster, clock, dynamicConfigurationManager))
private long executeValidateScript(final RateLimiterConfig config, final String key, final int amount, final boolean applyChanges) {
final List<String> keys = List.of(bucketName(name, key));
final List<String> arguments = List.of(
String.valueOf(config.bucketSize()),
String.valueOf(config.leakRatePerMillis()),
String.valueOf(clock.millis()),
String.valueOf(amount),
String.valueOf(applyChanges)
);
return (Long) validateScript.execute(keys, arguments);
}
private CompletionStage<Long> executeValidateScriptAsync(final RateLimiterConfig config, final String key, final int amount, final boolean applyChanges) {
final List<String> keys = List.of(bucketName(name, key));
final List<String> arguments = List.of(
String.valueOf(config.bucketSize()),
String.valueOf(config.leakRatePerMillis()),
String.valueOf(clock.millis()),
String.valueOf(amount),
String.valueOf(applyChanges)
);
return validateScript.executeAsync(keys, arguments).thenApply(o -> (Long) o);
}
private static String bucketName(final String name, final String key) {
return "leaky_bucket::" + name + "::" + key;
}
}

View File

@ -8,7 +8,7 @@ package org.whispersystems.textsecuregcm.limits;
import jakarta.validation.constraints.AssertTrue;
import java.time.Duration;
public record RateLimiterConfig(int bucketSize, Duration permitRegenerationDuration) {
public record RateLimiterConfig(int bucketSize, Duration permitRegenerationDuration, boolean failOpen) {
public double leakRatePerMillis() {
return 1.0 / (permitRegenerationDuration.toNanos() / 1e6);

View File

@ -15,14 +15,9 @@ public interface RateLimiterDescriptor {
*/
String id();
/**
* @return {@code true} if this rate limiter needs to watch for dynamic configuration changes.
*/
boolean isDynamic();
/**
* @return an instance of {@link RateLimiterConfig} to be used by default,
* i.e. if there is no overrides in the application configuration files (static or dynamic).
* i.e. if there is no override in the application dynamic configuration.
*/
RateLimiterConfig defaultConfig();
}

View File

@ -4,11 +4,9 @@
*/
package org.whispersystems.textsecuregcm.limits;
import com.google.common.annotations.VisibleForTesting;
import java.time.Clock;
import java.time.Duration;
import java.util.Map;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
@ -17,59 +15,54 @@ import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
public enum For implements RateLimiterDescriptor {
BACKUP_AUTH_CHECK("backupAuthCheck", false, new RateLimiterConfig(100, Duration.ofMinutes(15))),
PIN("pin", false, new RateLimiterConfig(10, Duration.ofDays(1))),
ATTACHMENT("attachmentCreate", false, new RateLimiterConfig(50, Duration.ofMillis(1200))),
BACKUP_ATTACHMENT("backupAttachmentCreate", true, new RateLimiterConfig(10_000, Duration.ofSeconds(1))),
PRE_KEYS("prekeys", false, new RateLimiterConfig(6, Duration.ofMinutes(10))),
MESSAGES("messages", false, new RateLimiterConfig(60, Duration.ofSeconds(1))),
STORIES("stories", false, new RateLimiterConfig(5_000, Duration.ofSeconds(8))),
ALLOCATE_DEVICE("allocateDevice", false, new RateLimiterConfig(6, Duration.ofMinutes(2))),
VERIFY_DEVICE("verifyDevice", false, new RateLimiterConfig(6, Duration.ofMinutes(2))),
TURN("turnAllocate", false, new RateLimiterConfig(60, Duration.ofSeconds(1))),
PROFILE("profile", false, new RateLimiterConfig(4320, Duration.ofSeconds(20))),
STICKER_PACK("stickerPack", false, new RateLimiterConfig(50, Duration.ofMinutes(72))),
USERNAME_LOOKUP("usernameLookup", false, new RateLimiterConfig(100, Duration.ofMinutes(15))),
USERNAME_SET("usernameSet", false, new RateLimiterConfig(100, Duration.ofMinutes(15))),
USERNAME_RESERVE("usernameReserve", false, new RateLimiterConfig(100, Duration.ofMinutes(15))),
USERNAME_LINK_OPERATION("usernameLinkOperation", false, new RateLimiterConfig(10, Duration.ofMinutes(1))),
USERNAME_LINK_LOOKUP_PER_IP("usernameLinkLookupPerIp", false, new RateLimiterConfig(100, Duration.ofSeconds(15))),
CHECK_ACCOUNT_EXISTENCE("checkAccountExistence", false, new RateLimiterConfig(1000, Duration.ofSeconds(4))),
REGISTRATION("registration", false, new RateLimiterConfig(6, Duration.ofSeconds(30))),
VERIFICATION_PUSH_CHALLENGE("verificationPushChallenge", false, new RateLimiterConfig(5, Duration.ofSeconds(30))),
VERIFICATION_CAPTCHA("verificationCaptcha", false, new RateLimiterConfig(10, Duration.ofSeconds(30))),
RATE_LIMIT_RESET("rateLimitReset", true, new RateLimiterConfig(2, Duration.ofHours(12))),
CAPTCHA_CHALLENGE_ATTEMPT("captchaChallengeAttempt", true, new RateLimiterConfig(10, Duration.ofMinutes(144))),
CAPTCHA_CHALLENGE_SUCCESS("captchaChallengeSuccess", true, new RateLimiterConfig(2, Duration.ofHours(12))),
SET_BACKUP_ID("setBackupId", true, new RateLimiterConfig(10, Duration.ofHours(1))),
SET_PAID_MEDIA_BACKUP_ID("setPaidMediaBackupId", true, new RateLimiterConfig(5, Duration.ofDays(7))),
PUSH_CHALLENGE_ATTEMPT("pushChallengeAttempt", true, new RateLimiterConfig(10, Duration.ofMinutes(144))),
PUSH_CHALLENGE_SUCCESS("pushChallengeSuccess", true, new RateLimiterConfig(2, Duration.ofHours(12))),
GET_CALLING_RELAYS("getCallingRelays", false, new RateLimiterConfig(100, Duration.ofMinutes(10))),
CREATE_CALL_LINK("createCallLink", false, new RateLimiterConfig(100, Duration.ofMinutes(15))),
INBOUND_MESSAGE_BYTES("inboundMessageBytes", true, new RateLimiterConfig(128 * 1024 * 1024, Duration.ofNanos(500_000))),
EXTERNAL_SERVICE_CREDENTIALS("externalServiceCredentials", true, new RateLimiterConfig(100, Duration.ofMinutes(15))),
KEY_TRANSPARENCY_DISTINGUISHED_PER_IP("keyTransparencyDistinguished", true,
new RateLimiterConfig(100, Duration.ofSeconds(15))),
KEY_TRANSPARENCY_SEARCH_PER_IP("keyTransparencySearch", true, new RateLimiterConfig(100, Duration.ofSeconds(15))),
KEY_TRANSPARENCY_MONITOR_PER_IP("keyTransparencyMonitor", true, new RateLimiterConfig(100, Duration.ofSeconds(15))),
WAIT_FOR_LINKED_DEVICE("waitForLinkedDevice", true, new RateLimiterConfig(10, Duration.ofSeconds(30))),
UPLOAD_TRANSFER_ARCHIVE("uploadTransferArchive", true, new RateLimiterConfig(10, Duration.ofMinutes(1))),
WAIT_FOR_TRANSFER_ARCHIVE("waitForTransferArchive", true, new RateLimiterConfig(10, Duration.ofSeconds(30))),
RECORD_DEVICE_TRANSFER_REQUEST("recordDeviceTransferRequest", true, new RateLimiterConfig(10, Duration.ofMillis(100))),
WAIT_FOR_DEVICE_TRANSFER_REQUEST("waitForDeviceTransferRequest", true, new RateLimiterConfig(10, Duration.ofMillis(100))),
DEVICE_CHECK_CHALLENGE("deviceCheckChallenge", true, new RateLimiterConfig(10, Duration.ofMinutes(1))),
BACKUP_AUTH_CHECK("backupAuthCheck", new RateLimiterConfig(100, Duration.ofMinutes(15), true)),
PIN("pin", new RateLimiterConfig(10, Duration.ofDays(1), false)),
ATTACHMENT("attachmentCreate", new RateLimiterConfig(50, Duration.ofMillis(1200), true)),
BACKUP_ATTACHMENT("backupAttachmentCreate", new RateLimiterConfig(10_000, Duration.ofSeconds(1), true)),
PRE_KEYS("prekeys", new RateLimiterConfig(6, Duration.ofMinutes(10), false)),
MESSAGES("messages", new RateLimiterConfig(60, Duration.ofSeconds(1), true)),
STORIES("stories", new RateLimiterConfig(5_000, Duration.ofSeconds(8), true)),
ALLOCATE_DEVICE("allocateDevice", new RateLimiterConfig(6, Duration.ofMinutes(2), false)),
VERIFY_DEVICE("verifyDevice", new RateLimiterConfig(6, Duration.ofMinutes(2), false)),
PROFILE("profile", new RateLimiterConfig(4320, Duration.ofSeconds(20), true)),
STICKER_PACK("stickerPack", new RateLimiterConfig(50, Duration.ofMinutes(72), false)),
USERNAME_LOOKUP("usernameLookup", new RateLimiterConfig(100, Duration.ofMinutes(15), true)),
USERNAME_SET("usernameSet", new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
USERNAME_RESERVE("usernameReserve", new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
USERNAME_LINK_OPERATION("usernameLinkOperation", new RateLimiterConfig(10, Duration.ofMinutes(1), false)),
USERNAME_LINK_LOOKUP_PER_IP("usernameLinkLookupPerIp", new RateLimiterConfig(100, Duration.ofSeconds(15), true)),
CHECK_ACCOUNT_EXISTENCE("checkAccountExistence", new RateLimiterConfig(1000, Duration.ofSeconds(4), true)),
REGISTRATION("registration", new RateLimiterConfig(6, Duration.ofSeconds(30), false)),
VERIFICATION_PUSH_CHALLENGE("verificationPushChallenge", new RateLimiterConfig(5, Duration.ofSeconds(30), false)),
VERIFICATION_CAPTCHA("verificationCaptcha", new RateLimiterConfig(10, Duration.ofSeconds(30), false)),
RATE_LIMIT_RESET("rateLimitReset", new RateLimiterConfig(2, Duration.ofHours(12), false)),
CAPTCHA_CHALLENGE_ATTEMPT("captchaChallengeAttempt", new RateLimiterConfig(10, Duration.ofMinutes(144), false)),
CAPTCHA_CHALLENGE_SUCCESS("captchaChallengeSuccess", new RateLimiterConfig(2, Duration.ofHours(12), false)),
SET_BACKUP_ID("setBackupId", new RateLimiterConfig(10, Duration.ofHours(1), false)),
SET_PAID_MEDIA_BACKUP_ID("setPaidMediaBackupId", new RateLimiterConfig(5, Duration.ofDays(7), false)),
PUSH_CHALLENGE_ATTEMPT("pushChallengeAttempt", new RateLimiterConfig(10, Duration.ofMinutes(144), false)),
PUSH_CHALLENGE_SUCCESS("pushChallengeSuccess", new RateLimiterConfig(2, Duration.ofHours(12), false)),
GET_CALLING_RELAYS("getCallingRelays", new RateLimiterConfig(100, Duration.ofMinutes(10), false)),
CREATE_CALL_LINK("createCallLink", new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
INBOUND_MESSAGE_BYTES("inboundMessageBytes", new RateLimiterConfig(128 * 1024 * 1024, Duration.ofNanos(500_000), true)),
EXTERNAL_SERVICE_CREDENTIALS("externalServiceCredentials", new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
KEY_TRANSPARENCY_DISTINGUISHED_PER_IP("keyTransparencyDistinguished", new RateLimiterConfig(100, Duration.ofSeconds(15), true)),
KEY_TRANSPARENCY_SEARCH_PER_IP("keyTransparencySearch", new RateLimiterConfig(100, Duration.ofSeconds(15), true)),
KEY_TRANSPARENCY_MONITOR_PER_IP("keyTransparencyMonitor", new RateLimiterConfig(100, Duration.ofSeconds(15), true)),
WAIT_FOR_LINKED_DEVICE("waitForLinkedDevice", new RateLimiterConfig(10, Duration.ofSeconds(30), false)),
UPLOAD_TRANSFER_ARCHIVE("uploadTransferArchive", new RateLimiterConfig(10, Duration.ofMinutes(1), false)),
WAIT_FOR_TRANSFER_ARCHIVE("waitForTransferArchive", new RateLimiterConfig(10, Duration.ofSeconds(30), false)),
RECORD_DEVICE_TRANSFER_REQUEST("recordDeviceTransferRequest", new RateLimiterConfig(10, Duration.ofMillis(100), true)),
WAIT_FOR_DEVICE_TRANSFER_REQUEST("waitForDeviceTransferRequest", new RateLimiterConfig(10, Duration.ofMillis(100), true)),
DEVICE_CHECK_CHALLENGE("deviceCheckChallenge", new RateLimiterConfig(10, Duration.ofMinutes(1), false)),
;
private final String id;
private final boolean dynamic;
private final RateLimiterConfig defaultConfig;
For(final String id, final boolean dynamic, final RateLimiterConfig defaultConfig) {
For(final String id, final RateLimiterConfig defaultConfig) {
this.id = id;
this.dynamic = dynamic;
this.defaultConfig = defaultConfig;
}
@ -77,34 +70,25 @@ public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
return id;
}
@Override
public boolean isDynamic() {
return dynamic;
}
public RateLimiterConfig defaultConfig() {
return defaultConfig;
}
}
public static RateLimiters createAndValidate(
final Map<String, RateLimiterConfig> configs,
public static RateLimiters create(
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final FaultTolerantRedisClusterClient cacheCluster) {
final RateLimiters rateLimiters = new RateLimiters(
configs, dynamicConfigurationManager, defaultScript(cacheCluster), cacheCluster, Clock.systemUTC());
rateLimiters.validateValuesAndConfigs();
return rateLimiters;
return new RateLimiters(
dynamicConfigurationManager, defaultScript(cacheCluster), cacheCluster, Clock.systemUTC());
}
@VisibleForTesting
RateLimiters(
final Map<String, RateLimiterConfig> configs,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final ClusterLuaScript validateScript,
final FaultTolerantRedisClusterClient cacheCluster,
final Clock clock) {
super(For.values(), configs, dynamicConfigurationManager, validateScript, cacheCluster, clock);
super(For.values(), dynamicConfigurationManager, validateScript, cacheCluster, clock);
}
public RateLimiter getAllocateDeviceLimiter() {
@ -131,10 +115,6 @@ public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
return forDescriptor(For.PIN);
}
public RateLimiter getTurnLimiter() {
return forDescriptor(For.TURN);
}
public RateLimiter getProfileLimiter() {
return forDescriptor(For.PROFILE);
}

View File

@ -1,171 +0,0 @@
/*
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.concurrent.CompletableFuture.failedFuture;
import com.google.common.annotations.VisibleForTesting;
import io.lettuce.core.RedisException;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import java.time.Clock;
import java.time.Duration;
import java.util.List;
import java.util.concurrent.CompletionStage;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.Util;
public class StaticRateLimiter implements RateLimiter {
protected final String name;
private final RateLimiterConfig config;
private final Counter counter;
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private final ClusterLuaScript validateScript;
private final FaultTolerantRedisClusterClient cacheCluster;
private final Clock clock;
public StaticRateLimiter(
final String name,
final RateLimiterConfig config,
final ClusterLuaScript validateScript,
final FaultTolerantRedisClusterClient cacheCluster,
final Clock clock,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager) {
this.name = requireNonNull(name);
this.config = requireNonNull(config);
this.validateScript = requireNonNull(validateScript);
this.cacheCluster = requireNonNull(cacheCluster);
this.clock = requireNonNull(clock);
this.counter = Metrics.counter(MetricsUtil.name(getClass(), "exceeded"), "rateLimiterName", name);
this.dynamicConfigurationManager = dynamicConfigurationManager;
}
@Override
public void validate(final String key, final int amount) throws RateLimitExceededException {
try {
final long deficitPermitsAmount = executeValidateScript(key, amount, true);
if (deficitPermitsAmount > 0) {
counter.increment();
final Duration retryAfter = Duration.ofMillis(
(long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis()));
throw new RateLimitExceededException(retryAfter);
}
} catch (RedisException e) {
if (!failOpen()) {
throw e;
}
}
}
@Override
public CompletionStage<Void> validateAsync(final String key, final int amount) {
return executeValidateScriptAsync(key, amount, true)
.thenCompose(deficitPermitsAmount -> {
if (deficitPermitsAmount == 0) {
return completedFuture((Void) null);
}
counter.increment();
final Duration retryAfter = Duration.ofMillis(
(long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis()));
return failedFuture(new RateLimitExceededException(retryAfter));
})
.exceptionally(throwable -> {
if (ExceptionUtils.unwrap(throwable) instanceof RedisException && failOpen()) {
return null;
}
throw ExceptionUtils.wrap(throwable);
});
}
@Override
public boolean hasAvailablePermits(final String key, final int amount) {
try {
final long deficitPermitsAmount = executeValidateScript(key, amount, false);
return deficitPermitsAmount == 0;
} catch (RedisException e) {
if (failOpen()) {
return true;
} else {
throw e;
}
}
}
@Override
public CompletionStage<Boolean> hasAvailablePermitsAsync(final String key, final int amount) {
return executeValidateScriptAsync(key, amount, false)
.thenApply(deficitPermitsAmount -> deficitPermitsAmount == 0)
.exceptionally(throwable -> {
if (ExceptionUtils.unwrap(throwable) instanceof RedisException && failOpen()) {
return true;
}
throw ExceptionUtils.wrap(throwable);
});
}
@Override
public void clear(final String key) {
cacheCluster.useCluster(connection -> connection.sync().del(bucketName(name, key)));
}
@Override
public CompletionStage<Void> clearAsync(final String key) {
return cacheCluster.withCluster(connection -> connection.async().del(bucketName(name, key)))
.thenRun(Util.NOOP);
}
@Override
public RateLimiterConfig config() {
return config;
}
private boolean failOpen() {
return this.dynamicConfigurationManager.getConfiguration().getRateLimitPolicy().failOpen();
}
private long executeValidateScript(final String key, final int amount, final boolean applyChanges) {
final List<String> keys = List.of(bucketName(name, key));
final List<String> arguments = List.of(
String.valueOf(config.bucketSize()),
String.valueOf(config.leakRatePerMillis()),
String.valueOf(clock.millis()),
String.valueOf(amount),
String.valueOf(applyChanges)
);
return (Long) validateScript.execute(keys, arguments);
}
private CompletionStage<Long> executeValidateScriptAsync(final String key, final int amount, final boolean applyChanges) {
final List<String> keys = List.of(bucketName(name, key));
final List<String> arguments = List.of(
String.valueOf(config.bucketSize()),
String.valueOf(config.leakRatePerMillis()),
String.valueOf(clock.millis()),
String.valueOf(amount),
String.valueOf(applyChanges)
);
return validateScript.executeAsync(keys, arguments).thenApply(o -> (Long) o);
}
@VisibleForTesting
protected static String bucketName(final String name, final String key) {
return "leaky_bucket::" + name + "::" + key;
}
}

View File

@ -0,0 +1,44 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.metrics;
import org.apache.commons.lang3.StringUtils;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import java.util.Optional;
public class DevicePlatformUtil {
private DevicePlatformUtil() {
}
/**
* Returns the most likely client platform for a device.
*
* @param device the device for which to find a client platform
*
* @return the most likely client platform for the given device or empty if no likely platform could be determined
*/
public static Optional<ClientPlatform> getDevicePlatform(final Device device) {
final Optional<ClientPlatform> clientPlatform;
if (StringUtils.isNotBlank(device.getGcmId())) {
clientPlatform = Optional.of(ClientPlatform.ANDROID);
} else if (StringUtils.isNotBlank(device.getApnId())) {
clientPlatform = Optional.of(ClientPlatform.IOS);
} else {
clientPlatform = Optional.empty();
}
return clientPlatform.or(() -> Optional.ofNullable(
switch (device.getUserAgent()) {
case "OWA" -> ClientPlatform.ANDROID;
case "OWI", "OWP" -> ClientPlatform.IOS;
case "OWD" -> ClientPlatform.DESKTOP;
case null, default -> null;
}));
}
}

View File

@ -1,32 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.metrics;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import com.sun.management.OperatingSystemMXBean;
import io.micrometer.core.instrument.Gauge;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.binder.MeterBinder;
import java.lang.management.ManagementFactory;
public class FreeMemoryGauge implements MeterBinder {
private final OperatingSystemMXBean operatingSystemMXBean;
public FreeMemoryGauge() {
this.operatingSystemMXBean = (com.sun.management.OperatingSystemMXBean)
ManagementFactory.getOperatingSystemMXBean();
}
@Override
public void bindTo(final MeterRegistry registry) {
Gauge.builder(name(FreeMemoryGauge.class, "freeMemory"), operatingSystemMXBean,
OperatingSystemMXBean::getFreeMemorySize)
.register(registry);
}
}

View File

@ -120,10 +120,7 @@ public class MetricsUtil {
public static void registerSystemResourceMetrics(final Environment environment) {
new ProcessorMetrics().bindTo(Metrics.globalRegistry);
new FreeMemoryGauge().bindTo(Metrics.globalRegistry);
new FileDescriptorMetrics().bindTo(Metrics.globalRegistry);
new OperatingSystemMemoryGauge("Buffers").bindTo(Metrics.globalRegistry);
new OperatingSystemMemoryGauge("Cached").bindTo(Metrics.globalRegistry);
new JvmMemoryMetrics().bindTo(Metrics.globalRegistry);
new JvmThreadMetrics().bindTo(Metrics.globalRegistry);

View File

@ -67,7 +67,7 @@ public class OpenWebSocketCounter {
try {
final ClientPlatform clientPlatform =
UserAgentUtil.parseUserAgentString(context.getClient().getUserAgent()).getPlatform();
UserAgentUtil.parseUserAgentString(context.getClient().getUserAgent()).platform();
calculatedOpenWebSocketCounter = openWebsocketsByClientPlatform.get(clientPlatform);
calculatedDurationTimer = durationTimersByClientPlatform.get(clientPlatform);

View File

@ -1,56 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.metrics;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import com.google.common.annotations.VisibleForTesting;
import io.micrometer.core.instrument.Gauge;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.binder.MeterBinder;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.Locale;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Stream;
public class OperatingSystemMemoryGauge implements MeterBinder {
private final String metricName;
private static final File MEMINFO_FILE = new File("/proc/meminfo");
private static final Pattern MEMORY_METRIC_PATTERN = Pattern.compile("^([^:]+):\\s+([0-9]+).*$");
public OperatingSystemMemoryGauge(final String metricName) {
this.metricName = metricName;
}
@Override
public void bindTo(MeterRegistry registry) {
final String metricName = this.metricName;
Gauge.builder(name(OperatingSystemMemoryGauge.class, metricName.toLowerCase(Locale.ROOT)), () -> {
try (final BufferedReader bufferedReader = new BufferedReader(new FileReader(MEMINFO_FILE))) {
return getValue(bufferedReader.lines(), metricName);
} catch (final IOException e) {
return 0L;
}
})
.register(registry);
}
@VisibleForTesting
static double getValue(final Stream<String> lines, final String metricName) {
return lines.map(MEMORY_METRIC_PATTERN::matcher)
.filter(Matcher::matches)
.filter(matcher -> metricName.equalsIgnoreCase(matcher.group(1)))
.map(matcher -> Double.parseDouble(matcher.group(2)))
.findFirst()
.orElse(0d);
}
}

View File

@ -9,6 +9,7 @@ import io.micrometer.core.instrument.Tag;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import org.apache.commons.lang3.StringUtils;
import org.whispersystems.textsecuregcm.WhisperServerVersion;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
@ -48,15 +49,15 @@ public class UserAgentTagUtil {
}
public static Tag getPlatformTag(@Nullable final UserAgent userAgent) {
return Tag.of(PLATFORM_TAG, userAgent != null ? userAgent.getPlatform().name().toLowerCase() : "unrecognized");
return Tag.of(PLATFORM_TAG, userAgent != null ? userAgent.platform().name().toLowerCase() : "unrecognized");
}
public static Optional<Tag> getClientVersionTag(final String userAgentString, final ClientReleaseManager clientReleaseManager) {
try {
final UserAgent userAgent = UserAgentUtil.parseUserAgentString(userAgentString);
if (clientReleaseManager.isVersionActive(userAgent.getPlatform(), userAgent.getVersion())) {
return Optional.of(Tag.of(VERSION_TAG, userAgent.getVersion().toString()));
if (clientReleaseManager.isVersionActive(userAgent.platform(), userAgent.version())) {
return Optional.of(Tag.of(VERSION_TAG, userAgent.version().toString()));
}
} catch (final UnrecognizedUserAgentException ignored) {
}
@ -70,10 +71,8 @@ public class UserAgentTagUtil {
try {
final UserAgent userAgent = UserAgentUtil.parseUserAgentString(userAgentString);
platform = userAgent.getPlatform().name().toLowerCase();
libsignal = userAgent.getAdditionalSpecifiers()
.map(additionalSpecifiers -> additionalSpecifiers.contains("libsignal"))
.orElse(false);
platform = userAgent.platform().name().toLowerCase();
libsignal = StringUtils.contains(userAgent.additionalSpecifiers(), "libsignal");
} catch (final UnrecognizedUserAgentException e) {
platform = "unrecognized";
libsignal = false;

View File

@ -13,7 +13,6 @@ import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
@ -21,6 +20,7 @@ import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.signal.libsignal.protocol.util.Pair;
@ -28,6 +28,7 @@ import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevices;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
@ -36,7 +37,6 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.util.Util;
import javax.annotation.Nullable;
/**
* A MessageSender sends Signal messages to destination devices. Messages may be "normal" user-to-user messages,
@ -53,10 +53,12 @@ public class MessageSender {
private final MessagesManager messagesManager;
private final PushNotificationManager pushNotificationManager;
private final ExperimentEnrollmentManager experimentEnrollmentManager;
// Note that these names deliberately reference `MessageController` for metric continuity
private static final String REJECT_OVERSIZE_MESSAGE_COUNTER_NAME = name(MessageController.class, "rejectOversizeMessage");
private static final String CONTENT_SIZE_DISTRIBUTION_NAME = MetricsUtil.name(MessageController.class, "messageContentSize");
private static final String EMPTY_MESSAGE_LIST_COUNTER_NAME = MetricsUtil.name(MessageSender.class, "emptyMessageList");
private static final String SEND_COUNTER_NAME = name(MessageSender.class, "sendMessage");
private static final String EPHEMERAL_TAG_NAME = "ephemeral";
@ -65,6 +67,7 @@ public class MessageSender {
private static final String STORY_TAG_NAME = "story";
private static final String SEALED_SENDER_TAG_NAME = "sealedSender";
private static final String MULTI_RECIPIENT_TAG_NAME = "multiRecipient";
private static final String SYNC_MESSAGE_TAG_NAME = "sync";
@VisibleForTesting
public static final int MAX_MESSAGE_SIZE = (int) DataSize.kibibytes(256).toBytes();
@ -72,9 +75,13 @@ public class MessageSender {
@VisibleForTesting
static final byte NO_EXCLUDED_DEVICE_ID = -1;
public MessageSender(final MessagesManager messagesManager, final PushNotificationManager pushNotificationManager) {
public MessageSender(
final MessagesManager messagesManager,
final PushNotificationManager pushNotificationManager,
final ExperimentEnrollmentManager experimentEnrollmentManager) {
this.messagesManager = messagesManager;
this.pushNotificationManager = pushNotificationManager;
this.experimentEnrollmentManager = experimentEnrollmentManager;
}
/**
@ -86,6 +93,8 @@ public class MessageSender {
* @param destinationIdentifier the service identifier to which the messages are addressed
* @param messagesByDeviceId a map of device IDs to message payloads
* @param registrationIdsByDeviceId a map of device IDs to device registration IDs
* @param syncMessageSenderDeviceId if the message is a sync message (i.e. a message to other devices linked to the
* caller's own account), contains the ID of the device that sent the message
* @param userAgent the User-Agent string for the sender; may be {@code null} if not known
*
* @throws MismatchedDevicesException if the given bundle of messages did not include a message for all required
@ -97,38 +106,48 @@ public class MessageSender {
final ServiceIdentifier destinationIdentifier,
final Map<Byte, Envelope> messagesByDeviceId,
final Map<Byte, Integer> registrationIdsByDeviceId,
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<Byte> syncMessageSenderDeviceId,
@Nullable final String userAgent) throws MismatchedDevicesException, MessageTooLargeException {
if (messagesByDeviceId.isEmpty()) {
// TODO Simply return and don't throw an exception when iOS clients no longer depend on this behavior
throw new MismatchedDevicesException(new MismatchedDevices(
destination.getDevices().stream().map(Device::getId).collect(Collectors.toSet()),
Collections.emptySet(),
Collections.emptySet()));
}
if (!destination.isIdentifiedBy(destinationIdentifier)) {
throw new IllegalArgumentException("Destination account not identified by destination service identifier");
}
final Envelope firstMessage = messagesByDeviceId.values().iterator().next();
final Tag platformTag = UserAgentTagUtil.getPlatformTag(userAgent);
final boolean isSyncMessage = StringUtils.isNotBlank(firstMessage.getSourceServiceId()) &&
destination.isIdentifiedBy(ServiceIdentifier.valueOf(firstMessage.getSourceServiceId()));
if (messagesByDeviceId.isEmpty()) {
Metrics.counter(EMPTY_MESSAGE_LIST_COUNTER_NAME,
Tags.of(SYNC_MESSAGE_TAG_NAME, String.valueOf(syncMessageSenderDeviceId.isPresent())).and(platformTag)).increment();
}
final boolean isStory = firstMessage.getStory();
final byte excludedDeviceId;
if (syncMessageSenderDeviceId.isPresent()) {
if (messagesByDeviceId.values().stream().anyMatch(message -> StringUtils.isBlank(message.getSourceServiceId()) ||
!destination.isIdentifiedBy(ServiceIdentifier.valueOf(message.getSourceServiceId())))) {
validateIndividualMessageContentLength(messagesByDeviceId.values(), isSyncMessage, isStory, userAgent);
throw new IllegalArgumentException("Sync message sender device ID specified, but one or more messages are not addressed to sender");
}
excludedDeviceId = syncMessageSenderDeviceId.get();
} else {
if (messagesByDeviceId.values().stream().anyMatch(message -> StringUtils.isNotBlank(message.getSourceServiceId()) &&
destination.isIdentifiedBy(ServiceIdentifier.valueOf(message.getSourceServiceId())))) {
throw new IllegalArgumentException("Sync message sender device ID not specified, but one or more messages are addressed to sender");
}
excludedDeviceId = NO_EXCLUDED_DEVICE_ID;
}
final Optional<MismatchedDevices> maybeMismatchedDevices = getMismatchedDevices(destination,
destinationIdentifier,
registrationIdsByDeviceId,
isSyncMessage ? (byte) firstMessage.getSourceDevice() : NO_EXCLUDED_DEVICE_ID);
excludedDeviceId);
if (maybeMismatchedDevices.isPresent()) {
throw new MismatchedDevicesException(maybeMismatchedDevices.get());
}
validateIndividualMessageContentLength(messagesByDeviceId.values(), syncMessageSenderDeviceId.isPresent(), userAgent);
messagesManager.insert(destination.getIdentifier(IdentityType.ACI), messagesByDeviceId)
.forEach((deviceId, destinationPresent) -> {
final Envelope message = messagesByDeviceId.get(deviceId);
@ -146,8 +165,9 @@ public class MessageSender {
URGENT_TAG_NAME, String.valueOf(message.getUrgent()),
STORY_TAG_NAME, String.valueOf(message.getStory()),
SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceServiceId()),
SYNC_MESSAGE_TAG_NAME, String.valueOf(syncMessageSenderDeviceId.isPresent()),
MULTI_RECIPIENT_TAG_NAME, "false")
.and(UserAgentTagUtil.getPlatformTag(userAgent));
.and(platformTag);
Metrics.counter(SEND_COUNTER_NAME, tags).increment();
});
@ -230,6 +250,7 @@ public class MessageSender {
URGENT_TAG_NAME, String.valueOf(isUrgent),
STORY_TAG_NAME, String.valueOf(isStory),
SEALED_SENDER_TAG_NAME, "true",
SYNC_MESSAGE_TAG_NAME, "false",
MULTI_RECIPIENT_TAG_NAME, "true")
.and(UserAgentTagUtil.getPlatformTag(userAgent));
@ -295,11 +316,7 @@ public class MessageSender {
// We know the device must be present because we've already filtered out device IDs that aren't associated
// with the given account
final Device device = account.getDevice(deviceId).orElseThrow();
final int expectedRegistrationId = switch (serviceIdentifier.identityType()) {
case ACI -> device.getRegistrationId();
case PNI -> device.getPhoneNumberIdentityRegistrationId().orElseGet(device::getRegistrationId);
};
final int expectedRegistrationId = device.getRegistrationId(serviceIdentifier.identityType());
return registrationId != expectedRegistrationId;
})
@ -313,14 +330,13 @@ public class MessageSender {
private static void validateIndividualMessageContentLength(final Iterable<Envelope> messages,
final boolean isSyncMessage,
final boolean isStory,
@Nullable final String userAgent) throws MessageTooLargeException {
for (final Envelope message : messages) {
MessageSender.validateContentLength(message.getContent().size(),
false,
isSyncMessage,
isStory,
message.getStory(),
userAgent);
}
}

View File

@ -17,6 +17,7 @@ import java.util.function.BiConsumer;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
@ -28,6 +29,9 @@ public class PushNotificationManager {
private final APNSender apnSender;
private final FcmSender fcmSender;
private final PushNotificationScheduler pushNotificationScheduler;
private final ExperimentEnrollmentManager experimentEnrollmentManager;
public static final String SCHEDULE_LOW_URGENCY_FCM_PUSH_EXPERIMENT = "scheduleLowUregencyFcmPush";
private static final String SENT_NOTIFICATION_COUNTER_NAME = name(PushNotificationManager.class, "sentPushNotification");
private static final String FAILED_NOTIFICATION_COUNTER_NAME = name(PushNotificationManager.class, "failedPushNotification");
@ -38,12 +42,14 @@ public class PushNotificationManager {
public PushNotificationManager(final AccountsManager accountsManager,
final APNSender apnSender,
final FcmSender fcmSender,
final PushNotificationScheduler pushNotificationScheduler) {
final PushNotificationScheduler pushNotificationScheduler,
final ExperimentEnrollmentManager experimentEnrollmentManager) {
this.accountsManager = accountsManager;
this.apnSender = apnSender;
this.fcmSender = fcmSender;
this.pushNotificationScheduler = pushNotificationScheduler;
this.experimentEnrollmentManager = experimentEnrollmentManager;
}
public CompletableFuture<Optional<SendPushNotificationResult>> sendNewMessageNotification(final Account destination, final byte destinationDeviceId, final boolean urgent) throws NotPushRegisteredException {
@ -101,11 +107,11 @@ public class PushNotificationManager {
@VisibleForTesting
CompletableFuture<Optional<SendPushNotificationResult>> sendNotification(final PushNotification pushNotification) {
if (pushNotification.tokenType() == PushNotification.TokenType.APN && !pushNotification.urgent()) {
// APNs imposes a per-device limit on background push notifications; schedule a notification for some time in the
// future (possibly even now!) rather than sending a notification directly
if (shouldScheduleNotification(pushNotification)) {
// Schedule a notification for some time in the future (possibly even now!) rather than sending a notification
// directly
return pushNotificationScheduler
.scheduleBackgroundApnsNotification(pushNotification.destination(), pushNotification.destinationDevice())
.scheduleBackgroundNotification(pushNotification.tokenType(), pushNotification.destination(), pushNotification.destinationDevice())
.whenComplete(logErrors())
.thenApply(ignored -> Optional.<SendPushNotificationResult>empty())
.toCompletableFuture();
@ -150,6 +156,16 @@ public class PushNotificationManager {
.thenApply(Optional::of);
}
private boolean shouldScheduleNotification(final PushNotification pushNotification) {
return !pushNotification.urgent() && switch (pushNotification.tokenType()) {
// APNs imposes a per-device limit on background push notifications
case APN -> true;
case FCM -> experimentEnrollmentManager.isEnrolled(
pushNotification.destination().getUuid(),
SCHEDULE_LOW_URGENCY_FCM_PUSH_EXPERIMENT);
};
}
private static <T> BiConsumer<T, Throwable> logErrors() {
return (ignored, throwable) -> {
if (throwable != null) {

View File

@ -13,7 +13,6 @@ import io.lettuce.core.Range;
import io.lettuce.core.ScriptOutputType;
import io.lettuce.core.SetArgs;
import io.lettuce.core.cluster.SlotHash;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import java.io.IOException;
import java.time.Clock;
@ -44,14 +43,15 @@ public class PushNotificationScheduler implements Managed {
private static final Logger logger = LoggerFactory.getLogger(PushNotificationScheduler.class);
private static final String PENDING_BACKGROUND_NOTIFICATIONS_KEY_PREFIX = "PENDING_BACKGROUND_APN";
private static final String PENDING_BACKGROUND_APN_NOTIFICATIONS_KEY_PREFIX = "PENDING_BACKGROUND_APN";
private static final String PENDING_BACKGROUND_FCM_NOTIFICATIONS_KEY_PREFIX = "PENDING_BACKGROUND_FCM";
private static final String LAST_BACKGROUND_NOTIFICATION_TIMESTAMP_KEY_PREFIX = "LAST_BACKGROUND_NOTIFICATION";
private static final String PENDING_DELAYED_NOTIFICATIONS_KEY_PREFIX = "DELAYED";
@VisibleForTesting
static final String NEXT_SLOT_TO_PROCESS_KEY = "pending_notification_next_slot";
private static final Counter BACKGROUND_NOTIFICATION_SCHEDULED_COUNTER = Metrics.counter(name(PushNotificationScheduler.class, "backgroundNotification", "scheduled"));
private static final String BACKGROUND_NOTIFICATION_SCHEDULED_COUNTER_NAME = name(PushNotificationScheduler.class, "backgroundNotification", "scheduled");
private static final String BACKGROUND_NOTIFICATION_SENT_COUNTER_NAME = name(PushNotificationScheduler.class, "backgroundNotification", "sent");
private static final String DELAYED_NOTIFICATION_SCHEDULED_COUNTER_NAME = name(PushNotificationScheduler.class, "delayedNotificationScheduled");
@ -65,7 +65,7 @@ public class PushNotificationScheduler implements Managed {
private final FaultTolerantRedisClusterClient pushSchedulingCluster;
private final Clock clock;
private final ClusterLuaScript scheduleBackgroundApnsNotificationScript;
private final ClusterLuaScript scheduleBackgroundNotificationScript;
private final Thread[] workerThreads;
@ -103,15 +103,18 @@ public class PushNotificationScheduler implements Managed {
final int slot = (int) (pushSchedulingCluster.withCluster(connection ->
connection.sync().incr(NEXT_SLOT_TO_PROCESS_KEY)) % SlotHash.SLOT_COUNT);
return processScheduledBackgroundApnsNotifications(slot) + processScheduledDelayedNotifications(slot);
return processScheduledBackgroundNotifications(PushNotification.TokenType.APN, slot)
+ processScheduledBackgroundNotifications(PushNotification.TokenType.FCM, slot)
+ processScheduledDelayedNotifications(slot);
}
@VisibleForTesting
long processScheduledBackgroundApnsNotifications(final int slot) {
return processScheduledNotifications(getPendingBackgroundApnsNotificationQueueKey(slot),
PushNotificationScheduler.this::sendBackgroundApnsNotification);
long processScheduledBackgroundNotifications(PushNotification.TokenType tokenType, final int slot) {
return processScheduledNotifications(getPendingBackgroundNotificationQueueKey(tokenType, slot),
(account, device) -> sendBackgroundNotification(tokenType, account, device));
}
@VisibleForTesting
long processScheduledDelayedNotifications(final int slot) {
return processScheduledNotifications(getDelayedNotificationQueueKey(slot),
@ -172,7 +175,7 @@ public class PushNotificationScheduler implements Managed {
this.pushSchedulingCluster = pushSchedulingCluster;
this.clock = clock;
this.scheduleBackgroundApnsNotificationScript = ClusterLuaScript.fromResource(pushSchedulingCluster,
this.scheduleBackgroundNotificationScript = ClusterLuaScript.fromResource(pushSchedulingCluster,
"lua/apn/schedule_background_notification.lua", ScriptOutputType.VALUE);
this.workerThreads = new Thread[dedicatedProcessThreadCount];
@ -183,23 +186,21 @@ public class PushNotificationScheduler implements Managed {
}
/**
* Schedule a background APNs notification to be sent some time in the future.
* Schedule a background push notification to be sent some time in the future.
*
* @return A CompletionStage that completes when the notification has successfully been scheduled
*
* @throws IllegalArgumentException if the given device does not have an APNs token
* @throws IllegalArgumentException if the given device does not have a push token
*/
public CompletionStage<Void> scheduleBackgroundApnsNotification(final Account account, final Device device) {
if (StringUtils.isBlank(device.getApnId())) {
throw new IllegalArgumentException("Device must have an APNs token");
public CompletionStage<Void> scheduleBackgroundNotification(final PushNotification.TokenType tokenType, final Account account, final Device device) {
if (StringUtils.isBlank(getPushToken(tokenType, device))) {
throw new IllegalArgumentException("Device must have an " + tokenType + " token");
}
BACKGROUND_NOTIFICATION_SCHEDULED_COUNTER.increment();
return scheduleBackgroundApnsNotificationScript.executeAsync(
Metrics.counter(BACKGROUND_NOTIFICATION_SCHEDULED_COUNTER_NAME, "type", tokenType.name()).increment();
return scheduleBackgroundNotificationScript.executeAsync(
List.of(
getLastBackgroundApnsNotificationTimestampKey(account, device),
getPendingBackgroundApnsNotificationQueueKey(account, device)),
getLastBackgroundNotificationTimestampKey(account, device),
getPendingBackgroundNotificationQueueKey(tokenType, account, device)),
List.of(
encodeAciAndDeviceId(account, device),
String.valueOf(clock.millis()),
@ -236,14 +237,15 @@ public class PushNotificationScheduler implements Managed {
*/
public CompletionStage<Void> cancelScheduledNotifications(Account account, Device device) {
return CompletableFuture.allOf(
cancelBackgroundApnsNotifications(account, device),
cancelBackgroundNotifications(PushNotification.TokenType.FCM, account, device),
cancelBackgroundNotifications(PushNotification.TokenType.APN, account, device),
cancelDelayedNotifications(account, device));
}
@VisibleForTesting
CompletableFuture<Void> cancelBackgroundApnsNotifications(final Account account, final Device device) {
CompletableFuture<Void> cancelBackgroundNotifications(final PushNotification.TokenType tokenType, final Account account, final Device device) {
return pushSchedulingCluster.withCluster(connection -> connection.async()
.zrem(getPendingBackgroundApnsNotificationQueueKey(account, device), encodeAciAndDeviceId(account, device)))
.zrem(getPendingBackgroundNotificationQueueKey(tokenType, account, device), encodeAciAndDeviceId(account, device)))
.thenRun(Util.NOOP)
.toCompletableFuture();
}
@ -276,17 +278,23 @@ public class PushNotificationScheduler implements Managed {
}
@VisibleForTesting
CompletableFuture<Void> sendBackgroundApnsNotification(final Account account, final Device device) {
if (StringUtils.isBlank(device.getApnId())) {
CompletableFuture<Void> sendBackgroundNotification(PushNotification.TokenType tokenType, final Account account, final Device device) {
final String pushToken = getPushToken(tokenType, device);
if (StringUtils.isBlank(pushToken)) {
return CompletableFuture.completedFuture(null);
}
final PushNotificationSender sender = switch (tokenType) {
case FCM -> fcmSender;
case APN -> apnSender;
};
// It's okay for the "last notification" timestamp to expire after the "cooldown" period has elapsed; a missing
// timestamp and a timestamp older than the period are functionally equivalent.
return pushSchedulingCluster.withCluster(connection -> connection.async().set(
getLastBackgroundApnsNotificationTimestampKey(account, device),
getLastBackgroundNotificationTimestampKey(account, device),
String.valueOf(clock.millis()), new SetArgs().ex(BACKGROUND_NOTIFICATION_PERIOD)))
.thenCompose(ignored -> apnSender.sendNotification(new PushNotification(device.getApnId(), PushNotification.TokenType.APN, PushNotification.NotificationType.NOTIFICATION, null, account, device, false)))
.thenCompose(ignored -> sender.sendNotification(new PushNotification(pushToken, tokenType, PushNotification.NotificationType.NOTIFICATION, null, account, device, false)))
.thenAccept(response -> Metrics.counter(BACKGROUND_NOTIFICATION_SENT_COUNTER_NAME,
ACCEPTED_TAG, String.valueOf(response.accepted()))
.increment())
@ -321,6 +329,10 @@ public class PushNotificationScheduler implements Managed {
@VisibleForTesting
static String encodeAciAndDeviceId(final Account account, final Device device) {
// Note: This does not include a device registration id. If a device is unlinked and a new device is linked with
// the original device's id, the new device might get the old device's scheduled push, or the new device might
// delay its own push because the old device had a recent push. An extra or delayed background push is harmless,
// so this is okay.
return account.getUuid() + ":" + device.getId();
}
@ -351,15 +363,19 @@ public class PushNotificationScheduler implements Managed {
}
@VisibleForTesting
static String getPendingBackgroundApnsNotificationQueueKey(final Account account, final Device device) {
return getPendingBackgroundApnsNotificationQueueKey(SlotHash.getSlot(encodeAciAndDeviceId(account, device)));
static String getPendingBackgroundNotificationQueueKey(final PushNotification.TokenType tokenType, final Account account, final Device device) {
return getPendingBackgroundNotificationQueueKey(tokenType, SlotHash.getSlot(encodeAciAndDeviceId(account, device)));
}
private static String getPendingBackgroundApnsNotificationQueueKey(final int slot) {
return PENDING_BACKGROUND_NOTIFICATIONS_KEY_PREFIX + "::{" + RedisClusterUtil.getMinimalHashTag(slot) + "}";
private static String getPendingBackgroundNotificationQueueKey(final PushNotification.TokenType tokenType, final int slot) {
final String prefix = switch (tokenType) {
case APN -> PENDING_BACKGROUND_APN_NOTIFICATIONS_KEY_PREFIX;
case FCM -> PENDING_BACKGROUND_FCM_NOTIFICATIONS_KEY_PREFIX;
};
return prefix + "::{" + RedisClusterUtil.getMinimalHashTag(slot) + "}";
}
private static String getLastBackgroundApnsNotificationTimestampKey(final Account account, final Device device) {
private static String getLastBackgroundNotificationTimestampKey(final Account account, final Device device) {
return LAST_BACKGROUND_NOTIFICATION_TIMESTAMP_KEY_PREFIX + "::{" + encodeAciAndDeviceId(account, device) + "}";
}
@ -376,15 +392,15 @@ public class PushNotificationScheduler implements Managed {
Optional<Instant> getLastBackgroundApnsNotificationTimestamp(final Account account, final Device device) {
return Optional.ofNullable(
pushSchedulingCluster.withCluster(connection ->
connection.sync().get(getLastBackgroundApnsNotificationTimestampKey(account, device))))
connection.sync().get(getLastBackgroundNotificationTimestampKey(account, device))))
.map(timestampString -> Instant.ofEpochMilli(Long.parseLong(timestampString)));
}
@VisibleForTesting
Optional<Instant> getNextScheduledBackgroundApnsNotificationTimestamp(final Account account, final Device device) {
Optional<Instant> getNextScheduledBackgroundNotificationTimestamp(PushNotification.TokenType tokenType, final Account account, final Device device) {
return Optional.ofNullable(
pushSchedulingCluster.withCluster(connection ->
connection.sync().zscore(getPendingBackgroundApnsNotificationQueueKey(account, device),
connection.sync().zscore(getPendingBackgroundNotificationQueueKey(tokenType, account, device),
encodeAciAndDeviceId(account, device))))
.map(timestamp -> Instant.ofEpochMilli(timestamp.longValue()));
}
@ -407,4 +423,11 @@ public class PushNotificationScheduler implements Managed {
return "unknown";
}
}
private static String getPushToken(final PushNotification.TokenType tokenType, final Device device) {
return switch (tokenType) {
case FCM -> device.getGcmId();
case APN -> device.getApnId();
};
}
}

View File

@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.push;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.stream.Collectors;
import org.slf4j.Logger;
@ -60,16 +61,15 @@ public class ReceiptSender {
.collect(Collectors.toMap(Device::getId, ignored -> message));
final Map<Byte, Integer> registrationIdsByDeviceId = destinationAccount.getDevices().stream()
.collect(Collectors.toMap(Device::getId, device -> switch (destinationIdentifier.identityType()) {
case ACI -> device.getRegistrationId();
case PNI -> device.getPhoneNumberIdentityRegistrationId().orElseGet(device::getRegistrationId);
}));
.collect(Collectors.toMap(Device::getId,
device -> device.getRegistrationId(destinationIdentifier.identityType())));
try {
messageSender.sendMessages(destinationAccount,
destinationIdentifier,
messagesByDeviceId,
registrationIdsByDeviceId,
Optional.empty(),
UserAgentTagUtil.SERVER_UA);
} catch (final Exception e) {
logger.warn("Could not send delivery receipt", e);

View File

@ -38,8 +38,8 @@ public class SchedulingUtil {
final LocalTime preferredTime,
final Clock clock) {
final ZonedDateTime candidateNotificationTime = getZoneOffset(account, clock)
.map(zoneOffset -> ZonedDateTime.now(zoneOffset).with(preferredTime))
final ZonedDateTime candidateNotificationTime = getZoneId(account, clock)
.map(zoneId -> ZonedDateTime.now(clock.withZone(zoneId)).with(preferredTime))
.orElseGet(() -> {
// We couldn't find a reasonable timezone for the account for some reason, so make an educated guess at a
// reasonable time to send a notification based on the account's creation time.
@ -59,7 +59,7 @@ public class SchedulingUtil {
}
@VisibleForTesting
static Optional<ZoneOffset> getZoneOffset(final Account account, final Clock clock) {
static Optional<ZoneId> getZoneId(final Account account, final Clock clock) {
try {
final Phonenumber.PhoneNumber phoneNumber = PhoneNumberUtil.getInstance().parse(account.getNumber(), null);
@ -70,7 +70,7 @@ public class SchedulingUtil {
return Optional.empty();
}
final List<ZoneOffset> sortedZoneOffsets = timeZonesForNumber
final List<ZoneId> sortedZoneOffsets = timeZonesForNumber
.stream()
.map(id -> {
try {
@ -80,9 +80,6 @@ public class SchedulingUtil {
}
})
.filter(Objects::nonNull)
.map(ZoneId::getRules)
.distinct()
.map(zoneRules -> zoneRules.getOffset(clock.instant()))
.sorted()
.toList();

View File

@ -9,6 +9,7 @@ import com.google.common.annotations.VisibleForTesting;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.Executor;
@ -49,10 +50,15 @@ public class AccountLockManager {
* @param task the task to execute once locks have been acquired
* @param lockAcquisitionExecutor the executor on which to run blocking lock acquire/release tasks. this executor
* should not use virtual threads.
* @throws InterruptedException if interrupted while acquiring a lock
*
* @return the value returned by the given {@code task}
*
* @throws Exception if an exception is thrown by the given {@code task}
*/
public void withLock(final List<UUID> phoneNumberIdentifiers, final Runnable task,
final Executor lockAcquisitionExecutor) {
public <V> V withLock(final List<UUID> phoneNumberIdentifiers,
final Callable<V> task,
final Executor lockAcquisitionExecutor) throws Exception {
if (phoneNumberIdentifiers.isEmpty()) {
throw new IllegalArgumentException("List of PNIs to lock must not be empty");
}
@ -75,7 +81,7 @@ public class AccountLockManager {
}
}, lockAcquisitionExecutor).join();
task.run();
return task.call();
} finally {
CompletableFuture.runAsync(() -> {
for (final LockItem lockItem : lockItems) {

View File

@ -11,6 +11,7 @@ import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectWriter;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Throwables;
import com.google.i18n.phonenumbers.PhoneNumberUtil;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import java.io.IOException;
@ -36,9 +37,11 @@ import java.util.function.Predicate;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.signal.libsignal.zkgroup.backups.BackupCredentialType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.util.AsyncTimerUtil;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
@ -57,6 +60,7 @@ import software.amazon.awssdk.services.dynamodb.model.Delete;
import software.amazon.awssdk.services.dynamodb.model.GetItemRequest;
import software.amazon.awssdk.services.dynamodb.model.GetItemResponse;
import software.amazon.awssdk.services.dynamodb.model.Put;
import software.amazon.awssdk.services.dynamodb.model.PutItemRequest;
import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
import software.amazon.awssdk.services.dynamodb.model.QueryResponse;
import software.amazon.awssdk.services.dynamodb.model.ReturnValuesOnConditionCheckFailure;
@ -87,7 +91,8 @@ public class Accounts {
static final List<String> ACCOUNT_FIELDS_TO_EXCLUDE_FROM_SERIALIZATION = List.of("uuid", "usernameLinkHandle");
private static final ObjectWriter ACCOUNT_DDB_JSON_WRITER = SystemMapper.jsonMapper()
@VisibleForTesting
static final ObjectWriter ACCOUNT_DDB_JSON_WRITER = SystemMapper.jsonMapper()
.writer(SystemMapper.excludingField(Account.class, ACCOUNT_FIELDS_TO_EXCLUDE_FROM_SERIALIZATION));
private static final Timer CREATE_TIMER = Metrics.timer(name(Accounts.class, "create"));
@ -299,8 +304,11 @@ public class Accounts {
final Collection<TransactWriteItem> additionalWriteItems) {
if (!existingAccount.getUuid().equals(accountToCreate.getUuid()) ||
!existingAccount.getNumber().equals(accountToCreate.getNumber())) {
!existingAccount.getPhoneNumberIdentifier().equals(accountToCreate.getPhoneNumberIdentifier())) {
log.error("Reclaimed accounts must match. Old account {}:{}:{}, New account {}:{}:{}",
existingAccount.getUuid(), redactPhoneNumber(existingAccount.getNumber()), existingAccount.getPhoneNumberIdentifier(),
accountToCreate.getUuid(), redactPhoneNumber(accountToCreate.getNumber()), accountToCreate.getPhoneNumberIdentifier());
throw new IllegalArgumentException("reclaimed accounts must match");
}
@ -1399,8 +1407,7 @@ public class Accounts {
final String tableName,
final AttributeValue uuidAttr,
final String keyName,
final AttributeValue keyValue
) {
final AttributeValue keyValue) {
return TransactWriteItem.builder()
.put(Put.builder()
.tableName(tableName)
@ -1465,6 +1472,68 @@ public class Accounts {
.build();
}
public CompletableFuture<Void> regenerateConstraints(final Account account) {
final List<CompletableFuture<?>> constraintFutures = new ArrayList<>();
constraintFutures.add(writeConstraint(phoneNumberConstraintTableName,
account.getIdentifier(IdentityType.ACI),
ATTR_ACCOUNT_E164,
AttributeValues.fromString(account.getNumber())));
constraintFutures.add(writeConstraint(phoneNumberIdentifierConstraintTableName,
account.getIdentifier(IdentityType.ACI),
ATTR_PNI_UUID,
AttributeValues.fromUUID(account.getPhoneNumberIdentifier())));
account.getUsernameHash().ifPresent(usernameHash ->
constraintFutures.add(writeUsernameConstraint(account.getIdentifier(IdentityType.ACI),
usernameHash,
Optional.empty())));
account.getUsernameHolds().forEach(usernameHold ->
constraintFutures.add(writeUsernameConstraint(account.getIdentifier(IdentityType.ACI),
usernameHold.usernameHash(),
Optional.of(Instant.ofEpochSecond(usernameHold.expirationSecs())))));
return CompletableFuture.allOf(constraintFutures.toArray(CompletableFuture[]::new));
}
private CompletableFuture<Void> writeConstraint(
final String tableName,
final UUID accountIdentifier,
final String keyName,
final AttributeValue keyValue) {
return dynamoDbAsyncClient.putItem(PutItemRequest.builder()
.tableName(tableName)
.item(Map.of(
keyName, keyValue,
KEY_ACCOUNT_UUID, AttributeValues.fromUUID(accountIdentifier)))
.build())
.thenRun(Util.NOOP);
}
private CompletableFuture<Void> writeUsernameConstraint(
final UUID accountIdentifier,
final byte[] usernameHash,
final Optional<Instant> maybeExpiration) {
final Map<String, AttributeValue> item = new HashMap<>(Map.of(
UsernameTable.KEY_USERNAME_HASH, AttributeValues.fromByteArray(usernameHash),
UsernameTable.ATTR_ACCOUNT_UUID, AttributeValues.fromUUID(accountIdentifier),
UsernameTable.ATTR_CONFIRMED, AttributeValues.fromBool(maybeExpiration.isEmpty())
));
maybeExpiration.ifPresent(expiration ->
item.put(UsernameTable.ATTR_TTL, AttributeValues.fromLong(expiration.getEpochSecond())));
return dynamoDbAsyncClient.putItem(PutItemRequest.builder()
.tableName(usernamesConstraintTableName)
.item(item)
.build())
.thenRun(Util.NOOP);
}
@Nonnull
private static String extractCancellationReasonCodes(final TransactionCanceledException exception) {
return exception.cancellationReasons().stream()
@ -1522,4 +1591,15 @@ public class Accounts {
private static boolean isTransactionConflict(final CancellationReason reason) {
return TRANSACTION_CONFLICT.equals(reason.code());
}
private static String redactPhoneNumber(final String phoneNumber) {
final StringBuilder sb = new StringBuilder();
sb.append("+");
sb.append(Util.getCountryCode(phoneNumber));
sb.append("???");
sb.append(StringUtils.length(phoneNumber) < 3
? ""
: phoneNumber.substring(phoneNumber.length() - 2, phoneNumber.length()));
return sb.toString();
}
}

View File

@ -62,7 +62,6 @@ import java.util.stream.Stream;
import javax.annotation.Nullable;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import org.apache.commons.lang3.ObjectUtils;
import org.apache.commons.lang3.StringUtils;
import org.signal.libsignal.protocol.IdentityKey;
import org.slf4j.Logger;
@ -274,11 +273,33 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
final DeviceSpec primaryDeviceSpec,
@Nullable final String userAgent) throws InterruptedException {
final Account account = new Account();
final UUID pni = phoneNumberIdentifiers.getPhoneNumberIdentifier(number).join();
return createTimer.record(() -> {
accountLockManager.withLock(List.of(pni), () -> {
try {
return accountLockManager.withLock(List.of(pni),
() -> create(number, pni, accountAttributes, accountBadges, aciIdentityKey, pniIdentityKey, primaryDeviceSpec, userAgent), accountLockExecutor);
} catch (final Exception e) {
if (e instanceof RuntimeException runtimeException) {
throw runtimeException;
}
logger.error("Unexpected exception while creating account", e);
throw new RuntimeException(e);
}
});
}
private Account create(final String number,
final UUID pni,
final AccountAttributes accountAttributes,
final List<AccountBadge> accountBadges,
final IdentityKey aciIdentityKey,
final IdentityKey pniIdentityKey,
final DeviceSpec primaryDeviceSpec,
@Nullable final String userAgent) {
final Account account = new Account();
final Optional<UUID> maybeRecentlyDeletedAccountIdentifier =
accounts.findRecentlyDeletedAccountIdentifier(pni);
@ -390,10 +411,8 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
accountAttributes.recoveryPassword().ifPresent(registrationRecoveryPassword ->
registrationRecoveryPasswordsManager.store(account.getIdentifier(IdentityType.PNI),
registrationRecoveryPassword));
}, accountLockExecutor);
return account;
});
}
public CompletableFuture<Pair<Account, Device>> addDevice(final Account account, final DeviceSpec deviceSpec, final String linkDeviceToken) {
@ -580,6 +599,15 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
return Optional.of(aci);
}
/**
* Unlink a device from the given account. The device will be immediately disconnected if it is
* connected to any chat frontend, but it is the caller's responsibility to make sure that the
* account's *other* devices are disconnected, either by use of
* {@link org.whispersystems.textsecuregcm.auth.LinkedDeviceRefreshRequirementProvider} or
* directly by calling {@link DeviceDisconnectionManager#requestDisconnection}.
*
* @returns the updated Account
*/
public CompletableFuture<Account> removeDevice(final Account account, final byte deviceId) {
if (deviceId == Device.PRIMARY_ID) {
throw new IllegalArgumentException("Cannot remove primary device");
@ -633,26 +661,45 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
public Account changeNumber(final Account account,
final String targetNumber,
@Nullable final IdentityKey pniIdentityKey,
@Nullable final Map<Byte, ECSignedPreKey> pniSignedPreKeys,
@Nullable final Map<Byte, KEMSignedPreKey> pniPqLastResortPreKeys,
@Nullable final Map<Byte, Integer> pniRegistrationIds) throws InterruptedException, MismatchedDevicesException {
final IdentityKey pniIdentityKey,
final Map<Byte, ECSignedPreKey> pniSignedPreKeys,
final Map<Byte, KEMSignedPreKey> pniPqLastResortPreKeys,
final Map<Byte, Integer> pniRegistrationIds) throws InterruptedException, MismatchedDevicesException {
final UUID originalPhoneNumberIdentifier = account.getPhoneNumberIdentifier();
final UUID targetPhoneNumberIdentifier = phoneNumberIdentifiers.getPhoneNumberIdentifier(targetNumber).join();
if (originalPhoneNumberIdentifier.equals(targetPhoneNumberIdentifier)) {
if (pniIdentityKey != null) {
throw new IllegalArgumentException("change number must supply a changed phone number; otherwise use updatePniKeys");
}
return account;
}
try {
return accountLockManager.withLock(List.of(account.getPhoneNumberIdentifier(), targetPhoneNumberIdentifier),
() -> changeNumber(account, targetNumber, targetPhoneNumberIdentifier, pniIdentityKey, pniSignedPreKeys, pniPqLastResortPreKeys, pniRegistrationIds), accountLockExecutor);
} catch (final Exception e) {
if (e instanceof MismatchedDevicesException mismatchedDevicesException) {
throw mismatchedDevicesException;
} if (e instanceof RuntimeException runtimeException) {
throw runtimeException;
}
logger.error("Unexpected exception when changing phone number", e);
throw new RuntimeException(e);
}
}
private Account changeNumber(final Account account,
final String targetNumber,
final UUID targetPhoneNumberIdentifier,
final IdentityKey pniIdentityKey,
final Map<Byte, ECSignedPreKey> pniSignedPreKeys,
final Map<Byte, KEMSignedPreKey> pniPqLastResortPreKeys,
final Map<Byte, Integer> pniRegistrationIds) throws MismatchedDevicesException {
validateDevices(account, pniSignedPreKeys, pniPqLastResortPreKeys, pniRegistrationIds);
final AtomicReference<Account> updatedAccount = new AtomicReference<>();
final UUID originalPhoneNumberIdentifier = account.getPhoneNumberIdentifier();
accountLockManager.withLock(List.of(account.getPhoneNumberIdentifier(), targetPhoneNumberIdentifier), () -> {
redisDelete(account);
// There are three possible states for accounts associated with the target phone number:
@ -685,9 +732,9 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
.join();
final Collection<TransactWriteItem> keyWriteItems =
buildPniKeyWriteItems(uuid, targetPhoneNumberIdentifier, pniSignedPreKeys, pniPqLastResortPreKeys);
buildPniKeyWriteItems(targetPhoneNumberIdentifier, pniSignedPreKeys, pniPqLastResortPreKeys);
final Account numberChangedAccount = updateWithRetries(
return updateWithRetries(
account,
a -> {
setPniKeys(account, pniIdentityKey, pniRegistrationIds);
@ -696,26 +743,23 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
a -> accounts.changeNumber(a, targetNumber, targetPhoneNumberIdentifier, maybeDisplacedUuid, keyWriteItems),
() -> accounts.getByAccountIdentifier(uuid).orElseThrow(),
AccountChangeValidator.NUMBER_CHANGE_VALIDATOR);
updatedAccount.set(numberChangedAccount);
}, accountLockExecutor);
return updatedAccount.get();
}
public Account updatePniKeys(final Account account,
final IdentityKey pniIdentityKey,
final Map<Byte, ECSignedPreKey> pniSignedPreKeys,
@Nullable final Map<Byte, KEMSignedPreKey> pniPqLastResortPreKeys,
final Map<Byte, KEMSignedPreKey> pniPqLastResortPreKeys,
final Map<Byte, Integer> pniRegistrationIds) throws MismatchedDevicesException {
try {
return accountLockManager.withLock(List.of(account.getIdentifier(IdentityType.PNI)), () -> {
validateDevices(account, pniSignedPreKeys, pniPqLastResortPreKeys, pniRegistrationIds);
final UUID aci = account.getIdentifier(IdentityType.ACI);
final UUID pni = account.getIdentifier(IdentityType.PNI);
final Collection<TransactWriteItem> keyWriteItems =
buildPniKeyWriteItems(pni, pni, pniSignedPreKeys, pniPqLastResortPreKeys);
buildPniKeyWriteItems(pni, pniSignedPreKeys, pniPqLastResortPreKeys);
return redisDeleteAsync(account)
.thenCompose(ignored -> keysManager.deleteSingleUsePreKeys(pni))
@ -727,44 +771,38 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
AccountChangeValidator.GENERAL_CHANGE_VALIDATOR,
MAX_UPDATE_ATTEMPTS))
.join();
}, accountLockExecutor);
} catch (final Exception e) {
if (e instanceof MismatchedDevicesException mismatchedDevicesException) {
throw mismatchedDevicesException;
} else if (e instanceof RuntimeException runtimeException) {
throw runtimeException;
}
logger.error("Unexpected exception when updating PNI key material", e);
throw new RuntimeException(e);
}
}
private Collection<TransactWriteItem> buildPniKeyWriteItems(
final UUID enabledDevicesIdentifier,
final UUID phoneNumberIdentifier,
@Nullable final Map<Byte, ECSignedPreKey> pniSignedPreKeys,
@Nullable final Map<Byte, KEMSignedPreKey> pniPqLastResortPreKeys) {
final Map<Byte, ECSignedPreKey> pniSignedPreKeys,
final Map<Byte, KEMSignedPreKey> pniPqLastResortPreKeys) {
final List<TransactWriteItem> keyWriteItems = new ArrayList<>();
if (pniSignedPreKeys != null) {
pniSignedPreKeys.forEach((deviceId, signedPreKey) ->
keyWriteItems.add(keysManager.buildWriteItemForEcSignedPreKey(phoneNumberIdentifier, deviceId, signedPreKey)));
}
if (pniPqLastResortPreKeys != null) {
keysManager.getPqEnabledDevices(enabledDevicesIdentifier)
.thenAccept(deviceIds -> deviceIds.stream()
.filter(pniPqLastResortPreKeys::containsKey)
.map(deviceId -> keysManager.buildWriteItemForLastResortKey(phoneNumberIdentifier,
deviceId,
pniPqLastResortPreKeys.get(deviceId)))
.forEach(keyWriteItems::add))
.join();
}
pniPqLastResortPreKeys.forEach((deviceId, lastResortKey) ->
keyWriteItems.add(keysManager.buildWriteItemForLastResortKey(phoneNumberIdentifier, deviceId, lastResortKey)));
return keyWriteItems;
}
private void setPniKeys(final Account account,
@Nullable final IdentityKey pniIdentityKey,
@Nullable final Map<Byte, Integer> pniRegistrationIds) {
if (ObjectUtils.allNull(pniIdentityKey, pniRegistrationIds)) {
return;
} else if (!ObjectUtils.allNotNull(pniIdentityKey, pniRegistrationIds)) {
throw new IllegalArgumentException("PNI identity key and registration IDs must be all null or all non-null");
}
final IdentityKey pniIdentityKey,
final Map<Byte, Integer> pniRegistrationIds) {
account.getDevices()
.forEach(device -> device.setPhoneNumberIdentityRegistrationId(pniRegistrationIds.get(device.getId())));
@ -773,22 +811,15 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
}
private void validateDevices(final Account account,
@Nullable final Map<Byte, ECSignedPreKey> pniSignedPreKeys,
@Nullable final Map<Byte, KEMSignedPreKey> pniPqLastResortPreKeys,
@Nullable final Map<Byte, Integer> pniRegistrationIds) throws MismatchedDevicesException {
if (pniSignedPreKeys == null && pniRegistrationIds == null) {
return;
} else if (pniSignedPreKeys == null || pniRegistrationIds == null) {
throw new IllegalArgumentException("Signed pre-keys and registration IDs must both be null or both be non-null");
}
final Map<Byte, ECSignedPreKey> pniSignedPreKeys,
final Map<Byte, KEMSignedPreKey> pniPqLastResortPreKeys,
final Map<Byte, Integer> pniRegistrationIds) throws MismatchedDevicesException {
// Check that all including primary ID are in signed pre-keys
validateCompleteDeviceList(account, pniSignedPreKeys.keySet());
// Check that all including primary ID are in Pq pre-keys
if (pniPqLastResortPreKeys != null) {
validateCompleteDeviceList(account, pniPqLastResortPreKeys.keySet());
}
// Check that all devices are accounted for in the map of new PNI registration IDs
validateCompleteDeviceList(account, pniRegistrationIds.keySet());
@ -807,8 +838,7 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
extraDeviceIds.removeAll(accountDeviceIds);
if (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty()) {
throw new MismatchedDevicesException(
new MismatchedDevices(missingDeviceIds, extraDeviceIds, Collections.emptySet()));
throw new MismatchedDevicesException(new MismatchedDevices(missingDeviceIds, extraDeviceIds, Set.of()));
}
}

View File

@ -5,8 +5,10 @@
package org.whispersystems.textsecuregcm.storage;
import com.google.protobuf.ByteString;
import java.time.Clock;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.commons.lang3.ObjectUtils;
@ -28,21 +30,26 @@ public class ChangeNumberManager {
private static final Logger logger = LoggerFactory.getLogger(ChangeNumberManager.class);
private final MessageSender messageSender;
private final AccountsManager accountsManager;
private final Clock clock;
public ChangeNumberManager(
final MessageSender messageSender,
final AccountsManager accountsManager) {
final AccountsManager accountsManager,
final Clock clock) {
this.messageSender = messageSender;
this.accountsManager = accountsManager;
this.clock = clock;
}
public Account changeNumber(final Account account, final String number,
@Nullable final IdentityKey pniIdentityKey,
@Nullable final Map<Byte, ECSignedPreKey> deviceSignedPreKeys,
@Nullable final Map<Byte, KEMSignedPreKey> devicePqLastResortPreKeys,
@Nullable final List<IncomingMessage> deviceMessages,
@Nullable final Map<Byte, Integer> pniRegistrationIds,
@Nullable final String senderUserAgent)
public Account changeNumber(final Account account,
final String number,
final IdentityKey pniIdentityKey,
final Map<Byte, ECSignedPreKey> deviceSignedPreKeys,
final Map<Byte, KEMSignedPreKey> devicePqLastResortPreKeys,
final List<IncomingMessage> deviceMessages,
final Map<Byte, Integer> pniRegistrationIds,
final String senderUserAgent)
throws InterruptedException, MismatchedDevicesException, MessageTooLargeException {
if (!(ObjectUtils.allNotNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds) ||
@ -96,7 +103,7 @@ public class ChangeNumberManager {
final String senderUserAgent) throws MessageTooLargeException, MismatchedDevicesException {
try {
final long serverTimestamp = System.currentTimeMillis();
final long serverTimestamp = clock.millis();
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(account.getUuid());
final Map<Byte, Envelope> messagesByDeviceId = deviceMessages.stream()
@ -113,10 +120,15 @@ public class ChangeNumberManager {
.setEphemeral(false)
.build()));
final Map<Byte, Integer> registrationIdsByDeviceId = account.getDevices().stream()
.collect(Collectors.toMap(Device::getId, Device::getRegistrationId));
final Map<Byte, Integer> registrationIdsByDeviceId = deviceMessages.stream()
.collect(Collectors.toMap(IncomingMessage::destinationDeviceId, IncomingMessage::destinationRegistrationId));
messageSender.sendMessages(account, serviceIdentifier, messagesByDeviceId, registrationIdsByDeviceId, senderUserAgent);
messageSender.sendMessages(account,
serviceIdentifier,
messagesByDeviceId,
registrationIdsByDeviceId,
Optional.of(Device.PRIMARY_ID),
senderUserAgent);
} catch (final RuntimeException e) {
logger.warn("Changed number but could not send all device messages on {}", account.getUuid(), e);
throw e;

View File

@ -20,6 +20,7 @@ import java.util.stream.IntStream;
import javax.annotation.Nullable;
import com.google.common.annotations.VisibleForTesting;
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.util.DeviceCapabilityAdapter;
import org.whispersystems.textsecuregcm.util.DeviceNameByteArrayAdapter;
@ -64,9 +65,8 @@ public class Device {
@JsonProperty
private int registrationId;
@Nullable
@JsonProperty("pniRegistrationId")
private Integer phoneNumberIdentityRegistrationId;
private int phoneNumberIdentityRegistrationId;
@JsonProperty
private long lastSeen;
@ -208,18 +208,17 @@ public class Device {
return getId() == PRIMARY_ID;
}
public int getRegistrationId() {
return registrationId;
public int getRegistrationId(final IdentityType identityType) {
return switch (identityType) {
case ACI -> registrationId;
case PNI -> phoneNumberIdentityRegistrationId;
};
}
public void setRegistrationId(int registrationId) {
this.registrationId = registrationId;
}
public OptionalInt getPhoneNumberIdentityRegistrationId() {
return phoneNumberIdentityRegistrationId != null ? OptionalInt.of(phoneNumberIdentityRegistrationId) : OptionalInt.empty();
}
public void setPhoneNumberIdentityRegistrationId(final int phoneNumberIdentityRegistrationId) {
this.phoneNumberIdentityRegistrationId = phoneNumberIdentityRegistrationId;
}

View File

@ -5,7 +5,6 @@
package org.whispersystems.textsecuregcm.storage;
import com.google.common.annotations.VisibleForTesting;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
@ -114,10 +113,6 @@ public class KeysManager {
return ecSignedPreKeys.find(identifier, deviceId);
}
public CompletableFuture<List<Byte>> getPqEnabledDevices(final UUID identifier) {
return pqLastResortKeys.getDeviceIdsWithKeys(identifier).collectList().toFuture();
}
public CompletableFuture<Integer> getEcCount(final UUID identifier, final byte deviceId) {
return ecPreKeys.getCount(identifier, deviceId);
}

Some files were not shown because too many files have changed in this diff Show More