Compare commits
87 Commits
v20250408.
...
main
Author | SHA1 | Date |
---|---|---|
![]() |
74ee1c8c4f | |
![]() |
35604cf151 | |
![]() |
aafcd63a9f | |
![]() |
43a534f05b | |
![]() |
9ec66dac7f | |
![]() |
13fc0ffbca | |
![]() |
93ba6616d1 | |
![]() |
a4b98f38a6 | |
![]() |
b95d08aaea | |
![]() |
b400d49e77 | |
![]() |
e43487155f | |
![]() |
dee3723d97 | |
![]() |
b7e986f43c | |
![]() |
664fb23e97 | |
![]() |
714ef128a1 | |
![]() |
7cf3fce624 | |
![]() |
0cc5431867 | |
![]() |
b8d5b2c8ea | |
![]() |
894ca6d290 | |
![]() |
847b25f695 | |
![]() |
703a05cb15 | |
![]() |
30c194c557 | |
![]() |
cc7b030a41 | |
![]() |
7a91c4d5b7 | |
![]() |
287da6e7e3 | |
![]() |
7cf89764e7 | |
![]() |
d316c72beb | |
![]() |
82d187cc45 | |
![]() |
0c240d21d2 | |
![]() |
009252c831 | |
![]() |
0c1146aaa5 | |
![]() |
4fd06594a0 | |
![]() |
4e175be88f | |
![]() |
771a700acd | |
![]() |
e9bd5da2c3 | |
![]() |
f64244f33a | |
![]() |
ed1417c3e3 | |
![]() |
0398e02690 | |
![]() |
e285bf1a52 | |
![]() |
2c9219d4f7 | |
![]() |
26b3b75054 | |
![]() |
cdb651b68f | |
![]() |
91a36f4421 | |
![]() |
21c1d71551 | |
![]() |
38befdb260 | |
![]() |
63c79173b2 | |
![]() |
d2ad003891 | |
![]() |
eb89773819 | |
![]() |
403abd84f6 | |
![]() |
f62f79c95c | |
![]() |
144c4c9223 | |
![]() |
ab4fc4f459 | |
![]() |
51569ce0a5 | |
![]() |
f191c68efc | |
![]() |
bb8ce6d981 | |
![]() |
e0ee75e0d0 | |
![]() |
1ef3a230a1 | |
![]() |
b1805d4bf1 | |
![]() |
cac979c7fd | |
![]() |
4072dcdda5 | |
![]() |
ed382fff6d | |
![]() |
23bb8277d5 | |
![]() |
8099d6465c | |
![]() |
28a0b9e84e | |
![]() |
9287aaf7ce | |
![]() |
0585f862cb | |
![]() |
7cac6f6f72 | |
![]() |
57be4d798b | |
![]() |
05c74f1997 | |
![]() |
f5e49b6db7 | |
![]() |
3c40e72d27 | |
![]() |
2f2ae7cec5 | |
![]() |
b236b53dc3 | |
![]() |
eb71e30046 | |
![]() |
aa5fd52302 | |
![]() |
d6bc2765b6 | |
![]() |
01258de560 | |
![]() |
3af2cc5c70 | |
![]() |
2278842531 | |
![]() |
60ab00ecc6 | |
![]() |
1fb6d23500 | |
![]() |
8d8a2a5583 | |
![]() |
caa81b4885 | |
![]() |
37c4a0451a | |
![]() |
11df8fcc6c | |
![]() |
5a7f4d8381 | |
![]() |
1f1e4c72ec |
|
@ -1,6 +1,7 @@
|
|||
name: Service CI
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
branches-ignore:
|
||||
- gh-pages
|
||||
|
|
|
@ -72,14 +72,12 @@ public final class Operations {
|
|||
}
|
||||
|
||||
public static TestUser newRegisteredUser(final String number) {
|
||||
final byte[] registrationPassword = randomBytes(32);
|
||||
final byte[] registrationPassword = populateRandomRecoveryPassword(number);
|
||||
final String accountPassword = Base64.getEncoder().encodeToString(randomBytes(32));
|
||||
|
||||
final TestUser user = TestUser.create(number, accountPassword, registrationPassword);
|
||||
final AccountAttributes accountAttributes = user.accountAttributes();
|
||||
|
||||
INTEGRATION_TOOLS.populateRecoveryPassword(number, registrationPassword).join();
|
||||
|
||||
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
|
||||
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
|
||||
|
||||
|
@ -108,6 +106,7 @@ public final class Operations {
|
|||
}
|
||||
|
||||
public record PrescribedVerificationNumber(String number, String verificationCode) {}
|
||||
|
||||
public static PrescribedVerificationNumber prescribedVerificationNumber() {
|
||||
return new PrescribedVerificationNumber(
|
||||
CONFIG.prescribedRegistrationNumber(),
|
||||
|
@ -123,6 +122,13 @@ public final class Operations {
|
|||
.orElseThrow(() -> new RuntimeException("push challenge not found for the verification session"));
|
||||
}
|
||||
|
||||
public static byte[] populateRandomRecoveryPassword(final String number) {
|
||||
final byte[] recoveryPassword = randomBytes(32);
|
||||
INTEGRATION_TOOLS.populateRecoveryPassword(number, recoveryPassword).join();
|
||||
|
||||
return recoveryPassword;
|
||||
}
|
||||
|
||||
public static <T> T sendEmptyRequestAuthenticated(
|
||||
final String endpoint,
|
||||
final String method,
|
||||
|
@ -329,15 +335,15 @@ public final class Operations {
|
|||
}
|
||||
}
|
||||
|
||||
private static ECSignedPreKey generateSignedECPreKey(long id, final ECKeyPair identityKeyPair) {
|
||||
public static ECSignedPreKey generateSignedECPreKey(final long id, final ECKeyPair identityKeyPair) {
|
||||
final ECPublicKey pubKey = Curve.generateKeyPair().getPublicKey();
|
||||
final byte[] sig = identityKeyPair.getPrivateKey().calculateSignature(pubKey.serialize());
|
||||
return new ECSignedPreKey(id, pubKey, sig);
|
||||
final byte[] signature = identityKeyPair.getPrivateKey().calculateSignature(pubKey.serialize());
|
||||
return new ECSignedPreKey(id, pubKey, signature);
|
||||
}
|
||||
|
||||
private static KEMSignedPreKey generateSignedKEMPreKey(long id, final ECKeyPair identityKeyPair) {
|
||||
public static KEMSignedPreKey generateSignedKEMPreKey(final long id, final ECKeyPair identityKeyPair) {
|
||||
final KEMPublicKey pubKey = KEMKeyPair.generate(KEMKeyType.KYBER_1024).getPublicKey();
|
||||
final byte[] sig = identityKeyPair.getPrivateKey().calculateSignature(pubKey.serialize());
|
||||
return new KEMSignedPreKey(id, pubKey, sig);
|
||||
final byte[] signature = identityKeyPair.getPrivateKey().calculateSignature(pubKey.serialize());
|
||||
return new KEMSignedPreKey(id, pubKey, signature);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,27 +6,35 @@
|
|||
package org.signal.integration;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotEquals;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Base64;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.apache.http.HttpStatus;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.signal.libsignal.protocol.IdentityKey;
|
||||
import org.signal.libsignal.protocol.ecc.Curve;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.signal.libsignal.usernames.BaseUsernameException;
|
||||
import org.signal.libsignal.usernames.Username;
|
||||
import org.whispersystems.textsecuregcm.entities.AccountIdentifierResponse;
|
||||
import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse;
|
||||
import org.whispersystems.textsecuregcm.entities.ChangeNumberRequest;
|
||||
import org.whispersystems.textsecuregcm.entities.ConfirmUsernameHashRequest;
|
||||
import org.whispersystems.textsecuregcm.entities.ReserveUsernameHashRequest;
|
||||
import org.whispersystems.textsecuregcm.entities.ReserveUsernameHashResponse;
|
||||
import org.whispersystems.textsecuregcm.entities.UsernameHashResponse;
|
||||
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
|
||||
public class AccountTest {
|
||||
|
||||
@Test
|
||||
public void testCreateAccount() throws Exception {
|
||||
public void testCreateAccount() {
|
||||
final TestUser user = Operations.newRegisteredUser("+19995550101");
|
||||
try {
|
||||
final Pair<Integer, AccountIdentityResponse> execute = Operations.apiGet("/v1/accounts/whoami")
|
||||
|
@ -39,7 +47,7 @@ public class AccountTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testCreateAccountAtomic() throws Exception {
|
||||
public void testCreateAccountAtomic() {
|
||||
final TestUser user = Operations.newRegisteredUser("+19995550201");
|
||||
try {
|
||||
final Pair<Integer, AccountIdentityResponse> execute = Operations.apiGet("/v1/accounts/whoami")
|
||||
|
@ -51,6 +59,33 @@ public class AccountTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void changePhoneNumber() {
|
||||
final TestUser user = Operations.newRegisteredUser("+19995550301");
|
||||
final String targetNumber = "+19995550302";
|
||||
|
||||
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
|
||||
|
||||
final ChangeNumberRequest changeNumberRequest = new ChangeNumberRequest(null,
|
||||
Operations.populateRandomRecoveryPassword(targetNumber),
|
||||
targetNumber,
|
||||
null,
|
||||
new IdentityKey(pniIdentityKeyPair.getPublicKey()),
|
||||
Collections.emptyList(),
|
||||
Map.of(Device.PRIMARY_ID, Operations.generateSignedECPreKey(1, pniIdentityKeyPair)),
|
||||
Map.of(Device.PRIMARY_ID, Operations.generateSignedKEMPreKey(2, pniIdentityKeyPair)),
|
||||
Map.of(Device.PRIMARY_ID, 17));
|
||||
|
||||
final AccountIdentityResponse accountIdentityResponse =
|
||||
Operations.apiPut("/v2/accounts/number", changeNumberRequest)
|
||||
.authorized(user)
|
||||
.executeExpectSuccess(AccountIdentityResponse.class);
|
||||
|
||||
assertEquals(user.aciUuid(), accountIdentityResponse.uuid());
|
||||
assertNotEquals(user.pniUuid(), accountIdentityResponse.pni());
|
||||
assertEquals(targetNumber, accountIdentityResponse.number());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testUsernameOperations() throws Exception {
|
||||
final TestUser user = Operations.newRegisteredUser("+19995550102");
|
||||
|
|
31
pom.xml
31
pom.xml
|
@ -46,7 +46,7 @@
|
|||
<!-- can be updated to latest version with Dropwizard 5 (Jetty 12); will then need to disable telemetry -->
|
||||
<dynamodblocal.version>2.2.1</dynamodblocal.version>
|
||||
<google-cloud-libraries.version>26.57.0</google-cloud-libraries.version>
|
||||
<grpc.version>1.69.0</grpc.version> <!-- should be kept in sync with the value from Google libraries-bom -->
|
||||
<grpc.version>1.70.0</grpc.version> <!-- should be kept in sync with the value from Google libraries-bom -->
|
||||
<gson.version>2.12.1</gson.version>
|
||||
<!-- several libraries (AWS, Google Cloud) use Apache http components transitively, and we need to align them -->
|
||||
<httpcore.version>4.4.16</httpcore.version>
|
||||
|
@ -65,14 +65,15 @@
|
|||
<luajava.version>3.5.0</luajava.version>
|
||||
<micrometer.version>1.14.5</micrometer.version>
|
||||
<netty.version>4.1.119.Final</netty.version>
|
||||
<!-- Must be greater than or equal to the value from Google libraries-bom
|
||||
since some of its libraries generate code. See https://protobuf.dev/support/cross-version-runtime-guarantee/. -->
|
||||
<protobuf.version>3.25.5</protobuf.version>
|
||||
<!-- Must be less than or equal to the value from Google libraries-bom which controls the protobuf runtime version.
|
||||
See https://protobuf.dev/support/cross-version-runtime-guarantee/. -->
|
||||
<protoc.version>4.29.4</protoc.version>
|
||||
<pushy.version>0.15.4</pushy.version>
|
||||
<reactive.grpc.version>1.2.4</reactive.grpc.version>
|
||||
<reactor-bom.version>2024.0.4</reactor-bom.version> <!-- 3.7.4, see https://github.com/reactor/reactor#bom-versioning-scheme -->
|
||||
<resilience4j.version>2.3.0</resilience4j.version>
|
||||
<semver4j.version>3.1.0</semver4j.version>
|
||||
<simple-grpc.version>0.1.0</simple-grpc.version>
|
||||
<slf4j.version>2.0.17</slf4j.version>
|
||||
<stripe.version>23.10.0</stripe.version>
|
||||
<swagger.version>2.2.27</swagger.version>
|
||||
|
@ -126,7 +127,7 @@
|
|||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.google.cloud</groupId>
|
||||
<artifactId>libraries-bom-protobuf3</artifactId>
|
||||
<artifactId>libraries-bom</artifactId>
|
||||
<version>${google-cloud-libraries.version}</version>
|
||||
<type>pom</type>
|
||||
<scope>import</scope>
|
||||
|
@ -174,11 +175,6 @@
|
|||
<artifactId>pushy-dropwizard-metrics-listener</artifactId>
|
||||
<version>${pushy.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.google.protobuf</groupId>
|
||||
<artifactId>protobuf-java</artifactId>
|
||||
<version>${protobuf.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.googlecode.libphonenumber</groupId>
|
||||
<artifactId>libphonenumber</artifactId>
|
||||
|
@ -267,6 +263,11 @@
|
|||
<artifactId>libsignal-server</artifactId>
|
||||
<version>0.67.6</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.signal</groupId>
|
||||
<artifactId>simple-grpc-runtime</artifactId>
|
||||
<version>${simple-grpc.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.signal.forks</groupId>
|
||||
<artifactId>noise-java</artifactId>
|
||||
|
@ -437,7 +438,7 @@
|
|||
<version>0.6.1</version>
|
||||
<configuration>
|
||||
<checkStaleness>false</checkStaleness>
|
||||
<protocArtifact>com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier}</protocArtifact>
|
||||
<protocArtifact>com.google.protobuf:protoc:${protoc.version}:exe:${os.detected.classifier}</protocArtifact>
|
||||
<pluginId>grpc-java</pluginId>
|
||||
<pluginArtifact>io.grpc:protoc-gen-grpc-java:${grpc.version}:exe:${os.detected.classifier}</pluginArtifact>
|
||||
|
||||
|
@ -449,6 +450,14 @@
|
|||
<version>${reactive.grpc.version}</version>
|
||||
<mainClass>com.salesforce.reactorgrpc.ReactorGrpcGenerator</mainClass>
|
||||
</protocPlugin>
|
||||
|
||||
<protocPlugin>
|
||||
<id>simple</id>
|
||||
<groupId>org.signal</groupId>
|
||||
<artifactId>simple-grpc-generator</artifactId>
|
||||
<version>${simple-grpc.version}</version>
|
||||
<mainClass>org.signal.grpc.simple.SimpleGrpcGenerator</mainClass>
|
||||
</protocPlugin>
|
||||
</protocPlugins>
|
||||
</configuration>
|
||||
<executions>
|
||||
|
|
|
@ -482,7 +482,8 @@ turn:
|
|||
- turn:%s
|
||||
- turn:%s:80?transport=tcp
|
||||
- turns:%s:443?transport=tcp
|
||||
ttl: 86400
|
||||
requestedCredentialTtl: PT24H
|
||||
clientCredentialTtl: PT12H
|
||||
hostname: turn.cloudflare.example.com
|
||||
numHttpClients: 1
|
||||
|
||||
|
|
|
@ -80,6 +80,11 @@
|
|||
<artifactId>noise-java</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.signal</groupId>
|
||||
<artifactId>simple-grpc-runtime</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>io.dropwizard</groupId>
|
||||
<artifactId>dropwizard-core</artifactId>
|
||||
|
|
|
@ -407,10 +407,6 @@ public class WhisperServerConfiguration extends Configuration {
|
|||
return rateLimitersCluster;
|
||||
}
|
||||
|
||||
public Map<String, RateLimiterConfig> getLimitsConfiguration() {
|
||||
return limits;
|
||||
}
|
||||
|
||||
public FcmConfiguration getFcmConfiguration() {
|
||||
return fcm;
|
||||
}
|
||||
|
|
|
@ -154,7 +154,7 @@ import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
|||
import org.whispersystems.textsecuregcm.grpc.net.ManagedDefaultEventLoopGroup;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.ManagedLocalGrpcServer;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.ManagedNioEventLoopGroup;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseWebSocketTunnelServer;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.websocket.NoiseWebSocketTunnelServer;
|
||||
import org.whispersystems.textsecuregcm.jetty.JettyHttpConfigurationCustomizer;
|
||||
import org.whispersystems.textsecuregcm.keytransparency.KeyTransparencyServiceClient;
|
||||
import org.whispersystems.textsecuregcm.limits.CardinalityEstimator;
|
||||
|
@ -269,6 +269,7 @@ import org.whispersystems.textsecuregcm.workers.IdleDeviceNotificationSchedulerF
|
|||
import org.whispersystems.textsecuregcm.workers.MessagePersisterServiceCommand;
|
||||
import org.whispersystems.textsecuregcm.workers.NotifyIdleDevicesCommand;
|
||||
import org.whispersystems.textsecuregcm.workers.ProcessScheduledJobsServiceCommand;
|
||||
import org.whispersystems.textsecuregcm.workers.RegenerateAccountConstraintDataCommand;
|
||||
import org.whispersystems.textsecuregcm.workers.RemoveExpiredAccountsCommand;
|
||||
import org.whispersystems.textsecuregcm.workers.RemoveExpiredBackupsCommand;
|
||||
import org.whispersystems.textsecuregcm.workers.RemoveExpiredLinkedDevicesCommand;
|
||||
|
@ -335,6 +336,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
|||
bootstrap.addCommand(new ProcessScheduledJobsServiceCommand("process-idle-device-notification-jobs",
|
||||
"Processes scheduled jobs to send notifications to idle devices",
|
||||
new IdleDeviceNotificationSchedulerFactory()));
|
||||
|
||||
bootstrap.addCommand(new RegenerateAccountConstraintDataCommand());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -633,11 +636,10 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
|||
PushNotificationScheduler pushNotificationScheduler = new PushNotificationScheduler(pushSchedulerCluster,
|
||||
apnSender, fcmSender, accountsManager, 0, 0);
|
||||
PushNotificationManager pushNotificationManager =
|
||||
new PushNotificationManager(accountsManager, apnSender, fcmSender, pushNotificationScheduler);
|
||||
new PushNotificationManager(accountsManager, apnSender, fcmSender, pushNotificationScheduler, experimentEnrollmentManager);
|
||||
WebSocketConnectionEventManager webSocketConnectionEventManager =
|
||||
new WebSocketConnectionEventManager(accountsManager, pushNotificationManager, messagesCluster, clientEventExecutor, asyncOperationQueueingExecutor);
|
||||
RateLimiters rateLimiters = RateLimiters.createAndValidate(config.getLimitsConfiguration(),
|
||||
dynamicConfigurationManager, rateLimitersCluster);
|
||||
RateLimiters rateLimiters = RateLimiters.create(dynamicConfigurationManager, rateLimitersCluster);
|
||||
ProvisioningManager provisioningManager = new ProvisioningManager(pubsubClient);
|
||||
IssuedReceiptsManager issuedReceiptsManager = new IssuedReceiptsManager(
|
||||
config.getDynamoDbTables().getIssuedReceipts().getTableName(),
|
||||
|
@ -668,12 +670,13 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
|||
|
||||
final AccountAuthenticator accountAuthenticator = new AccountAuthenticator(accountsManager);
|
||||
|
||||
final MessageSender messageSender = new MessageSender(messagesManager, pushNotificationManager);
|
||||
final MessageSender messageSender = new MessageSender(messagesManager, pushNotificationManager, experimentEnrollmentManager);
|
||||
final ReceiptSender receiptSender = new ReceiptSender(accountsManager, messageSender, receiptSenderExecutor);
|
||||
final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager = new CloudflareTurnCredentialsManager(
|
||||
config.getTurnConfiguration().cloudflare().apiToken().value(),
|
||||
config.getTurnConfiguration().cloudflare().endpoint(),
|
||||
config.getTurnConfiguration().cloudflare().ttl(),
|
||||
config.getTurnConfiguration().cloudflare().requestedCredentialTtl(),
|
||||
config.getTurnConfiguration().cloudflare().clientCredentialTtl(),
|
||||
config.getTurnConfiguration().cloudflare().urls(),
|
||||
config.getTurnConfiguration().cloudflare().urlsWithIps(),
|
||||
config.getTurnConfiguration().cloudflare().hostname(),
|
||||
|
@ -693,7 +696,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
|||
PushChallengeManager pushChallengeManager = new PushChallengeManager(pushNotificationManager,
|
||||
pushChallengeDynamoDb);
|
||||
|
||||
ChangeNumberManager changeNumberManager = new ChangeNumberManager(messageSender, accountsManager);
|
||||
ChangeNumberManager changeNumberManager = new ChangeNumberManager(messageSender, accountsManager, Clock.systemUTC());
|
||||
|
||||
HttpClient currencyClient = HttpClient.newBuilder().version(HttpClient.Version.HTTP_2).connectTimeout(Duration.ofSeconds(10)).build();
|
||||
FixerClient fixerClient = config.getPaymentsServiceConfiguration().externalClients()
|
||||
|
@ -987,7 +990,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
|||
webSocketEnvironment.setConnectListener(
|
||||
new AuthenticatedConnectListener(receiptSender, messagesManager, messageMetrics, pushNotificationManager,
|
||||
pushNotificationScheduler, webSocketConnectionEventManager, websocketScheduledExecutor,
|
||||
messageDeliveryScheduler, clientReleaseManager, messageDeliveryLoopMonitor));
|
||||
messageDeliveryScheduler, clientReleaseManager, messageDeliveryLoopMonitor, experimentEnrollmentManager));
|
||||
webSocketEnvironment.jersey()
|
||||
.register(new WebsocketRefreshApplicationEventListener(accountsManager, disconnectionRequestManager));
|
||||
webSocketEnvironment.jersey().register(new RateLimitByIpFilter(rateLimiters));
|
||||
|
|
|
@ -15,6 +15,7 @@ import java.net.Inet6Address;
|
|||
import java.net.URI;
|
||||
import java.net.http.HttpRequest;
|
||||
import java.net.http.HttpResponse;
|
||||
import java.time.Duration;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CompletionException;
|
||||
|
@ -39,16 +40,18 @@ public class CloudflareTurnCredentialsManager {
|
|||
private final List<String> cloudflareTurnUrls;
|
||||
private final List<String> cloudflareTurnUrlsWithIps;
|
||||
private final String cloudflareTurnHostname;
|
||||
private final HttpRequest request;
|
||||
private final HttpRequest getCredentialsRequest;
|
||||
|
||||
private final FaultTolerantHttpClient cloudflareTurnClient;
|
||||
private final DnsNameResolver dnsNameResolver;
|
||||
|
||||
record CredentialRequest(long ttl) {}
|
||||
private final Duration clientCredentialTtl;
|
||||
|
||||
record CloudflareTurnResponse(IceServer iceServers) {
|
||||
private record CredentialRequest(long ttl) {}
|
||||
|
||||
record IceServer(
|
||||
private record CloudflareTurnResponse(IceServer iceServers) {
|
||||
|
||||
private record IceServer(
|
||||
String username,
|
||||
String credential,
|
||||
List<String> urls) {
|
||||
|
@ -56,10 +59,17 @@ public class CloudflareTurnCredentialsManager {
|
|||
}
|
||||
|
||||
public CloudflareTurnCredentialsManager(final String cloudflareTurnApiToken,
|
||||
final String cloudflareTurnEndpoint, final long cloudflareTurnTtl, final List<String> cloudflareTurnUrls,
|
||||
final List<String> cloudflareTurnUrlsWithIps, final String cloudflareTurnHostname,
|
||||
final int cloudflareTurnNumHttpClients, final CircuitBreakerConfiguration circuitBreaker,
|
||||
final ExecutorService executor, final RetryConfiguration retry, final ScheduledExecutorService retryExecutor,
|
||||
final String cloudflareTurnEndpoint,
|
||||
final Duration requestedCredentialTtl,
|
||||
final Duration clientCredentialTtl,
|
||||
final List<String> cloudflareTurnUrls,
|
||||
final List<String> cloudflareTurnUrlsWithIps,
|
||||
final String cloudflareTurnHostname,
|
||||
final int cloudflareTurnNumHttpClients,
|
||||
final CircuitBreakerConfiguration circuitBreaker,
|
||||
final ExecutorService executor,
|
||||
final RetryConfiguration retry,
|
||||
final ScheduledExecutorService retryExecutor,
|
||||
final DnsNameResolver dnsNameResolver) {
|
||||
|
||||
this.cloudflareTurnClient = FaultTolerantHttpClient.newBuilder()
|
||||
|
@ -75,17 +85,24 @@ public class CloudflareTurnCredentialsManager {
|
|||
this.cloudflareTurnHostname = cloudflareTurnHostname;
|
||||
this.dnsNameResolver = dnsNameResolver;
|
||||
|
||||
final String credentialsRequestBody;
|
||||
|
||||
try {
|
||||
final String body = SystemMapper.jsonMapper().writeValueAsString(new CredentialRequest(cloudflareTurnTtl));
|
||||
this.request = HttpRequest.newBuilder()
|
||||
credentialsRequestBody =
|
||||
SystemMapper.jsonMapper().writeValueAsString(new CredentialRequest(requestedCredentialTtl.toSeconds()));
|
||||
} catch (final JsonProcessingException e) {
|
||||
throw new IllegalArgumentException(e);
|
||||
}
|
||||
|
||||
// We repeat the same request to Cloudflare every time, so we can construct it once and re-use it
|
||||
this.getCredentialsRequest = HttpRequest.newBuilder()
|
||||
.uri(URI.create(cloudflareTurnEndpoint))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", String.format("Bearer %s", cloudflareTurnApiToken))
|
||||
.POST(HttpRequest.BodyPublishers.ofString(body))
|
||||
.POST(HttpRequest.BodyPublishers.ofString(credentialsRequestBody))
|
||||
.build();
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new IllegalArgumentException(e);
|
||||
}
|
||||
|
||||
this.clientCredentialTtl = clientCredentialTtl;
|
||||
}
|
||||
|
||||
public TurnToken retrieveFromCloudflare() throws IOException {
|
||||
|
@ -105,7 +122,7 @@ public class CloudflareTurnCredentialsManager {
|
|||
final Timer.Sample sample = Timer.start();
|
||||
final HttpResponse<String> response;
|
||||
try {
|
||||
response = cloudflareTurnClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()).join();
|
||||
response = cloudflareTurnClient.sendAsync(getCredentialsRequest, HttpResponse.BodyHandlers.ofString()).join();
|
||||
sample.stop(Timer.builder(CREDENTIAL_FETCH_TIMER_NAME)
|
||||
.publishPercentileHistogram(true)
|
||||
.tags("outcome", "success")
|
||||
|
@ -130,6 +147,7 @@ public class CloudflareTurnCredentialsManager {
|
|||
return new TurnToken(
|
||||
cloudflareTurnResponse.iceServers().username(),
|
||||
cloudflareTurnResponse.iceServers().credential(),
|
||||
clientCredentialTtl.toSeconds(),
|
||||
cloudflareTurnUrls == null ? Collections.emptyList() : cloudflareTurnUrls,
|
||||
cloudflareTurnComposedUrls,
|
||||
cloudflareTurnHostname
|
||||
|
|
|
@ -5,13 +5,15 @@
|
|||
|
||||
package org.whispersystems.textsecuregcm.auth;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import java.util.List;
|
||||
import javax.annotation.Nonnull;
|
||||
import javax.annotation.Nullable;
|
||||
import java.util.List;
|
||||
|
||||
public record TurnToken(
|
||||
String username,
|
||||
String password,
|
||||
@JsonProperty("ttl") long ttlSeconds,
|
||||
@Nonnull List<String> urls,
|
||||
@Nonnull List<String> urlsWithIps,
|
||||
@Nullable String hostname) {
|
||||
|
|
|
@ -1,34 +1,22 @@
|
|||
package org.whispersystems.textsecuregcm.auth.grpc;
|
||||
|
||||
import io.grpc.Grpc;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerInterceptor;
|
||||
import io.grpc.Status;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
import java.util.Optional;
|
||||
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
|
||||
abstract class AbstractAuthenticationInterceptor implements ServerInterceptor {
|
||||
|
||||
private final GrpcClientConnectionManager grpcClientConnectionManager;
|
||||
|
||||
private static final Metadata EMPTY_TRAILERS = new Metadata();
|
||||
|
||||
AbstractAuthenticationInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
|
||||
this.grpcClientConnectionManager = grpcClientConnectionManager;
|
||||
}
|
||||
|
||||
protected Optional<AuthenticatedDevice> getAuthenticatedDevice(final ServerCall<?, ?> call) {
|
||||
if (call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress) {
|
||||
return grpcClientConnectionManager.getAuthenticatedDevice(localAddress);
|
||||
} else {
|
||||
throw new AssertionError("Unexpected channel type: " + call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
|
||||
}
|
||||
}
|
||||
protected Optional<AuthenticatedDevice> getAuthenticatedDevice(final ServerCall<?, ?> call)
|
||||
throws ChannelNotFoundException {
|
||||
|
||||
protected <ReqT, RespT> ServerCall.Listener<ReqT> closeAsUnauthenticated(final ServerCall<ReqT, RespT> call) {
|
||||
call.close(Status.UNAUTHENTICATED, EMPTY_TRAILERS);
|
||||
return new ServerCall.Listener<>() {};
|
||||
return grpcClientConnectionManager.getAuthenticatedDevice(call);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,12 +3,17 @@ package org.whispersystems.textsecuregcm.auth.grpc;
|
|||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerCallHandler;
|
||||
import io.grpc.Status;
|
||||
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
|
||||
import org.whispersystems.textsecuregcm.grpc.ServerInterceptorUtil;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
|
||||
/**
|
||||
* A "prohibit authentication" interceptor ensures that requests to endpoints that should be invoked anonymously do not
|
||||
* originate from a channel that is associated with an authenticated device. Calls with an associated authenticated
|
||||
* device are closed with an {@code UNAUTHENTICATED} status.
|
||||
* device are closed with an {@code UNAUTHENTICATED} status. If a call's authentication status cannot be determined
|
||||
* (i.e. because the underlying remote channel closed before the {@code ServerCall} started), the interceptor will
|
||||
* reject the call with a status of {@code UNAVAILABLE}.
|
||||
*/
|
||||
public class ProhibitAuthenticationInterceptor extends AbstractAuthenticationInterceptor {
|
||||
|
||||
|
@ -21,8 +26,15 @@ public class ProhibitAuthenticationInterceptor extends AbstractAuthenticationInt
|
|||
final Metadata headers,
|
||||
final ServerCallHandler<ReqT, RespT> next) {
|
||||
|
||||
try {
|
||||
return getAuthenticatedDevice(call)
|
||||
.map(ignored -> closeAsUnauthenticated(call))
|
||||
// Status.INTERNAL may seem a little surprising here, but if a caller is reaching an authentication-prohibited
|
||||
// service via an authenticated connection, then that's actually a server configuration issue and not a
|
||||
// problem with the client's request.
|
||||
.map(ignored -> ServerInterceptorUtil.closeWithStatus(call, Status.INTERNAL))
|
||||
.orElseGet(() -> next.startCall(call, headers));
|
||||
} catch (final ChannelNotFoundException e) {
|
||||
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,12 +5,16 @@ import io.grpc.Contexts;
|
|||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerCallHandler;
|
||||
import io.grpc.Status;
|
||||
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
|
||||
import org.whispersystems.textsecuregcm.grpc.ServerInterceptorUtil;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
|
||||
/**
|
||||
* A "require authentication" interceptor requires that requests be issued from a connection that is associated with an
|
||||
* authenticated device. Calls without an associated authenticated device are closed with an {@code UNAUTHENTICATED}
|
||||
* status.
|
||||
* status. If a call's authentication status cannot be determined (i.e. because the underlying remote channel closed
|
||||
* before the {@code ServerCall} started), the interceptor will reject the call with a status of {@code UNAVAILABLE}.
|
||||
*/
|
||||
public class RequireAuthenticationInterceptor extends AbstractAuthenticationInterceptor {
|
||||
|
||||
|
@ -23,10 +27,17 @@ public class RequireAuthenticationInterceptor extends AbstractAuthenticationInte
|
|||
final Metadata headers,
|
||||
final ServerCallHandler<ReqT, RespT> next) {
|
||||
|
||||
try {
|
||||
return getAuthenticatedDevice(call)
|
||||
.map(authenticatedDevice -> Contexts.interceptCall(Context.current()
|
||||
.withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE, authenticatedDevice),
|
||||
call, headers, next))
|
||||
.orElseGet(() -> closeAsUnauthenticated(call));
|
||||
// Status.INTERNAL may seem a little surprising here, but if a caller is reaching an authentication-required
|
||||
// service via an unauthenticated connection, then that's actually a server configuration issue and not a
|
||||
// problem with the client's request.
|
||||
.orElseGet(() -> ServerInterceptorUtil.closeWithStatus(call, Status.INTERNAL));
|
||||
} catch (final ChannelNotFoundException e) {
|
||||
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,16 +6,36 @@
|
|||
package org.whispersystems.textsecuregcm.configuration;
|
||||
|
||||
import jakarta.validation.Valid;
|
||||
import jakarta.validation.constraints.AssertTrue;
|
||||
import jakarta.validation.constraints.NotBlank;
|
||||
import jakarta.validation.constraints.NotEmpty;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import java.time.Duration;
|
||||
import java.util.List;
|
||||
import jakarta.validation.constraints.Positive;
|
||||
import org.whispersystems.textsecuregcm.configuration.secrets.SecretString;
|
||||
|
||||
/**
|
||||
* Configuration properties for Cloudflare TURN integration.
|
||||
*
|
||||
* @param apiToken the API token to use when requesting TURN tokens from Cloudflare
|
||||
* @param endpoint the URI of the Cloudflare API endpoint that vends TURN tokens
|
||||
* @param requestedCredentialTtl the lifetime of TURN tokens to request from Cloudflare
|
||||
* @param clientCredentialTtl the time clients may cache a TURN token; must be less than or equal to {@link #requestedCredentialTtl}
|
||||
* @param urls a collection of TURN URLs to include verbatim in responses to clients
|
||||
* @param urlsWithIps a collection of {@link String#format(String, Object...)} patterns to be populated with resolved IP
|
||||
* addresses for {@link #hostname} in responses to clients; each pattern must include a single
|
||||
* {@code %s} placeholder for the IP address
|
||||
* @param circuitBreaker a circuit breaker for requests to Cloudflare
|
||||
* @param retry a retry policy for requests to Cloudflare
|
||||
* @param hostname the hostname to resolve to IP addresses for use with {@link #urlsWithIps}; also transmitted to
|
||||
* clients for use as an SNI when connecting to pre-resolved hosts
|
||||
* @param numHttpClients the number of parallel HTTP clients to use to communicate with Cloudflare
|
||||
*/
|
||||
public record CloudflareTurnConfiguration(@NotNull SecretString apiToken,
|
||||
@NotBlank String endpoint,
|
||||
@NotBlank long ttl,
|
||||
@NotNull Duration requestedCredentialTtl,
|
||||
@NotNull Duration clientCredentialTtl,
|
||||
@NotNull @NotEmpty @Valid List<@NotBlank String> urls,
|
||||
@NotNull @NotEmpty @Valid List<@NotBlank String> urlsWithIps,
|
||||
@NotNull @Valid CircuitBreakerConfiguration circuitBreaker,
|
||||
|
@ -35,4 +55,9 @@ public record CloudflareTurnConfiguration(@NotNull SecretString apiToken,
|
|||
retry = new RetryConfiguration();
|
||||
}
|
||||
}
|
||||
|
||||
@AssertTrue
|
||||
public boolean isClientTtlShorterThanRequestedTtl() {
|
||||
return clientCredentialTtl.compareTo(requestedCredentialTtl) <= 0;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -46,10 +46,6 @@ public class DynamicConfiguration {
|
|||
@Valid
|
||||
DynamicMessagePersisterConfiguration messagePersister = new DynamicMessagePersisterConfiguration();
|
||||
|
||||
@JsonProperty
|
||||
@Valid
|
||||
DynamicRateLimitPolicy rateLimitPolicy = new DynamicRateLimitPolicy(false);
|
||||
|
||||
@JsonProperty
|
||||
@Valid
|
||||
DynamicRegistrationConfiguration registrationConfiguration = new DynamicRegistrationConfiguration(false);
|
||||
|
@ -100,10 +96,6 @@ public class DynamicConfiguration {
|
|||
return messagePersister;
|
||||
}
|
||||
|
||||
public DynamicRateLimitPolicy getRateLimitPolicy() {
|
||||
return rateLimitPolicy;
|
||||
}
|
||||
|
||||
public DynamicRegistrationConfiguration getRegistrationConfiguration() {
|
||||
return registrationConfiguration;
|
||||
}
|
||||
|
|
|
@ -1,8 +0,0 @@
|
|||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.configuration.dynamic;
|
||||
|
||||
public record DynamicRateLimitPolicy(boolean failOpen) {}
|
|
@ -102,9 +102,9 @@ public class AccountControllerV2 {
|
|||
name = "Retry-After",
|
||||
description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed"))
|
||||
public AccountIdentityResponse changeNumber(@Mutable @Auth final AuthenticatedDevice authenticatedDevice,
|
||||
@NotNull @Valid final ChangeNumberRequest request, @HeaderParam(HttpHeaders.USER_AGENT) final String userAgentString,
|
||||
@Context final ContainerRequestContext requestContext)
|
||||
throws RateLimitExceededException, InterruptedException {
|
||||
@NotNull @Valid final ChangeNumberRequest request,
|
||||
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgentString,
|
||||
@Context final ContainerRequestContext requestContext) throws RateLimitExceededException, InterruptedException {
|
||||
|
||||
if (!authenticatedDevice.getAuthenticatedDevice().isPrimary()) {
|
||||
throw new ForbiddenException();
|
||||
|
|
|
@ -15,16 +15,12 @@ import io.swagger.v3.oas.annotations.responses.ApiResponse;
|
|||
import jakarta.ws.rs.GET;
|
||||
import jakarta.ws.rs.Path;
|
||||
import jakarta.ws.rs.Produces;
|
||||
import jakarta.ws.rs.container.ContainerRequestContext;
|
||||
import jakarta.ws.rs.core.Context;
|
||||
import jakarta.ws.rs.core.MediaType;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.auth.CloudflareTurnCredentialsManager;
|
||||
import org.whispersystems.textsecuregcm.auth.TurnToken;
|
||||
import org.whispersystems.textsecuregcm.limits.RateLimiters;
|
||||
import org.whispersystems.websocket.auth.ReadOnly;
|
||||
|
||||
|
@ -32,14 +28,16 @@ import org.whispersystems.websocket.auth.ReadOnly;
|
|||
@Path("/v2/calling")
|
||||
public class CallRoutingControllerV2 {
|
||||
|
||||
private static final Counter CLOUDFLARE_TURN_ERROR_COUNTER = Metrics.counter(name(CallRoutingControllerV2.class, "cloudflareTurnError"));
|
||||
private final RateLimiters rateLimiters;
|
||||
private final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager;
|
||||
|
||||
private static final Counter CLOUDFLARE_TURN_ERROR_COUNTER =
|
||||
Metrics.counter(name(CallRoutingControllerV2.class, "cloudflareTurnError"));
|
||||
|
||||
public CallRoutingControllerV2(
|
||||
final RateLimiters rateLimiters,
|
||||
final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager
|
||||
) {
|
||||
final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager) {
|
||||
|
||||
this.rateLimiters = rateLimiters;
|
||||
this.cloudflareTurnCredentialsManager = cloudflareTurnCredentialsManager;
|
||||
}
|
||||
|
@ -58,25 +56,17 @@ public class CallRoutingControllerV2 {
|
|||
@ApiResponse(responseCode = "401", description = "Account authentication check failed.")
|
||||
@ApiResponse(responseCode = "422", description = "Invalid request format.")
|
||||
@ApiResponse(responseCode = "429", description = "Rate limited.")
|
||||
public GetCallingRelaysResponse getCallingRelays(
|
||||
final @ReadOnly @Auth AuthenticatedDevice auth
|
||||
) throws RateLimitExceededException, IOException {
|
||||
UUID aci = auth.getAccount().getUuid();
|
||||
public GetCallingRelaysResponse getCallingRelays(final @ReadOnly @Auth AuthenticatedDevice auth)
|
||||
throws RateLimitExceededException, IOException {
|
||||
|
||||
final UUID aci = auth.getAccount().getUuid();
|
||||
rateLimiters.getCallEndpointLimiter().validate(aci);
|
||||
|
||||
List<TurnToken> tokens = new ArrayList<>();
|
||||
try {
|
||||
tokens.add(cloudflareTurnCredentialsManager.retrieveFromCloudflare());
|
||||
} catch (Exception e) {
|
||||
CallRoutingControllerV2.CLOUDFLARE_TURN_ERROR_COUNTER.increment();
|
||||
return new GetCallingRelaysResponse(List.of(cloudflareTurnCredentialsManager.retrieveFromCloudflare()));
|
||||
} catch (final Exception e) {
|
||||
CLOUDFLARE_TURN_ERROR_COUNTER.increment();
|
||||
throw e;
|
||||
}
|
||||
|
||||
return new GetCallingRelaysResponse(tokens);
|
||||
}
|
||||
|
||||
public record GetCallingRelaysResponse(
|
||||
List<TurnToken> relays
|
||||
) {
|
||||
}
|
||||
}
|
||||
|
|
|
@ -44,7 +44,6 @@ import java.util.EnumMap;
|
|||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.CompletionException;
|
||||
|
@ -52,7 +51,6 @@ import java.util.concurrent.CompletionStage;
|
|||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.stream.Collectors;
|
||||
import javax.annotation.Nullable;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.glassfish.jersey.server.ContainerRequest;
|
||||
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.auth.BasicAuthorizationHeader;
|
||||
|
@ -74,6 +72,7 @@ import org.whispersystems.textsecuregcm.entities.TransferArchiveUploadedRequest;
|
|||
import org.whispersystems.textsecuregcm.identity.IdentityType;
|
||||
import org.whispersystems.textsecuregcm.limits.RateLimitedByIp;
|
||||
import org.whispersystems.textsecuregcm.limits.RateLimiters;
|
||||
import org.whispersystems.textsecuregcm.metrics.DevicePlatformUtil;
|
||||
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
||||
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
|
@ -402,7 +401,7 @@ public class DeviceController {
|
|||
|
||||
private AtomicInteger getCounterForLinkedDeviceListeners(final String userAgent) {
|
||||
try {
|
||||
return linkedDeviceListenersByPlatform.get(UserAgentUtil.parseUserAgentString(userAgent).getPlatform());
|
||||
return linkedDeviceListenersByPlatform.get(UserAgentUtil.parseUserAgentString(userAgent).platform());
|
||||
} catch (final UnrecognizedUserAgentException ignored) {
|
||||
return linkedDeviceListenersForUnrecognizedPlatforms;
|
||||
}
|
||||
|
@ -600,25 +599,9 @@ public class DeviceController {
|
|||
}
|
||||
|
||||
private static io.micrometer.core.instrument.Tag primaryPlatformTag(final Account account) {
|
||||
final Device primaryDevice = account.getPrimaryDevice();
|
||||
|
||||
Optional<ClientPlatform> clientPlatform = Optional.empty();
|
||||
if (StringUtils.isNotBlank(primaryDevice.getGcmId())) {
|
||||
clientPlatform = Optional.of(ClientPlatform.ANDROID);
|
||||
} else if (StringUtils.isNotBlank(primaryDevice.getApnId())) {
|
||||
clientPlatform = Optional.of(ClientPlatform.IOS);
|
||||
}
|
||||
clientPlatform = clientPlatform.or(() -> Optional.ofNullable(
|
||||
switch (primaryDevice.getUserAgent()) {
|
||||
case "OWA" -> ClientPlatform.ANDROID;
|
||||
case "OWI", "OWP" -> ClientPlatform.IOS;
|
||||
case "OWD" -> ClientPlatform.DESKTOP;
|
||||
case null, default -> null;
|
||||
}));
|
||||
|
||||
return io.micrometer.core.instrument.Tag.of(
|
||||
"primaryPlatform",
|
||||
clientPlatform
|
||||
DevicePlatformUtil.getDevicePlatform(account.getPrimaryDevice())
|
||||
.map(p -> p.name().toLowerCase(Locale.ROOT))
|
||||
.orElse("unknown"));
|
||||
}
|
||||
|
|
|
@ -98,19 +98,18 @@ public class DonationController {
|
|||
receiptSerial, receiptExpiration.getEpochSecond(), receiptLevel, auth.getAccount().getUuid())
|
||||
.thenCompose(receiptMatched -> {
|
||||
if (!receiptMatched) {
|
||||
|
||||
return CompletableFuture.completedFuture(
|
||||
Response.status(Status.BAD_REQUEST).entity("receipt serial is already redeemed")
|
||||
.type(MediaType.TEXT_PLAIN_TYPE).build());
|
||||
}
|
||||
|
||||
return accountsManager.getByAccountIdentifierAsync(auth.getAccount().getUuid())
|
||||
.thenCompose(optionalAccount ->
|
||||
optionalAccount.map(account -> accountsManager.updateAsync(account, a -> {
|
||||
return accountsManager.updateAsync(auth.getAccount(), a -> {
|
||||
a.addBadge(clock, new AccountBadge(badgeId, receiptExpiration, request.isVisible()));
|
||||
if (request.isPrimary()) {
|
||||
a.makeBadgePrimaryIfExists(clock, badgeId);
|
||||
}
|
||||
})).orElse(CompletableFuture.completedFuture(null)))
|
||||
})
|
||||
.thenApply(ignored -> Response.ok().build());
|
||||
});
|
||||
}).thenCompose(Function.identity());
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.controllers;
|
||||
|
||||
import org.whispersystems.textsecuregcm.auth.TurnToken;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public record GetCallingRelaysResponse(List<TurnToken> relays) {
|
||||
}
|
|
@ -152,7 +152,7 @@ public class KeysController {
|
|||
|
||||
final List<CompletableFuture<Void>> storeFutures = new ArrayList<>(4);
|
||||
|
||||
if (setKeysRequest.preKeys() != null && !setKeysRequest.preKeys().isEmpty()) {
|
||||
if (!setKeysRequest.preKeys().isEmpty()) {
|
||||
Metrics.counter(STORE_KEYS_COUNTER_NAME,
|
||||
Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "ec")))
|
||||
.increment();
|
||||
|
@ -168,7 +168,7 @@ public class KeysController {
|
|||
storeFutures.add(keysManager.storeEcSignedPreKeys(identifier, device.getId(), setKeysRequest.signedPreKey()));
|
||||
}
|
||||
|
||||
if (setKeysRequest.pqPreKeys() != null && !setKeysRequest.pqPreKeys().isEmpty()) {
|
||||
if (!setKeysRequest.pqPreKeys().isEmpty()) {
|
||||
Metrics.counter(STORE_KEYS_COUNTER_NAME,
|
||||
Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "kyber")))
|
||||
.increment();
|
||||
|
@ -192,11 +192,7 @@ public class KeysController {
|
|||
final IdentityKey identityKey,
|
||||
@Nullable final String userAgent) {
|
||||
|
||||
final List<SignedPreKey<?>> signedPreKeys = new ArrayList<>();
|
||||
|
||||
if (setKeysRequest.pqPreKeys() != null) {
|
||||
signedPreKeys.addAll(setKeysRequest.pqPreKeys());
|
||||
}
|
||||
final List<SignedPreKey<?>> signedPreKeys = new ArrayList<>(setKeysRequest.pqPreKeys());
|
||||
|
||||
if (setKeysRequest.pqLastResortPreKey() != null) {
|
||||
signedPreKeys.add(setKeysRequest.pqLastResortPreKey());
|
||||
|
@ -244,8 +240,7 @@ public class KeysController {
|
|||
@ApiResponse(responseCode = "422", description = "Invalid request format")
|
||||
public CompletableFuture<Response> checkKeys(
|
||||
@ReadOnly @Auth final AuthenticatedDevice auth,
|
||||
@RequestBody @NotNull @Valid final CheckKeysRequest checkKeysRequest,
|
||||
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent) {
|
||||
@RequestBody @NotNull @Valid final CheckKeysRequest checkKeysRequest) {
|
||||
|
||||
final UUID identifier = auth.getAccount().getIdentifier(checkKeysRequest.identityType());
|
||||
final byte deviceId = auth.getAuthenticatedDevice().getId();
|
||||
|
@ -386,10 +381,7 @@ public class KeysController {
|
|||
.increment();
|
||||
|
||||
if (signedEcPreKey != null || unsignedEcPreKey != null || pqPreKey != null) {
|
||||
final int registrationId = switch (targetIdentifier.identityType()) {
|
||||
case ACI -> device.getRegistrationId();
|
||||
case PNI -> device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId());
|
||||
};
|
||||
final int registrationId = device.getRegistrationId(targetIdentifier.identityType());
|
||||
|
||||
responseItems.add(
|
||||
new PreKeyResponseItem(device.getId(), registrationId, signedEcPreKey, unsignedEcPreKey,
|
||||
|
|
|
@ -43,7 +43,6 @@ import jakarta.ws.rs.core.Response;
|
|||
import jakarta.ws.rs.core.Response.Status;
|
||||
import java.time.Clock;
|
||||
import java.time.Duration;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
@ -96,6 +95,7 @@ import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
|
|||
import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider;
|
||||
import org.whispersystems.textsecuregcm.push.MessageSender;
|
||||
import org.whispersystems.textsecuregcm.push.MessageTooLargeException;
|
||||
import org.whispersystems.textsecuregcm.push.MessageUtil;
|
||||
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
|
||||
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
|
||||
import org.whispersystems.textsecuregcm.push.ReceiptSender;
|
||||
|
@ -105,7 +105,6 @@ import org.whispersystems.textsecuregcm.spam.SpamChecker;
|
|||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.storage.MessagesManager;
|
||||
import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers;
|
||||
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
|
||||
|
@ -114,11 +113,7 @@ import org.whispersystems.textsecuregcm.util.Util;
|
|||
import org.whispersystems.textsecuregcm.websocket.WebSocketConnection;
|
||||
import org.whispersystems.websocket.WebsocketHeaders;
|
||||
import org.whispersystems.websocket.auth.ReadOnly;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.core.scheduler.Scheduler;
|
||||
import reactor.util.function.Tuple2;
|
||||
import reactor.util.function.Tuples;
|
||||
|
||||
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
|
||||
@Path("/v1/messages")
|
||||
|
@ -145,8 +140,6 @@ public class MessageController {
|
|||
private final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor;
|
||||
private final Clock clock;
|
||||
|
||||
private static final int MAX_FETCH_ACCOUNT_CONCURRENCY = 8;
|
||||
|
||||
private static final CompletableFuture<?>[] EMPTY_FUTURE_ARRAY = new CompletableFuture<?>[0];
|
||||
|
||||
private static final String OUTGOING_MESSAGE_LIST_SIZE_BYTES_DISTRIBUTION_NAME = name(MessageController.class, "outgoingMessageListSizeBytes");
|
||||
|
@ -443,11 +436,16 @@ public class MessageController {
|
|||
final Map<Byte, Integer> registrationIdsByDeviceId = messages.messages().stream()
|
||||
.collect(Collectors.toMap(IncomingMessage::destinationDeviceId, IncomingMessage::destinationRegistrationId));
|
||||
|
||||
final Optional<Byte> syncMessageSenderDeviceId = messageType == MessageType.SYNC
|
||||
? Optional.ofNullable(sender).map(authenticatedDevice -> authenticatedDevice.getAuthenticatedDevice().getId())
|
||||
: Optional.empty();
|
||||
|
||||
try {
|
||||
messageSender.sendMessages(destination,
|
||||
destinationIdentifier,
|
||||
messagesByDeviceId,
|
||||
registrationIdsByDeviceId,
|
||||
syncMessageSenderDeviceId,
|
||||
userAgent);
|
||||
} catch (final MismatchedDevicesException e) {
|
||||
if (!e.getMismatchedDevices().staleDeviceIds().isEmpty()) {
|
||||
|
@ -563,7 +561,9 @@ public class MessageController {
|
|||
final ContainerRequestContext context) {
|
||||
|
||||
// Perform fast, inexpensive checks before attempting to resolve recipients
|
||||
validateNoDuplicateDevices(multiRecipientMessage);
|
||||
if (MessageUtil.hasDuplicateDevices(multiRecipientMessage)) {
|
||||
throw new BadRequestException("Multi-recipient message contains duplicate recipient");
|
||||
}
|
||||
|
||||
if (groupSendTokenHeader == null && combinedUnidentifiedSenderAccessKeys == null) {
|
||||
throw new NotAuthorizedException("A group send endorsement token or unidentified access key is required for non-story messages");
|
||||
|
@ -582,7 +582,14 @@ public class MessageController {
|
|||
// At this point, the caller has at least superficially provided the information needed to send a multi-recipient
|
||||
// message. Attempt to resolve the destination service identifiers to Signal accounts.
|
||||
final Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolvedRecipients =
|
||||
resolveRecipients(multiRecipientMessage, groupSendTokenHeader == null);
|
||||
MessageUtil.resolveRecipients(accountsManager, multiRecipientMessage);
|
||||
|
||||
final List<ServiceIdentifier> unresolvedRecipientServiceIdentifiers =
|
||||
MessageUtil.getUnresolvedRecipients(multiRecipientMessage, resolvedRecipients);
|
||||
|
||||
if (groupSendTokenHeader == null && !unresolvedRecipientServiceIdentifiers.isEmpty()) {
|
||||
throw new NotFoundException();
|
||||
}
|
||||
|
||||
// Access keys are checked against the UAK in the resolved accounts, so we have to check after resolving accounts above.
|
||||
// Group send endorsements are checked earlier; for stories, we don't check permissions at all because only clients check them
|
||||
|
@ -598,17 +605,6 @@ public class MessageController {
|
|||
urgent,
|
||||
context);
|
||||
|
||||
final List<ServiceIdentifier> unresolvedRecipientServiceIdentifiers;
|
||||
|
||||
if (groupSendTokenHeader != null) {
|
||||
unresolvedRecipientServiceIdentifiers = multiRecipientMessage.getRecipients().entrySet().stream()
|
||||
.filter(entry -> !resolvedRecipients.containsKey(entry.getValue()))
|
||||
.map(entry -> ServiceIdentifier.fromLibsignal(entry.getKey()))
|
||||
.toList();
|
||||
} else {
|
||||
unresolvedRecipientServiceIdentifiers = List.of();
|
||||
}
|
||||
|
||||
return new SendMultiRecipientMessageResponse(unresolvedRecipientServiceIdentifiers);
|
||||
}
|
||||
|
||||
|
@ -620,12 +616,14 @@ public class MessageController {
|
|||
final ContainerRequestContext context) {
|
||||
|
||||
// Perform fast, inexpensive checks before attempting to resolve recipients
|
||||
validateNoDuplicateDevices(multiRecipientMessage);
|
||||
if (MessageUtil.hasDuplicateDevices(multiRecipientMessage)) {
|
||||
throw new BadRequestException("Multi-recipient message contains duplicate recipient");
|
||||
}
|
||||
|
||||
// At this point, the caller has at least superficially provided the information needed to send a multi-recipient
|
||||
// message. Attempt to resolve the destination service identifiers to Signal accounts.
|
||||
final Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolvedRecipients =
|
||||
resolveRecipients(multiRecipientMessage, false);
|
||||
MessageUtil.resolveRecipients(accountsManager, multiRecipientMessage);
|
||||
|
||||
// We might filter out all the recipients of a story (if none exist).
|
||||
// In this case there is no error so we should just return 200 now.
|
||||
|
@ -909,43 +907,4 @@ public class MessageController {
|
|||
return Response.status(Status.ACCEPTED)
|
||||
.build();
|
||||
}
|
||||
|
||||
private static void validateNoDuplicateDevices(final SealedSenderMultiRecipientMessage multiRecipientMessage) {
|
||||
final boolean[] usedDeviceIds = new boolean[Device.MAXIMUM_DEVICE_ID + 1];
|
||||
|
||||
for (final SealedSenderMultiRecipientMessage.Recipient recipient : multiRecipientMessage.getRecipients().values()) {
|
||||
if (recipient.getDevices().length == 1) {
|
||||
// A recipient can't have repeated devices if they only have one device
|
||||
continue;
|
||||
}
|
||||
|
||||
Arrays.fill(usedDeviceIds, false);
|
||||
|
||||
for (final byte deviceId : recipient.getDevices()) {
|
||||
if (usedDeviceIds[deviceId]) {
|
||||
throw new BadRequestException();
|
||||
}
|
||||
|
||||
usedDeviceIds[deviceId] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolveRecipients(final SealedSenderMultiRecipientMessage multiRecipientMessage,
|
||||
final boolean throwOnNotFound) {
|
||||
|
||||
return Flux.fromIterable(multiRecipientMessage.getRecipients().entrySet())
|
||||
.flatMap(serviceIdAndRecipient -> {
|
||||
final ServiceIdentifier serviceIdentifier =
|
||||
ServiceIdentifier.fromLibsignal(serviceIdAndRecipient.getKey());
|
||||
|
||||
return Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(serviceIdentifier))
|
||||
.flatMap(Mono::justOrEmpty)
|
||||
.switchIfEmpty(throwOnNotFound ? Mono.error(NotFoundException::new) : Mono.empty())
|
||||
.map(account -> Tuples.of(serviceIdAndRecipient.getValue(), account));
|
||||
}, MAX_FETCH_ACCOUNT_CONCURRENCY)
|
||||
.collectMap(Tuple2::getT1, Tuple2::getT2)
|
||||
.blockOptional()
|
||||
.orElse(Collections.emptyMap());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -428,7 +428,7 @@ public class OneTimeDonationController {
|
|||
@Nullable
|
||||
private static ClientPlatform getClientPlatform(@Nullable final String userAgentString) {
|
||||
try {
|
||||
return UserAgentUtil.parseUserAgentString(userAgentString).getPlatform();
|
||||
return UserAgentUtil.parseUserAgentString(userAgentString).platform();
|
||||
} catch (final UnrecognizedUserAgentException e) {
|
||||
return null;
|
||||
}
|
||||
|
|
|
@ -47,6 +47,7 @@ import java.util.concurrent.CompletableFuture;
|
|||
import java.util.concurrent.Executor;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
import javax.annotation.Nullable;
|
||||
import org.glassfish.jersey.server.ManagedAsync;
|
||||
import org.signal.libsignal.protocol.IdentityKey;
|
||||
import org.signal.libsignal.protocol.ServiceId;
|
||||
|
@ -123,6 +124,7 @@ public class ProfileController {
|
|||
private static final String EXPIRING_PROFILE_KEY_CREDENTIAL_TYPE = "expiringProfileKey";
|
||||
|
||||
private static final String VERSION_NOT_FOUND_COUNTER_NAME = name(ProfileController.class, "versionNotFound");
|
||||
private static final String DUPLICATE_AUTHENTICATION_COUNTER_NAME = name(ProfileController.class, "duplicateAuthentication");
|
||||
|
||||
public ProfileController(
|
||||
Clock clock,
|
||||
|
@ -204,11 +206,12 @@ public class ProfileController {
|
|||
.build()));
|
||||
}
|
||||
|
||||
final List<AccountBadge> updatedBadges = request.badges()
|
||||
.map(badges -> ProfileHelper.mergeBadgeIdsWithExistingAccountBadges(clock, badgeConfigurationMap, badges, auth.getAccount().getBadges()))
|
||||
.orElseGet(() -> auth.getAccount().getBadges());
|
||||
|
||||
accountsManager.update(auth.getAccount(), a -> {
|
||||
|
||||
final List<AccountBadge> updatedBadges = request.badges()
|
||||
.map(badges -> ProfileHelper.mergeBadgeIdsWithExistingAccountBadges(clock, badgeConfigurationMap, badges, a.getBadges()))
|
||||
.orElseGet(a::getBadges);
|
||||
|
||||
a.setBadges(clock, updatedBadges);
|
||||
a.setCurrentProfileVersion(request.version());
|
||||
});
|
||||
|
@ -229,11 +232,12 @@ public class ProfileController {
|
|||
@HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional<Anonymous> accessKey,
|
||||
@Context ContainerRequestContext containerRequestContext,
|
||||
@PathParam("identifier") AciServiceIdentifier accountIdentifier,
|
||||
@PathParam("version") String version)
|
||||
@PathParam("version") String version,
|
||||
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent)
|
||||
throws RateLimitExceededException {
|
||||
|
||||
final Optional<Account> maybeRequester = auth.map(AuthenticatedDevice::getAccount);
|
||||
final Account targetAccount = verifyPermissionToReceiveProfile(maybeRequester, accessKey, accountIdentifier);
|
||||
final Account targetAccount = verifyPermissionToReceiveProfile(maybeRequester, accessKey, accountIdentifier, "getVersionedProfile", userAgent);
|
||||
|
||||
return buildVersionedProfileResponse(targetAccount,
|
||||
version,
|
||||
|
@ -252,7 +256,8 @@ public class ProfileController {
|
|||
@PathParam("identifier") AciServiceIdentifier accountIdentifier,
|
||||
@PathParam("version") String version,
|
||||
@PathParam("credentialRequest") String credentialRequest,
|
||||
@QueryParam("credentialType") String credentialType)
|
||||
@QueryParam("credentialType") String credentialType,
|
||||
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent)
|
||||
throws RateLimitExceededException {
|
||||
|
||||
if (!EXPIRING_PROFILE_KEY_CREDENTIAL_TYPE.equals(credentialType)) {
|
||||
|
@ -260,7 +265,7 @@ public class ProfileController {
|
|||
}
|
||||
|
||||
final Optional<Account> maybeRequester = auth.map(AuthenticatedDevice::getAccount);
|
||||
final Account targetAccount = verifyPermissionToReceiveProfile(maybeRequester, accessKey, accountIdentifier);
|
||||
final Account targetAccount = verifyPermissionToReceiveProfile(maybeRequester, accessKey, accountIdentifier, "credentialRequest", userAgent);
|
||||
final boolean isSelf = maybeRequester.map(requester -> ProfileHelper.isSelfProfileRequest(requester.getUuid(), accountIdentifier)).orElse(false);
|
||||
|
||||
return buildExpiringProfileKeyCredentialProfileResponse(targetAccount,
|
||||
|
@ -282,8 +287,7 @@ public class ProfileController {
|
|||
@HeaderParam(HeaderUtils.GROUP_SEND_TOKEN) Optional<GroupSendTokenHeader> groupSendToken,
|
||||
@Context ContainerRequestContext containerRequestContext,
|
||||
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent,
|
||||
@PathParam("identifier") ServiceIdentifier identifier,
|
||||
@QueryParam("ca") boolean useCaCertificate)
|
||||
@PathParam("identifier") ServiceIdentifier identifier)
|
||||
throws RateLimitExceededException {
|
||||
|
||||
final Optional<Account> maybeRequester = auth.map(AuthenticatedDevice::getAccount);
|
||||
|
@ -302,7 +306,7 @@ public class ProfileController {
|
|||
}
|
||||
} else {
|
||||
targetAccount = verifyPermissionToReceiveProfile(
|
||||
maybeRequester, accessKey.filter(ignored -> identifier.identityType() == IdentityType.ACI), identifier);
|
||||
maybeRequester, accessKey.filter(ignored -> identifier.identityType() == IdentityType.ACI), identifier, "getUnversionedProfile", userAgent);
|
||||
}
|
||||
return switch (identifier.identityType()) {
|
||||
case ACI -> buildBaseProfileResponseForAccountIdentity(targetAccount,
|
||||
|
@ -385,7 +389,7 @@ public class ProfileController {
|
|||
profileKeyCredentialResponse = ProfileHelper.getExpiringProfileKeyCredential(HexFormat.of().parseHex(encodedCredentialRequest),
|
||||
profile, new ServiceId.Aci(account.getUuid()), zkProfileOperations);
|
||||
} catch (VerificationFailedException | InvalidInputException e) {
|
||||
throw new BadRequestException(Response.status(Response.Status.BAD_REQUEST).build(), e);
|
||||
throw new BadRequestException(e);
|
||||
}
|
||||
return profileKeyCredentialResponse;
|
||||
})
|
||||
|
@ -473,7 +477,15 @@ public class ProfileController {
|
|||
*/
|
||||
private Account verifyPermissionToReceiveProfile(final Optional<Account> maybeRequester,
|
||||
final Optional<Anonymous> maybeAccessKey,
|
||||
final ServiceIdentifier accountIdentifier) throws RateLimitExceededException {
|
||||
final ServiceIdentifier accountIdentifier,
|
||||
final String endpoint,
|
||||
@Nullable final String userAgent) throws RateLimitExceededException {
|
||||
|
||||
if (maybeRequester.isPresent() && maybeAccessKey.isPresent()) {
|
||||
Metrics.counter(DUPLICATE_AUTHENTICATION_COUNTER_NAME,
|
||||
Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), io.micrometer.core.instrument.Tag.of("endpoint", endpoint)))
|
||||
.increment();
|
||||
}
|
||||
|
||||
if (maybeRequester.isPresent()) {
|
||||
rateLimiters.getProfileLimiter().validate(maybeRequester.get().getUuid());
|
||||
|
|
|
@ -755,7 +755,7 @@ public class SubscriptionController {
|
|||
@Nullable
|
||||
private static ClientPlatform getClientPlatform(@Nullable final String userAgentString) {
|
||||
try {
|
||||
return UserAgentUtil.parseUserAgentString(userAgentString).getPlatform();
|
||||
return UserAgentUtil.parseUserAgentString(userAgentString).platform();
|
||||
} catch (final UnrecognizedUserAgentException e) {
|
||||
return null;
|
||||
}
|
||||
|
|
|
@ -7,30 +7,30 @@ package org.whispersystems.textsecuregcm.entities;
|
|||
|
||||
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
|
||||
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
|
||||
import org.whispersystems.textsecuregcm.util.ByteArrayBase64UrlAdapter;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import java.util.UUID;
|
||||
import javax.annotation.Nullable;
|
||||
import org.whispersystems.textsecuregcm.util.ByteArrayBase64UrlAdapter;
|
||||
|
||||
public record AccountIdentityResponse(
|
||||
@Schema(description="the account identifier for this account")
|
||||
@Schema(description = "the account identifier for this account")
|
||||
UUID uuid,
|
||||
|
||||
@Schema(description="the phone number associated with this account")
|
||||
@Schema(description = "the phone number associated with this account")
|
||||
String number,
|
||||
|
||||
@Schema(description="the account identifier for this account's phone-number identity")
|
||||
@Schema(description = "the account identifier for this account's phone-number identity")
|
||||
UUID pni,
|
||||
|
||||
@Schema(description="a hash of this account's username, if set")
|
||||
@Schema(description = "a hash of this account's username, if set")
|
||||
@JsonSerialize(using = ByteArrayBase64UrlAdapter.Serializing.class)
|
||||
@JsonDeserialize(using = ByteArrayBase64UrlAdapter.Deserializing.class)
|
||||
@Nullable byte[] usernameHash,
|
||||
|
||||
@Schema(description="this account's username link handle, if set")
|
||||
@Schema(description = "this account's username link handle, if set")
|
||||
@Nullable UUID usernameLinkHandle,
|
||||
|
||||
@Schema(description="whether any of this account's devices support storage")
|
||||
@Schema(description = "whether any of this account's devices support storage")
|
||||
boolean storageCapable,
|
||||
|
||||
@Schema(description = "entitlements for this account and their current expirations")
|
||||
|
|
|
@ -17,6 +17,7 @@ import javax.annotation.Nullable;
|
|||
import jakarta.validation.Valid;
|
||||
import jakarta.validation.constraints.AssertTrue;
|
||||
import jakarta.validation.constraints.NotBlank;
|
||||
import jakarta.validation.constraints.NotEmpty;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import org.signal.libsignal.protocol.IdentityKey;
|
||||
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
|
||||
|
@ -51,34 +52,27 @@ public record ChangeNumberRequest(
|
|||
arraySchema=@Schema(description="""
|
||||
A list of synchronization messages to send to companion devices to supply the private keysManager
|
||||
associated with the new identity key and their new prekeys.
|
||||
Exactly one message must be supplied for each enabled device other than the sending (primary) device."""))
|
||||
Exactly one message must be supplied for each device other than the sending (primary) device."""))
|
||||
@NotNull @Valid List<@NotNull @Valid IncomingMessage> deviceMessages,
|
||||
|
||||
@Schema(description="""
|
||||
A new signed elliptic-curve prekey for each enabled device on the account, including this one.
|
||||
A new signed elliptic-curve prekey for each device on the account, including this one.
|
||||
Each must be accompanied by a valid signature from the new identity key in this request.""")
|
||||
@NotNull @Valid Map<Byte, @NotNull @Valid ECSignedPreKey> devicePniSignedPrekeys,
|
||||
@NotNull @NotEmpty @Valid Map<Byte, @NotNull @Valid ECSignedPreKey> devicePniSignedPrekeys,
|
||||
|
||||
@Schema(description="""
|
||||
A new signed post-quantum last-resort prekey for each enabled device on the account, including this one.
|
||||
May be absent, in which case the last resort PQ prekeys for each device will be deleted if any had been stored.
|
||||
If present, must contain one prekey per enabled device including this one.
|
||||
Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped.
|
||||
A new signed post-quantum last-resort prekey for each device on the account, including this one.
|
||||
Each must be accompanied by a valid signature from the new identity key in this request.""")
|
||||
@Valid Map<Byte, @NotNull @Valid KEMSignedPreKey> devicePniPqLastResortPrekeys,
|
||||
@NotNull @NotEmpty @Valid Map<Byte, @NotNull @Valid KEMSignedPreKey> devicePniPqLastResortPrekeys,
|
||||
|
||||
@Schema(description="the new phone-number-identity registration ID for each enabled device on the account, including this one")
|
||||
@NotNull Map<Byte, Integer> pniRegistrationIds) implements PhoneVerificationRequest {
|
||||
@Schema(description="the new phone-number-identity registration ID for each device on the account, including this one")
|
||||
@NotNull @NotEmpty Map<Byte, Integer> pniRegistrationIds) implements PhoneVerificationRequest {
|
||||
|
||||
public boolean isSignatureValidOnEachSignedPreKey(@Nullable final String userAgent) {
|
||||
List<SignedPreKey<?>> spks = new ArrayList<>();
|
||||
if (devicePniSignedPrekeys != null) {
|
||||
spks.addAll(devicePniSignedPrekeys.values());
|
||||
}
|
||||
if (devicePniPqLastResortPrekeys != null) {
|
||||
final List<SignedPreKey<?>> spks = new ArrayList<>(devicePniSignedPrekeys.values());
|
||||
spks.addAll(devicePniPqLastResortPrekeys.values());
|
||||
}
|
||||
return spks.isEmpty() || PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey, spks, userAgent, "change-number");
|
||||
|
||||
return PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey, spks, userAgent, "change-number");
|
||||
}
|
||||
|
||||
@AssertTrue
|
||||
|
|
|
@ -9,6 +9,7 @@ import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
|
|||
import io.swagger.v3.oas.annotations.media.ArraySchema;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import jakarta.validation.Valid;
|
||||
import jakarta.validation.constraints.NotEmpty;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
@ -29,36 +30,36 @@ public record PhoneNumberIdentityKeyDistributionRequest(
|
|||
arraySchema=@Schema(description="""
|
||||
A list of synchronization messages to send to companion devices to supply the private keys
|
||||
associated with the new identity key and their new prekeys.
|
||||
Exactly one message must be supplied for each enabled device other than the sending (primary) device.
|
||||
Exactly one message must be supplied for each device other than the sending (primary) device.
|
||||
"""))
|
||||
List<@NotNull @Valid IncomingMessage> deviceMessages,
|
||||
|
||||
@NotNull
|
||||
@NotEmpty
|
||||
@Valid
|
||||
@Schema(description="""
|
||||
A new signed elliptic-curve prekey for each enabled device on the account, including this one.
|
||||
A new signed elliptic-curve prekey for each device on the account, including this one.
|
||||
Each must be accompanied by a valid signature from the new identity key in this request.""")
|
||||
Map<Byte, @NotNull @Valid ECSignedPreKey> devicePniSignedPrekeys,
|
||||
|
||||
@NotNull
|
||||
@NotEmpty
|
||||
@Valid
|
||||
@Schema(description="""
|
||||
A new signed post-quantum last-resort prekey for each enabled device on the account, including this one.
|
||||
May be absent, in which case the last resort PQ prekeys for each device will be deleted if any had been stored.
|
||||
If present, must contain one prekey per enabled device including this one.
|
||||
Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped.
|
||||
A new signed post-quantum last-resort prekey for each device on the account, including this one.
|
||||
Each must be accompanied by a valid signature from the new identity key in this request.""")
|
||||
@Valid Map<Byte, @NotNull @Valid KEMSignedPreKey> devicePniPqLastResortPrekeys,
|
||||
Map<Byte, @NotNull @Valid KEMSignedPreKey> devicePniPqLastResortPrekeys,
|
||||
|
||||
@NotNull
|
||||
@NotEmpty
|
||||
@Valid
|
||||
@Schema(description="The new registration ID to use for the phone-number identity of each device, including this one.")
|
||||
Map<Byte, Integer> pniRegistrationIds) {
|
||||
|
||||
public boolean isSignatureValidOnEachSignedPreKey(@Nullable final String userAgent) {
|
||||
List<SignedPreKey<?>> spks = new ArrayList<>(devicePniSignedPrekeys.values());
|
||||
if (devicePniPqLastResortPrekeys != null) {
|
||||
spks.addAll(devicePniPqLastResortPrekeys.values());
|
||||
}
|
||||
return spks.isEmpty() || PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey, spks, userAgent, "distribute-pni-keys");
|
||||
}
|
||||
final List<SignedPreKey<?>> signedPreKeys = new ArrayList<>(devicePniSignedPrekeys.values());
|
||||
signedPreKeys.addAll(devicePniPqLastResortPrekeys.values());
|
||||
|
||||
return PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey, signedPreKeys, userAgent, "distribute-pni-keys");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,11 +5,15 @@
|
|||
package org.whispersystems.textsecuregcm.entities;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import jakarta.validation.constraints.Size;
|
||||
import jakarta.validation.Valid;
|
||||
import java.util.List;
|
||||
|
||||
public record SetKeysRequest(
|
||||
@NotNull
|
||||
@Valid
|
||||
@Size(max=100)
|
||||
@Schema(description = """
|
||||
A list of unsigned elliptic-curve prekeys to use for this device. If present and not empty, replaces all stored
|
||||
unsigned EC prekeys for the device; if absent or empty, any stored unsigned EC prekeys for the device are not
|
||||
|
@ -25,7 +29,9 @@ public record SetKeysRequest(
|
|||
""")
|
||||
ECSignedPreKey signedPreKey,
|
||||
|
||||
@NotNull
|
||||
@Valid
|
||||
@Size(max=100)
|
||||
@Schema(description = """
|
||||
A list of signed post-quantum one-time prekeys to use for this device. Each key must have a valid signature from
|
||||
the identity key in this request. If present and not empty, replaces all stored unsigned PQ prekeys for the
|
||||
|
@ -40,4 +46,16 @@ public record SetKeysRequest(
|
|||
deleted. If present, must have a valid signature from the identity key in this request.
|
||||
""")
|
||||
KEMSignedPreKey pqLastResortPreKey) {
|
||||
public SetKeysRequest {
|
||||
// It’s a little counter-intuitive, but this compact constructor allows a default value
|
||||
// to be used when one isn’t specified, allowing the field to still be
|
||||
// validated as @NotNull
|
||||
if (preKeys == null) {
|
||||
preKeys = List.of();
|
||||
}
|
||||
|
||||
if (pqPreKeys == null) {
|
||||
pqPreKeys = List.of();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -81,7 +81,16 @@ public class RemoteDeprecationFilter implements Filter, ServerInterceptor {
|
|||
final Metadata headers,
|
||||
final ServerCallHandler<ReqT, RespT> next) {
|
||||
|
||||
if (shouldBlock(RequestAttributesUtil.getUserAgent().orElse(null))) {
|
||||
@Nullable final UserAgent userAgent = RequestAttributesUtil.getUserAgent()
|
||||
.map(userAgentString -> {
|
||||
try {
|
||||
return UserAgentUtil.parseUserAgentString(userAgentString);
|
||||
} catch (final UnrecognizedUserAgentException e) {
|
||||
return null;
|
||||
}
|
||||
}).orElse(null);
|
||||
|
||||
if (shouldBlock(userAgent)) {
|
||||
call.close(StatusConstants.UPGRADE_NEEDED_STATUS, new Metadata());
|
||||
return new ServerCall.Listener<>() {};
|
||||
} else {
|
||||
|
@ -108,28 +117,28 @@ public class RemoteDeprecationFilter implements Filter, ServerInterceptor {
|
|||
return true;
|
||||
}
|
||||
|
||||
if (blockedVersionsByPlatform.containsKey(userAgent.getPlatform())) {
|
||||
if (blockedVersionsByPlatform.get(userAgent.getPlatform()).contains(userAgent.getVersion())) {
|
||||
if (blockedVersionsByPlatform.containsKey(userAgent.platform())) {
|
||||
if (blockedVersionsByPlatform.get(userAgent.platform()).contains(userAgent.version())) {
|
||||
recordDeprecation(userAgent, BLOCKED_CLIENT_REASON);
|
||||
shouldBlock = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (minimumVersionsByPlatform.containsKey(userAgent.getPlatform())) {
|
||||
if (userAgent.getVersion().isLowerThan(minimumVersionsByPlatform.get(userAgent.getPlatform()))) {
|
||||
if (minimumVersionsByPlatform.containsKey(userAgent.platform())) {
|
||||
if (userAgent.version().isLowerThan(minimumVersionsByPlatform.get(userAgent.platform()))) {
|
||||
recordDeprecation(userAgent, EXPIRED_CLIENT_REASON);
|
||||
shouldBlock = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (versionsPendingBlockByPlatform.containsKey(userAgent.getPlatform())) {
|
||||
if (versionsPendingBlockByPlatform.get(userAgent.getPlatform()).contains(userAgent.getVersion())) {
|
||||
if (versionsPendingBlockByPlatform.containsKey(userAgent.platform())) {
|
||||
if (versionsPendingBlockByPlatform.get(userAgent.platform()).contains(userAgent.version())) {
|
||||
recordPendingDeprecation(userAgent, BLOCKED_CLIENT_REASON);
|
||||
}
|
||||
}
|
||||
|
||||
if (versionsPendingDeprecationByPlatform.containsKey(userAgent.getPlatform())) {
|
||||
if (userAgent.getVersion().isLowerThan(versionsPendingDeprecationByPlatform.get(userAgent.getPlatform()))) {
|
||||
if (versionsPendingDeprecationByPlatform.containsKey(userAgent.platform())) {
|
||||
if (userAgent.version().isLowerThan(versionsPendingDeprecationByPlatform.get(userAgent.platform()))) {
|
||||
recordPendingDeprecation(userAgent, EXPIRED_CLIENT_REASON);
|
||||
}
|
||||
}
|
||||
|
@ -139,13 +148,13 @@ public class RemoteDeprecationFilter implements Filter, ServerInterceptor {
|
|||
|
||||
private void recordDeprecation(final UserAgent userAgent, final String reason) {
|
||||
Metrics.counter(DEPRECATED_CLIENT_COUNTER_NAME,
|
||||
PLATFORM_TAG, userAgent != null ? userAgent.getPlatform().name().toLowerCase() : "unrecognized",
|
||||
PLATFORM_TAG, userAgent != null ? userAgent.platform().name().toLowerCase() : "unrecognized",
|
||||
REASON_TAG_NAME, reason).increment();
|
||||
}
|
||||
|
||||
private void recordPendingDeprecation(final UserAgent userAgent, final String reason) {
|
||||
Metrics.counter(PENDING_DEPRECATION_COUNTER_NAME,
|
||||
PLATFORM_TAG, userAgent.getPlatform().name().toLowerCase(),
|
||||
PLATFORM_TAG, userAgent.platform().name().toLowerCase(),
|
||||
REASON_TAG_NAME, reason).increment();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,8 +15,6 @@ import jakarta.ws.rs.container.ContainerRequestFilter;
|
|||
import jakarta.ws.rs.core.SecurityContext;
|
||||
import java.io.IOException;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
|
||||
|
@ -70,8 +68,8 @@ public class RestDeprecationFilter implements ContainerRequestFilter {
|
|||
|
||||
try {
|
||||
final UserAgent userAgent = UserAgentUtil.parseUserAgentString(userAgentString);
|
||||
final ClientPlatform platform = userAgent.getPlatform();
|
||||
final Semver version = userAgent.getVersion();
|
||||
final ClientPlatform platform = userAgent.platform();
|
||||
final Semver version = userAgent.version();
|
||||
if (!minimumRestFreeVersion.containsKey(platform)) {
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
/**
|
||||
* Indicates that a remote channel was not found for a given server call or remote address.
|
||||
*/
|
||||
public class ChannelNotFoundException extends Exception {
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import io.grpc.Context;
|
||||
import io.grpc.ForwardingServerCallListener;
|
||||
import io.grpc.Grpc;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerCallHandler;
|
||||
import io.grpc.ServerInterceptor;
|
||||
import io.grpc.Status;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
|
||||
/**
|
||||
* Then channel shutdown interceptor rejects new requests if a channel is shutting down and works in tandem with
|
||||
* {@link GrpcClientConnectionManager} to maintain an active call count for each channel otherwise.
|
||||
*/
|
||||
public class ChannelShutdownInterceptor implements ServerInterceptor {
|
||||
|
||||
private final GrpcClientConnectionManager grpcClientConnectionManager;
|
||||
|
||||
public ChannelShutdownInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
|
||||
this.grpcClientConnectionManager = grpcClientConnectionManager;
|
||||
}
|
||||
|
||||
@Override
|
||||
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
|
||||
final Metadata headers,
|
||||
final ServerCallHandler<ReqT, RespT> next) {
|
||||
|
||||
if (!grpcClientConnectionManager.handleServerCallStart(call)) {
|
||||
// Don't allow new calls if the connection is getting ready to close
|
||||
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
|
||||
}
|
||||
|
||||
return new ForwardingServerCallListener.SimpleForwardingServerCallListener<>(next.startCall(call, headers)) {
|
||||
@Override
|
||||
public void onComplete() {
|
||||
grpcClientConnectionManager.handleServerCallComplete(call);
|
||||
super.onComplete();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onCancel() {
|
||||
grpcClientConnectionManager.handleServerCallComplete(call);
|
||||
super.onCancel();
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
|
@ -10,8 +10,12 @@ import org.whispersystems.textsecuregcm.storage.Device;
|
|||
|
||||
public class DeviceIdUtil {
|
||||
|
||||
public static boolean isValid(int deviceId) {
|
||||
return deviceId >= Device.PRIMARY_ID && deviceId <= Byte.MAX_VALUE;
|
||||
}
|
||||
|
||||
static byte validate(int deviceId) {
|
||||
if (deviceId < Device.PRIMARY_ID || deviceId > Byte.MAX_VALUE) {
|
||||
if (!isValid(deviceId)) {
|
||||
throw Status.INVALID_ARGUMENT.withDescription("Device ID is out of range").asRuntimeException();
|
||||
}
|
||||
|
||||
|
|
|
@ -7,8 +7,9 @@ package org.whispersystems.textsecuregcm.grpc;
|
|||
|
||||
import com.google.protobuf.ByteString;
|
||||
import io.grpc.Status;
|
||||
|
||||
import io.grpc.StatusException;
|
||||
import java.time.Clock;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import org.signal.libsignal.protocol.ServiceId;
|
||||
import org.signal.libsignal.zkgroup.InvalidInputException;
|
||||
|
@ -18,8 +19,6 @@ import org.signal.libsignal.zkgroup.groupsend.GroupSendDerivedKeyPair;
|
|||
import org.signal.libsignal.zkgroup.groupsend.GroupSendFullToken;
|
||||
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
|
||||
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
public class GroupSendTokenUtil {
|
||||
|
||||
private final ServerSecretParams serverSecretParams;
|
||||
|
@ -30,16 +29,22 @@ public class GroupSendTokenUtil {
|
|||
this.clock = clock;
|
||||
}
|
||||
|
||||
public Mono<Void> checkGroupSendToken(final ByteString serializedGroupSendToken, List<ServiceIdentifier> serviceIdentifiers) {
|
||||
public void checkGroupSendToken(final ByteString serializedGroupSendToken,
|
||||
final ServiceIdentifier serviceIdentifier) throws StatusException {
|
||||
|
||||
checkGroupSendToken(serializedGroupSendToken, List.of(serviceIdentifier.toLibsignal()));
|
||||
}
|
||||
|
||||
public void checkGroupSendToken(final ByteString serializedGroupSendToken,
|
||||
final Collection<ServiceId> serviceIds) throws StatusException {
|
||||
|
||||
try {
|
||||
final GroupSendFullToken token = new GroupSendFullToken(serializedGroupSendToken.toByteArray());
|
||||
final List<ServiceId> serviceIds = serviceIdentifiers.stream().map(ServiceIdentifier::toLibsignal).toList();
|
||||
token.verify(serviceIds, clock.instant(), GroupSendDerivedKeyPair.forExpiration(token.getExpiration(), serverSecretParams));
|
||||
return Mono.empty();
|
||||
} catch (InvalidInputException e) {
|
||||
return Mono.error(Status.INVALID_ARGUMENT.asException());
|
||||
} catch (final InvalidInputException e) {
|
||||
throw Status.INVALID_ARGUMENT.asException();
|
||||
} catch (VerificationFailedException e) {
|
||||
return Mono.error(Status.UNAUTHENTICATED.asException());
|
||||
throw Status.UNAUTHENTICATED.asException();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,12 +7,11 @@ package org.whispersystems.textsecuregcm.grpc;
|
|||
|
||||
import com.google.protobuf.ByteString;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.StatusException;
|
||||
import java.security.MessageDigest;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.time.Clock;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import org.signal.chat.keys.CheckIdentityKeyRequest;
|
||||
import org.signal.chat.keys.CheckIdentityKeyResponse;
|
||||
import org.signal.chat.keys.GetPreKeysAnonymousRequest;
|
||||
|
@ -52,16 +51,24 @@ public class KeysAnonymousGrpcService extends ReactorKeysAnonymousGrpc.KeysAnony
|
|||
: KeysGrpcHelper.ALL_DEVICES;
|
||||
|
||||
return switch (request.getAuthorizationCase()) {
|
||||
case GROUP_SEND_TOKEN ->
|
||||
groupSendTokenUtil.checkGroupSendToken(request.getGroupSendToken(), List.of(serviceIdentifier))
|
||||
.then(lookUpAccount(serviceIdentifier, Status.NOT_FOUND))
|
||||
case GROUP_SEND_TOKEN -> {
|
||||
try {
|
||||
groupSendTokenUtil.checkGroupSendToken(request.getGroupSendToken(), serviceIdentifier);
|
||||
|
||||
yield lookUpAccount(serviceIdentifier, Status.NOT_FOUND)
|
||||
.flatMap(targetAccount -> KeysGrpcHelper.getPreKeys(targetAccount, serviceIdentifier.identityType(), deviceId, keysManager));
|
||||
} catch (final StatusException e) {
|
||||
yield Mono.error(e);
|
||||
}
|
||||
}
|
||||
|
||||
case UNIDENTIFIED_ACCESS_KEY ->
|
||||
lookUpAccount(serviceIdentifier, Status.UNAUTHENTICATED)
|
||||
.flatMap(targetAccount ->
|
||||
UnidentifiedAccessUtil.checkUnidentifiedAccess(targetAccount, request.getUnidentifiedAccessKey().toByteArray())
|
||||
? KeysGrpcHelper.getPreKeys(targetAccount, serviceIdentifier.identityType(), deviceId, keysManager)
|
||||
: Mono.error(Status.UNAUTHENTICATED.asException()));
|
||||
|
||||
default -> Mono.error(Status.INVALID_ARGUMENT.asException());
|
||||
};
|
||||
}
|
||||
|
|
|
@ -0,0 +1,302 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.StatusException;
|
||||
import java.time.Clock;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.Collectors;
|
||||
import org.signal.chat.messages.IndividualRecipientMessageBundle;
|
||||
import org.signal.chat.messages.MultiRecipientMismatchedDevices;
|
||||
import org.signal.chat.messages.SendMessageResponse;
|
||||
import org.signal.chat.messages.SendMultiRecipientMessageRequest;
|
||||
import org.signal.chat.messages.SendMultiRecipientMessageResponse;
|
||||
import org.signal.chat.messages.SendMultiRecipientStoryRequest;
|
||||
import org.signal.chat.messages.SendSealedSenderMessageRequest;
|
||||
import org.signal.chat.messages.SendStoryMessageRequest;
|
||||
import org.signal.chat.messages.SimpleMessagesAnonymousGrpc;
|
||||
import org.signal.libsignal.protocol.InvalidMessageException;
|
||||
import org.signal.libsignal.protocol.InvalidVersionException;
|
||||
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
|
||||
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil;
|
||||
import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException;
|
||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
||||
import org.whispersystems.textsecuregcm.identity.IdentityType;
|
||||
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.limits.CardinalityEstimator;
|
||||
import org.whispersystems.textsecuregcm.limits.RateLimiters;
|
||||
import org.whispersystems.textsecuregcm.push.MessageSender;
|
||||
import org.whispersystems.textsecuregcm.push.MessageTooLargeException;
|
||||
import org.whispersystems.textsecuregcm.push.MessageUtil;
|
||||
import org.whispersystems.textsecuregcm.spam.GrpcResponse;
|
||||
import org.whispersystems.textsecuregcm.spam.MessageType;
|
||||
import org.whispersystems.textsecuregcm.spam.SpamCheckResult;
|
||||
import org.whispersystems.textsecuregcm.spam.SpamChecker;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
||||
|
||||
public class MessagesAnonymousGrpcService extends SimpleMessagesAnonymousGrpc.MessagesAnonymousImplBase {
|
||||
|
||||
private final AccountsManager accountsManager;
|
||||
private final RateLimiters rateLimiters;
|
||||
private final MessageSender messageSender;
|
||||
private final GroupSendTokenUtil groupSendTokenUtil;
|
||||
private final CardinalityEstimator messageByteLimitEstimator;
|
||||
private final SpamChecker spamChecker;
|
||||
private final Clock clock;
|
||||
|
||||
private static final SendMessageResponse SEND_MESSAGE_SUCCESS_RESPONSE = SendMessageResponse.newBuilder().build();
|
||||
|
||||
public MessagesAnonymousGrpcService(final AccountsManager accountsManager,
|
||||
final RateLimiters rateLimiters,
|
||||
final MessageSender messageSender,
|
||||
final GroupSendTokenUtil groupSendTokenUtil,
|
||||
final CardinalityEstimator messageByteLimitEstimator,
|
||||
final SpamChecker spamChecker,
|
||||
final Clock clock) {
|
||||
|
||||
this.accountsManager = accountsManager;
|
||||
this.rateLimiters = rateLimiters;
|
||||
this.messageSender = messageSender;
|
||||
this.messageByteLimitEstimator = messageByteLimitEstimator;
|
||||
this.spamChecker = spamChecker;
|
||||
this.clock = clock;
|
||||
this.groupSendTokenUtil = groupSendTokenUtil;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SendMessageResponse sendSingleRecipientMessage(final SendSealedSenderMessageRequest request)
|
||||
throws StatusException, RateLimitExceededException {
|
||||
|
||||
final ServiceIdentifier destinationServiceIdentifier =
|
||||
ServiceIdentifierUtil.fromGrpcServiceIdentifier(request.getDestination());
|
||||
|
||||
final Account destination = accountsManager.getByServiceIdentifier(destinationServiceIdentifier)
|
||||
.orElseThrow(Status.UNAUTHENTICATED::asException);
|
||||
|
||||
switch (request.getAuthorizationCase()) {
|
||||
case UNIDENTIFIED_ACCESS_KEY -> {
|
||||
if (!UnidentifiedAccessUtil.checkUnidentifiedAccess(destination, request.getUnidentifiedAccessKey().toByteArray())) {
|
||||
throw Status.UNAUTHENTICATED.asException();
|
||||
}
|
||||
}
|
||||
case GROUP_SEND_TOKEN ->
|
||||
groupSendTokenUtil.checkGroupSendToken(request.getGroupSendToken(), destinationServiceIdentifier);
|
||||
|
||||
case AUTHORIZATION_NOT_SET -> throw Status.UNAUTHENTICATED.asException();
|
||||
}
|
||||
|
||||
return sendIndividualMessage(destination,
|
||||
destinationServiceIdentifier,
|
||||
request.getMessages(),
|
||||
request.getEphemeral(),
|
||||
request.getUrgent(),
|
||||
false);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SendMessageResponse sendStory(final SendStoryMessageRequest request)
|
||||
throws StatusException, RateLimitExceededException {
|
||||
|
||||
final ServiceIdentifier destinationServiceIdentifier =
|
||||
ServiceIdentifierUtil.fromGrpcServiceIdentifier(request.getDestination());
|
||||
|
||||
final Optional<Account> maybeDestination = accountsManager.getByServiceIdentifier(destinationServiceIdentifier);
|
||||
|
||||
if (maybeDestination.isEmpty()) {
|
||||
// Don't reveal to unauthenticated callers whether a destination account actually exists
|
||||
return SEND_MESSAGE_SUCCESS_RESPONSE;
|
||||
}
|
||||
|
||||
final Account destination = maybeDestination.get();
|
||||
|
||||
rateLimiters.getStoriesLimiter().validate(destination.getIdentifier(IdentityType.ACI));
|
||||
|
||||
return sendIndividualMessage(destination,
|
||||
destinationServiceIdentifier,
|
||||
request.getMessages(),
|
||||
false,
|
||||
request.getUrgent(),
|
||||
true);
|
||||
}
|
||||
|
||||
private SendMessageResponse sendIndividualMessage(final Account destination,
|
||||
final ServiceIdentifier destinationServiceIdentifier,
|
||||
final IndividualRecipientMessageBundle messages,
|
||||
final boolean ephemeral,
|
||||
final boolean urgent,
|
||||
final boolean story) throws StatusException, RateLimitExceededException {
|
||||
|
||||
final SpamCheckResult<GrpcResponse<SendMessageResponse>> spamCheckResult =
|
||||
spamChecker.checkForIndividualRecipientSpamGrpc(
|
||||
story ? MessageType.INDIVIDUAL_STORY : MessageType.INDIVIDUAL_SEALED_SENDER,
|
||||
Optional.empty(),
|
||||
Optional.of(destination),
|
||||
destinationServiceIdentifier);
|
||||
|
||||
if (spamCheckResult.response().isPresent()) {
|
||||
return spamCheckResult.response().get().getResponseOrThrowStatus();
|
||||
}
|
||||
|
||||
try {
|
||||
final int totalPayloadLength = messages.getMessagesMap().values().stream()
|
||||
.mapToInt(message -> message.getPayload().size())
|
||||
.sum();
|
||||
|
||||
rateLimiters.getInboundMessageBytes().validate(destinationServiceIdentifier.uuid(), totalPayloadLength);
|
||||
} catch (final RateLimitExceededException e) {
|
||||
messageByteLimitEstimator.add(destinationServiceIdentifier.uuid().toString());
|
||||
throw e;
|
||||
}
|
||||
|
||||
final Map<Byte, MessageProtos.Envelope> messagesByDeviceId = messages.getMessagesMap().entrySet()
|
||||
.stream()
|
||||
.collect(Collectors.toMap(
|
||||
entry -> DeviceIdUtil.validate(entry.getKey()),
|
||||
entry -> {
|
||||
final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder()
|
||||
.setType(MessageProtos.Envelope.Type.UNIDENTIFIED_SENDER)
|
||||
.setClientTimestamp(messages.getTimestamp())
|
||||
.setServerTimestamp(clock.millis())
|
||||
.setDestinationServiceId(destinationServiceIdentifier.toServiceIdentifierString())
|
||||
.setEphemeral(ephemeral)
|
||||
.setUrgent(urgent)
|
||||
.setStory(story)
|
||||
.setContent(entry.getValue().getPayload());
|
||||
|
||||
spamCheckResult.token().ifPresent(reportSpamToken ->
|
||||
envelopeBuilder.setReportSpamToken(ByteString.copyFrom(reportSpamToken)));
|
||||
|
||||
return envelopeBuilder.build();
|
||||
}
|
||||
));
|
||||
|
||||
final Map<Byte, Integer> registrationIdsByDeviceId = messages.getMessagesMap().entrySet().stream()
|
||||
.collect(Collectors.toMap(
|
||||
entry -> entry.getKey().byteValue(),
|
||||
entry -> entry.getValue().getRegistrationId()));
|
||||
|
||||
return MessagesGrpcHelper.sendMessage(messageSender,
|
||||
destination,
|
||||
destinationServiceIdentifier,
|
||||
messagesByDeviceId,
|
||||
registrationIdsByDeviceId,
|
||||
Optional.empty());
|
||||
}
|
||||
|
||||
@Override
|
||||
public SendMultiRecipientMessageResponse sendMultiRecipientMessage(final SendMultiRecipientMessageRequest request)
|
||||
throws StatusException {
|
||||
|
||||
final SealedSenderMultiRecipientMessage multiRecipientMessage =
|
||||
parseAndValidateMultiRecipientMessage(request.getMessage().getPayload().toByteArray());
|
||||
|
||||
groupSendTokenUtil.checkGroupSendToken(request.getGroupSendToken(), multiRecipientMessage.getRecipients().keySet());
|
||||
|
||||
return sendMultiRecipientMessage(multiRecipientMessage,
|
||||
request.getMessage().getTimestamp(),
|
||||
request.getEphemeral(),
|
||||
request.getUrgent(),
|
||||
false);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SendMultiRecipientMessageResponse sendMultiRecipientStory(final SendMultiRecipientStoryRequest request)
|
||||
throws StatusException {
|
||||
|
||||
final SealedSenderMultiRecipientMessage multiRecipientMessage =
|
||||
parseAndValidateMultiRecipientMessage(request.getMessage().getPayload().toByteArray());
|
||||
|
||||
return sendMultiRecipientMessage(multiRecipientMessage,
|
||||
request.getMessage().getTimestamp(),
|
||||
false,
|
||||
request.getUrgent(),
|
||||
true)
|
||||
.toBuilder()
|
||||
// Don't identify unresolved recipients for stories
|
||||
.clearUnresolvedRecipients()
|
||||
.build();
|
||||
}
|
||||
|
||||
private SendMultiRecipientMessageResponse sendMultiRecipientMessage(
|
||||
final SealedSenderMultiRecipientMessage multiRecipientMessage,
|
||||
final long timestamp,
|
||||
final boolean ephemeral,
|
||||
final boolean urgent,
|
||||
final boolean story) throws StatusException {
|
||||
|
||||
final SpamCheckResult<GrpcResponse<SendMultiRecipientMessageResponse>> spamCheckResult =
|
||||
spamChecker.checkForMultiRecipientSpamGrpc(story
|
||||
? MessageType.MULTI_RECIPIENT_STORY
|
||||
: MessageType.MULTI_RECIPIENT_SEALED_SENDER);
|
||||
|
||||
if (spamCheckResult.response().isPresent()) {
|
||||
return spamCheckResult.response().get().getResponseOrThrowStatus();
|
||||
}
|
||||
|
||||
// At this point, the caller has at least superficially provided the information needed to send a multi-recipient
|
||||
// message. Attempt to resolve the destination service identifiers to Signal accounts.
|
||||
final Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolvedRecipients =
|
||||
MessageUtil.resolveRecipients(accountsManager, multiRecipientMessage);
|
||||
|
||||
try {
|
||||
messageSender.sendMultiRecipientMessage(multiRecipientMessage,
|
||||
resolvedRecipients,
|
||||
timestamp,
|
||||
story,
|
||||
ephemeral,
|
||||
urgent,
|
||||
RequestAttributesUtil.getUserAgent().orElse(null));
|
||||
|
||||
final SendMultiRecipientMessageResponse.Builder responseBuilder = SendMultiRecipientMessageResponse.newBuilder();
|
||||
|
||||
MessageUtil.getUnresolvedRecipients(multiRecipientMessage, resolvedRecipients).stream()
|
||||
.map(ServiceIdentifierUtil::toGrpcServiceIdentifier)
|
||||
.forEach(responseBuilder::addUnresolvedRecipients);
|
||||
|
||||
return responseBuilder.build();
|
||||
} catch (final MessageTooLargeException e) {
|
||||
throw Status.INVALID_ARGUMENT
|
||||
.withDescription("Message for an individual recipient was too large")
|
||||
.withCause(e)
|
||||
.asRuntimeException();
|
||||
} catch (final MultiRecipientMismatchedDevicesException e) {
|
||||
final MultiRecipientMismatchedDevices.Builder mismatchedDevicesBuilder =
|
||||
MultiRecipientMismatchedDevices.newBuilder();
|
||||
|
||||
e.getMismatchedDevicesByServiceIdentifier().forEach((serviceIdentifier, mismatchedDevices) ->
|
||||
mismatchedDevicesBuilder.addMismatchedDevices(MessagesGrpcHelper.buildMismatchedDevices(serviceIdentifier, mismatchedDevices)));
|
||||
|
||||
return SendMultiRecipientMessageResponse.newBuilder()
|
||||
.setMismatchedDevices(mismatchedDevicesBuilder)
|
||||
.build();
|
||||
}
|
||||
}
|
||||
|
||||
private SealedSenderMultiRecipientMessage parseAndValidateMultiRecipientMessage(
|
||||
final byte[] serializedMultiRecipientMessage) throws StatusException {
|
||||
|
||||
final SealedSenderMultiRecipientMessage multiRecipientMessage;
|
||||
|
||||
try {
|
||||
multiRecipientMessage = SealedSenderMultiRecipientMessage.parse(serializedMultiRecipientMessage);
|
||||
} catch (final InvalidMessageException | InvalidVersionException e) {
|
||||
throw Status.INVALID_ARGUMENT.withCause(e).asException();
|
||||
}
|
||||
|
||||
// Check that the request is well-formed and doesn't contain repeated entries for the same device for the same
|
||||
// recipient
|
||||
if (MessageUtil.hasDuplicateDevices(multiRecipientMessage)) {
|
||||
throw Status.INVALID_ARGUMENT.withDescription("Multi-recipient message contains duplicate recipient").asException();
|
||||
}
|
||||
|
||||
return multiRecipientMessage;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,91 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import io.grpc.Status;
|
||||
import io.grpc.StatusException;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import org.signal.chat.messages.MismatchedDevices;
|
||||
import org.signal.chat.messages.SendMessageResponse;
|
||||
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
|
||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
||||
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.push.MessageSender;
|
||||
import org.whispersystems.textsecuregcm.push.MessageTooLargeException;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
|
||||
public class MessagesGrpcHelper {
|
||||
|
||||
private static final SendMessageResponse SEND_MESSAGE_SUCCESS_RESPONSE = SendMessageResponse.newBuilder().build();
|
||||
|
||||
/**
|
||||
* Sends a "bundle" of messages to an individual destination account, mapping common exceptions to appropriate gRPC
|
||||
* statuses.
|
||||
*
|
||||
* @param messageSender the {@code MessageSender} instance to use to send the messages
|
||||
* @param destination the destination account for the messages
|
||||
* @param destinationServiceIdentifier the service identifier for the destination account
|
||||
* @param messagesByDeviceId a map of device IDs to message payloads
|
||||
* @param registrationIdsByDeviceId a map of device IDs to device registration IDs
|
||||
* @param syncMessageSenderDeviceId if the message is a sync message (i.e. a message to other devices linked to the
|
||||
* caller's own account), contains the ID of the device that sent the message
|
||||
*
|
||||
* @return a response object to send to callers
|
||||
*
|
||||
* @throws StatusException if the message bundle could not be sent due to an out-of-date device set or an invalid
|
||||
* message payload
|
||||
* @throws RateLimitExceededException if the message bundle could not be sent due to a violated rated limit
|
||||
*/
|
||||
public static SendMessageResponse sendMessage(final MessageSender messageSender,
|
||||
final Account destination,
|
||||
final ServiceIdentifier destinationServiceIdentifier,
|
||||
final Map<Byte, MessageProtos.Envelope> messagesByDeviceId,
|
||||
final Map<Byte, Integer> registrationIdsByDeviceId,
|
||||
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<Byte> syncMessageSenderDeviceId)
|
||||
throws StatusException, RateLimitExceededException {
|
||||
|
||||
try {
|
||||
messageSender.sendMessages(destination,
|
||||
destinationServiceIdentifier,
|
||||
messagesByDeviceId,
|
||||
registrationIdsByDeviceId,
|
||||
syncMessageSenderDeviceId,
|
||||
RequestAttributesUtil.getUserAgent().orElse(null));
|
||||
|
||||
return SEND_MESSAGE_SUCCESS_RESPONSE;
|
||||
} catch (final MismatchedDevicesException e) {
|
||||
return SendMessageResponse.newBuilder()
|
||||
.setMismatchedDevices(buildMismatchedDevices(destinationServiceIdentifier, e.getMismatchedDevices()))
|
||||
.build();
|
||||
} catch (final MessageTooLargeException e) {
|
||||
throw Status.INVALID_ARGUMENT.withDescription("Message too large").withCause(e).asException();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Translates an internal {@link org.whispersystems.textsecuregcm.controllers.MismatchedDevices} entity to a gRPC
|
||||
* {@link MismatchedDevices} entity.
|
||||
*
|
||||
* @param serviceIdentifier the service identifier to which the mismatched device response applies
|
||||
* @param mismatchedDevices the mismatched device entity to translate to gRPC
|
||||
*
|
||||
* @return a gRPC {@code MismatchedDevices} representation of the given mismatched devices
|
||||
*/
|
||||
public static MismatchedDevices buildMismatchedDevices(final ServiceIdentifier serviceIdentifier,
|
||||
final org.whispersystems.textsecuregcm.controllers.MismatchedDevices mismatchedDevices) {
|
||||
|
||||
final MismatchedDevices.Builder mismatchedDevicesBuilder = MismatchedDevices.newBuilder()
|
||||
.setServiceIdentifier(ServiceIdentifierUtil.toGrpcServiceIdentifier(serviceIdentifier));
|
||||
|
||||
mismatchedDevices.missingDeviceIds().forEach(mismatchedDevicesBuilder::addMissingDevices);
|
||||
mismatchedDevices.extraDeviceIds().forEach(mismatchedDevicesBuilder::addExtraDevices);
|
||||
mismatchedDevices.staleDeviceIds().forEach(mismatchedDevicesBuilder::addStaleDevices);
|
||||
|
||||
return mismatchedDevicesBuilder.build();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,188 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.StatusException;
|
||||
import java.time.Clock;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.Collectors;
|
||||
import org.signal.chat.messages.AuthenticatedSenderMessageType;
|
||||
import org.signal.chat.messages.IndividualRecipientMessageBundle;
|
||||
import org.signal.chat.messages.SendAuthenticatedSenderMessageRequest;
|
||||
import org.signal.chat.messages.SendMessageResponse;
|
||||
import org.signal.chat.messages.SendSyncMessageRequest;
|
||||
import org.signal.chat.messages.SimpleMessagesGrpc;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticationUtil;
|
||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
||||
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.limits.CardinalityEstimator;
|
||||
import org.whispersystems.textsecuregcm.limits.RateLimiters;
|
||||
import org.whispersystems.textsecuregcm.push.MessageSender;
|
||||
import org.whispersystems.textsecuregcm.spam.GrpcResponse;
|
||||
import org.whispersystems.textsecuregcm.spam.MessageType;
|
||||
import org.whispersystems.textsecuregcm.spam.SpamCheckResult;
|
||||
import org.whispersystems.textsecuregcm.spam.SpamChecker;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
||||
|
||||
public class MessagesGrpcService extends SimpleMessagesGrpc.MessagesImplBase {
|
||||
|
||||
private final AccountsManager accountsManager;
|
||||
private final RateLimiters rateLimiters;
|
||||
private final MessageSender messageSender;
|
||||
private final CardinalityEstimator messageByteLimitEstimator;
|
||||
private final SpamChecker spamChecker;
|
||||
private final Clock clock;
|
||||
|
||||
public MessagesGrpcService(final AccountsManager accountsManager,
|
||||
final RateLimiters rateLimiters,
|
||||
final MessageSender messageSender,
|
||||
final CardinalityEstimator messageByteLimitEstimator,
|
||||
final SpamChecker spamChecker,
|
||||
final Clock clock) {
|
||||
|
||||
this.accountsManager = accountsManager;
|
||||
this.rateLimiters = rateLimiters;
|
||||
this.messageSender = messageSender;
|
||||
this.messageByteLimitEstimator = messageByteLimitEstimator;
|
||||
this.spamChecker = spamChecker;
|
||||
this.clock = clock;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SendMessageResponse sendMessage(final SendAuthenticatedSenderMessageRequest request)
|
||||
throws StatusException, RateLimitExceededException {
|
||||
|
||||
final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedDevice();
|
||||
final AciServiceIdentifier senderServiceIdentifier = new AciServiceIdentifier(authenticatedDevice.accountIdentifier());
|
||||
final Account sender =
|
||||
accountsManager.getByServiceIdentifier(senderServiceIdentifier).orElseThrow(Status.UNAUTHENTICATED::asException);
|
||||
|
||||
final ServiceIdentifier destinationServiceIdentifier =
|
||||
ServiceIdentifierUtil.fromGrpcServiceIdentifier(request.getDestination());
|
||||
|
||||
if (sender.isIdentifiedBy(destinationServiceIdentifier)) {
|
||||
throw Status.INVALID_ARGUMENT
|
||||
.withDescription("Use `sendSyncMessage` to send messages to own account")
|
||||
.asException();
|
||||
}
|
||||
|
||||
final Account destination = accountsManager.getByServiceIdentifier(destinationServiceIdentifier)
|
||||
.orElseThrow(Status.NOT_FOUND::asException);
|
||||
|
||||
rateLimiters.getMessagesLimiter().validate(authenticatedDevice.accountIdentifier(), destination.getUuid());
|
||||
|
||||
return sendMessage(destination,
|
||||
destinationServiceIdentifier,
|
||||
authenticatedDevice,
|
||||
request.getType(),
|
||||
MessageType.INDIVIDUAL_IDENTIFIED_SENDER,
|
||||
request.getMessages(),
|
||||
request.getEphemeral(),
|
||||
request.getUrgent());
|
||||
}
|
||||
|
||||
@Override
|
||||
public SendMessageResponse sendSyncMessage(final SendSyncMessageRequest request)
|
||||
throws StatusException, RateLimitExceededException {
|
||||
|
||||
final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedDevice();
|
||||
final AciServiceIdentifier senderServiceIdentifier = new AciServiceIdentifier(authenticatedDevice.accountIdentifier());
|
||||
final Account sender =
|
||||
accountsManager.getByServiceIdentifier(senderServiceIdentifier).orElseThrow(Status.UNAUTHENTICATED::asException);
|
||||
|
||||
return sendMessage(sender,
|
||||
senderServiceIdentifier,
|
||||
authenticatedDevice,
|
||||
request.getType(),
|
||||
MessageType.SYNC,
|
||||
request.getMessages(),
|
||||
false,
|
||||
request.getUrgent());
|
||||
}
|
||||
|
||||
private SendMessageResponse sendMessage(final Account destination,
|
||||
final ServiceIdentifier destinationServiceIdentifier,
|
||||
final AuthenticatedDevice sender,
|
||||
final AuthenticatedSenderMessageType envelopeType,
|
||||
final MessageType messageType,
|
||||
final IndividualRecipientMessageBundle messages,
|
||||
final boolean ephemeral,
|
||||
final boolean urgent) throws StatusException, RateLimitExceededException {
|
||||
|
||||
try {
|
||||
final int totalPayloadLength = messages.getMessagesMap().values().stream()
|
||||
.mapToInt(message -> message.getPayload().size())
|
||||
.sum();
|
||||
|
||||
rateLimiters.getInboundMessageBytes().validate(destinationServiceIdentifier.uuid(), totalPayloadLength);
|
||||
} catch (final RateLimitExceededException e) {
|
||||
messageByteLimitEstimator.add(destinationServiceIdentifier.uuid().toString());
|
||||
throw e;
|
||||
}
|
||||
|
||||
final SpamCheckResult<GrpcResponse<SendMessageResponse>> spamCheckResult =
|
||||
spamChecker.checkForIndividualRecipientSpamGrpc(messageType,
|
||||
Optional.of(sender),
|
||||
Optional.of(destination),
|
||||
destinationServiceIdentifier);
|
||||
|
||||
if (spamCheckResult.response().isPresent()) {
|
||||
return spamCheckResult.response().get().getResponseOrThrowStatus();
|
||||
}
|
||||
|
||||
final Map<Byte, MessageProtos.Envelope> messagesByDeviceId = messages.getMessagesMap().entrySet()
|
||||
.stream()
|
||||
.collect(Collectors.toMap(
|
||||
entry -> DeviceIdUtil.validate(entry.getKey()),
|
||||
entry -> {
|
||||
final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder()
|
||||
.setType(getEnvelopeType(envelopeType))
|
||||
.setClientTimestamp(messages.getTimestamp())
|
||||
.setServerTimestamp(clock.millis())
|
||||
.setDestinationServiceId(destinationServiceIdentifier.toServiceIdentifierString())
|
||||
.setSourceServiceId(new AciServiceIdentifier(sender.accountIdentifier()).toServiceIdentifierString())
|
||||
.setSourceDevice(sender.deviceId())
|
||||
.setEphemeral(ephemeral)
|
||||
.setUrgent(urgent)
|
||||
.setContent(entry.getValue().getPayload());
|
||||
|
||||
spamCheckResult.token().ifPresent(reportSpamToken ->
|
||||
envelopeBuilder.setReportSpamToken(ByteString.copyFrom(reportSpamToken)));
|
||||
|
||||
return envelopeBuilder.build();
|
||||
}
|
||||
));
|
||||
|
||||
final Map<Byte, Integer> registrationIdsByDeviceId = messages.getMessagesMap().entrySet().stream()
|
||||
.collect(Collectors.toMap(
|
||||
entry -> entry.getKey().byteValue(),
|
||||
entry -> entry.getValue().getRegistrationId()));
|
||||
|
||||
return MessagesGrpcHelper.sendMessage(messageSender,
|
||||
destination,
|
||||
destinationServiceIdentifier,
|
||||
messagesByDeviceId,
|
||||
registrationIdsByDeviceId,
|
||||
messageType == MessageType.SYNC ? Optional.of(sender.deviceId()) : Optional.empty());
|
||||
}
|
||||
|
||||
private static MessageProtos.Envelope.Type getEnvelopeType(final AuthenticatedSenderMessageType type) {
|
||||
return switch (type) {
|
||||
case DOUBLE_RATCHET -> MessageProtos.Envelope.Type.CIPHERTEXT;
|
||||
case PREKEY_MESSAGE -> MessageProtos.Envelope.Type.PREKEY_BUNDLE;
|
||||
case PLAINTEXT_CONTENT -> MessageProtos.Envelope.Type.PLAINTEXT_CONTENT;
|
||||
case UNSPECIFIED, UNRECOGNIZED ->
|
||||
throw Status.INVALID_ARGUMENT.withDescription("Unrecognized envelope type").asRuntimeException();
|
||||
};
|
||||
}
|
||||
}
|
|
@ -6,10 +6,8 @@
|
|||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import io.grpc.Status;
|
||||
|
||||
import io.grpc.StatusException;
|
||||
import java.time.Clock;
|
||||
import java.util.List;
|
||||
|
||||
import org.signal.chat.profile.CredentialType;
|
||||
import org.signal.chat.profile.GetExpiringProfileKeyCredentialAnonymousRequest;
|
||||
import org.signal.chat.profile.GetExpiringProfileKeyCredentialResponse;
|
||||
|
@ -59,11 +57,17 @@ public class ProfileAnonymousGrpcService extends ReactorProfileAnonymousGrpc.Pro
|
|||
}
|
||||
|
||||
final Mono<Account> account = switch (request.getAuthenticationCase()) {
|
||||
case GROUP_SEND_TOKEN ->
|
||||
groupSendTokenUtil.checkGroupSendToken(request.getGroupSendToken(), List.of(targetIdentifier))
|
||||
.then(Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(targetIdentifier)))
|
||||
case GROUP_SEND_TOKEN -> {
|
||||
try {
|
||||
groupSendTokenUtil.checkGroupSendToken(request.getGroupSendToken(), targetIdentifier);
|
||||
|
||||
yield Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(targetIdentifier))
|
||||
.flatMap(Mono::justOrEmpty)
|
||||
.switchIfEmpty(Mono.error(Status.NOT_FOUND.asException()));
|
||||
} catch (final StatusException e) {
|
||||
yield Mono.error(e);
|
||||
}
|
||||
}
|
||||
case UNIDENTIFIED_ACCESS_KEY ->
|
||||
getTargetAccountAndValidateUnidentifiedAccess(targetIdentifier, request.getUnidentifiedAccessKey().toByteArray());
|
||||
default -> Mono.error(Status.INVALID_ARGUMENT.asException());
|
||||
|
|
|
@ -145,11 +145,13 @@ public class ProfileGrpcService extends ReactorProfileGrpc.ProfileImplBase {
|
|||
request.getCommitment().toByteArray())));
|
||||
|
||||
final List<Mono<?>> updates = new ArrayList<>(2);
|
||||
final List<AccountBadge> updatedBadges = Optional.of(request.getBadgeIdsList())
|
||||
.map(badges -> ProfileHelper.mergeBadgeIdsWithExistingAccountBadges(clock, badgeConfigurationMap, badges, account.getBadges()))
|
||||
.orElseGet(account::getBadges);
|
||||
|
||||
updates.add(Mono.fromFuture(() -> accountsManager.updateAsync(account, a -> {
|
||||
|
||||
final List<AccountBadge> updatedBadges = Optional.of(request.getBadgeIdsList())
|
||||
.map(badges -> ProfileHelper.mergeBadgeIdsWithExistingAccountBadges(clock, badgeConfigurationMap, badges, a.getBadges()))
|
||||
.orElseGet(a::getBadges);
|
||||
|
||||
a.setBadges(clock, updatedBadges);
|
||||
a.setCurrentProfileVersion(request.getVersion());
|
||||
})));
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import java.net.InetAddress;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
public record RequestAttributes(InetAddress remoteAddress,
|
||||
@Nullable String userAgent,
|
||||
List<Locale.LanguageRange> acceptLanguage) {
|
||||
}
|
|
@ -2,28 +2,25 @@ package org.whispersystems.textsecuregcm.grpc;
|
|||
|
||||
import io.grpc.Context;
|
||||
import io.grpc.Contexts;
|
||||
import io.grpc.Grpc;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerCallHandler;
|
||||
import io.grpc.ServerInterceptor;
|
||||
import io.grpc.Status;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
|
||||
import java.net.InetAddress;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
* The request attributes interceptor makes request attributes from the underlying remote channel available to service
|
||||
* implementations by attaching them to a {@link Context} attribute that can be read via {@link RequestAttributesUtil}.
|
||||
* All server calls should have request attributes, and calls will be rejected with a status of {@code UNAVAILABLE} if
|
||||
* request attributes are unavailable (i.e. the underlying channel closed before the {@code ServerCall} started).
|
||||
*
|
||||
* @see RequestAttributesUtil
|
||||
*/
|
||||
public class RequestAttributesInterceptor implements ServerInterceptor {
|
||||
|
||||
private final GrpcClientConnectionManager grpcClientConnectionManager;
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(RequestAttributesInterceptor.class);
|
||||
|
||||
public RequestAttributesInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
|
||||
this.grpcClientConnectionManager = grpcClientConnectionManager;
|
||||
}
|
||||
|
@ -33,52 +30,12 @@ public class RequestAttributesInterceptor implements ServerInterceptor {
|
|||
final Metadata headers,
|
||||
final ServerCallHandler<ReqT, RespT> next) {
|
||||
|
||||
if (call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress) {
|
||||
Context context = Context.current();
|
||||
|
||||
{
|
||||
final Optional<InetAddress> maybeRemoteAddress = grpcClientConnectionManager.getRemoteAddress(localAddress);
|
||||
|
||||
if (maybeRemoteAddress.isEmpty()) {
|
||||
// We should never have a call from a party whose remote address we can't identify
|
||||
log.warn("No remote address available");
|
||||
|
||||
call.close(Status.INTERNAL, new Metadata());
|
||||
return new ServerCall.Listener<>() {};
|
||||
}
|
||||
|
||||
context = context.withValue(RequestAttributesUtil.REMOTE_ADDRESS_CONTEXT_KEY, maybeRemoteAddress.get());
|
||||
}
|
||||
|
||||
{
|
||||
final Optional<List<Locale.LanguageRange>> maybeAcceptLanguage =
|
||||
grpcClientConnectionManager.getAcceptableLanguages(localAddress);
|
||||
|
||||
if (maybeAcceptLanguage.isPresent()) {
|
||||
context = context.withValue(RequestAttributesUtil.ACCEPT_LANGUAGE_CONTEXT_KEY, maybeAcceptLanguage.get());
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
final Optional<String> maybeRawUserAgent =
|
||||
grpcClientConnectionManager.getRawUserAgent(localAddress);
|
||||
|
||||
if (maybeRawUserAgent.isPresent()) {
|
||||
context = context.withValue(RequestAttributesUtil.RAW_USER_AGENT_CONTEXT_KEY, maybeRawUserAgent.get());
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
final Optional<UserAgent> maybeUserAgent = grpcClientConnectionManager.getUserAgent(localAddress);
|
||||
|
||||
if (maybeUserAgent.isPresent()) {
|
||||
context = context.withValue(RequestAttributesUtil.USER_AGENT_CONTEXT_KEY, maybeUserAgent.get());
|
||||
}
|
||||
}
|
||||
|
||||
return Contexts.interceptCall(context, call, headers, next);
|
||||
} else {
|
||||
throw new AssertionError("Unexpected channel type: " + call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
|
||||
try {
|
||||
return Contexts.interceptCall(Context.current()
|
||||
.withValue(RequestAttributesUtil.REQUEST_ATTRIBUTES_CONTEXT_KEY,
|
||||
grpcClientConnectionManager.getRequestAttributes(call)), call, headers, next);
|
||||
} catch (final ChannelNotFoundException e) {
|
||||
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,18 +3,13 @@ package org.whispersystems.textsecuregcm.grpc;
|
|||
import io.grpc.Context;
|
||||
import java.net.InetAddress;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Optional;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
|
||||
|
||||
public class RequestAttributesUtil {
|
||||
|
||||
static final Context.Key<List<Locale.LanguageRange>> ACCEPT_LANGUAGE_CONTEXT_KEY = Context.key("accept-language");
|
||||
static final Context.Key<InetAddress> REMOTE_ADDRESS_CONTEXT_KEY = Context.key("remote-address");
|
||||
static final Context.Key<String> RAW_USER_AGENT_CONTEXT_KEY = Context.key("unparsed-user-agent");
|
||||
static final Context.Key<UserAgent> USER_AGENT_CONTEXT_KEY = Context.key("parsed-user-agent");
|
||||
static final Context.Key<RequestAttributes> REQUEST_ATTRIBUTES_CONTEXT_KEY = Context.key("request-attributes");
|
||||
|
||||
private static final List<Locale> AVAILABLE_LOCALES = Arrays.asList(Locale.getAvailableLocales());
|
||||
|
||||
|
@ -23,8 +18,8 @@ public class RequestAttributesUtil {
|
|||
*
|
||||
* @return the acceptable languages listed by the remote client; may be empty if unparseable or not specified
|
||||
*/
|
||||
public static Optional<List<Locale.LanguageRange>> getAcceptableLanguages() {
|
||||
return Optional.ofNullable(ACCEPT_LANGUAGE_CONTEXT_KEY.get());
|
||||
public static List<Locale.LanguageRange> getAcceptableLanguages() {
|
||||
return REQUEST_ATTRIBUTES_CONTEXT_KEY.get().acceptLanguage();
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -35,9 +30,7 @@ public class RequestAttributesUtil {
|
|||
* @return a list of distinct locales acceptable to the remote client and available in this JVM
|
||||
*/
|
||||
public static List<Locale> getAvailableAcceptedLocales() {
|
||||
return getAcceptableLanguages()
|
||||
.map(languageRanges -> Locale.filter(languageRanges, AVAILABLE_LOCALES))
|
||||
.orElseGet(Collections::emptyList);
|
||||
return Locale.filter(getAcceptableLanguages(), AVAILABLE_LOCALES);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -46,16 +39,7 @@ public class RequestAttributesUtil {
|
|||
* @return the remote address of the remote client
|
||||
*/
|
||||
public static InetAddress getRemoteAddress() {
|
||||
return REMOTE_ADDRESS_CONTEXT_KEY.get();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the parsed user-agent of the remote client in the current gRPC request context.
|
||||
*
|
||||
* @return the parsed user-agent of the remote client; may be empty if unparseable or not specified
|
||||
*/
|
||||
public static Optional<UserAgent> getUserAgent() {
|
||||
return Optional.ofNullable(USER_AGENT_CONTEXT_KEY.get());
|
||||
return REQUEST_ATTRIBUTES_CONTEXT_KEY.get().remoteAddress();
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -63,7 +47,7 @@ public class RequestAttributesUtil {
|
|||
*
|
||||
* @return the unparsed user-agent of the remote client; may be empty if not specified
|
||||
*/
|
||||
public static Optional<String> getRawUserAgent() {
|
||||
return Optional.ofNullable(RAW_USER_AGENT_CONTEXT_KEY.get());
|
||||
public static Optional<String> getUserAgent() {
|
||||
return Optional.ofNullable(REQUEST_ATTRIBUTES_CONTEXT_KEY.get().userAgent());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.Status;
|
||||
|
||||
public class ServerInterceptorUtil {
|
||||
|
||||
@SuppressWarnings("rawtypes")
|
||||
private static final ServerCall.Listener NO_OP_LISTENER = new ServerCall.Listener<>() {};
|
||||
|
||||
private static final Metadata EMPTY_TRAILERS = new Metadata();
|
||||
|
||||
private ServerInterceptorUtil() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Closes the given server call with the given status, returning a no-op listener.
|
||||
*
|
||||
* @param call the server call to close
|
||||
* @param status the status with which to close the call
|
||||
*
|
||||
* @return a no-op server call listener
|
||||
*
|
||||
* @param <ReqT> the type of request object handled by the server call
|
||||
* @param <RespT> the type of response object returned by the server call
|
||||
*/
|
||||
public static <ReqT, RespT> ServerCall.Listener<ReqT> closeWithStatus(final ServerCall<ReqT, RespT> call, final Status status) {
|
||||
call.close(status, EMPTY_TRAILERS);
|
||||
|
||||
//noinspection unchecked
|
||||
return NO_OP_LISTENER;
|
||||
}
|
||||
}
|
|
@ -8,7 +8,7 @@ package org.whispersystems.textsecuregcm.grpc;
|
|||
import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.internalError;
|
||||
|
||||
import com.google.protobuf.Descriptors;
|
||||
import com.google.protobuf.GeneratedMessageV3;
|
||||
import com.google.protobuf.Message;
|
||||
import io.grpc.ForwardingServerCallListener;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
|
@ -75,7 +75,7 @@ public class ValidatingInterceptor implements ServerInterceptor {
|
|||
}
|
||||
|
||||
private void validateMessage(final Object message) throws StatusException {
|
||||
if (message instanceof GeneratedMessageV3 msg) {
|
||||
if (message instanceof Message msg) {
|
||||
try {
|
||||
for (final Descriptors.FieldDescriptor fd: msg.getDescriptorForType().getFields()) {
|
||||
for (final Map.Entry<Descriptors.FieldDescriptor, Object> entry: fd.getOptions().getAllFields().entrySet()) {
|
||||
|
|
|
@ -1,7 +0,0 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
/**
|
||||
* Indicates that an attempt to authenticate a remote client failed for some reason.
|
||||
*/
|
||||
class ClientAuthenticationException extends Exception {
|
||||
}
|
|
@ -3,60 +3,39 @@ package org.whispersystems.textsecuregcm.grpc.net;
|
|||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
|
||||
import javax.crypto.BadPaddingException;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
||||
|
||||
/**
|
||||
* An error handler serves as a general backstop for exceptions elsewhere in the pipeline. If the client has completed a
|
||||
* WebSocket handshake, the error handler will send appropriate WebSocket closure codes to the client in an attempt to
|
||||
* identify the problem. If the client has not completed a WebSocket handshake, the handler simply closes the
|
||||
* connection.
|
||||
* An error handler serves as a general backstop for exceptions elsewhere in the pipeline. It translates exceptions
|
||||
* thrown in inbound handlers into {@link OutboundCloseErrorMessage}s.
|
||||
*/
|
||||
class ErrorHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private boolean websocketHandshakeComplete = false;
|
||||
|
||||
public class ErrorHandler extends ChannelInboundHandlerAdapter {
|
||||
private static final Logger log = LoggerFactory.getLogger(ErrorHandler.class);
|
||||
|
||||
@Override
|
||||
public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception {
|
||||
if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
|
||||
setWebsocketHandshakeComplete();
|
||||
}
|
||||
|
||||
context.fireUserEventTriggered(event);
|
||||
}
|
||||
|
||||
protected void setWebsocketHandshakeComplete() {
|
||||
this.websocketHandshakeComplete = true;
|
||||
}
|
||||
private static OutboundCloseErrorMessage NOISE_ENCRYPTION_ERROR_CLOSE = new OutboundCloseErrorMessage(
|
||||
OutboundCloseErrorMessage.Code.NOISE_ERROR,
|
||||
"Noise encryption error");
|
||||
|
||||
@Override
|
||||
public void exceptionCaught(final ChannelHandlerContext context, final Throwable cause) {
|
||||
if (websocketHandshakeComplete) {
|
||||
final WebSocketCloseStatus webSocketCloseStatus = switch (ExceptionUtils.unwrap(cause)) {
|
||||
case NoiseHandshakeException e -> ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.toWebSocketCloseStatus(e.getMessage());
|
||||
case ClientAuthenticationException ignored -> ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.toWebSocketCloseStatus("Not authenticated");
|
||||
case BadPaddingException ignored -> ApplicationWebSocketCloseReason.NOISE_ENCRYPTION_ERROR.toWebSocketCloseStatus("Noise encryption error");
|
||||
case NoiseException ignored -> ApplicationWebSocketCloseReason.NOISE_ENCRYPTION_ERROR.toWebSocketCloseStatus("Noise encryption error");
|
||||
final OutboundCloseErrorMessage closeMessage = switch (ExceptionUtils.unwrap(cause)) {
|
||||
case NoiseHandshakeException e -> new OutboundCloseErrorMessage(
|
||||
OutboundCloseErrorMessage.Code.NOISE_HANDSHAKE_ERROR,
|
||||
e.getMessage());
|
||||
case BadPaddingException ignored -> NOISE_ENCRYPTION_ERROR_CLOSE;
|
||||
case NoiseException ignored -> NOISE_ENCRYPTION_ERROR_CLOSE;
|
||||
default -> {
|
||||
log.warn("An unexpected exception reached the end of the pipeline", cause);
|
||||
yield WebSocketCloseStatus.INTERNAL_SERVER_ERROR;
|
||||
yield new OutboundCloseErrorMessage(
|
||||
OutboundCloseErrorMessage.Code.INTERNAL_SERVER_ERROR,
|
||||
cause.getMessage());
|
||||
}
|
||||
};
|
||||
|
||||
context.writeAndFlush(new CloseWebSocketFrame(webSocketCloseStatus))
|
||||
context.writeAndFlush(closeMessage)
|
||||
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
|
||||
} else {
|
||||
log.debug("Error occurred before websocket handshake complete", cause);
|
||||
// We haven't completed a websocket handshake, so we can't really communicate errors in a semantically-meaningful
|
||||
// way; just close the connection instead.
|
||||
context.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,20 +7,21 @@ import io.netty.channel.ChannelInboundHandlerAdapter;
|
|||
import io.netty.channel.ChannelInitializer;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import io.netty.channel.local.LocalChannel;
|
||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import java.net.InetAddress;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
|
||||
/**
|
||||
* An "establish local connection" handler waits for a Noise handshake to complete upstream in the pipeline, buffering
|
||||
* any inbound messages until the connection is fully-established, and then opens a proxy connection to a local gRPC
|
||||
* server.
|
||||
*/
|
||||
class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
|
||||
public class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private final GrpcClientConnectionManager grpcClientConnectionManager;
|
||||
|
||||
|
@ -48,15 +49,20 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
|
|||
|
||||
@Override
|
||||
public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) {
|
||||
if (event instanceof NoiseIdentityDeterminedEvent noiseIdentityDeterminedEvent) {
|
||||
if (event instanceof NoiseIdentityDeterminedEvent(
|
||||
final Optional<AuthenticatedDevice> authenticatedDevice,
|
||||
InetAddress remoteAddress, String userAgent, String acceptLanguage)) {
|
||||
// We assume that we'll only get a completed handshake event if the handshake met all authentication requirements
|
||||
// for the requested service. If the handshake doesn't have an authenticated device, we assume we're trying to
|
||||
// connect to the anonymous service. If it does have an authenticated device, we assume we're aiming for the
|
||||
// authenticated service.
|
||||
final LocalAddress grpcServerAddress = noiseIdentityDeterminedEvent.authenticatedDevice().isPresent()
|
||||
final LocalAddress grpcServerAddress = authenticatedDevice.isPresent()
|
||||
? authenticatedGrpcServerAddress
|
||||
: anonymousGrpcServerAddress;
|
||||
|
||||
GrpcClientConnectionManager.handleHandshakeInitiated(
|
||||
remoteChannelContext.channel(), remoteAddress, userAgent, acceptLanguage);
|
||||
|
||||
new Bootstrap()
|
||||
.remoteAddress(grpcServerAddress)
|
||||
.channel(LocalChannel.class)
|
||||
|
@ -72,12 +78,14 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
|
|||
if (localChannelFuture.isSuccess()) {
|
||||
grpcClientConnectionManager.handleConnectionEstablished((LocalChannel) localChannelFuture.channel(),
|
||||
remoteChannelContext.channel(),
|
||||
noiseIdentityDeterminedEvent.authenticatedDevice());
|
||||
authenticatedDevice);
|
||||
|
||||
// Close the local connection if the remote channel closes and vice versa
|
||||
remoteChannelContext.channel().closeFuture().addListener(closeFuture -> localChannelFuture.channel().close());
|
||||
localChannelFuture.channel().closeFuture().addListener(closeFuture ->
|
||||
remoteChannelContext.write(new CloseWebSocketFrame(WebSocketCloseStatus.SERVICE_RESTART)));
|
||||
remoteChannelContext.channel()
|
||||
.write(new OutboundCloseErrorMessage(OutboundCloseErrorMessage.Code.SERVER_CLOSED, "server closed"))
|
||||
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE));
|
||||
|
||||
remoteChannelContext.pipeline()
|
||||
.addAfter(remoteChannelContext.name(), null, new ProxyHandler(localChannelFuture.channel()));
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import io.grpc.Grpc;
|
||||
import io.grpc.ServerCall;
|
||||
import io.netty.channel.Channel;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import io.netty.channel.local.LocalChannel;
|
||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||
import io.netty.util.AttributeKey;
|
||||
import java.net.InetAddress;
|
||||
import java.util.ArrayList;
|
||||
|
@ -23,15 +24,26 @@ import org.slf4j.Logger;
|
|||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
|
||||
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
|
||||
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
|
||||
import org.whispersystems.textsecuregcm.util.ClosableEpoch;
|
||||
|
||||
/**
|
||||
* A client connection manager associates a local connection to a local gRPC server with a remote connection through a
|
||||
* Noise-over-WebSocket tunnel. It provides access to metadata associated with the remote connection, including the
|
||||
* authenticated identity of the device that opened the connection (for non-anonymous connections). It can also close
|
||||
* connections associated with a given device if that device's credentials have changed and clients must reauthenticate.
|
||||
* Noise tunnel. It provides access to metadata associated with the remote connection, including the authenticated
|
||||
* identity of the device that opened the connection (for non-anonymous connections). It can also close connections
|
||||
* associated with a given device if that device's credentials have changed and clients must reauthenticate.
|
||||
* <p>
|
||||
* In general, all {@link ServerCall}s <em>must</em> have a local address that in turn <em>should</em> be resolvable to
|
||||
* a remote channel, which <em>must</em> have associated request attributes and authentication status. It is possible
|
||||
* that a server call's local address may not be resolvable to a remote channel if the remote channel closed in the
|
||||
* narrow window between a server call being created and the start of call execution, in which case accessor methods
|
||||
* in this class will throw a {@link ChannelNotFoundException}.
|
||||
* <p>
|
||||
* A gRPC client connection manager's methods for getting request attributes accept {@link ServerCall} entities to
|
||||
* identify connections. In general, these methods should only be called from {@link io.grpc.ServerInterceptor}s.
|
||||
* Methods for requesting connection closure accept an {@link AuthenticatedDevice} to identify the connection and may
|
||||
* be called from any application code.
|
||||
*/
|
||||
public class GrpcClientConnectionManager implements DisconnectionRequestListener {
|
||||
|
||||
|
@ -43,94 +55,96 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
|
|||
AttributeKey.valueOf(GrpcClientConnectionManager.class, "authenticatedDevice");
|
||||
|
||||
@VisibleForTesting
|
||||
static final AttributeKey<InetAddress> REMOTE_ADDRESS_ATTRIBUTE_KEY =
|
||||
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "remoteAddress");
|
||||
public static final AttributeKey<RequestAttributes> REQUEST_ATTRIBUTES_KEY =
|
||||
AttributeKey.valueOf(GrpcClientConnectionManager.class, "requestAttributes");
|
||||
|
||||
@VisibleForTesting
|
||||
static final AttributeKey<String> RAW_USER_AGENT_ATTRIBUTE_KEY =
|
||||
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "rawUserAgent");
|
||||
static final AttributeKey<ClosableEpoch> EPOCH_ATTRIBUTE_KEY =
|
||||
AttributeKey.valueOf(GrpcClientConnectionManager.class, "epoch");
|
||||
|
||||
@VisibleForTesting
|
||||
static final AttributeKey<UserAgent> PARSED_USER_AGENT_ATTRIBUTE_KEY =
|
||||
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "userAgent");
|
||||
|
||||
@VisibleForTesting
|
||||
static final AttributeKey<List<Locale.LanguageRange>> ACCEPT_LANGUAGE_ATTRIBUTE_KEY =
|
||||
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "acceptLanguage");
|
||||
private static OutboundCloseErrorMessage SERVER_CLOSED =
|
||||
new OutboundCloseErrorMessage(OutboundCloseErrorMessage.Code.SERVER_CLOSED, "server closed");
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(GrpcClientConnectionManager.class);
|
||||
|
||||
/**
|
||||
* Returns the authenticated device associated with the given local address, if any. An authenticated device is
|
||||
* available if and only if the given local address maps to an active local connection and that connection is
|
||||
* authenticated (i.e. not anonymous).
|
||||
* Returns the authenticated device associated with the given server call, if any. If the connection is anonymous
|
||||
* (i.e. unauthenticated), the returned value will be empty.
|
||||
*
|
||||
* @param localAddress the local address for which to find an authenticated device
|
||||
* @param serverCall the gRPC server call for which to find an authenticated device
|
||||
*
|
||||
* @return the authenticated device associated with the given local address, if any
|
||||
*
|
||||
* @throws ChannelNotFoundException if the server call is not associated with a known channel; in practice, this
|
||||
* generally indicates that the channel has closed while request processing is still in progress
|
||||
*/
|
||||
public Optional<AuthenticatedDevice> getAuthenticatedDevice(final LocalAddress localAddress) {
|
||||
return getAuthenticatedDevice(remoteChannelsByLocalAddress.get(localAddress));
|
||||
public Optional<AuthenticatedDevice> getAuthenticatedDevice(final ServerCall<?, ?> serverCall)
|
||||
throws ChannelNotFoundException {
|
||||
|
||||
return getAuthenticatedDevice(getRemoteChannel(serverCall));
|
||||
}
|
||||
|
||||
private Optional<AuthenticatedDevice> getAuthenticatedDevice(@Nullable final Channel remoteChannel) {
|
||||
return Optional.ofNullable(remoteChannel)
|
||||
.map(channel -> channel.attr(AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).get());
|
||||
@VisibleForTesting
|
||||
Optional<AuthenticatedDevice> getAuthenticatedDevice(final Channel remoteChannel) {
|
||||
return Optional.ofNullable(remoteChannel.attr(AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).get());
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the parsed acceptable languages associated with the given local address, if any. Acceptable languages may
|
||||
* be unavailable if the local connection associated with the given local address has already closed, if the client
|
||||
* did not provide a list of acceptable languages, or the list provided by the client could not be parsed.
|
||||
* Returns the request attributes associated with the given server call.
|
||||
*
|
||||
* @param localAddress the local address for which to find acceptable languages
|
||||
* @param serverCall the gRPC server call for which to retrieve request attributes
|
||||
*
|
||||
* @return the acceptable languages associated with the given local address, if any
|
||||
* @return the request attributes associated with the given server call
|
||||
*
|
||||
* @throws ChannelNotFoundException if the server call is not associated with a known channel; in practice, this
|
||||
* generally indicates that the channel has closed while request processing is still in progress
|
||||
*/
|
||||
public Optional<List<Locale.LanguageRange>> getAcceptableLanguages(final LocalAddress localAddress) {
|
||||
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
|
||||
.map(remoteChannel -> remoteChannel.attr(ACCEPT_LANGUAGE_ATTRIBUTE_KEY).get());
|
||||
public RequestAttributes getRequestAttributes(final ServerCall<?, ?> serverCall) throws ChannelNotFoundException {
|
||||
return getRequestAttributes(getRemoteChannel(serverCall));
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
RequestAttributes getRequestAttributes(final Channel remoteChannel) {
|
||||
final RequestAttributes requestAttributes = remoteChannel.attr(REQUEST_ATTRIBUTES_KEY).get();
|
||||
|
||||
if (requestAttributes == null) {
|
||||
throw new IllegalStateException("Channel does not have request attributes");
|
||||
}
|
||||
|
||||
return requestAttributes;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the remote address associated with the given local address, if any. A remote address may be unavailable if
|
||||
* the local connection associated with the given local address has already closed.
|
||||
* Handles the start of a server call, incrementing the active call count for the remote channel associated with the
|
||||
* given server call.
|
||||
*
|
||||
* @param localAddress the local address for which to find a remote address
|
||||
* @param serverCall the server call to start
|
||||
*
|
||||
* @return the remote address associated with the given local address, if any
|
||||
* @return {@code true} if the call should start normally or {@code false} if the call should be aborted because the
|
||||
* underlying channel is closing
|
||||
*/
|
||||
public Optional<InetAddress> getRemoteAddress(final LocalAddress localAddress) {
|
||||
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
|
||||
.map(remoteChannel -> remoteChannel.attr(REMOTE_ADDRESS_ATTRIBUTE_KEY).get());
|
||||
public boolean handleServerCallStart(final ServerCall<?, ?> serverCall) {
|
||||
try {
|
||||
return getRemoteChannel(serverCall).attr(EPOCH_ATTRIBUTE_KEY).get().tryArrive();
|
||||
} catch (final ChannelNotFoundException e) {
|
||||
// This would only happen if the channel had already closed, which is certainly possible. In this case, the call
|
||||
// should certainly not proceed.
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the unparsed user agent provided by the client that opened the connection associated with the given local
|
||||
* address. This method may return an empty value if no active local connection is associated with the given local
|
||||
* address.
|
||||
* Handles completion (successful or not) of a server call, decrementing the active call count for the remote channel
|
||||
* associated with the given server call.
|
||||
*
|
||||
* @param localAddress the local address for which to find a User-Agent string
|
||||
*
|
||||
* @return the user agent string associated with the given local address
|
||||
* @param serverCall the server call to complete
|
||||
*/
|
||||
public Optional<String> getRawUserAgent(final LocalAddress localAddress) {
|
||||
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
|
||||
.map(remoteChannel -> remoteChannel.attr(RAW_USER_AGENT_ATTRIBUTE_KEY).get());
|
||||
public void handleServerCallComplete(final ServerCall<?, ?> serverCall) {
|
||||
try {
|
||||
getRemoteChannel(serverCall).attr(EPOCH_ATTRIBUTE_KEY).get().depart();
|
||||
} catch (final ChannelNotFoundException ignored) {
|
||||
// In practice, we'd only get here if the channel has already closed, so we can just ignore the exception
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the parsed user agent provided by the client that opened the connection associated with the given local
|
||||
* address. This method may return an empty value if no active local connection is associated with the given local
|
||||
* address or if the client's user-agent string was not recognized.
|
||||
*
|
||||
* @param localAddress the local address for which to find a User-Agent string
|
||||
*
|
||||
* @return the user agent associated with the given local address
|
||||
*/
|
||||
public Optional<UserAgent> getUserAgent(final LocalAddress localAddress) {
|
||||
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
|
||||
.map(remoteChannel -> remoteChannel.attr(PARSED_USER_AGENT_ATTRIBUTE_KEY).get());
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -139,11 +153,17 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
|
|||
* @param authenticatedDevice the authenticated device for which to close connections
|
||||
*/
|
||||
public void closeConnection(final AuthenticatedDevice authenticatedDevice) {
|
||||
// Channels will actually get removed from the list/map by their closeFuture listeners
|
||||
remoteChannelsByAuthenticatedDevice.getOrDefault(authenticatedDevice, Collections.emptyList()).forEach(channel ->
|
||||
channel.writeAndFlush(new CloseWebSocketFrame(ApplicationWebSocketCloseReason.REAUTHENTICATION_REQUIRED
|
||||
.toWebSocketCloseStatus("Reauthentication required")))
|
||||
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE));
|
||||
// Channels will actually get removed from the list/map by their closeFuture listeners. We copy the list to avoid
|
||||
// concurrent modification; it's possible (though practically unlikely) that a channel can close and remove itself
|
||||
// from the list while we're still iterating, resulting in a `ConcurrentModificationException`.
|
||||
final List<Channel> channelsToClose =
|
||||
new ArrayList<>(remoteChannelsByAuthenticatedDevice.getOrDefault(authenticatedDevice, Collections.emptyList()));
|
||||
|
||||
channelsToClose.forEach(channel -> channel.attr(EPOCH_ATTRIBUTE_KEY).get().close());
|
||||
}
|
||||
|
||||
private static void closeRemoteChannel(final Channel channel) {
|
||||
channel.writeAndFlush(SERVER_CLOSED).addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
|
@ -151,53 +171,66 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
|
|||
return remoteChannelsByAuthenticatedDevice.get(authenticatedDevice);
|
||||
}
|
||||
|
||||
private Channel getRemoteChannel(final ServerCall<?, ?> serverCall) throws ChannelNotFoundException {
|
||||
return getRemoteChannel(getLocalAddress(serverCall));
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
Channel getRemoteChannelByLocalAddress(final LocalAddress localAddress) {
|
||||
Channel getRemoteChannel(final LocalAddress localAddress) throws ChannelNotFoundException {
|
||||
final Channel remoteChannel = remoteChannelsByLocalAddress.get(localAddress);
|
||||
|
||||
if (remoteChannel == null) {
|
||||
throw new ChannelNotFoundException();
|
||||
}
|
||||
|
||||
return remoteChannelsByLocalAddress.get(localAddress);
|
||||
}
|
||||
|
||||
private static LocalAddress getLocalAddress(final ServerCall<?, ?> serverCall) {
|
||||
// In this server, gRPC's "remote" channel is actually a local channel that proxies to a distinct Noise channel.
|
||||
// The gRPC "remote" address is the "local address" for the proxy connection, and the local address uniquely maps to
|
||||
// a proxied Noise channel.
|
||||
if (!(serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress)) {
|
||||
throw new IllegalArgumentException("Unexpected channel type: " + serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
|
||||
}
|
||||
|
||||
return localAddress;
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles successful completion of a WebSocket handshake and associates attributes and headers from the handshake
|
||||
* Handles receipt of a handshake message and associates attributes and headers from the handshake
|
||||
* request with the channel via which the handshake took place.
|
||||
*
|
||||
* @param channel the channel that completed a WebSocket handshake
|
||||
* @param channel the channel where the handshake was initiated
|
||||
* @param preferredRemoteAddress the preferred remote address (potentially from a request header) for the handshake
|
||||
* @param userAgentHeader the value of the User-Agent header provided in the handshake request; may be {@code null}
|
||||
* @param acceptLanguageHeader the value of the Accept-Language header provided in the handshake request; may be
|
||||
* {@code null}
|
||||
*/
|
||||
static void handleWebSocketHandshakeComplete(final Channel channel,
|
||||
public static void handleHandshakeInitiated(final Channel channel,
|
||||
final InetAddress preferredRemoteAddress,
|
||||
@Nullable final String userAgentHeader,
|
||||
@Nullable final String acceptLanguageHeader) {
|
||||
|
||||
channel.attr(GrpcClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).set(preferredRemoteAddress);
|
||||
|
||||
if (StringUtils.isNotBlank(userAgentHeader)) {
|
||||
channel.attr(GrpcClientConnectionManager.RAW_USER_AGENT_ATTRIBUTE_KEY).set(userAgentHeader);
|
||||
|
||||
try {
|
||||
channel.attr(GrpcClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY)
|
||||
.set(UserAgentUtil.parseUserAgentString(userAgentHeader));
|
||||
} catch (final UnrecognizedUserAgentException ignored) {
|
||||
}
|
||||
}
|
||||
@Nullable List<Locale.LanguageRange> acceptLanguages = Collections.emptyList();
|
||||
|
||||
if (StringUtils.isNotBlank(acceptLanguageHeader)) {
|
||||
try {
|
||||
channel.attr(GrpcClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).set(Locale.LanguageRange.parse(acceptLanguageHeader));
|
||||
acceptLanguages = Locale.LanguageRange.parse(acceptLanguageHeader);
|
||||
} catch (final IllegalArgumentException e) {
|
||||
log.debug("Invalid Accept-Language header from User-Agent {}: {}", userAgentHeader, acceptLanguageHeader, e);
|
||||
}
|
||||
}
|
||||
|
||||
channel.attr(REQUEST_ATTRIBUTES_KEY)
|
||||
.set(new RequestAttributes(preferredRemoteAddress, userAgentHeader, acceptLanguages));
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles successful establishment of a Noise-over-WebSocket connection from a remote client to a local gRPC server.
|
||||
* Handles successful establishment of a Noise connection from a remote client to a local gRPC server.
|
||||
*
|
||||
* @param localChannel the newly-opened local channel between the Noise-over-WebSocket tunnel and the local gRPC
|
||||
* server
|
||||
* @param remoteChannel the channel from the remote client to the Noise-over-WebSocket tunnel
|
||||
* @param localChannel the newly-opened local channel between the Noise tunnel and the local gRPC server
|
||||
* @param remoteChannel the channel from the remote client to the Noise tunnel
|
||||
* @param maybeAuthenticatedDevice the authenticated device (if any) associated with the new connection
|
||||
*/
|
||||
void handleConnectionEstablished(final LocalChannel localChannel,
|
||||
|
@ -207,6 +240,9 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
|
|||
maybeAuthenticatedDevice.ifPresent(authenticatedDevice ->
|
||||
remoteChannel.attr(GrpcClientConnectionManager.AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).set(authenticatedDevice));
|
||||
|
||||
remoteChannel.attr(EPOCH_ATTRIBUTE_KEY)
|
||||
.set(new ClosableEpoch(() -> closeRemoteChannel(remoteChannel)));
|
||||
|
||||
remoteChannelsByLocalAddress.put(localChannel.localAddress(), remoteChannel);
|
||||
|
||||
getAuthenticatedDevice(remoteChannel).ifPresent(authenticatedDevice ->
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
enum HandshakePattern {
|
||||
public enum HandshakePattern {
|
||||
NK("Noise_NK_25519_ChaChaPoly_BLAKE2b"),
|
||||
IK("Noise_IK_25519_ChaChaPoly_BLAKE2b");
|
||||
|
||||
|
|
|
@ -1,34 +0,0 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
|
||||
/**
|
||||
* A NoiseAnonymousHandler is a netty pipeline element that handles the responder side of an unauthenticated handshake
|
||||
* and noise encryption/decryption.
|
||||
* <p>
|
||||
* A noise NK handshake must be used for unauthenticated connections. Optionally, the initiator can also include an
|
||||
* initial request in their payload. If provided, this allows the server to begin processing the request without an
|
||||
* initial message delay (fast open).
|
||||
* <p>
|
||||
* Once the handler receives the handshake initiator message, it will fire a {@link NoiseIdentityDeterminedEvent}
|
||||
* indicating that initiator connected anonymously.
|
||||
*/
|
||||
class NoiseAnonymousHandler extends NoiseHandler {
|
||||
|
||||
public NoiseAnonymousHandler(final ECKeyPair ecKeyPair) {
|
||||
super(new NoiseHandshakeHelper(HandshakePattern.NK, ecKeyPair));
|
||||
}
|
||||
|
||||
@Override
|
||||
CompletableFuture<HandshakeResult> handleHandshakePayload(final ChannelHandlerContext context,
|
||||
final Optional<byte[]> initiatorPublicKey, final ByteBuf handshakePayload) {
|
||||
return CompletableFuture.completedFuture(new HandshakeResult(
|
||||
handshakePayload,
|
||||
Optional.empty()
|
||||
));
|
||||
}
|
||||
}
|
|
@ -1,96 +0,0 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import java.security.MessageDigest;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
||||
|
||||
/**
|
||||
* A NoiseAuthenticatedHandler is a netty pipeline element that handles the responder side of an authenticated handshake
|
||||
* and noise encryption/decryption. Authenticated handshakes are noise IK handshakes where the initiator's static public
|
||||
* key is authenticated by the responder.
|
||||
* <p>
|
||||
* The authenticated handshake requires the initiator to provide a payload with their first handshake message that
|
||||
* includes their account identifier and device id in network byte-order. Optionally, the initiator can also include an
|
||||
* initial request in their payload. If provided, this allows the server to begin processing the request without an
|
||||
* initial message delay (fast open).
|
||||
* <pre>
|
||||
* +-----------------+----------------+------------------------+
|
||||
* | UUID (16) | deviceId (1) | request bytes (N) |
|
||||
* +-----------------+----------------+------------------------+
|
||||
* </pre>
|
||||
* <p>
|
||||
* For a successful handshake, the static key provided in the handshake message must match the server's stored public
|
||||
* key for the device identified by the provided ACI and deviceId.
|
||||
* <p>
|
||||
* As soon as the handler authenticates the caller, it will fire a {@link NoiseIdentityDeterminedEvent}.
|
||||
*/
|
||||
class NoiseAuthenticatedHandler extends NoiseHandler {
|
||||
|
||||
private final ClientPublicKeysManager clientPublicKeysManager;
|
||||
|
||||
NoiseAuthenticatedHandler(final ClientPublicKeysManager clientPublicKeysManager,
|
||||
final ECKeyPair ecKeyPair) {
|
||||
super(new NoiseHandshakeHelper(HandshakePattern.IK, ecKeyPair));
|
||||
this.clientPublicKeysManager = clientPublicKeysManager;
|
||||
}
|
||||
|
||||
@Override
|
||||
CompletableFuture<HandshakeResult> handleHandshakePayload(
|
||||
final ChannelHandlerContext context,
|
||||
final Optional<byte[]> initiatorPublicKey,
|
||||
final ByteBuf handshakePayload) throws NoiseHandshakeException {
|
||||
if (handshakePayload.readableBytes() < 17) {
|
||||
throw new NoiseHandshakeException("Invalid handshake payload");
|
||||
}
|
||||
|
||||
final byte[] publicKeyFromClient = initiatorPublicKey
|
||||
.orElseThrow(() -> new IllegalStateException("No remote public key"));
|
||||
|
||||
// Advances the read index by 16 bytes
|
||||
final UUID accountIdentifier = parseUUID(handshakePayload);
|
||||
|
||||
// Advances the read index by 1 byte
|
||||
final byte deviceId = handshakePayload.readByte();
|
||||
|
||||
final ByteBuf fastOpenRequest = handshakePayload.slice();
|
||||
return clientPublicKeysManager
|
||||
.findPublicKey(accountIdentifier, deviceId)
|
||||
.handleAsync((storedPublicKey, throwable) -> {
|
||||
if (throwable != null) {
|
||||
ReferenceCountUtil.release(fastOpenRequest);
|
||||
throw ExceptionUtils.wrap(throwable);
|
||||
}
|
||||
final boolean valid = storedPublicKey
|
||||
.map(spk -> MessageDigest.isEqual(publicKeyFromClient, spk.getPublicKeyBytes()))
|
||||
.orElse(false);
|
||||
if (!valid) {
|
||||
throw ExceptionUtils.wrap(new ClientAuthenticationException());
|
||||
}
|
||||
return new HandshakeResult(
|
||||
fastOpenRequest,
|
||||
Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId)));
|
||||
}, context.executor());
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse a {@link UUID} out of bytes, advancing the readerIdx by 16
|
||||
*
|
||||
* @param bytes The {@link ByteBuf} to read from
|
||||
* @return The parsed UUID
|
||||
* @throws NoiseHandshakeException If a UUID could not be parsed from bytes
|
||||
*/
|
||||
private UUID parseUUID(final ByteBuf bytes) throws NoiseHandshakeException {
|
||||
if (bytes.readableBytes() < 16) {
|
||||
throw new NoiseHandshakeException("Could not parse account identifier");
|
||||
}
|
||||
return new UUID(bytes.readLong(), bytes.readLong());
|
||||
}
|
||||
}
|
|
@ -1,10 +1,12 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import org.whispersystems.textsecuregcm.util.NoStackTraceException;
|
||||
|
||||
/**
|
||||
* Indicates that some problem occurred while processing an encrypted noise message (e.g. an unexpected message size/
|
||||
* format or a general encryption error).
|
||||
*/
|
||||
class NoiseException extends Exception {
|
||||
public class NoiseException extends NoStackTraceException {
|
||||
public NoiseException(final String message) {
|
||||
super(message);
|
||||
}
|
||||
|
|
|
@ -11,135 +11,50 @@ import io.netty.buffer.ByteBuf;
|
|||
import io.netty.buffer.ByteBufUtil;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelDuplexHandler;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import io.netty.util.concurrent.PromiseCombiner;
|
||||
import io.netty.util.internal.EmptyArrays;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import javax.crypto.BadPaddingException;
|
||||
import javax.crypto.ShortBufferException;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
||||
|
||||
/**
|
||||
* A bidirectional {@link io.netty.channel.ChannelHandler} that establishes a noise session with an initiator, decrypts
|
||||
* inbound messages, and encrypts outbound messages
|
||||
* A bidirectional {@link io.netty.channel.ChannelHandler} that decrypts inbound messages, and encrypts outbound
|
||||
* messages
|
||||
*/
|
||||
abstract class NoiseHandler extends ChannelDuplexHandler {
|
||||
public class NoiseHandler extends ChannelDuplexHandler {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(NoiseHandler.class);
|
||||
private final CipherStatePair cipherStatePair;
|
||||
|
||||
private enum State {
|
||||
// Waiting for handshake to complete
|
||||
HANDSHAKE,
|
||||
// Can freely exchange encrypted noise messages on an established session
|
||||
TRANSPORT,
|
||||
// Finished with error
|
||||
ERROR
|
||||
NoiseHandler(CipherStatePair cipherStatePair) {
|
||||
this.cipherStatePair = cipherStatePair;
|
||||
}
|
||||
|
||||
private final NoiseHandshakeHelper handshakeHelper;
|
||||
|
||||
private State state = State.HANDSHAKE;
|
||||
private CipherStatePair cipherStatePair;
|
||||
|
||||
NoiseHandler(NoiseHandshakeHelper handshakeHelper) {
|
||||
this.handshakeHelper = handshakeHelper;
|
||||
}
|
||||
|
||||
/**
|
||||
* The result of processing an initiator handshake payload
|
||||
*
|
||||
* @param fastOpenRequest A fast-open request included in the handshake. If none was present, this should be an
|
||||
* empty ByteBuf
|
||||
* @param authenticatedDevice If present, the successfully authenticated initiator identity
|
||||
*/
|
||||
record HandshakeResult(ByteBuf fastOpenRequest, Optional<AuthenticatedDevice> authenticatedDevice) {}
|
||||
|
||||
/**
|
||||
* Parse and potentially authenticate the initiator handshake message
|
||||
*
|
||||
* @param context A {@link ChannelHandlerContext}
|
||||
* @param initiatorPublicKey The initiator's static public key, if a handshake pattern that includes it was used
|
||||
* @param handshakePayload The handshake payload provided in the initiator message
|
||||
* @return A {@link HandshakeResult} that includes an authenticated device and a parsed fast-open request if one was
|
||||
* present in the handshake payload.
|
||||
* @throws NoiseHandshakeException If the handshake payload was invalid
|
||||
* @throws ClientAuthenticationException If the initiatorPublicKey could not be authenticated
|
||||
*/
|
||||
abstract CompletableFuture<HandshakeResult> handleHandshakePayload(
|
||||
final ChannelHandlerContext context,
|
||||
final Optional<byte[]> initiatorPublicKey,
|
||||
final ByteBuf handshakePayload) throws NoiseHandshakeException, ClientAuthenticationException;
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
||||
try {
|
||||
if (message instanceof BinaryWebSocketFrame frame) {
|
||||
if (frame.content().readableBytes() > Noise.MAX_PACKET_LEN) {
|
||||
final String error = "Invalid noise message length " + frame.content().readableBytes();
|
||||
throw state == State.HANDSHAKE ? new NoiseHandshakeException(error) : new NoiseException(error);
|
||||
if (message instanceof ByteBuf frame) {
|
||||
if (frame.readableBytes() > Noise.MAX_PACKET_LEN) {
|
||||
throw new NoiseException("Invalid noise message length " + frame.readableBytes());
|
||||
}
|
||||
// We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array.
|
||||
// We'll need to copy it to a heap buffer.
|
||||
handleInboundMessage(context, ByteBufUtil.getBytes(frame.content()));
|
||||
handleInboundDataMessage(context, ByteBufUtil.getBytes(frame));
|
||||
} else {
|
||||
// Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an
|
||||
// error
|
||||
// Anything except ByteBufs should have been filtered out of the pipeline by now; treat this as an error
|
||||
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
fail(context, e);
|
||||
} finally {
|
||||
ReferenceCountUtil.release(message);
|
||||
}
|
||||
}
|
||||
|
||||
private void handleInboundMessage(final ChannelHandlerContext context, final byte[] frameBytes)
|
||||
throws NoiseHandshakeException, ShortBufferException, BadPaddingException, ClientAuthenticationException {
|
||||
switch (state) {
|
||||
|
||||
// Got an initiator handshake message
|
||||
case HANDSHAKE -> {
|
||||
final ByteBuf payload = handshakeHelper.read(frameBytes);
|
||||
handleHandshakePayload(context, handshakeHelper.remotePublicKey(), payload).whenCompleteAsync(
|
||||
(result, throwable) -> {
|
||||
if (state == State.ERROR) {
|
||||
return;
|
||||
}
|
||||
if (throwable != null) {
|
||||
fail(context, ExceptionUtils.unwrap(throwable));
|
||||
return;
|
||||
}
|
||||
context.fireUserEventTriggered(new NoiseIdentityDeterminedEvent(result.authenticatedDevice()));
|
||||
|
||||
// Now that we've authenticated, write the handshake response
|
||||
byte[] handshakeMessage = handshakeHelper.write(EmptyArrays.EMPTY_BYTES);
|
||||
context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(handshakeMessage)))
|
||||
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
|
||||
|
||||
// The handshake is complete. We can start intercepting read/write for noise encryption/decryption
|
||||
this.state = State.TRANSPORT;
|
||||
this.cipherStatePair = handshakeHelper.getHandshakeState().split();
|
||||
if (result.fastOpenRequest().isReadable()) {
|
||||
// The handshake had a fast-open request. Forward the plaintext of the request to the server, we'll
|
||||
// encrypt the response when the server writes back through us
|
||||
context.fireChannelRead(result.fastOpenRequest());
|
||||
} else {
|
||||
ReferenceCountUtil.release(result.fastOpenRequest());
|
||||
}
|
||||
}, context.executor());
|
||||
}
|
||||
|
||||
// Got a client message that should be decrypted and forwarded
|
||||
case TRANSPORT -> {
|
||||
private void handleInboundDataMessage(final ChannelHandlerContext context, final byte[] frameBytes)
|
||||
throws ShortBufferException, BadPaddingException {
|
||||
final CipherState cipherState = cipherStatePair.getReceiver();
|
||||
// Overwrite the ciphertext with the plaintext to avoid an extra allocation for a dedicated plaintext buffer
|
||||
final int plaintextLength = cipherState.decryptWithAd(null,
|
||||
|
@ -151,21 +66,6 @@ abstract class NoiseHandler extends ChannelDuplexHandler {
|
|||
context.fireChannelRead(Unpooled.wrappedBuffer(frameBytes, 0, plaintextLength));
|
||||
}
|
||||
|
||||
// The session is already in an error state, drop the message
|
||||
case ERROR -> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the state to the error state (so subsequent messages fast-fail) and propagate the failure reason on the
|
||||
* context
|
||||
*/
|
||||
private void fail(final ChannelHandlerContext context, final Throwable cause) {
|
||||
this.state = State.ERROR;
|
||||
context.fireExceptionCaught(cause);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise)
|
||||
throws Exception {
|
||||
|
@ -193,19 +93,27 @@ abstract class NoiseHandler extends ChannelDuplexHandler {
|
|||
// Overwrite the plaintext with the ciphertext to avoid an extra allocation for a dedicated ciphertext buffer
|
||||
cipherState.encryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, plaintextLength);
|
||||
|
||||
pc.add(context.write(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(noiseBuffer))));
|
||||
pc.add(context.write(Unpooled.wrappedBuffer(noiseBuffer)));
|
||||
}
|
||||
pc.finish(promise);
|
||||
} finally {
|
||||
ReferenceCountUtil.release(byteBuf);
|
||||
}
|
||||
} else {
|
||||
if (!(message instanceof WebSocketFrame)) {
|
||||
// Downstream handlers may write WebSocket frames that don't need to be encrypted (e.g. "close" frames that
|
||||
// get issued in response to exceptions)
|
||||
if (!(message instanceof OutboundCloseErrorMessage)) {
|
||||
// Downstream handlers may write OutboundCloseErrorMessages that don't need to be encrypted (e.g. "close" frames
|
||||
// that get issued in response to exceptions)
|
||||
log.warn("Unexpected object in pipeline: {}", message);
|
||||
}
|
||||
context.write(message, promise);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handlerRemoved(ChannelHandlerContext var1) {
|
||||
if (cipherStatePair != null) {
|
||||
cipherStatePair.destroy();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import org.whispersystems.textsecuregcm.util.NoStackTraceException;
|
||||
|
||||
/**
|
||||
* Indicates that some problem occurred while completing a Noise handshake (e.g. an unexpected message size/format or
|
||||
* a general encryption error).
|
||||
*/
|
||||
class NoiseHandshakeException extends Exception {
|
||||
public class NoiseHandshakeException extends NoStackTraceException {
|
||||
|
||||
public NoiseHandshakeException(final String message) {
|
||||
super(message);
|
||||
|
|
|
@ -0,0 +1,197 @@
|
|||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import com.southernstorm.noise.protocol.Noise;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufInputStream;
|
||||
import io.netty.buffer.ByteBufUtil;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import java.io.IOException;
|
||||
import java.net.InetAddress;
|
||||
import java.security.MessageDigest;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.grpc.DeviceIdUtil;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
||||
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||
|
||||
/**
|
||||
* Handles the responder side of a noise handshake and then replaces itself with a {@link NoiseHandler} which will
|
||||
* encrypt/decrypt subsequent data frames
|
||||
* <p>
|
||||
* The handler expects to receive a single inbound message, a {@link NoiseHandshakeInit} that includes the initiator
|
||||
* handshake message, connection metadata, and the type of handshake determined by the framing layer. This handler
|
||||
* currently supports two types of handshakes.
|
||||
* <p>
|
||||
* The first are IK handshakes where the initiator's static public key is authenticated by the responder. The initiator
|
||||
* handshake message must contain the ACI and deviceId of the initiator. To be authenticated, the static key provided in
|
||||
* the handshake message must match the server's stored public key for the device identified by the provided ACI and
|
||||
* deviceId.
|
||||
* <p>
|
||||
* The second are NK handshakes which are anonymous.
|
||||
* <p>
|
||||
* Optionally, the initiator can also include an initial request in their payload. If provided, this allows the server
|
||||
* to begin processing the request without an initial message delay (fast open).
|
||||
* <p>
|
||||
* Once the handshake has been validated, a {@link NoiseIdentityDeterminedEvent} will be fired. For an IK handshake,
|
||||
* this will include the {@link org.whispersystems.textsecuregcm.auth.AuthenticatedDevice} of the initiator. This
|
||||
* handler will then replace itself with a {@link NoiseHandler} with a noise state pair ready to encrypt/decrypt data
|
||||
* frames.
|
||||
*/
|
||||
public class NoiseHandshakeHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private static final byte[] HANDSHAKE_WRONG_PK = NoiseTunnelProtos.HandshakeResponse.newBuilder()
|
||||
.setCode(NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY)
|
||||
.build().toByteArray();
|
||||
private static final byte[] HANDSHAKE_OK = NoiseTunnelProtos.HandshakeResponse.newBuilder()
|
||||
.setCode(NoiseTunnelProtos.HandshakeResponse.Code.OK)
|
||||
.build().toByteArray();
|
||||
|
||||
// We might get additional messages while we're waiting to process a handshake, so keep track of where we are
|
||||
private boolean receivedHandshakeInit = false;
|
||||
|
||||
private final ClientPublicKeysManager clientPublicKeysManager;
|
||||
private final ECKeyPair ecKeyPair;
|
||||
|
||||
public NoiseHandshakeHandler(final ClientPublicKeysManager clientPublicKeysManager, final ECKeyPair ecKeyPair) {
|
||||
this.clientPublicKeysManager = clientPublicKeysManager;
|
||||
this.ecKeyPair = ecKeyPair;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
||||
try {
|
||||
if (!(message instanceof NoiseHandshakeInit handshakeInit)) {
|
||||
// Anything except HandshakeInit should have been filtered out of the pipeline by now; treat this as an error
|
||||
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
|
||||
}
|
||||
if (receivedHandshakeInit) {
|
||||
throw new NoiseHandshakeException("Should not receive messages until handshake complete");
|
||||
}
|
||||
receivedHandshakeInit = true;
|
||||
|
||||
if (handshakeInit.content().readableBytes() > Noise.MAX_PACKET_LEN) {
|
||||
throw new NoiseHandshakeException("Invalid noise message length " + handshakeInit.content().readableBytes());
|
||||
}
|
||||
|
||||
// We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array.
|
||||
// We'll need to copy it to a heap buffer
|
||||
handleInboundHandshake(context,
|
||||
handshakeInit.getRemoteAddress(),
|
||||
handshakeInit.getHandshakePattern(),
|
||||
ByteBufUtil.getBytes(handshakeInit.content()));
|
||||
} finally {
|
||||
ReferenceCountUtil.release(message);
|
||||
}
|
||||
}
|
||||
|
||||
private void handleInboundHandshake(
|
||||
final ChannelHandlerContext context,
|
||||
final InetAddress remoteAddress,
|
||||
final HandshakePattern handshakePattern,
|
||||
final byte[] frameBytes) throws NoiseHandshakeException {
|
||||
final NoiseHandshakeHelper handshakeHelper = new NoiseHandshakeHelper(handshakePattern, ecKeyPair);
|
||||
final ByteBuf payload = handshakeHelper.read(frameBytes);
|
||||
|
||||
// Parse the handshake message
|
||||
final NoiseTunnelProtos.HandshakeInit handshakeInit;
|
||||
try {
|
||||
handshakeInit = NoiseTunnelProtos.HandshakeInit.parseFrom(new ByteBufInputStream(payload));
|
||||
} catch (IOException e) {
|
||||
throw new NoiseHandshakeException("Failed to parse handshake message");
|
||||
}
|
||||
|
||||
switch (handshakePattern) {
|
||||
case NK -> {
|
||||
if (handshakeInit.getDeviceId() != 0 || !handshakeInit.getAci().isEmpty()) {
|
||||
throw new NoiseHandshakeException("Anonymous handshake should not include identifiers");
|
||||
}
|
||||
handleAuthenticated(context, handshakeHelper, remoteAddress, handshakeInit, Optional.empty());
|
||||
}
|
||||
case IK -> {
|
||||
final byte[] publicKeyFromClient = handshakeHelper.remotePublicKey()
|
||||
.orElseThrow(() -> new IllegalStateException("No remote public key"));
|
||||
final UUID accountIdentifier = aci(handshakeInit);
|
||||
final byte deviceId = deviceId(handshakeInit);
|
||||
clientPublicKeysManager
|
||||
.findPublicKey(accountIdentifier, deviceId)
|
||||
.whenCompleteAsync((storedPublicKey, throwable) -> {
|
||||
if (throwable != null) {
|
||||
context.fireExceptionCaught(ExceptionUtils.unwrap(throwable));
|
||||
return;
|
||||
}
|
||||
final boolean valid = storedPublicKey
|
||||
.map(spk -> MessageDigest.isEqual(publicKeyFromClient, spk.getPublicKeyBytes()))
|
||||
.orElse(false);
|
||||
if (!valid) {
|
||||
// Write a handshake response indicating that the client used the wrong public key
|
||||
final byte[] handshakeMessage = handshakeHelper.write(HANDSHAKE_WRONG_PK);
|
||||
context.writeAndFlush(Unpooled.wrappedBuffer(handshakeMessage))
|
||||
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
|
||||
|
||||
context.fireExceptionCaught(new NoiseHandshakeException("Bad public key"));
|
||||
return;
|
||||
}
|
||||
handleAuthenticated(context,
|
||||
handshakeHelper, remoteAddress, handshakeInit,
|
||||
Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId)));
|
||||
}, context.executor());
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
private void handleAuthenticated(final ChannelHandlerContext context,
|
||||
final NoiseHandshakeHelper handshakeHelper,
|
||||
final InetAddress remoteAddress,
|
||||
final NoiseTunnelProtos.HandshakeInit handshakeInit,
|
||||
final Optional<AuthenticatedDevice> maybeAuthenticatedDevice) {
|
||||
context.fireUserEventTriggered(new NoiseIdentityDeterminedEvent(
|
||||
maybeAuthenticatedDevice,
|
||||
remoteAddress,
|
||||
handshakeInit.getUserAgent(),
|
||||
handshakeInit.getAcceptLanguage()));
|
||||
|
||||
// Now that we've authenticated, write the handshake response
|
||||
final byte[] handshakeMessage = handshakeHelper.write(HANDSHAKE_OK);
|
||||
context.writeAndFlush(Unpooled.wrappedBuffer(handshakeMessage))
|
||||
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
|
||||
|
||||
// The handshake is complete. We can start intercepting read/write for noise encryption/decryption
|
||||
// Note: It may be tempting to swap the before/remove for a replace, but then when we forward the fast open
|
||||
// request it will go through the NoiseHandler. We want to skip the NoiseHandler because we've already
|
||||
// decrypted the fastOpen request
|
||||
context.pipeline()
|
||||
.addBefore(context.name(), null, new NoiseHandler(handshakeHelper.getHandshakeState().split()));
|
||||
context.pipeline().remove(NoiseHandshakeHandler.class);
|
||||
if (!handshakeInit.getFastOpenRequest().isEmpty()) {
|
||||
// The handshake had a fast-open request. Forward the plaintext of the request to the server, we'll
|
||||
// encrypt the response when the server writes back through us
|
||||
context.fireChannelRead(Unpooled.wrappedBuffer(handshakeInit.getFastOpenRequest().asReadOnlyByteBuffer()));
|
||||
}
|
||||
}
|
||||
|
||||
private static UUID aci(final NoiseTunnelProtos.HandshakeInit handshakePayload) throws NoiseHandshakeException {
|
||||
try {
|
||||
return UUIDUtil.fromByteString(handshakePayload.getAci());
|
||||
} catch (IllegalArgumentException e) {
|
||||
throw new NoiseHandshakeException("Could not parse aci");
|
||||
}
|
||||
}
|
||||
|
||||
private static byte deviceId(final NoiseTunnelProtos.HandshakeInit handshakePayload) throws NoiseHandshakeException {
|
||||
if (!DeviceIdUtil.isValid(handshakePayload.getDeviceId())) {
|
||||
throw new NoiseHandshakeException("Invalid deviceId");
|
||||
}
|
||||
return (byte) handshakePayload.getDeviceId();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.DefaultByteBufHolder;
|
||||
import java.net.InetAddress;
|
||||
|
||||
/**
|
||||
* A message that includes the initiator's handshake message, connection metadata, and the handshake type. The metadata
|
||||
* and handshake type are extracted from the framing layer, so this allows receivers to be framing layer agnostic.
|
||||
*/
|
||||
public class NoiseHandshakeInit extends DefaultByteBufHolder {
|
||||
|
||||
private final InetAddress remoteAddress;
|
||||
private final HandshakePattern handshakePattern;
|
||||
|
||||
public NoiseHandshakeInit(
|
||||
final InetAddress remoteAddress,
|
||||
final HandshakePattern handshakePattern,
|
||||
final ByteBuf initiatorHandshakeMessage) {
|
||||
super(initiatorHandshakeMessage);
|
||||
this.remoteAddress = remoteAddress;
|
||||
this.handshakePattern = handshakePattern;
|
||||
}
|
||||
|
||||
public InetAddress getRemoteAddress() {
|
||||
return remoteAddress;
|
||||
}
|
||||
|
||||
public HandshakePattern getHandshakePattern() {
|
||||
return handshakePattern;
|
||||
}
|
||||
|
||||
}
|
|
@ -1,5 +1,6 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import java.net.InetAddress;
|
||||
import java.util.Optional;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
|
||||
|
@ -9,5 +10,12 @@ import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
|||
*
|
||||
* @param authenticatedDevice the device authenticated as part of the handshake, or empty if the handshake was not of a
|
||||
* type that performs authentication
|
||||
* @param remoteAddress the remote address of the connecting client
|
||||
* @param userAgent the client supplied userAgent
|
||||
* @param acceptLanguage the client supplied acceptLanguage
|
||||
*/
|
||||
record NoiseIdentityDeterminedEvent(Optional<AuthenticatedDevice> authenticatedDevice) {}
|
||||
public record NoiseIdentityDeterminedEvent(
|
||||
Optional<AuthenticatedDevice> authenticatedDevice,
|
||||
InetAddress remoteAddress,
|
||||
String userAgent,
|
||||
String acceptLanguage) {}
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
/**
|
||||
* An error written to the outbound pipeline that indicates the connection should be closed
|
||||
*/
|
||||
public record OutboundCloseErrorMessage(Code code, String message) {
|
||||
public enum Code {
|
||||
|
||||
/**
|
||||
* The server decided to close the connection. This could be because the server is going away, or it could be
|
||||
* because the credentials for the connected client have been updated.
|
||||
*/
|
||||
SERVER_CLOSED,
|
||||
|
||||
/**
|
||||
* There was a noise decryption error after the noise session was established
|
||||
*/
|
||||
NOISE_ERROR,
|
||||
|
||||
/**
|
||||
* There was an error establishing the noise handshake
|
||||
*/
|
||||
NOISE_HANDSHAKE_ERROR,
|
||||
|
||||
INTERNAL_SERVER_ERROR
|
||||
}
|
||||
}
|
|
@ -8,7 +8,7 @@ import io.netty.channel.ChannelInboundHandlerAdapter;
|
|||
/**
|
||||
* A proxy handler writes all data read from one channel to another peer channel.
|
||||
*/
|
||||
class ProxyHandler extends ChannelInboundHandlerAdapter {
|
||||
public class ProxyHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private final Channel peerChannel;
|
||||
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.channel.ChannelDuplexHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseException;
|
||||
|
||||
/**
|
||||
* In the inbound direction, this handler strips the NoiseDirectFrame wrapper we read off the wire and then forwards the
|
||||
* noise packet to the noise layer as a {@link ByteBuf} for decryption.
|
||||
* <p>
|
||||
* In the outbound direction, this handler wraps encrypted noise packet {@link ByteBuf}s in a NoiseDirectFrame wrapper
|
||||
* so it can be wire serialized. This handler assumes the first outbound message received will correspond to the
|
||||
* handshake response, and then the subsequent messages are all data frame payloads.
|
||||
*/
|
||||
public class NoiseDirectDataFrameCodec extends ChannelDuplexHandler {
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
|
||||
if (msg instanceof NoiseDirectFrame frame) {
|
||||
if (frame.frameType() != NoiseDirectFrame.FrameType.DATA) {
|
||||
ReferenceCountUtil.release(msg);
|
||||
throw new NoiseException("Invalid frame type received (expected DATA): " + frame.frameType());
|
||||
}
|
||||
ctx.fireChannelRead(frame.content());
|
||||
} else {
|
||||
ctx.fireChannelRead(msg);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
|
||||
if (msg instanceof ByteBuf bb) {
|
||||
ctx.write(new NoiseDirectFrame(NoiseDirectFrame.FrameType.DATA, bb), promise);
|
||||
} else {
|
||||
ctx.write(msg, promise);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,72 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.DefaultByteBufHolder;
|
||||
|
||||
public class NoiseDirectFrame extends DefaultByteBufHolder {
|
||||
|
||||
static final byte VERSION = 0x00;
|
||||
|
||||
private final FrameType frameType;
|
||||
|
||||
public NoiseDirectFrame(final FrameType frameType, final ByteBuf data) {
|
||||
super(data);
|
||||
this.frameType = frameType;
|
||||
}
|
||||
|
||||
public FrameType frameType() {
|
||||
return frameType;
|
||||
}
|
||||
|
||||
public byte versionedFrameTypeByte() {
|
||||
final byte frameBits = frameType().getFrameBits();
|
||||
return (byte) ((NoiseDirectFrame.VERSION << 4) | frameBits);
|
||||
}
|
||||
|
||||
|
||||
public enum FrameType {
|
||||
/**
|
||||
* The payload is the initiator message for a Noise NK handshake. If established, the
|
||||
* session will be unauthenticated.
|
||||
*/
|
||||
NK_HANDSHAKE((byte) 1),
|
||||
/**
|
||||
* The payload is the initiator message for a Noise IK handshake. If established, the
|
||||
* session will be authenticated.
|
||||
*/
|
||||
IK_HANDSHAKE((byte) 2),
|
||||
/**
|
||||
* The payload is an encrypted noise packet.
|
||||
*/
|
||||
DATA((byte) 3),
|
||||
/**
|
||||
* A frame sent before the connection is closed. The payload is a protobuf indicating why the connection is being
|
||||
* closed.
|
||||
*/
|
||||
CLOSE((byte) 4);
|
||||
|
||||
private final byte frameType;
|
||||
|
||||
FrameType(byte frameType) {
|
||||
if (frameType != (0x0F & frameType)) {
|
||||
throw new IllegalStateException("Frame type must fit in 4 bits");
|
||||
}
|
||||
this.frameType = frameType;
|
||||
}
|
||||
|
||||
public byte getFrameBits() {
|
||||
return frameType;
|
||||
}
|
||||
|
||||
public boolean isHandshake() {
|
||||
return switch (this) {
|
||||
case IK_HANDSHAKE, NK_HANDSHAKE -> true;
|
||||
case DATA, CLOSE -> false;
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,90 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
|
||||
|
||||
import com.southernstorm.noise.protocol.Noise;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.channel.ChannelDuplexHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
|
||||
|
||||
/**
|
||||
* Handles conversion between bytes on the wire and {@link NoiseDirectFrame}s. This handler assumes that inbound bytes
|
||||
* have already been framed using a {@link io.netty.handler.codec.LengthFieldBasedFrameDecoder}
|
||||
*/
|
||||
public class NoiseDirectFrameCodec extends ChannelDuplexHandler {
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
|
||||
if (msg instanceof ByteBuf byteBuf) {
|
||||
try {
|
||||
ctx.fireChannelRead(deserialize(byteBuf));
|
||||
} catch (Exception e) {
|
||||
ReferenceCountUtil.release(byteBuf);
|
||||
throw e;
|
||||
}
|
||||
} else {
|
||||
ctx.fireChannelRead(msg);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
|
||||
if (msg instanceof NoiseDirectFrame noiseDirectFrame) {
|
||||
try {
|
||||
// Serialize the frame into a newly allocated direct buffer. Since this is the last handler before the
|
||||
// network, nothing should have to make another copy of this. If later another layer is added, it may be more
|
||||
// efficient to reuse the input buffer (typically not direct) by using a composite byte buffer
|
||||
final ByteBuf serialized = serialize(ctx, noiseDirectFrame);
|
||||
ctx.writeAndFlush(serialized, promise);
|
||||
} finally {
|
||||
ReferenceCountUtil.release(noiseDirectFrame);
|
||||
}
|
||||
} else {
|
||||
ctx.write(msg, promise);
|
||||
}
|
||||
}
|
||||
|
||||
private ByteBuf serialize(
|
||||
final ChannelHandlerContext ctx,
|
||||
final NoiseDirectFrame noiseDirectFrame) {
|
||||
if (noiseDirectFrame.content().readableBytes() > Noise.MAX_PACKET_LEN) {
|
||||
throw new IllegalStateException("Payload too long: " + noiseDirectFrame.content().readableBytes());
|
||||
}
|
||||
|
||||
// 1 version/frametype byte, 2 length bytes, content
|
||||
final ByteBuf byteBuf = ctx.alloc().buffer(1 + 2 + noiseDirectFrame.content().readableBytes());
|
||||
|
||||
byteBuf.writeByte(noiseDirectFrame.versionedFrameTypeByte());
|
||||
byteBuf.writeShort(noiseDirectFrame.content().readableBytes());
|
||||
byteBuf.writeBytes(noiseDirectFrame.content());
|
||||
return byteBuf;
|
||||
}
|
||||
|
||||
private NoiseDirectFrame deserialize(final ByteBuf byteBuf) throws Exception {
|
||||
final byte versionAndFrameByte = byteBuf.readByte();
|
||||
final int version = (versionAndFrameByte & 0xF0) >> 4;
|
||||
if (version != NoiseDirectFrame.VERSION) {
|
||||
throw new NoiseHandshakeException("Invalid NoiseDirect version: " + version);
|
||||
}
|
||||
final byte frameTypeBits = (byte) (versionAndFrameByte & 0x0F);
|
||||
final NoiseDirectFrame.FrameType frameType = switch (frameTypeBits) {
|
||||
case 1 -> NoiseDirectFrame.FrameType.NK_HANDSHAKE;
|
||||
case 2 -> NoiseDirectFrame.FrameType.IK_HANDSHAKE;
|
||||
case 3 -> NoiseDirectFrame.FrameType.DATA;
|
||||
case 4 -> NoiseDirectFrame.FrameType.CLOSE;
|
||||
default -> throw new NoiseHandshakeException("Invalid NoiseDirect frame type: " + frameTypeBits);
|
||||
};
|
||||
|
||||
final int length = Short.toUnsignedInt(byteBuf.readShort());
|
||||
if (length != byteBuf.readableBytes()) {
|
||||
throw new IllegalArgumentException(
|
||||
"Payload length did not match remaining buffer, should have been guaranteed by a previous handler");
|
||||
}
|
||||
return new NoiseDirectFrame(frameType, byteBuf.readSlice(length));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,53 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
|
||||
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import java.io.IOException;
|
||||
import java.net.InetSocketAddress;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.HandshakePattern;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeInit;
|
||||
|
||||
/**
|
||||
* Waits for a Handshake {@link NoiseDirectFrame} and then replaces itself with a {@link NoiseDirectDataFrameCodec} and
|
||||
* forwards the handshake frame along as a {@link NoiseHandshakeInit} message
|
||||
*/
|
||||
public class NoiseDirectHandshakeSelector extends ChannelInboundHandlerAdapter {
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
|
||||
if (msg instanceof NoiseDirectFrame frame) {
|
||||
try {
|
||||
if (!(ctx.channel().remoteAddress() instanceof InetSocketAddress inetSocketAddress)) {
|
||||
throw new IOException("Could not determine remote address");
|
||||
}
|
||||
// We've received an inbound handshake frame. Pull the framing-protocol specific data the downstream handler
|
||||
// needs into a NoiseHandshakeInit message and forward that along
|
||||
final NoiseHandshakeInit handshakeMessage = new NoiseHandshakeInit(inetSocketAddress.getAddress(),
|
||||
switch (frame.frameType()) {
|
||||
case DATA -> throw new NoiseHandshakeException("First message must have handshake frame type");
|
||||
case CLOSE -> throw new IllegalStateException("Close frames should not reach handshake selector");
|
||||
case IK_HANDSHAKE -> HandshakePattern.IK;
|
||||
case NK_HANDSHAKE -> HandshakePattern.NK;
|
||||
}, frame.content());
|
||||
|
||||
// Subsequent inbound messages and outbound should be data type frames or close frames. Inbound data frames
|
||||
// should be unwrapped and forwarded to the noise handler, outbound buffers should be wrapped and forwarded
|
||||
// for network serialization. Note that we need to install the Data frame handler before firing the read,
|
||||
// because we may receive an outbound message from the noiseHandler
|
||||
ctx.pipeline().replace(ctx.name(), null, new NoiseDirectDataFrameCodec());
|
||||
ctx.fireChannelRead(handshakeMessage);
|
||||
} catch (Exception e) {
|
||||
ReferenceCountUtil.release(msg);
|
||||
throw e;
|
||||
}
|
||||
} else {
|
||||
ctx.fireChannelRead(msg);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
|
||||
|
||||
import io.micrometer.core.instrument.Metrics;
|
||||
import io.netty.buffer.ByteBufUtil;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
||||
|
||||
|
||||
/**
|
||||
* Watches for inbound close frames and closes the connection in response
|
||||
*/
|
||||
public class NoiseDirectInboundCloseHandler extends ChannelInboundHandlerAdapter {
|
||||
private static String CLIENT_CLOSE_COUNTER_NAME = MetricsUtil.name(ChannelInboundHandlerAdapter.class, "clientClose");
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
|
||||
if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.CLOSE) {
|
||||
try {
|
||||
final NoiseDirectProtos.CloseReason closeReason = NoiseDirectProtos.CloseReason
|
||||
.parseFrom(ByteBufUtil.getBytes(ndf.content()));
|
||||
|
||||
Metrics.counter(CLIENT_CLOSE_COUNTER_NAME, "reason", closeReason.getCode().name()).increment();
|
||||
} finally {
|
||||
ReferenceCountUtil.release(msg);
|
||||
ctx.close();
|
||||
}
|
||||
} else {
|
||||
ctx.fireChannelRead(msg);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufOutputStream;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelOutboundHandlerAdapter;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage;
|
||||
|
||||
/**
|
||||
* Translates {@link OutboundCloseErrorMessage}s into {@link NoiseDirectFrame} error frames. After error frames are
|
||||
* written, the channel is closed
|
||||
*/
|
||||
class NoiseDirectOutboundErrorHandler extends ChannelOutboundHandlerAdapter {
|
||||
|
||||
@Override
|
||||
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
|
||||
if (msg instanceof OutboundCloseErrorMessage err) {
|
||||
final NoiseDirectProtos.CloseReason.Code code = switch (err.code()) {
|
||||
case SERVER_CLOSED -> NoiseDirectProtos.CloseReason.Code.UNAVAILABLE;
|
||||
case NOISE_ERROR -> NoiseDirectProtos.CloseReason.Code.ENCRYPTION_ERROR;
|
||||
case NOISE_HANDSHAKE_ERROR -> NoiseDirectProtos.CloseReason.Code.HANDSHAKE_ERROR;
|
||||
case INTERNAL_SERVER_ERROR -> NoiseDirectProtos.CloseReason.Code.INTERNAL_ERROR;
|
||||
};
|
||||
final NoiseDirectProtos.CloseReason proto = NoiseDirectProtos.CloseReason.newBuilder()
|
||||
.setCode(code)
|
||||
.setMessage(err.message())
|
||||
.build();
|
||||
final ByteBuf byteBuf = ctx.alloc().buffer(proto.getSerializedSize());
|
||||
proto.writeTo(new ByteBufOutputStream(byteBuf));
|
||||
ctx.writeAndFlush(new NoiseDirectFrame(NoiseDirectFrame.FrameType.CLOSE, byteBuf))
|
||||
.addListener(ChannelFutureListener.CLOSE);
|
||||
} else {
|
||||
ctx.write(msg, promise);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,93 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.southernstorm.noise.protocol.Noise;
|
||||
import io.dropwizard.lifecycle.Managed;
|
||||
import io.netty.bootstrap.ServerBootstrap;
|
||||
import io.netty.channel.ChannelInitializer;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import io.netty.channel.nio.NioEventLoopGroup;
|
||||
import io.netty.channel.socket.ServerSocketChannel;
|
||||
import io.netty.channel.socket.SocketChannel;
|
||||
import io.netty.channel.socket.nio.NioServerSocketChannel;
|
||||
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
|
||||
import java.net.InetSocketAddress;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.ErrorHandler;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.EstablishLocalGrpcConnectionHandler;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.HAProxyMessageHandler;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeHandler;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.ProxyProtocolDetectionHandler;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||
|
||||
/**
|
||||
* A NoiseDirectTunnelServer accepts traffic from the public internet (in the form of Noise packets framed by a custom
|
||||
* binary framing protocol) and passes it through to a local gRPC server.
|
||||
*/
|
||||
public class NoiseDirectTunnelServer implements Managed {
|
||||
|
||||
private final ServerBootstrap bootstrap;
|
||||
private ServerSocketChannel channel;
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(NoiseDirectTunnelServer.class);
|
||||
|
||||
public NoiseDirectTunnelServer(final int port,
|
||||
final NioEventLoopGroup eventLoopGroup,
|
||||
final GrpcClientConnectionManager grpcClientConnectionManager,
|
||||
final ClientPublicKeysManager clientPublicKeysManager,
|
||||
final ECKeyPair ecKeyPair,
|
||||
final LocalAddress authenticatedGrpcServerAddress,
|
||||
final LocalAddress anonymousGrpcServerAddress) {
|
||||
|
||||
this.bootstrap = new ServerBootstrap()
|
||||
.group(eventLoopGroup)
|
||||
.channel(NioServerSocketChannel.class)
|
||||
.localAddress(port)
|
||||
.childHandler(new ChannelInitializer<SocketChannel>() {
|
||||
@Override
|
||||
protected void initChannel(SocketChannel socketChannel) {
|
||||
socketChannel.pipeline()
|
||||
.addLast(new ProxyProtocolDetectionHandler())
|
||||
.addLast(new HAProxyMessageHandler())
|
||||
// frame byte followed by a 2-byte length field
|
||||
.addLast(new LengthFieldBasedFrameDecoder(Noise.MAX_PACKET_LEN, 1, 2))
|
||||
// Parses NoiseDirectFrames from wire bytes and vice versa
|
||||
.addLast(new NoiseDirectFrameCodec())
|
||||
// Terminate the connection if the client sends us a close frame
|
||||
.addLast(new NoiseDirectInboundCloseHandler())
|
||||
// Turn generic OutboundCloseErrorMessages into noise direct error frames
|
||||
.addLast(new NoiseDirectOutboundErrorHandler())
|
||||
// Forwards the first payload supplemented with handshake metadata, and then replaces itself with a
|
||||
// NoiseDirectDataFrameCodec to handle subsequent data frames
|
||||
.addLast(new NoiseDirectHandshakeSelector())
|
||||
// Performs the noise handshake and then replace itself with a NoiseHandler
|
||||
.addLast(new NoiseHandshakeHandler(clientPublicKeysManager, ecKeyPair))
|
||||
// This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler
|
||||
// once the Noise handshake has completed
|
||||
.addLast(new EstablishLocalGrpcConnectionHandler(
|
||||
grpcClientConnectionManager, authenticatedGrpcServerAddress, anonymousGrpcServerAddress))
|
||||
.addLast(new ErrorHandler());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
public InetSocketAddress getLocalAddress() {
|
||||
return channel.localAddress();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void start() throws InterruptedException {
|
||||
channel = (ServerSocketChannel) bootstrap.bind().await().channel();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void stop() throws InterruptedException {
|
||||
if (channel != null) {
|
||||
channel.close().await();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,12 +1,10 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
||||
|
||||
enum ApplicationWebSocketCloseReason {
|
||||
NOISE_HANDSHAKE_ERROR(4001),
|
||||
CLIENT_AUTHENTICATION_ERROR(4002),
|
||||
NOISE_ENCRYPTION_ERROR(4003),
|
||||
REAUTHENTICATION_REQUIRED(4004);
|
||||
NOISE_ENCRYPTION_ERROR(4002);
|
||||
|
||||
private final int statusCode;
|
||||
|
||||
|
@ -17,8 +15,4 @@ enum ApplicationWebSocketCloseReason {
|
|||
public int getStatusCode() {
|
||||
return statusCode;
|
||||
}
|
||||
|
||||
WebSocketCloseStatus toWebSocketCloseStatus(final String reason) {
|
||||
return new WebSocketCloseStatus(statusCode, reason);
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.southernstorm.noise.protocol.Noise;
|
||||
|
@ -28,6 +28,7 @@ import javax.net.ssl.SSLException;
|
|||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.*;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||
|
||||
/**
|
||||
|
@ -103,10 +104,16 @@ public class NoiseWebSocketTunnelServer implements Managed {
|
|||
// request and passed it down the pipeline
|
||||
.addLast(new WebSocketOpeningHandshakeHandler(AUTHENTICATED_SERVICE_PATH, ANONYMOUS_SERVICE_PATH, HEALTH_CHECK_PATH))
|
||||
.addLast(new WebSocketServerProtocolHandler("/", true))
|
||||
// Turn generic OutboundCloseErrorMessages into websocket close frames
|
||||
.addLast(new WebSocketOutboundErrorHandler())
|
||||
.addLast(new RejectUnsupportedMessagesHandler())
|
||||
// The WebSocket handshake complete listener will replace itself with an appropriate Noise handshake handler once
|
||||
// a WebSocket handshake has been completed
|
||||
.addLast(new WebsocketHandshakeCompleteHandler(clientPublicKeysManager, ecKeyPair, recognizedProxySecret))
|
||||
.addLast(new WebsocketPayloadCodec())
|
||||
// The WebSocket handshake complete listener will forward the first payload supplemented with
|
||||
// data from the websocket handshake completion event, and then remove itself from the pipeline
|
||||
.addLast(new WebsocketHandshakeCompleteHandler(recognizedProxySecret))
|
||||
// The NoiseHandshakeHandler will perform the noise handshake and then replace itself with a
|
||||
// NoiseHandler
|
||||
.addLast(new NoiseHandshakeHandler(clientPublicKeysManager, ecKeyPair))
|
||||
// This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler
|
||||
// once the Noise handshake has completed
|
||||
.addLast(new EstablishLocalGrpcConnectionHandler(grpcClientConnectionManager, authenticatedGrpcServerAddress, anonymousGrpcServerAddress))
|
|
@ -1,4 +1,4 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
|
@ -1,4 +1,4 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
|
@ -0,0 +1,58 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||
|
||||
import io.netty.channel.ChannelDuplexHandler;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage;
|
||||
|
||||
/**
|
||||
* Converts {@link OutboundCloseErrorMessage}s written to the pipeline into WebSocket close frames
|
||||
*/
|
||||
class WebSocketOutboundErrorHandler extends ChannelDuplexHandler {
|
||||
|
||||
private boolean websocketHandshakeComplete = false;
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(WebSocketOutboundErrorHandler.class);
|
||||
|
||||
@Override
|
||||
public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception {
|
||||
if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
|
||||
setWebsocketHandshakeComplete();
|
||||
}
|
||||
|
||||
context.fireUserEventTriggered(event);
|
||||
}
|
||||
|
||||
protected void setWebsocketHandshakeComplete() {
|
||||
this.websocketHandshakeComplete = true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
|
||||
if (msg instanceof OutboundCloseErrorMessage err) {
|
||||
if (websocketHandshakeComplete) {
|
||||
final int status = switch (err.code()) {
|
||||
case SERVER_CLOSED -> WebSocketCloseStatus.SERVICE_RESTART.code();
|
||||
case NOISE_ERROR -> ApplicationWebSocketCloseReason.NOISE_ENCRYPTION_ERROR.getStatusCode();
|
||||
case NOISE_HANDSHAKE_ERROR -> ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode();
|
||||
case INTERNAL_SERVER_ERROR -> WebSocketCloseStatus.INTERNAL_SERVER_ERROR.code();
|
||||
};
|
||||
ctx.write(new CloseWebSocketFrame(new WebSocketCloseStatus(status, err.message())), promise)
|
||||
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
|
||||
} else {
|
||||
log.debug("Error {} occurred before websocket handshake complete", err);
|
||||
// We haven't completed a websocket handshake, so we can't really communicate errors in a semantically-meaningful
|
||||
// way; just close the connection instead.
|
||||
ctx.close();
|
||||
}
|
||||
} else {
|
||||
ctx.write(msg, promise);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,15 +1,15 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.common.net.InetAddresses;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.handler.codec.http.HttpHeaderNames;
|
||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import java.net.InetAddress;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
|
@ -17,10 +17,10 @@ import java.security.MessageDigest;
|
|||
import java.util.Optional;
|
||||
import javax.annotation.Nullable;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.HandshakePattern;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeInit;
|
||||
|
||||
/**
|
||||
* A WebSocket handshake handler waits for a WebSocket handshake to complete, then replaces itself with the appropriate
|
||||
|
@ -28,10 +28,6 @@ import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
|||
*/
|
||||
class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private final ClientPublicKeysManager clientPublicKeysManager;
|
||||
|
||||
private final ECKeyPair ecKeyPair;
|
||||
|
||||
private final byte[] recognizedProxySecret;
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(WebsocketHandshakeCompleteHandler.class);
|
||||
|
@ -42,12 +38,10 @@ class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
|
|||
@VisibleForTesting
|
||||
static final String FORWARDED_FOR_HEADER = "X-Forwarded-For";
|
||||
|
||||
WebsocketHandshakeCompleteHandler(final ClientPublicKeysManager clientPublicKeysManager,
|
||||
final ECKeyPair ecKeyPair,
|
||||
final String recognizedProxySecret) {
|
||||
private InetAddress remoteAddress = null;
|
||||
private HandshakePattern handshakePattern = null;
|
||||
|
||||
this.clientPublicKeysManager = clientPublicKeysManager;
|
||||
this.ecKeyPair = ecKeyPair;
|
||||
WebsocketHandshakeCompleteHandler(final String recognizedProxySecret) {
|
||||
|
||||
// The recognized proxy secret is an arbitrary string, and not an encoded byte sequence (i.e. a base64- or hex-
|
||||
// encoded value). We convert it into a byte array here for easier constant-time comparisons via
|
||||
|
@ -58,8 +52,6 @@ class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
|
|||
@Override
|
||||
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
|
||||
if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) {
|
||||
final InetAddress preferredRemoteAddress;
|
||||
{
|
||||
final Optional<InetAddress> maybePreferredRemoteAddress =
|
||||
getPreferredRemoteAddress(context, handshakeCompleteEvent);
|
||||
|
||||
|
@ -71,34 +63,41 @@ class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
|
|||
return;
|
||||
}
|
||||
|
||||
preferredRemoteAddress = maybePreferredRemoteAddress.get();
|
||||
}
|
||||
|
||||
GrpcClientConnectionManager.handleWebSocketHandshakeComplete(context.channel(),
|
||||
preferredRemoteAddress,
|
||||
handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.USER_AGENT),
|
||||
handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.ACCEPT_LANGUAGE));
|
||||
|
||||
final ChannelHandler noiseHandshakeHandler = switch (handshakeCompleteEvent.requestUri()) {
|
||||
case NoiseWebSocketTunnelServer.AUTHENTICATED_SERVICE_PATH ->
|
||||
new NoiseAuthenticatedHandler(clientPublicKeysManager, ecKeyPair);
|
||||
|
||||
case NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH ->
|
||||
new NoiseAnonymousHandler(ecKeyPair);
|
||||
|
||||
default -> {
|
||||
remoteAddress = maybePreferredRemoteAddress.get();
|
||||
handshakePattern = switch (handshakeCompleteEvent.requestUri()) {
|
||||
case NoiseWebSocketTunnelServer.AUTHENTICATED_SERVICE_PATH -> HandshakePattern.IK;
|
||||
case NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH -> HandshakePattern.NK;
|
||||
// The WebSocketOpeningHandshakeHandler should have caught all of these cases already; we'll consider it an
|
||||
// internal error if something slipped through.
|
||||
throw new IllegalArgumentException("Unexpected URI: " + handshakeCompleteEvent.requestUri());
|
||||
}
|
||||
default -> throw new IllegalArgumentException("Unexpected URI: " + handshakeCompleteEvent.requestUri());
|
||||
};
|
||||
|
||||
context.pipeline().replace(WebsocketHandshakeCompleteHandler.this, null, noiseHandshakeHandler);
|
||||
}
|
||||
|
||||
context.fireUserEventTriggered(event);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object msg) {
|
||||
try {
|
||||
if (!(msg instanceof ByteBuf frame)) {
|
||||
throw new IllegalStateException("Unexpected msg type: " + msg.getClass());
|
||||
}
|
||||
|
||||
if (handshakePattern == null || remoteAddress == null) {
|
||||
throw new IllegalStateException("Received payload before websocket handshake complete");
|
||||
}
|
||||
|
||||
final NoiseHandshakeInit handshakeMessage =
|
||||
new NoiseHandshakeInit(remoteAddress, handshakePattern, frame);
|
||||
|
||||
context.pipeline().remove(WebsocketHandshakeCompleteHandler.class);
|
||||
context.fireChannelRead(handshakeMessage);
|
||||
} catch (Exception e) {
|
||||
ReferenceCountUtil.release(msg);
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
private Optional<InetAddress> getPreferredRemoteAddress(final ChannelHandlerContext context,
|
||||
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) {
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.channel.ChannelDuplexHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||
|
||||
/**
|
||||
* Extracts buffers from inbound BinaryWebsocketFrames before forwarding to a
|
||||
* {@link org.whispersystems.textsecuregcm.grpc.net.NoiseHandler} for decryption and wraps outbound encrypted noise
|
||||
* packet buffers in BinaryWebsocketFrames for writing through the websocket layer.
|
||||
*/
|
||||
public class WebsocketPayloadCodec extends ChannelDuplexHandler {
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext ctx, final Object msg) {
|
||||
if (msg instanceof BinaryWebSocketFrame frame) {
|
||||
ctx.fireChannelRead(frame.content());
|
||||
} else {
|
||||
ctx.fireChannelRead(msg);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
|
||||
if (msg instanceof ByteBuf bb) {
|
||||
ctx.write(new BinaryWebSocketFrame(bb), promise);
|
||||
} else {
|
||||
ctx.write(msg, promise);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -11,7 +11,6 @@ import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.in
|
|||
|
||||
import com.google.protobuf.ByteString;
|
||||
import com.google.protobuf.Descriptors;
|
||||
import com.google.protobuf.GeneratedMessageV3;
|
||||
import com.google.protobuf.Message;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.StatusException;
|
||||
|
@ -49,7 +48,7 @@ public abstract class BaseFieldValidator<T> implements FieldValidator {
|
|||
public void validate(
|
||||
final Object extensionValue,
|
||||
final Descriptors.FieldDescriptor fd,
|
||||
final GeneratedMessageV3 msg) throws StatusException {
|
||||
final Message msg) throws StatusException {
|
||||
try {
|
||||
final T extensionValueTyped = resolveExtensionValue(extensionValue);
|
||||
|
||||
|
@ -116,7 +115,7 @@ public abstract class BaseFieldValidator<T> implements FieldValidator {
|
|||
protected void validateRepeatedField(
|
||||
final T extensionValue,
|
||||
final Descriptors.FieldDescriptor fd,
|
||||
final GeneratedMessageV3 msg) throws StatusException {
|
||||
final Message msg) throws StatusException {
|
||||
throw internalError("`validateRepeatedField` method needs to be implemented");
|
||||
}
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.in
|
|||
|
||||
import com.google.protobuf.ByteString;
|
||||
import com.google.protobuf.Descriptors;
|
||||
import com.google.protobuf.GeneratedMessageV3;
|
||||
import com.google.protobuf.Message;
|
||||
import io.grpc.StatusException;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
@ -53,7 +53,7 @@ public class ExactlySizeFieldValidator extends BaseFieldValidator<Set<Integer>>
|
|||
protected void validateRepeatedField(
|
||||
final Set<Integer> permittedSizes,
|
||||
final Descriptors.FieldDescriptor fd,
|
||||
final GeneratedMessageV3 msg) throws StatusException {
|
||||
final Message msg) throws StatusException {
|
||||
final int size = msg.getRepeatedFieldCount(fd);
|
||||
if (permittedSizes.contains(size)) {
|
||||
return;
|
||||
|
|
|
@ -6,11 +6,11 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.validators;
|
||||
|
||||
import com.google.protobuf.Descriptors;
|
||||
import com.google.protobuf.GeneratedMessageV3;
|
||||
import com.google.protobuf.Message;
|
||||
import io.grpc.StatusException;
|
||||
|
||||
public interface FieldValidator {
|
||||
|
||||
void validate(Object extensionValue, Descriptors.FieldDescriptor fd, GeneratedMessageV3 msg)
|
||||
void validate(Object extensionValue, Descriptors.FieldDescriptor fd, Message msg)
|
||||
throws StatusException;
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@ import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.in
|
|||
|
||||
import com.google.protobuf.ByteString;
|
||||
import com.google.protobuf.Descriptors;
|
||||
import com.google.protobuf.GeneratedMessageV3;
|
||||
import com.google.protobuf.Message;
|
||||
import io.grpc.StatusException;
|
||||
import java.util.Set;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
@ -52,7 +52,7 @@ public class NonEmptyFieldValidator extends BaseFieldValidator<Boolean> {
|
|||
protected void validateRepeatedField(
|
||||
final Boolean extensionValue,
|
||||
final Descriptors.FieldDescriptor fd,
|
||||
final GeneratedMessageV3 msg) throws StatusException {
|
||||
final Message msg) throws StatusException {
|
||||
if (msg.getRepeatedFieldCount(fd) > 0) {
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
package org.whispersystems.textsecuregcm.grpc.validators;
|
||||
|
||||
public record Range(int min, int max) {
|
||||
public record Range(long min, long max) {
|
||||
public Range {
|
||||
if (min > max) {
|
||||
throw new IllegalArgumentException("invalid range values: expected min <= max but have [%d, %d],".formatted(min, max));
|
||||
|
|
|
@ -39,8 +39,8 @@ public class RangeFieldValidator extends BaseFieldValidator<Range> {
|
|||
@Override
|
||||
protected Range resolveExtensionValue(final Object extensionValue) throws StatusException {
|
||||
final ValueRangeConstraint rangeConstraint = (ValueRangeConstraint) extensionValue;
|
||||
final int min = rangeConstraint.hasMin() ? rangeConstraint.getMin() : Integer.MIN_VALUE;
|
||||
final int max = rangeConstraint.hasMax() ? rangeConstraint.getMax() : Integer.MAX_VALUE;
|
||||
final long min = rangeConstraint.hasMin() ? rangeConstraint.getMin() : Long.MIN_VALUE;
|
||||
final long max = rangeConstraint.hasMax() ? rangeConstraint.getMax() : Long.MAX_VALUE;
|
||||
return new Range(min, max);
|
||||
}
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.in
|
|||
|
||||
import com.google.protobuf.ByteString;
|
||||
import com.google.protobuf.Descriptors;
|
||||
import com.google.protobuf.GeneratedMessageV3;
|
||||
import com.google.protobuf.Message;
|
||||
import io.grpc.StatusException;
|
||||
import java.util.Set;
|
||||
import org.signal.chat.require.SizeConstraint;
|
||||
|
@ -48,7 +48,7 @@ public class SizeFieldValidator extends BaseFieldValidator<Range> {
|
|||
}
|
||||
|
||||
@Override
|
||||
protected void validateRepeatedField(final Range range, final Descriptors.FieldDescriptor fd, final GeneratedMessageV3 msg) throws StatusException {
|
||||
protected void validateRepeatedField(final Range range, final Descriptors.FieldDescriptor fd, final Message msg) throws StatusException {
|
||||
final int size = msg.getRepeatedFieldCount(fd);
|
||||
if (size < range.min() || size > range.max()) {
|
||||
throw invalidArgument("field value is [%d] but expected to be within the [%d, %d] range".formatted(
|
||||
|
|
|
@ -10,16 +10,12 @@ import static java.util.Objects.requireNonNull;
|
|||
import io.lettuce.core.ScriptOutputType;
|
||||
import java.io.IOException;
|
||||
import java.io.UncheckedIOException;
|
||||
import java.lang.invoke.MethodHandles;
|
||||
import java.time.Clock;
|
||||
import java.util.Arrays;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.stream.Collectors;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
||||
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
|
||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
|
||||
|
@ -27,25 +23,18 @@ import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
|
|||
|
||||
public abstract class BaseRateLimiters<T extends RateLimiterDescriptor> {
|
||||
|
||||
private final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
|
||||
|
||||
private final Map<T, RateLimiter> rateLimiterByDescriptor;
|
||||
|
||||
private final Map<String, RateLimiterConfig> configs;
|
||||
|
||||
|
||||
protected BaseRateLimiters(
|
||||
final T[] values,
|
||||
final Map<String, RateLimiterConfig> configs,
|
||||
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
|
||||
final ClusterLuaScript validateScript,
|
||||
final FaultTolerantRedisClusterClient cacheCluster,
|
||||
final Clock clock) {
|
||||
this.configs = configs;
|
||||
this.rateLimiterByDescriptor = Arrays.stream(values)
|
||||
.map(descriptor -> Pair.of(
|
||||
descriptor,
|
||||
createForDescriptor(descriptor, configs, dynamicConfigurationManager, validateScript, cacheCluster, clock)))
|
||||
createForDescriptor(descriptor, dynamicConfigurationManager, validateScript, cacheCluster, clock)))
|
||||
.collect(Collectors.toUnmodifiableMap(Pair::getKey, Pair::getValue));
|
||||
}
|
||||
|
||||
|
@ -53,22 +42,6 @@ public abstract class BaseRateLimiters<T extends RateLimiterDescriptor> {
|
|||
return requireNonNull(rateLimiterByDescriptor.get(handle));
|
||||
}
|
||||
|
||||
public void validateValuesAndConfigs() {
|
||||
final Set<String> ids = rateLimiterByDescriptor.keySet().stream()
|
||||
.map(RateLimiterDescriptor::id)
|
||||
.collect(Collectors.toSet());
|
||||
for (final String key: configs.keySet()) {
|
||||
if (!ids.contains(key)) {
|
||||
final String message = String.format(
|
||||
"Static configuration has an unexpected field '%s' that doesn't match any RateLimiterDescriptor",
|
||||
key
|
||||
);
|
||||
logger.error(message);
|
||||
throw new IllegalArgumentException(message);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected static ClusterLuaScript defaultScript(final FaultTolerantRedisClusterClient cacheCluster) {
|
||||
try {
|
||||
return ClusterLuaScript.fromResource(
|
||||
|
@ -80,21 +53,12 @@ public abstract class BaseRateLimiters<T extends RateLimiterDescriptor> {
|
|||
|
||||
private static RateLimiter createForDescriptor(
|
||||
final RateLimiterDescriptor descriptor,
|
||||
final Map<String, RateLimiterConfig> configs,
|
||||
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
|
||||
final ClusterLuaScript validateScript,
|
||||
final FaultTolerantRedisClusterClient cacheCluster,
|
||||
final Clock clock) {
|
||||
if (descriptor.isDynamic()) {
|
||||
final Supplier<RateLimiterConfig> configResolver = () -> {
|
||||
final RateLimiterConfig config = dynamicConfigurationManager.getConfiguration().getLimits().get(descriptor.id());
|
||||
return config != null
|
||||
? config
|
||||
: configs.getOrDefault(descriptor.id(), descriptor.defaultConfig());
|
||||
};
|
||||
return new DynamicRateLimiter(descriptor.id(), dynamicConfigurationManager, configResolver, validateScript, cacheCluster, clock);
|
||||
}
|
||||
final RateLimiterConfig cfg = configs.getOrDefault(descriptor.id(), descriptor.defaultConfig());
|
||||
return new StaticRateLimiter(descriptor.id(), cfg, validateScript, cacheCluster, clock, dynamicConfigurationManager);
|
||||
final Supplier<RateLimiterConfig> configResolver =
|
||||
() -> dynamicConfigurationManager.getConfiguration().getLimits().getOrDefault(descriptor.id(), descriptor.defaultConfig());
|
||||
return new DynamicRateLimiter(descriptor.id(), configResolver, validateScript, cacheCluster, clock);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,87 +7,167 @@ package org.whispersystems.textsecuregcm.limits;
|
|||
|
||||
import static java.util.Objects.requireNonNull;
|
||||
|
||||
import io.micrometer.core.instrument.Counter;
|
||||
import io.micrometer.core.instrument.Metrics;
|
||||
import java.time.Clock;
|
||||
import java.time.Duration;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.CompletionStage;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import java.util.function.Supplier;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
||||
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
|
||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
|
||||
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
|
||||
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
||||
import org.whispersystems.textsecuregcm.util.Util;
|
||||
|
||||
public class DynamicRateLimiter implements RateLimiter {
|
||||
|
||||
private final String name;
|
||||
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
|
||||
private final Supplier<RateLimiterConfig> configResolver;
|
||||
|
||||
private final ClusterLuaScript validateScript;
|
||||
|
||||
private final FaultTolerantRedisClusterClient cluster;
|
||||
|
||||
private final Clock clock;
|
||||
private final Counter limitExceededCounter;
|
||||
|
||||
private final AtomicReference<Pair<RateLimiterConfig, RateLimiter>> currentHolder = new AtomicReference<>();
|
||||
private final Clock clock;
|
||||
|
||||
|
||||
public DynamicRateLimiter(
|
||||
final String name,
|
||||
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
|
||||
final Supplier<RateLimiterConfig> configResolver,
|
||||
final ClusterLuaScript validateScript,
|
||||
final FaultTolerantRedisClusterClient cluster,
|
||||
final Clock clock) {
|
||||
this.name = requireNonNull(name);
|
||||
this.dynamicConfigurationManager = dynamicConfigurationManager;
|
||||
this.configResolver = requireNonNull(configResolver);
|
||||
this.validateScript = requireNonNull(validateScript);
|
||||
this.cluster = requireNonNull(cluster);
|
||||
this.clock = requireNonNull(clock);
|
||||
this.limitExceededCounter = Metrics.counter(MetricsUtil.name(getClass(), "exceeded"), "rateLimiterName", name);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void validate(final String key, final int amount) throws RateLimitExceededException {
|
||||
current().getRight().validate(key, amount);
|
||||
final RateLimiterConfig config = config();
|
||||
try {
|
||||
final long deficitPermitsAmount = executeValidateScript(config, key, amount, true);
|
||||
if (deficitPermitsAmount > 0) {
|
||||
limitExceededCounter.increment();
|
||||
final Duration retryAfter = Duration.ofMillis(
|
||||
(long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis()));
|
||||
throw new RateLimitExceededException(retryAfter);
|
||||
}
|
||||
} catch (final Exception e) {
|
||||
if (e instanceof RateLimitExceededException rateLimitExceededException) {
|
||||
throw rateLimitExceededException;
|
||||
}
|
||||
|
||||
if (!config.failOpen()) {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public CompletionStage<Void> validateAsync(final String key, final int amount) {
|
||||
return current().getRight().validateAsync(key, amount);
|
||||
final RateLimiterConfig config = config();
|
||||
|
||||
return executeValidateScriptAsync(config, key, amount, true)
|
||||
.thenCompose(deficitPermitsAmount -> {
|
||||
if (deficitPermitsAmount == 0) {
|
||||
return CompletableFuture.completedFuture((Void) null);
|
||||
}
|
||||
limitExceededCounter.increment();
|
||||
final Duration retryAfter = Duration.ofMillis(
|
||||
(long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis()));
|
||||
return CompletableFuture.failedFuture(new RateLimitExceededException(retryAfter));
|
||||
})
|
||||
.exceptionally(throwable -> {
|
||||
if (ExceptionUtils.unwrap(throwable) instanceof RateLimitExceededException rateLimitExceededException) {
|
||||
throw ExceptionUtils.wrap(rateLimitExceededException);
|
||||
}
|
||||
|
||||
if (config.failOpen()) {
|
||||
return null;
|
||||
}
|
||||
|
||||
throw ExceptionUtils.wrap(throwable);
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasAvailablePermits(final String key, final int permits) {
|
||||
return current().getRight().hasAvailablePermits(key, permits);
|
||||
final RateLimiterConfig config = config();
|
||||
try {
|
||||
final long deficitPermitsAmount = executeValidateScript(config, key, permits, false);
|
||||
return deficitPermitsAmount == 0;
|
||||
} catch (final Exception e) {
|
||||
if (config.failOpen()) {
|
||||
return true;
|
||||
} else {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public CompletionStage<Boolean> hasAvailablePermitsAsync(final String key, final int amount) {
|
||||
return current().getRight().hasAvailablePermitsAsync(key, amount);
|
||||
final RateLimiterConfig config = config();
|
||||
return executeValidateScriptAsync(config, key, amount, false)
|
||||
.thenApply(deficitPermitsAmount -> deficitPermitsAmount == 0)
|
||||
.exceptionally(throwable -> {
|
||||
if (config.failOpen()) {
|
||||
return true;
|
||||
}
|
||||
throw ExceptionUtils.wrap(throwable);
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public void clear(final String key) {
|
||||
current().getRight().clear(key);
|
||||
cluster.useCluster(connection -> connection.sync().del(bucketName(name, key)));
|
||||
}
|
||||
|
||||
@Override
|
||||
public CompletionStage<Void> clearAsync(final String key) {
|
||||
return current().getRight().clearAsync(key);
|
||||
return cluster.withCluster(connection -> connection.async().del(bucketName(name, key)))
|
||||
.thenRun(Util.NOOP);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RateLimiterConfig config() {
|
||||
return current().getLeft();
|
||||
return configResolver.get();
|
||||
}
|
||||
|
||||
private Pair<RateLimiterConfig, RateLimiter> current() {
|
||||
final RateLimiterConfig cfg = configResolver.get();
|
||||
return currentHolder.updateAndGet(p -> p != null && p.getLeft().equals(cfg)
|
||||
? p
|
||||
: Pair.of(cfg, new StaticRateLimiter(name, cfg, validateScript, cluster, clock, dynamicConfigurationManager))
|
||||
private long executeValidateScript(final RateLimiterConfig config, final String key, final int amount, final boolean applyChanges) {
|
||||
final List<String> keys = List.of(bucketName(name, key));
|
||||
final List<String> arguments = List.of(
|
||||
String.valueOf(config.bucketSize()),
|
||||
String.valueOf(config.leakRatePerMillis()),
|
||||
String.valueOf(clock.millis()),
|
||||
String.valueOf(amount),
|
||||
String.valueOf(applyChanges)
|
||||
);
|
||||
return (Long) validateScript.execute(keys, arguments);
|
||||
}
|
||||
|
||||
private CompletionStage<Long> executeValidateScriptAsync(final RateLimiterConfig config, final String key, final int amount, final boolean applyChanges) {
|
||||
final List<String> keys = List.of(bucketName(name, key));
|
||||
final List<String> arguments = List.of(
|
||||
String.valueOf(config.bucketSize()),
|
||||
String.valueOf(config.leakRatePerMillis()),
|
||||
String.valueOf(clock.millis()),
|
||||
String.valueOf(amount),
|
||||
String.valueOf(applyChanges)
|
||||
);
|
||||
return validateScript.executeAsync(keys, arguments).thenApply(o -> (Long) o);
|
||||
}
|
||||
|
||||
private static String bucketName(final String name, final String key) {
|
||||
return "leaky_bucket::" + name + "::" + key;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,7 +8,7 @@ package org.whispersystems.textsecuregcm.limits;
|
|||
import jakarta.validation.constraints.AssertTrue;
|
||||
import java.time.Duration;
|
||||
|
||||
public record RateLimiterConfig(int bucketSize, Duration permitRegenerationDuration) {
|
||||
public record RateLimiterConfig(int bucketSize, Duration permitRegenerationDuration, boolean failOpen) {
|
||||
|
||||
public double leakRatePerMillis() {
|
||||
return 1.0 / (permitRegenerationDuration.toNanos() / 1e6);
|
||||
|
|
|
@ -15,14 +15,9 @@ public interface RateLimiterDescriptor {
|
|||
*/
|
||||
String id();
|
||||
|
||||
/**
|
||||
* @return {@code true} if this rate limiter needs to watch for dynamic configuration changes.
|
||||
*/
|
||||
boolean isDynamic();
|
||||
|
||||
/**
|
||||
* @return an instance of {@link RateLimiterConfig} to be used by default,
|
||||
* i.e. if there is no overrides in the application configuration files (static or dynamic).
|
||||
* i.e. if there is no override in the application dynamic configuration.
|
||||
*/
|
||||
RateLimiterConfig defaultConfig();
|
||||
}
|
||||
|
|
|
@ -4,11 +4,9 @@
|
|||
*/
|
||||
package org.whispersystems.textsecuregcm.limits;
|
||||
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import java.time.Clock;
|
||||
import java.time.Duration;
|
||||
import java.util.Map;
|
||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
||||
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
|
||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
|
||||
|
@ -17,59 +15,54 @@ import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
|
|||
public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
|
||||
|
||||
public enum For implements RateLimiterDescriptor {
|
||||
BACKUP_AUTH_CHECK("backupAuthCheck", false, new RateLimiterConfig(100, Duration.ofMinutes(15))),
|
||||
PIN("pin", false, new RateLimiterConfig(10, Duration.ofDays(1))),
|
||||
ATTACHMENT("attachmentCreate", false, new RateLimiterConfig(50, Duration.ofMillis(1200))),
|
||||
BACKUP_ATTACHMENT("backupAttachmentCreate", true, new RateLimiterConfig(10_000, Duration.ofSeconds(1))),
|
||||
PRE_KEYS("prekeys", false, new RateLimiterConfig(6, Duration.ofMinutes(10))),
|
||||
MESSAGES("messages", false, new RateLimiterConfig(60, Duration.ofSeconds(1))),
|
||||
STORIES("stories", false, new RateLimiterConfig(5_000, Duration.ofSeconds(8))),
|
||||
ALLOCATE_DEVICE("allocateDevice", false, new RateLimiterConfig(6, Duration.ofMinutes(2))),
|
||||
VERIFY_DEVICE("verifyDevice", false, new RateLimiterConfig(6, Duration.ofMinutes(2))),
|
||||
TURN("turnAllocate", false, new RateLimiterConfig(60, Duration.ofSeconds(1))),
|
||||
PROFILE("profile", false, new RateLimiterConfig(4320, Duration.ofSeconds(20))),
|
||||
STICKER_PACK("stickerPack", false, new RateLimiterConfig(50, Duration.ofMinutes(72))),
|
||||
USERNAME_LOOKUP("usernameLookup", false, new RateLimiterConfig(100, Duration.ofMinutes(15))),
|
||||
USERNAME_SET("usernameSet", false, new RateLimiterConfig(100, Duration.ofMinutes(15))),
|
||||
USERNAME_RESERVE("usernameReserve", false, new RateLimiterConfig(100, Duration.ofMinutes(15))),
|
||||
USERNAME_LINK_OPERATION("usernameLinkOperation", false, new RateLimiterConfig(10, Duration.ofMinutes(1))),
|
||||
USERNAME_LINK_LOOKUP_PER_IP("usernameLinkLookupPerIp", false, new RateLimiterConfig(100, Duration.ofSeconds(15))),
|
||||
CHECK_ACCOUNT_EXISTENCE("checkAccountExistence", false, new RateLimiterConfig(1000, Duration.ofSeconds(4))),
|
||||
REGISTRATION("registration", false, new RateLimiterConfig(6, Duration.ofSeconds(30))),
|
||||
VERIFICATION_PUSH_CHALLENGE("verificationPushChallenge", false, new RateLimiterConfig(5, Duration.ofSeconds(30))),
|
||||
VERIFICATION_CAPTCHA("verificationCaptcha", false, new RateLimiterConfig(10, Duration.ofSeconds(30))),
|
||||
RATE_LIMIT_RESET("rateLimitReset", true, new RateLimiterConfig(2, Duration.ofHours(12))),
|
||||
CAPTCHA_CHALLENGE_ATTEMPT("captchaChallengeAttempt", true, new RateLimiterConfig(10, Duration.ofMinutes(144))),
|
||||
CAPTCHA_CHALLENGE_SUCCESS("captchaChallengeSuccess", true, new RateLimiterConfig(2, Duration.ofHours(12))),
|
||||
SET_BACKUP_ID("setBackupId", true, new RateLimiterConfig(10, Duration.ofHours(1))),
|
||||
SET_PAID_MEDIA_BACKUP_ID("setPaidMediaBackupId", true, new RateLimiterConfig(5, Duration.ofDays(7))),
|
||||
PUSH_CHALLENGE_ATTEMPT("pushChallengeAttempt", true, new RateLimiterConfig(10, Duration.ofMinutes(144))),
|
||||
PUSH_CHALLENGE_SUCCESS("pushChallengeSuccess", true, new RateLimiterConfig(2, Duration.ofHours(12))),
|
||||
GET_CALLING_RELAYS("getCallingRelays", false, new RateLimiterConfig(100, Duration.ofMinutes(10))),
|
||||
CREATE_CALL_LINK("createCallLink", false, new RateLimiterConfig(100, Duration.ofMinutes(15))),
|
||||
INBOUND_MESSAGE_BYTES("inboundMessageBytes", true, new RateLimiterConfig(128 * 1024 * 1024, Duration.ofNanos(500_000))),
|
||||
EXTERNAL_SERVICE_CREDENTIALS("externalServiceCredentials", true, new RateLimiterConfig(100, Duration.ofMinutes(15))),
|
||||
KEY_TRANSPARENCY_DISTINGUISHED_PER_IP("keyTransparencyDistinguished", true,
|
||||
new RateLimiterConfig(100, Duration.ofSeconds(15))),
|
||||
KEY_TRANSPARENCY_SEARCH_PER_IP("keyTransparencySearch", true, new RateLimiterConfig(100, Duration.ofSeconds(15))),
|
||||
KEY_TRANSPARENCY_MONITOR_PER_IP("keyTransparencyMonitor", true, new RateLimiterConfig(100, Duration.ofSeconds(15))),
|
||||
WAIT_FOR_LINKED_DEVICE("waitForLinkedDevice", true, new RateLimiterConfig(10, Duration.ofSeconds(30))),
|
||||
UPLOAD_TRANSFER_ARCHIVE("uploadTransferArchive", true, new RateLimiterConfig(10, Duration.ofMinutes(1))),
|
||||
WAIT_FOR_TRANSFER_ARCHIVE("waitForTransferArchive", true, new RateLimiterConfig(10, Duration.ofSeconds(30))),
|
||||
RECORD_DEVICE_TRANSFER_REQUEST("recordDeviceTransferRequest", true, new RateLimiterConfig(10, Duration.ofMillis(100))),
|
||||
WAIT_FOR_DEVICE_TRANSFER_REQUEST("waitForDeviceTransferRequest", true, new RateLimiterConfig(10, Duration.ofMillis(100))),
|
||||
DEVICE_CHECK_CHALLENGE("deviceCheckChallenge", true, new RateLimiterConfig(10, Duration.ofMinutes(1))),
|
||||
BACKUP_AUTH_CHECK("backupAuthCheck", new RateLimiterConfig(100, Duration.ofMinutes(15), true)),
|
||||
PIN("pin", new RateLimiterConfig(10, Duration.ofDays(1), false)),
|
||||
ATTACHMENT("attachmentCreate", new RateLimiterConfig(50, Duration.ofMillis(1200), true)),
|
||||
BACKUP_ATTACHMENT("backupAttachmentCreate", new RateLimiterConfig(10_000, Duration.ofSeconds(1), true)),
|
||||
PRE_KEYS("prekeys", new RateLimiterConfig(6, Duration.ofMinutes(10), false)),
|
||||
MESSAGES("messages", new RateLimiterConfig(60, Duration.ofSeconds(1), true)),
|
||||
STORIES("stories", new RateLimiterConfig(5_000, Duration.ofSeconds(8), true)),
|
||||
ALLOCATE_DEVICE("allocateDevice", new RateLimiterConfig(6, Duration.ofMinutes(2), false)),
|
||||
VERIFY_DEVICE("verifyDevice", new RateLimiterConfig(6, Duration.ofMinutes(2), false)),
|
||||
PROFILE("profile", new RateLimiterConfig(4320, Duration.ofSeconds(20), true)),
|
||||
STICKER_PACK("stickerPack", new RateLimiterConfig(50, Duration.ofMinutes(72), false)),
|
||||
USERNAME_LOOKUP("usernameLookup", new RateLimiterConfig(100, Duration.ofMinutes(15), true)),
|
||||
USERNAME_SET("usernameSet", new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
|
||||
USERNAME_RESERVE("usernameReserve", new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
|
||||
USERNAME_LINK_OPERATION("usernameLinkOperation", new RateLimiterConfig(10, Duration.ofMinutes(1), false)),
|
||||
USERNAME_LINK_LOOKUP_PER_IP("usernameLinkLookupPerIp", new RateLimiterConfig(100, Duration.ofSeconds(15), true)),
|
||||
CHECK_ACCOUNT_EXISTENCE("checkAccountExistence", new RateLimiterConfig(1000, Duration.ofSeconds(4), true)),
|
||||
REGISTRATION("registration", new RateLimiterConfig(6, Duration.ofSeconds(30), false)),
|
||||
VERIFICATION_PUSH_CHALLENGE("verificationPushChallenge", new RateLimiterConfig(5, Duration.ofSeconds(30), false)),
|
||||
VERIFICATION_CAPTCHA("verificationCaptcha", new RateLimiterConfig(10, Duration.ofSeconds(30), false)),
|
||||
RATE_LIMIT_RESET("rateLimitReset", new RateLimiterConfig(2, Duration.ofHours(12), false)),
|
||||
CAPTCHA_CHALLENGE_ATTEMPT("captchaChallengeAttempt", new RateLimiterConfig(10, Duration.ofMinutes(144), false)),
|
||||
CAPTCHA_CHALLENGE_SUCCESS("captchaChallengeSuccess", new RateLimiterConfig(2, Duration.ofHours(12), false)),
|
||||
SET_BACKUP_ID("setBackupId", new RateLimiterConfig(10, Duration.ofHours(1), false)),
|
||||
SET_PAID_MEDIA_BACKUP_ID("setPaidMediaBackupId", new RateLimiterConfig(5, Duration.ofDays(7), false)),
|
||||
PUSH_CHALLENGE_ATTEMPT("pushChallengeAttempt", new RateLimiterConfig(10, Duration.ofMinutes(144), false)),
|
||||
PUSH_CHALLENGE_SUCCESS("pushChallengeSuccess", new RateLimiterConfig(2, Duration.ofHours(12), false)),
|
||||
GET_CALLING_RELAYS("getCallingRelays", new RateLimiterConfig(100, Duration.ofMinutes(10), false)),
|
||||
CREATE_CALL_LINK("createCallLink", new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
|
||||
INBOUND_MESSAGE_BYTES("inboundMessageBytes", new RateLimiterConfig(128 * 1024 * 1024, Duration.ofNanos(500_000), true)),
|
||||
EXTERNAL_SERVICE_CREDENTIALS("externalServiceCredentials", new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
|
||||
KEY_TRANSPARENCY_DISTINGUISHED_PER_IP("keyTransparencyDistinguished", new RateLimiterConfig(100, Duration.ofSeconds(15), true)),
|
||||
KEY_TRANSPARENCY_SEARCH_PER_IP("keyTransparencySearch", new RateLimiterConfig(100, Duration.ofSeconds(15), true)),
|
||||
KEY_TRANSPARENCY_MONITOR_PER_IP("keyTransparencyMonitor", new RateLimiterConfig(100, Duration.ofSeconds(15), true)),
|
||||
WAIT_FOR_LINKED_DEVICE("waitForLinkedDevice", new RateLimiterConfig(10, Duration.ofSeconds(30), false)),
|
||||
UPLOAD_TRANSFER_ARCHIVE("uploadTransferArchive", new RateLimiterConfig(10, Duration.ofMinutes(1), false)),
|
||||
WAIT_FOR_TRANSFER_ARCHIVE("waitForTransferArchive", new RateLimiterConfig(10, Duration.ofSeconds(30), false)),
|
||||
RECORD_DEVICE_TRANSFER_REQUEST("recordDeviceTransferRequest", new RateLimiterConfig(10, Duration.ofMillis(100), true)),
|
||||
WAIT_FOR_DEVICE_TRANSFER_REQUEST("waitForDeviceTransferRequest", new RateLimiterConfig(10, Duration.ofMillis(100), true)),
|
||||
DEVICE_CHECK_CHALLENGE("deviceCheckChallenge", new RateLimiterConfig(10, Duration.ofMinutes(1), false)),
|
||||
;
|
||||
|
||||
private final String id;
|
||||
|
||||
private final boolean dynamic;
|
||||
|
||||
private final RateLimiterConfig defaultConfig;
|
||||
|
||||
For(final String id, final boolean dynamic, final RateLimiterConfig defaultConfig) {
|
||||
For(final String id, final RateLimiterConfig defaultConfig) {
|
||||
this.id = id;
|
||||
this.dynamic = dynamic;
|
||||
this.defaultConfig = defaultConfig;
|
||||
}
|
||||
|
||||
|
@ -77,34 +70,25 @@ public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
|
|||
return id;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isDynamic() {
|
||||
return dynamic;
|
||||
}
|
||||
|
||||
public RateLimiterConfig defaultConfig() {
|
||||
return defaultConfig;
|
||||
}
|
||||
}
|
||||
|
||||
public static RateLimiters createAndValidate(
|
||||
final Map<String, RateLimiterConfig> configs,
|
||||
public static RateLimiters create(
|
||||
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
|
||||
final FaultTolerantRedisClusterClient cacheCluster) {
|
||||
final RateLimiters rateLimiters = new RateLimiters(
|
||||
configs, dynamicConfigurationManager, defaultScript(cacheCluster), cacheCluster, Clock.systemUTC());
|
||||
rateLimiters.validateValuesAndConfigs();
|
||||
return rateLimiters;
|
||||
return new RateLimiters(
|
||||
dynamicConfigurationManager, defaultScript(cacheCluster), cacheCluster, Clock.systemUTC());
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
RateLimiters(
|
||||
final Map<String, RateLimiterConfig> configs,
|
||||
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
|
||||
final ClusterLuaScript validateScript,
|
||||
final FaultTolerantRedisClusterClient cacheCluster,
|
||||
final Clock clock) {
|
||||
super(For.values(), configs, dynamicConfigurationManager, validateScript, cacheCluster, clock);
|
||||
super(For.values(), dynamicConfigurationManager, validateScript, cacheCluster, clock);
|
||||
}
|
||||
|
||||
public RateLimiter getAllocateDeviceLimiter() {
|
||||
|
@ -131,10 +115,6 @@ public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
|
|||
return forDescriptor(For.PIN);
|
||||
}
|
||||
|
||||
public RateLimiter getTurnLimiter() {
|
||||
return forDescriptor(For.TURN);
|
||||
}
|
||||
|
||||
public RateLimiter getProfileLimiter() {
|
||||
return forDescriptor(For.PROFILE);
|
||||
}
|
||||
|
|
|
@ -1,171 +0,0 @@
|
|||
/*
|
||||
* Copyright 2013 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.limits;
|
||||
|
||||
import static java.util.Objects.requireNonNull;
|
||||
import static java.util.concurrent.CompletableFuture.completedFuture;
|
||||
import static java.util.concurrent.CompletableFuture.failedFuture;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import io.lettuce.core.RedisException;
|
||||
import io.micrometer.core.instrument.Counter;
|
||||
import io.micrometer.core.instrument.Metrics;
|
||||
import java.time.Clock;
|
||||
import java.time.Duration;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CompletionStage;
|
||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
||||
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
|
||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
|
||||
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
|
||||
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
||||
import org.whispersystems.textsecuregcm.util.Util;
|
||||
|
||||
public class StaticRateLimiter implements RateLimiter {
|
||||
|
||||
protected final String name;
|
||||
|
||||
private final RateLimiterConfig config;
|
||||
|
||||
private final Counter counter;
|
||||
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
|
||||
|
||||
private final ClusterLuaScript validateScript;
|
||||
|
||||
private final FaultTolerantRedisClusterClient cacheCluster;
|
||||
|
||||
private final Clock clock;
|
||||
|
||||
|
||||
public StaticRateLimiter(
|
||||
final String name,
|
||||
final RateLimiterConfig config,
|
||||
final ClusterLuaScript validateScript,
|
||||
final FaultTolerantRedisClusterClient cacheCluster,
|
||||
final Clock clock,
|
||||
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager) {
|
||||
this.name = requireNonNull(name);
|
||||
this.config = requireNonNull(config);
|
||||
this.validateScript = requireNonNull(validateScript);
|
||||
this.cacheCluster = requireNonNull(cacheCluster);
|
||||
this.clock = requireNonNull(clock);
|
||||
this.counter = Metrics.counter(MetricsUtil.name(getClass(), "exceeded"), "rateLimiterName", name);
|
||||
this.dynamicConfigurationManager = dynamicConfigurationManager;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void validate(final String key, final int amount) throws RateLimitExceededException {
|
||||
try {
|
||||
final long deficitPermitsAmount = executeValidateScript(key, amount, true);
|
||||
if (deficitPermitsAmount > 0) {
|
||||
counter.increment();
|
||||
final Duration retryAfter = Duration.ofMillis(
|
||||
(long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis()));
|
||||
throw new RateLimitExceededException(retryAfter);
|
||||
}
|
||||
} catch (RedisException e) {
|
||||
if (!failOpen()) {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public CompletionStage<Void> validateAsync(final String key, final int amount) {
|
||||
return executeValidateScriptAsync(key, amount, true)
|
||||
.thenCompose(deficitPermitsAmount -> {
|
||||
if (deficitPermitsAmount == 0) {
|
||||
return completedFuture((Void) null);
|
||||
}
|
||||
counter.increment();
|
||||
final Duration retryAfter = Duration.ofMillis(
|
||||
(long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis()));
|
||||
return failedFuture(new RateLimitExceededException(retryAfter));
|
||||
})
|
||||
.exceptionally(throwable -> {
|
||||
if (ExceptionUtils.unwrap(throwable) instanceof RedisException && failOpen()) {
|
||||
return null;
|
||||
}
|
||||
throw ExceptionUtils.wrap(throwable);
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasAvailablePermits(final String key, final int amount) {
|
||||
try {
|
||||
final long deficitPermitsAmount = executeValidateScript(key, amount, false);
|
||||
return deficitPermitsAmount == 0;
|
||||
} catch (RedisException e) {
|
||||
if (failOpen()) {
|
||||
return true;
|
||||
} else {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public CompletionStage<Boolean> hasAvailablePermitsAsync(final String key, final int amount) {
|
||||
return executeValidateScriptAsync(key, amount, false)
|
||||
.thenApply(deficitPermitsAmount -> deficitPermitsAmount == 0)
|
||||
.exceptionally(throwable -> {
|
||||
if (ExceptionUtils.unwrap(throwable) instanceof RedisException && failOpen()) {
|
||||
return true;
|
||||
}
|
||||
throw ExceptionUtils.wrap(throwable);
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public void clear(final String key) {
|
||||
cacheCluster.useCluster(connection -> connection.sync().del(bucketName(name, key)));
|
||||
}
|
||||
|
||||
@Override
|
||||
public CompletionStage<Void> clearAsync(final String key) {
|
||||
return cacheCluster.withCluster(connection -> connection.async().del(bucketName(name, key)))
|
||||
.thenRun(Util.NOOP);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RateLimiterConfig config() {
|
||||
return config;
|
||||
}
|
||||
|
||||
private boolean failOpen() {
|
||||
return this.dynamicConfigurationManager.getConfiguration().getRateLimitPolicy().failOpen();
|
||||
}
|
||||
|
||||
private long executeValidateScript(final String key, final int amount, final boolean applyChanges) {
|
||||
final List<String> keys = List.of(bucketName(name, key));
|
||||
final List<String> arguments = List.of(
|
||||
String.valueOf(config.bucketSize()),
|
||||
String.valueOf(config.leakRatePerMillis()),
|
||||
String.valueOf(clock.millis()),
|
||||
String.valueOf(amount),
|
||||
String.valueOf(applyChanges)
|
||||
);
|
||||
return (Long) validateScript.execute(keys, arguments);
|
||||
}
|
||||
|
||||
private CompletionStage<Long> executeValidateScriptAsync(final String key, final int amount, final boolean applyChanges) {
|
||||
final List<String> keys = List.of(bucketName(name, key));
|
||||
final List<String> arguments = List.of(
|
||||
String.valueOf(config.bucketSize()),
|
||||
String.valueOf(config.leakRatePerMillis()),
|
||||
String.valueOf(clock.millis()),
|
||||
String.valueOf(amount),
|
||||
String.valueOf(applyChanges)
|
||||
);
|
||||
return validateScript.executeAsync(keys, arguments).thenApply(o -> (Long) o);
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
protected static String bucketName(final String name, final String key) {
|
||||
return "leaky_bucket::" + name + "::" + key;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,44 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.metrics;
|
||||
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
|
||||
import java.util.Optional;
|
||||
|
||||
public class DevicePlatformUtil {
|
||||
|
||||
private DevicePlatformUtil() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the most likely client platform for a device.
|
||||
*
|
||||
* @param device the device for which to find a client platform
|
||||
*
|
||||
* @return the most likely client platform for the given device or empty if no likely platform could be determined
|
||||
*/
|
||||
public static Optional<ClientPlatform> getDevicePlatform(final Device device) {
|
||||
final Optional<ClientPlatform> clientPlatform;
|
||||
|
||||
if (StringUtils.isNotBlank(device.getGcmId())) {
|
||||
clientPlatform = Optional.of(ClientPlatform.ANDROID);
|
||||
} else if (StringUtils.isNotBlank(device.getApnId())) {
|
||||
clientPlatform = Optional.of(ClientPlatform.IOS);
|
||||
} else {
|
||||
clientPlatform = Optional.empty();
|
||||
}
|
||||
|
||||
return clientPlatform.or(() -> Optional.ofNullable(
|
||||
switch (device.getUserAgent()) {
|
||||
case "OWA" -> ClientPlatform.ANDROID;
|
||||
case "OWI", "OWP" -> ClientPlatform.IOS;
|
||||
case "OWD" -> ClientPlatform.DESKTOP;
|
||||
case null, default -> null;
|
||||
}));
|
||||
}
|
||||
}
|
|
@ -1,32 +0,0 @@
|
|||
/*
|
||||
* Copyright 2013-2020 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.metrics;
|
||||
|
||||
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
|
||||
|
||||
import com.sun.management.OperatingSystemMXBean;
|
||||
import io.micrometer.core.instrument.Gauge;
|
||||
import io.micrometer.core.instrument.MeterRegistry;
|
||||
import io.micrometer.core.instrument.binder.MeterBinder;
|
||||
import java.lang.management.ManagementFactory;
|
||||
|
||||
public class FreeMemoryGauge implements MeterBinder {
|
||||
|
||||
private final OperatingSystemMXBean operatingSystemMXBean;
|
||||
|
||||
public FreeMemoryGauge() {
|
||||
this.operatingSystemMXBean = (com.sun.management.OperatingSystemMXBean)
|
||||
ManagementFactory.getOperatingSystemMXBean();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void bindTo(final MeterRegistry registry) {
|
||||
Gauge.builder(name(FreeMemoryGauge.class, "freeMemory"), operatingSystemMXBean,
|
||||
OperatingSystemMXBean::getFreeMemorySize)
|
||||
.register(registry);
|
||||
|
||||
}
|
||||
}
|
|
@ -120,10 +120,7 @@ public class MetricsUtil {
|
|||
|
||||
public static void registerSystemResourceMetrics(final Environment environment) {
|
||||
new ProcessorMetrics().bindTo(Metrics.globalRegistry);
|
||||
new FreeMemoryGauge().bindTo(Metrics.globalRegistry);
|
||||
new FileDescriptorMetrics().bindTo(Metrics.globalRegistry);
|
||||
new OperatingSystemMemoryGauge("Buffers").bindTo(Metrics.globalRegistry);
|
||||
new OperatingSystemMemoryGauge("Cached").bindTo(Metrics.globalRegistry);
|
||||
|
||||
new JvmMemoryMetrics().bindTo(Metrics.globalRegistry);
|
||||
new JvmThreadMetrics().bindTo(Metrics.globalRegistry);
|
||||
|
|
|
@ -67,7 +67,7 @@ public class OpenWebSocketCounter {
|
|||
|
||||
try {
|
||||
final ClientPlatform clientPlatform =
|
||||
UserAgentUtil.parseUserAgentString(context.getClient().getUserAgent()).getPlatform();
|
||||
UserAgentUtil.parseUserAgentString(context.getClient().getUserAgent()).platform();
|
||||
|
||||
calculatedOpenWebSocketCounter = openWebsocketsByClientPlatform.get(clientPlatform);
|
||||
calculatedDurationTimer = durationTimersByClientPlatform.get(clientPlatform);
|
||||
|
|
|
@ -1,56 +0,0 @@
|
|||
/*
|
||||
* Copyright 2013-2020 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.metrics;
|
||||
|
||||
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import io.micrometer.core.instrument.Gauge;
|
||||
import io.micrometer.core.instrument.MeterRegistry;
|
||||
import io.micrometer.core.instrument.binder.MeterBinder;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.File;
|
||||
import java.io.FileReader;
|
||||
import java.io.IOException;
|
||||
import java.util.Locale;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public class OperatingSystemMemoryGauge implements MeterBinder {
|
||||
|
||||
private final String metricName;
|
||||
|
||||
private static final File MEMINFO_FILE = new File("/proc/meminfo");
|
||||
private static final Pattern MEMORY_METRIC_PATTERN = Pattern.compile("^([^:]+):\\s+([0-9]+).*$");
|
||||
|
||||
public OperatingSystemMemoryGauge(final String metricName) {
|
||||
this.metricName = metricName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void bindTo(MeterRegistry registry) {
|
||||
final String metricName = this.metricName;
|
||||
Gauge.builder(name(OperatingSystemMemoryGauge.class, metricName.toLowerCase(Locale.ROOT)), () -> {
|
||||
try (final BufferedReader bufferedReader = new BufferedReader(new FileReader(MEMINFO_FILE))) {
|
||||
return getValue(bufferedReader.lines(), metricName);
|
||||
} catch (final IOException e) {
|
||||
return 0L;
|
||||
}
|
||||
})
|
||||
.register(registry);
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
static double getValue(final Stream<String> lines, final String metricName) {
|
||||
return lines.map(MEMORY_METRIC_PATTERN::matcher)
|
||||
.filter(Matcher::matches)
|
||||
.filter(matcher -> metricName.equalsIgnoreCase(matcher.group(1)))
|
||||
.map(matcher -> Double.parseDouble(matcher.group(2)))
|
||||
.findFirst()
|
||||
.orElse(0d);
|
||||
}
|
||||
}
|
|
@ -9,6 +9,7 @@ import io.micrometer.core.instrument.Tag;
|
|||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.whispersystems.textsecuregcm.WhisperServerVersion;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
|
||||
|
@ -48,15 +49,15 @@ public class UserAgentTagUtil {
|
|||
}
|
||||
|
||||
public static Tag getPlatformTag(@Nullable final UserAgent userAgent) {
|
||||
return Tag.of(PLATFORM_TAG, userAgent != null ? userAgent.getPlatform().name().toLowerCase() : "unrecognized");
|
||||
return Tag.of(PLATFORM_TAG, userAgent != null ? userAgent.platform().name().toLowerCase() : "unrecognized");
|
||||
}
|
||||
|
||||
public static Optional<Tag> getClientVersionTag(final String userAgentString, final ClientReleaseManager clientReleaseManager) {
|
||||
try {
|
||||
final UserAgent userAgent = UserAgentUtil.parseUserAgentString(userAgentString);
|
||||
|
||||
if (clientReleaseManager.isVersionActive(userAgent.getPlatform(), userAgent.getVersion())) {
|
||||
return Optional.of(Tag.of(VERSION_TAG, userAgent.getVersion().toString()));
|
||||
if (clientReleaseManager.isVersionActive(userAgent.platform(), userAgent.version())) {
|
||||
return Optional.of(Tag.of(VERSION_TAG, userAgent.version().toString()));
|
||||
}
|
||||
} catch (final UnrecognizedUserAgentException ignored) {
|
||||
}
|
||||
|
@ -70,10 +71,8 @@ public class UserAgentTagUtil {
|
|||
|
||||
try {
|
||||
final UserAgent userAgent = UserAgentUtil.parseUserAgentString(userAgentString);
|
||||
platform = userAgent.getPlatform().name().toLowerCase();
|
||||
libsignal = userAgent.getAdditionalSpecifiers()
|
||||
.map(additionalSpecifiers -> additionalSpecifiers.contains("libsignal"))
|
||||
.orElse(false);
|
||||
platform = userAgent.platform().name().toLowerCase();
|
||||
libsignal = StringUtils.contains(userAgent.additionalSpecifiers(), "libsignal");
|
||||
} catch (final UnrecognizedUserAgentException e) {
|
||||
platform = "unrecognized";
|
||||
libsignal = false;
|
||||
|
|
|
@ -20,6 +20,7 @@ import java.util.Optional;
|
|||
import java.util.Set;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.stream.Collectors;
|
||||
import javax.annotation.Nullable;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
|
||||
import org.signal.libsignal.protocol.util.Pair;
|
||||
|
@ -27,6 +28,7 @@ import org.whispersystems.textsecuregcm.controllers.MessageController;
|
|||
import org.whispersystems.textsecuregcm.controllers.MismatchedDevices;
|
||||
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
|
||||
import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException;
|
||||
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
|
||||
import org.whispersystems.textsecuregcm.identity.IdentityType;
|
||||
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
||||
|
@ -35,7 +37,6 @@ import org.whispersystems.textsecuregcm.storage.Account;
|
|||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.storage.MessagesManager;
|
||||
import org.whispersystems.textsecuregcm.util.Util;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
/**
|
||||
* A MessageSender sends Signal messages to destination devices. Messages may be "normal" user-to-user messages,
|
||||
|
@ -52,10 +53,12 @@ public class MessageSender {
|
|||
|
||||
private final MessagesManager messagesManager;
|
||||
private final PushNotificationManager pushNotificationManager;
|
||||
private final ExperimentEnrollmentManager experimentEnrollmentManager;
|
||||
|
||||
// Note that these names deliberately reference `MessageController` for metric continuity
|
||||
private static final String REJECT_OVERSIZE_MESSAGE_COUNTER_NAME = name(MessageController.class, "rejectOversizeMessage");
|
||||
private static final String CONTENT_SIZE_DISTRIBUTION_NAME = MetricsUtil.name(MessageController.class, "messageContentSize");
|
||||
private static final String EMPTY_MESSAGE_LIST_COUNTER_NAME = MetricsUtil.name(MessageSender.class, "emptyMessageList");
|
||||
|
||||
private static final String SEND_COUNTER_NAME = name(MessageSender.class, "sendMessage");
|
||||
private static final String EPHEMERAL_TAG_NAME = "ephemeral";
|
||||
|
@ -64,6 +67,7 @@ public class MessageSender {
|
|||
private static final String STORY_TAG_NAME = "story";
|
||||
private static final String SEALED_SENDER_TAG_NAME = "sealedSender";
|
||||
private static final String MULTI_RECIPIENT_TAG_NAME = "multiRecipient";
|
||||
private static final String SYNC_MESSAGE_TAG_NAME = "sync";
|
||||
|
||||
@VisibleForTesting
|
||||
public static final int MAX_MESSAGE_SIZE = (int) DataSize.kibibytes(256).toBytes();
|
||||
|
@ -71,9 +75,13 @@ public class MessageSender {
|
|||
@VisibleForTesting
|
||||
static final byte NO_EXCLUDED_DEVICE_ID = -1;
|
||||
|
||||
public MessageSender(final MessagesManager messagesManager, final PushNotificationManager pushNotificationManager) {
|
||||
public MessageSender(
|
||||
final MessagesManager messagesManager,
|
||||
final PushNotificationManager pushNotificationManager,
|
||||
final ExperimentEnrollmentManager experimentEnrollmentManager) {
|
||||
this.messagesManager = messagesManager;
|
||||
this.pushNotificationManager = pushNotificationManager;
|
||||
this.experimentEnrollmentManager = experimentEnrollmentManager;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -85,6 +93,8 @@ public class MessageSender {
|
|||
* @param destinationIdentifier the service identifier to which the messages are addressed
|
||||
* @param messagesByDeviceId a map of device IDs to message payloads
|
||||
* @param registrationIdsByDeviceId a map of device IDs to device registration IDs
|
||||
* @param syncMessageSenderDeviceId if the message is a sync message (i.e. a message to other devices linked to the
|
||||
* caller's own account), contains the ID of the device that sent the message
|
||||
* @param userAgent the User-Agent string for the sender; may be {@code null} if not known
|
||||
*
|
||||
* @throws MismatchedDevicesException if the given bundle of messages did not include a message for all required
|
||||
|
@ -96,34 +106,48 @@ public class MessageSender {
|
|||
final ServiceIdentifier destinationIdentifier,
|
||||
final Map<Byte, Envelope> messagesByDeviceId,
|
||||
final Map<Byte, Integer> registrationIdsByDeviceId,
|
||||
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<Byte> syncMessageSenderDeviceId,
|
||||
@Nullable final String userAgent) throws MismatchedDevicesException, MessageTooLargeException {
|
||||
|
||||
if (messagesByDeviceId.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!destination.isIdentifiedBy(destinationIdentifier)) {
|
||||
throw new IllegalArgumentException("Destination account not identified by destination service identifier");
|
||||
}
|
||||
|
||||
final Envelope firstMessage = messagesByDeviceId.values().iterator().next();
|
||||
final Tag platformTag = UserAgentTagUtil.getPlatformTag(userAgent);
|
||||
|
||||
final boolean isSyncMessage = StringUtils.isNotBlank(firstMessage.getSourceServiceId()) &&
|
||||
destination.isIdentifiedBy(ServiceIdentifier.valueOf(firstMessage.getSourceServiceId()));
|
||||
if (messagesByDeviceId.isEmpty()) {
|
||||
Metrics.counter(EMPTY_MESSAGE_LIST_COUNTER_NAME,
|
||||
Tags.of(SYNC_MESSAGE_TAG_NAME, String.valueOf(syncMessageSenderDeviceId.isPresent())).and(platformTag)).increment();
|
||||
}
|
||||
|
||||
final boolean isStory = firstMessage.getStory();
|
||||
final byte excludedDeviceId;
|
||||
if (syncMessageSenderDeviceId.isPresent()) {
|
||||
if (messagesByDeviceId.values().stream().anyMatch(message -> StringUtils.isBlank(message.getSourceServiceId()) ||
|
||||
!destination.isIdentifiedBy(ServiceIdentifier.valueOf(message.getSourceServiceId())))) {
|
||||
|
||||
validateIndividualMessageContentLength(messagesByDeviceId.values(), isSyncMessage, isStory, userAgent);
|
||||
throw new IllegalArgumentException("Sync message sender device ID specified, but one or more messages are not addressed to sender");
|
||||
}
|
||||
excludedDeviceId = syncMessageSenderDeviceId.get();
|
||||
} else {
|
||||
if (messagesByDeviceId.values().stream().anyMatch(message -> StringUtils.isNotBlank(message.getSourceServiceId()) &&
|
||||
destination.isIdentifiedBy(ServiceIdentifier.valueOf(message.getSourceServiceId())))) {
|
||||
|
||||
throw new IllegalArgumentException("Sync message sender device ID not specified, but one or more messages are addressed to sender");
|
||||
}
|
||||
excludedDeviceId = NO_EXCLUDED_DEVICE_ID;
|
||||
}
|
||||
|
||||
final Optional<MismatchedDevices> maybeMismatchedDevices = getMismatchedDevices(destination,
|
||||
destinationIdentifier,
|
||||
registrationIdsByDeviceId,
|
||||
isSyncMessage ? (byte) firstMessage.getSourceDevice() : NO_EXCLUDED_DEVICE_ID);
|
||||
excludedDeviceId);
|
||||
|
||||
if (maybeMismatchedDevices.isPresent()) {
|
||||
throw new MismatchedDevicesException(maybeMismatchedDevices.get());
|
||||
}
|
||||
|
||||
validateIndividualMessageContentLength(messagesByDeviceId.values(), syncMessageSenderDeviceId.isPresent(), userAgent);
|
||||
|
||||
messagesManager.insert(destination.getIdentifier(IdentityType.ACI), messagesByDeviceId)
|
||||
.forEach((deviceId, destinationPresent) -> {
|
||||
final Envelope message = messagesByDeviceId.get(deviceId);
|
||||
|
@ -141,8 +165,9 @@ public class MessageSender {
|
|||
URGENT_TAG_NAME, String.valueOf(message.getUrgent()),
|
||||
STORY_TAG_NAME, String.valueOf(message.getStory()),
|
||||
SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceServiceId()),
|
||||
SYNC_MESSAGE_TAG_NAME, String.valueOf(syncMessageSenderDeviceId.isPresent()),
|
||||
MULTI_RECIPIENT_TAG_NAME, "false")
|
||||
.and(UserAgentTagUtil.getPlatformTag(userAgent));
|
||||
.and(platformTag);
|
||||
|
||||
Metrics.counter(SEND_COUNTER_NAME, tags).increment();
|
||||
});
|
||||
|
@ -225,6 +250,7 @@ public class MessageSender {
|
|||
URGENT_TAG_NAME, String.valueOf(isUrgent),
|
||||
STORY_TAG_NAME, String.valueOf(isStory),
|
||||
SEALED_SENDER_TAG_NAME, "true",
|
||||
SYNC_MESSAGE_TAG_NAME, "false",
|
||||
MULTI_RECIPIENT_TAG_NAME, "true")
|
||||
.and(UserAgentTagUtil.getPlatformTag(userAgent));
|
||||
|
||||
|
@ -290,11 +316,7 @@ public class MessageSender {
|
|||
// We know the device must be present because we've already filtered out device IDs that aren't associated
|
||||
// with the given account
|
||||
final Device device = account.getDevice(deviceId).orElseThrow();
|
||||
|
||||
final int expectedRegistrationId = switch (serviceIdentifier.identityType()) {
|
||||
case ACI -> device.getRegistrationId();
|
||||
case PNI -> device.getPhoneNumberIdentityRegistrationId().orElseGet(device::getRegistrationId);
|
||||
};
|
||||
final int expectedRegistrationId = device.getRegistrationId(serviceIdentifier.identityType());
|
||||
|
||||
return registrationId != expectedRegistrationId;
|
||||
})
|
||||
|
@ -308,14 +330,13 @@ public class MessageSender {
|
|||
|
||||
private static void validateIndividualMessageContentLength(final Iterable<Envelope> messages,
|
||||
final boolean isSyncMessage,
|
||||
final boolean isStory,
|
||||
@Nullable final String userAgent) throws MessageTooLargeException {
|
||||
|
||||
for (final Envelope message : messages) {
|
||||
MessageSender.validateContentLength(message.getContent().size(),
|
||||
false,
|
||||
isSyncMessage,
|
||||
isStory,
|
||||
message.getStory(),
|
||||
userAgent);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,128 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.push;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
|
||||
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.util.function.Tuple2;
|
||||
import reactor.util.function.Tuples;
|
||||
|
||||
public class MessageUtil {
|
||||
|
||||
public static final int DEFAULT_MAX_FETCH_ACCOUNT_CONCURRENCY = 8;
|
||||
|
||||
private MessageUtil() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Finds account records for all recipients named in the given multi-recipient manager. Note that the returned map
|
||||
* of recipients to account records will omit entries for recipients that could not be resolved to active accounts;
|
||||
* callers that require full resolution should check for a missing entries and take appropriate action.
|
||||
*
|
||||
* @param accountsManager the {@code AccountsManager} instance to use to find account records
|
||||
* @param multiRecipientMessage the message for which to resolve recipients
|
||||
*
|
||||
* @return a map of recipients to account records
|
||||
*
|
||||
* @see #getUnresolvedRecipients(SealedSenderMultiRecipientMessage, Map)
|
||||
*/
|
||||
public static Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolveRecipients(
|
||||
final AccountsManager accountsManager,
|
||||
final SealedSenderMultiRecipientMessage multiRecipientMessage) {
|
||||
|
||||
return resolveRecipients(accountsManager, multiRecipientMessage, DEFAULT_MAX_FETCH_ACCOUNT_CONCURRENCY);
|
||||
}
|
||||
|
||||
/**
|
||||
* Finds account records for all recipients named in the given multi-recipient manager. Note that the returned map
|
||||
* of recipients to account records will omit entries for recipients that could not be resolved to active accounts;
|
||||
* callers that require full resolution should check for a missing entries and take appropriate action.
|
||||
*
|
||||
* @param accountsManager the {@code AccountsManager} instance to use to find account records
|
||||
* @param multiRecipientMessage the message for which to resolve recipients
|
||||
* @param maxFetchAccountConcurrency the maximum number of concurrent account-retrieval operations
|
||||
*
|
||||
* @return a map of recipients to account records
|
||||
*
|
||||
* @see #getUnresolvedRecipients(SealedSenderMultiRecipientMessage, Map)
|
||||
*/
|
||||
public static Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolveRecipients(
|
||||
final AccountsManager accountsManager,
|
||||
final SealedSenderMultiRecipientMessage multiRecipientMessage,
|
||||
final int maxFetchAccountConcurrency) {
|
||||
|
||||
return Flux.fromIterable(multiRecipientMessage.getRecipients().entrySet())
|
||||
.flatMap(serviceIdAndRecipient -> {
|
||||
final ServiceIdentifier serviceIdentifier =
|
||||
ServiceIdentifier.fromLibsignal(serviceIdAndRecipient.getKey());
|
||||
|
||||
return Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(serviceIdentifier))
|
||||
.flatMap(Mono::justOrEmpty)
|
||||
.map(account -> Tuples.of(serviceIdAndRecipient.getValue(), account));
|
||||
}, maxFetchAccountConcurrency)
|
||||
.collectMap(Tuple2::getT1, Tuple2::getT2)
|
||||
.blockOptional()
|
||||
.orElse(Collections.emptyMap());
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a list of recipients missing from the map of resolved recipients for a multi-recipient message.
|
||||
*
|
||||
* @param multiRecipientMessage the multi-recipient message
|
||||
* @param resolvedRecipients the map of resolved recipients to check for missing entries
|
||||
*
|
||||
* @return a list of {@code ServiceIdentifiers} belonging to multi-recipient message recipients that are not present
|
||||
* in the given map of {@code resolvedRecipients}
|
||||
*/
|
||||
public static List<ServiceIdentifier> getUnresolvedRecipients(
|
||||
final SealedSenderMultiRecipientMessage multiRecipientMessage,
|
||||
final Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolvedRecipients) {
|
||||
|
||||
return multiRecipientMessage.getRecipients().entrySet().stream()
|
||||
.filter(entry -> !resolvedRecipients.containsKey(entry.getValue()))
|
||||
.map(entry -> ServiceIdentifier.fromLibsignal(entry.getKey()))
|
||||
.toList();
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if a multi-recipient message contains duplicate recipients.
|
||||
*
|
||||
* @param multiRecipientMessage the message to check for duplicate recipients
|
||||
*
|
||||
* @return {@code true} if the message contains duplicate recipients or {@code false} otherwise
|
||||
*/
|
||||
public static boolean hasDuplicateDevices(final SealedSenderMultiRecipientMessage multiRecipientMessage) {
|
||||
final boolean[] usedDeviceIds = new boolean[Device.MAXIMUM_DEVICE_ID + 1];
|
||||
|
||||
for (final SealedSenderMultiRecipientMessage.Recipient recipient : multiRecipientMessage.getRecipients().values()) {
|
||||
if (recipient.getDevices().length == 1) {
|
||||
// A recipient can't have repeated devices if they only have one device
|
||||
continue;
|
||||
}
|
||||
|
||||
Arrays.fill(usedDeviceIds, false);
|
||||
|
||||
for (final byte deviceId : recipient.getDevices()) {
|
||||
if (usedDeviceIds[deviceId]) {
|
||||
return true;
|
||||
}
|
||||
|
||||
usedDeviceIds[deviceId] = true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
|
@ -17,6 +17,7 @@ import java.util.function.BiConsumer;
|
|||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
|
@ -28,6 +29,9 @@ public class PushNotificationManager {
|
|||
private final APNSender apnSender;
|
||||
private final FcmSender fcmSender;
|
||||
private final PushNotificationScheduler pushNotificationScheduler;
|
||||
private final ExperimentEnrollmentManager experimentEnrollmentManager;
|
||||
|
||||
public static final String SCHEDULE_LOW_URGENCY_FCM_PUSH_EXPERIMENT = "scheduleLowUregencyFcmPush";
|
||||
|
||||
private static final String SENT_NOTIFICATION_COUNTER_NAME = name(PushNotificationManager.class, "sentPushNotification");
|
||||
private static final String FAILED_NOTIFICATION_COUNTER_NAME = name(PushNotificationManager.class, "failedPushNotification");
|
||||
|
@ -38,12 +42,14 @@ public class PushNotificationManager {
|
|||
public PushNotificationManager(final AccountsManager accountsManager,
|
||||
final APNSender apnSender,
|
||||
final FcmSender fcmSender,
|
||||
final PushNotificationScheduler pushNotificationScheduler) {
|
||||
final PushNotificationScheduler pushNotificationScheduler,
|
||||
final ExperimentEnrollmentManager experimentEnrollmentManager) {
|
||||
|
||||
this.accountsManager = accountsManager;
|
||||
this.apnSender = apnSender;
|
||||
this.fcmSender = fcmSender;
|
||||
this.pushNotificationScheduler = pushNotificationScheduler;
|
||||
this.experimentEnrollmentManager = experimentEnrollmentManager;
|
||||
}
|
||||
|
||||
public CompletableFuture<Optional<SendPushNotificationResult>> sendNewMessageNotification(final Account destination, final byte destinationDeviceId, final boolean urgent) throws NotPushRegisteredException {
|
||||
|
@ -101,11 +107,11 @@ public class PushNotificationManager {
|
|||
|
||||
@VisibleForTesting
|
||||
CompletableFuture<Optional<SendPushNotificationResult>> sendNotification(final PushNotification pushNotification) {
|
||||
if (pushNotification.tokenType() == PushNotification.TokenType.APN && !pushNotification.urgent()) {
|
||||
// APNs imposes a per-device limit on background push notifications; schedule a notification for some time in the
|
||||
// future (possibly even now!) rather than sending a notification directly
|
||||
if (shouldScheduleNotification(pushNotification)) {
|
||||
// Schedule a notification for some time in the future (possibly even now!) rather than sending a notification
|
||||
// directly
|
||||
return pushNotificationScheduler
|
||||
.scheduleBackgroundApnsNotification(pushNotification.destination(), pushNotification.destinationDevice())
|
||||
.scheduleBackgroundNotification(pushNotification.tokenType(), pushNotification.destination(), pushNotification.destinationDevice())
|
||||
.whenComplete(logErrors())
|
||||
.thenApply(ignored -> Optional.<SendPushNotificationResult>empty())
|
||||
.toCompletableFuture();
|
||||
|
@ -150,6 +156,16 @@ public class PushNotificationManager {
|
|||
.thenApply(Optional::of);
|
||||
}
|
||||
|
||||
private boolean shouldScheduleNotification(final PushNotification pushNotification) {
|
||||
return !pushNotification.urgent() && switch (pushNotification.tokenType()) {
|
||||
// APNs imposes a per-device limit on background push notifications
|
||||
case APN -> true;
|
||||
case FCM -> experimentEnrollmentManager.isEnrolled(
|
||||
pushNotification.destination().getUuid(),
|
||||
SCHEDULE_LOW_URGENCY_FCM_PUSH_EXPERIMENT);
|
||||
};
|
||||
}
|
||||
|
||||
private static <T> BiConsumer<T, Throwable> logErrors() {
|
||||
return (ignored, throwable) -> {
|
||||
if (throwable != null) {
|
||||
|
|
|
@ -13,7 +13,6 @@ import io.lettuce.core.Range;
|
|||
import io.lettuce.core.ScriptOutputType;
|
||||
import io.lettuce.core.SetArgs;
|
||||
import io.lettuce.core.cluster.SlotHash;
|
||||
import io.micrometer.core.instrument.Counter;
|
||||
import io.micrometer.core.instrument.Metrics;
|
||||
import java.io.IOException;
|
||||
import java.time.Clock;
|
||||
|
@ -44,14 +43,15 @@ public class PushNotificationScheduler implements Managed {
|
|||
|
||||
private static final Logger logger = LoggerFactory.getLogger(PushNotificationScheduler.class);
|
||||
|
||||
private static final String PENDING_BACKGROUND_NOTIFICATIONS_KEY_PREFIX = "PENDING_BACKGROUND_APN";
|
||||
private static final String PENDING_BACKGROUND_APN_NOTIFICATIONS_KEY_PREFIX = "PENDING_BACKGROUND_APN";
|
||||
private static final String PENDING_BACKGROUND_FCM_NOTIFICATIONS_KEY_PREFIX = "PENDING_BACKGROUND_FCM";
|
||||
private static final String LAST_BACKGROUND_NOTIFICATION_TIMESTAMP_KEY_PREFIX = "LAST_BACKGROUND_NOTIFICATION";
|
||||
private static final String PENDING_DELAYED_NOTIFICATIONS_KEY_PREFIX = "DELAYED";
|
||||
|
||||
@VisibleForTesting
|
||||
static final String NEXT_SLOT_TO_PROCESS_KEY = "pending_notification_next_slot";
|
||||
|
||||
private static final Counter BACKGROUND_NOTIFICATION_SCHEDULED_COUNTER = Metrics.counter(name(PushNotificationScheduler.class, "backgroundNotification", "scheduled"));
|
||||
private static final String BACKGROUND_NOTIFICATION_SCHEDULED_COUNTER_NAME = name(PushNotificationScheduler.class, "backgroundNotification", "scheduled");
|
||||
private static final String BACKGROUND_NOTIFICATION_SENT_COUNTER_NAME = name(PushNotificationScheduler.class, "backgroundNotification", "sent");
|
||||
|
||||
private static final String DELAYED_NOTIFICATION_SCHEDULED_COUNTER_NAME = name(PushNotificationScheduler.class, "delayedNotificationScheduled");
|
||||
|
@ -65,7 +65,7 @@ public class PushNotificationScheduler implements Managed {
|
|||
private final FaultTolerantRedisClusterClient pushSchedulingCluster;
|
||||
private final Clock clock;
|
||||
|
||||
private final ClusterLuaScript scheduleBackgroundApnsNotificationScript;
|
||||
private final ClusterLuaScript scheduleBackgroundNotificationScript;
|
||||
|
||||
private final Thread[] workerThreads;
|
||||
|
||||
|
@ -103,15 +103,18 @@ public class PushNotificationScheduler implements Managed {
|
|||
final int slot = (int) (pushSchedulingCluster.withCluster(connection ->
|
||||
connection.sync().incr(NEXT_SLOT_TO_PROCESS_KEY)) % SlotHash.SLOT_COUNT);
|
||||
|
||||
return processScheduledBackgroundApnsNotifications(slot) + processScheduledDelayedNotifications(slot);
|
||||
return processScheduledBackgroundNotifications(PushNotification.TokenType.APN, slot)
|
||||
+ processScheduledBackgroundNotifications(PushNotification.TokenType.FCM, slot)
|
||||
+ processScheduledDelayedNotifications(slot);
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
long processScheduledBackgroundApnsNotifications(final int slot) {
|
||||
return processScheduledNotifications(getPendingBackgroundApnsNotificationQueueKey(slot),
|
||||
PushNotificationScheduler.this::sendBackgroundApnsNotification);
|
||||
long processScheduledBackgroundNotifications(PushNotification.TokenType tokenType, final int slot) {
|
||||
return processScheduledNotifications(getPendingBackgroundNotificationQueueKey(tokenType, slot),
|
||||
(account, device) -> sendBackgroundNotification(tokenType, account, device));
|
||||
}
|
||||
|
||||
|
||||
@VisibleForTesting
|
||||
long processScheduledDelayedNotifications(final int slot) {
|
||||
return processScheduledNotifications(getDelayedNotificationQueueKey(slot),
|
||||
|
@ -172,7 +175,7 @@ public class PushNotificationScheduler implements Managed {
|
|||
this.pushSchedulingCluster = pushSchedulingCluster;
|
||||
this.clock = clock;
|
||||
|
||||
this.scheduleBackgroundApnsNotificationScript = ClusterLuaScript.fromResource(pushSchedulingCluster,
|
||||
this.scheduleBackgroundNotificationScript = ClusterLuaScript.fromResource(pushSchedulingCluster,
|
||||
"lua/apn/schedule_background_notification.lua", ScriptOutputType.VALUE);
|
||||
|
||||
this.workerThreads = new Thread[dedicatedProcessThreadCount];
|
||||
|
@ -183,23 +186,21 @@ public class PushNotificationScheduler implements Managed {
|
|||
}
|
||||
|
||||
/**
|
||||
* Schedule a background APNs notification to be sent some time in the future.
|
||||
* Schedule a background push notification to be sent some time in the future.
|
||||
*
|
||||
* @return A CompletionStage that completes when the notification has successfully been scheduled
|
||||
*
|
||||
* @throws IllegalArgumentException if the given device does not have an APNs token
|
||||
* @throws IllegalArgumentException if the given device does not have a push token
|
||||
*/
|
||||
public CompletionStage<Void> scheduleBackgroundApnsNotification(final Account account, final Device device) {
|
||||
if (StringUtils.isBlank(device.getApnId())) {
|
||||
throw new IllegalArgumentException("Device must have an APNs token");
|
||||
public CompletionStage<Void> scheduleBackgroundNotification(final PushNotification.TokenType tokenType, final Account account, final Device device) {
|
||||
if (StringUtils.isBlank(getPushToken(tokenType, device))) {
|
||||
throw new IllegalArgumentException("Device must have an " + tokenType + " token");
|
||||
}
|
||||
|
||||
BACKGROUND_NOTIFICATION_SCHEDULED_COUNTER.increment();
|
||||
|
||||
return scheduleBackgroundApnsNotificationScript.executeAsync(
|
||||
Metrics.counter(BACKGROUND_NOTIFICATION_SCHEDULED_COUNTER_NAME, "type", tokenType.name()).increment();
|
||||
return scheduleBackgroundNotificationScript.executeAsync(
|
||||
List.of(
|
||||
getLastBackgroundApnsNotificationTimestampKey(account, device),
|
||||
getPendingBackgroundApnsNotificationQueueKey(account, device)),
|
||||
getLastBackgroundNotificationTimestampKey(account, device),
|
||||
getPendingBackgroundNotificationQueueKey(tokenType, account, device)),
|
||||
List.of(
|
||||
encodeAciAndDeviceId(account, device),
|
||||
String.valueOf(clock.millis()),
|
||||
|
@ -236,14 +237,15 @@ public class PushNotificationScheduler implements Managed {
|
|||
*/
|
||||
public CompletionStage<Void> cancelScheduledNotifications(Account account, Device device) {
|
||||
return CompletableFuture.allOf(
|
||||
cancelBackgroundApnsNotifications(account, device),
|
||||
cancelBackgroundNotifications(PushNotification.TokenType.FCM, account, device),
|
||||
cancelBackgroundNotifications(PushNotification.TokenType.APN, account, device),
|
||||
cancelDelayedNotifications(account, device));
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
CompletableFuture<Void> cancelBackgroundApnsNotifications(final Account account, final Device device) {
|
||||
CompletableFuture<Void> cancelBackgroundNotifications(final PushNotification.TokenType tokenType, final Account account, final Device device) {
|
||||
return pushSchedulingCluster.withCluster(connection -> connection.async()
|
||||
.zrem(getPendingBackgroundApnsNotificationQueueKey(account, device), encodeAciAndDeviceId(account, device)))
|
||||
.zrem(getPendingBackgroundNotificationQueueKey(tokenType, account, device), encodeAciAndDeviceId(account, device)))
|
||||
.thenRun(Util.NOOP)
|
||||
.toCompletableFuture();
|
||||
}
|
||||
|
@ -276,17 +278,23 @@ public class PushNotificationScheduler implements Managed {
|
|||
}
|
||||
|
||||
@VisibleForTesting
|
||||
CompletableFuture<Void> sendBackgroundApnsNotification(final Account account, final Device device) {
|
||||
if (StringUtils.isBlank(device.getApnId())) {
|
||||
CompletableFuture<Void> sendBackgroundNotification(PushNotification.TokenType tokenType, final Account account, final Device device) {
|
||||
final String pushToken = getPushToken(tokenType, device);
|
||||
if (StringUtils.isBlank(pushToken)) {
|
||||
return CompletableFuture.completedFuture(null);
|
||||
}
|
||||
|
||||
final PushNotificationSender sender = switch (tokenType) {
|
||||
case FCM -> fcmSender;
|
||||
case APN -> apnSender;
|
||||
};
|
||||
|
||||
// It's okay for the "last notification" timestamp to expire after the "cooldown" period has elapsed; a missing
|
||||
// timestamp and a timestamp older than the period are functionally equivalent.
|
||||
return pushSchedulingCluster.withCluster(connection -> connection.async().set(
|
||||
getLastBackgroundApnsNotificationTimestampKey(account, device),
|
||||
getLastBackgroundNotificationTimestampKey(account, device),
|
||||
String.valueOf(clock.millis()), new SetArgs().ex(BACKGROUND_NOTIFICATION_PERIOD)))
|
||||
.thenCompose(ignored -> apnSender.sendNotification(new PushNotification(device.getApnId(), PushNotification.TokenType.APN, PushNotification.NotificationType.NOTIFICATION, null, account, device, false)))
|
||||
.thenCompose(ignored -> sender.sendNotification(new PushNotification(pushToken, tokenType, PushNotification.NotificationType.NOTIFICATION, null, account, device, false)))
|
||||
.thenAccept(response -> Metrics.counter(BACKGROUND_NOTIFICATION_SENT_COUNTER_NAME,
|
||||
ACCEPTED_TAG, String.valueOf(response.accepted()))
|
||||
.increment())
|
||||
|
@ -321,6 +329,10 @@ public class PushNotificationScheduler implements Managed {
|
|||
|
||||
@VisibleForTesting
|
||||
static String encodeAciAndDeviceId(final Account account, final Device device) {
|
||||
// Note: This does not include a device registration id. If a device is unlinked and a new device is linked with
|
||||
// the original device's id, the new device might get the old device's scheduled push, or the new device might
|
||||
// delay its own push because the old device had a recent push. An extra or delayed background push is harmless,
|
||||
// so this is okay.
|
||||
return account.getUuid() + ":" + device.getId();
|
||||
}
|
||||
|
||||
|
@ -351,15 +363,19 @@ public class PushNotificationScheduler implements Managed {
|
|||
}
|
||||
|
||||
@VisibleForTesting
|
||||
static String getPendingBackgroundApnsNotificationQueueKey(final Account account, final Device device) {
|
||||
return getPendingBackgroundApnsNotificationQueueKey(SlotHash.getSlot(encodeAciAndDeviceId(account, device)));
|
||||
static String getPendingBackgroundNotificationQueueKey(final PushNotification.TokenType tokenType, final Account account, final Device device) {
|
||||
return getPendingBackgroundNotificationQueueKey(tokenType, SlotHash.getSlot(encodeAciAndDeviceId(account, device)));
|
||||
}
|
||||
|
||||
private static String getPendingBackgroundApnsNotificationQueueKey(final int slot) {
|
||||
return PENDING_BACKGROUND_NOTIFICATIONS_KEY_PREFIX + "::{" + RedisClusterUtil.getMinimalHashTag(slot) + "}";
|
||||
private static String getPendingBackgroundNotificationQueueKey(final PushNotification.TokenType tokenType, final int slot) {
|
||||
final String prefix = switch (tokenType) {
|
||||
case APN -> PENDING_BACKGROUND_APN_NOTIFICATIONS_KEY_PREFIX;
|
||||
case FCM -> PENDING_BACKGROUND_FCM_NOTIFICATIONS_KEY_PREFIX;
|
||||
};
|
||||
return prefix + "::{" + RedisClusterUtil.getMinimalHashTag(slot) + "}";
|
||||
}
|
||||
|
||||
private static String getLastBackgroundApnsNotificationTimestampKey(final Account account, final Device device) {
|
||||
private static String getLastBackgroundNotificationTimestampKey(final Account account, final Device device) {
|
||||
return LAST_BACKGROUND_NOTIFICATION_TIMESTAMP_KEY_PREFIX + "::{" + encodeAciAndDeviceId(account, device) + "}";
|
||||
}
|
||||
|
||||
|
@ -376,15 +392,15 @@ public class PushNotificationScheduler implements Managed {
|
|||
Optional<Instant> getLastBackgroundApnsNotificationTimestamp(final Account account, final Device device) {
|
||||
return Optional.ofNullable(
|
||||
pushSchedulingCluster.withCluster(connection ->
|
||||
connection.sync().get(getLastBackgroundApnsNotificationTimestampKey(account, device))))
|
||||
connection.sync().get(getLastBackgroundNotificationTimestampKey(account, device))))
|
||||
.map(timestampString -> Instant.ofEpochMilli(Long.parseLong(timestampString)));
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
Optional<Instant> getNextScheduledBackgroundApnsNotificationTimestamp(final Account account, final Device device) {
|
||||
Optional<Instant> getNextScheduledBackgroundNotificationTimestamp(PushNotification.TokenType tokenType, final Account account, final Device device) {
|
||||
return Optional.ofNullable(
|
||||
pushSchedulingCluster.withCluster(connection ->
|
||||
connection.sync().zscore(getPendingBackgroundApnsNotificationQueueKey(account, device),
|
||||
connection.sync().zscore(getPendingBackgroundNotificationQueueKey(tokenType, account, device),
|
||||
encodeAciAndDeviceId(account, device))))
|
||||
.map(timestamp -> Instant.ofEpochMilli(timestamp.longValue()));
|
||||
}
|
||||
|
@ -407,4 +423,11 @@ public class PushNotificationScheduler implements Managed {
|
|||
return "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
private static String getPushToken(final PushNotification.TokenType tokenType, final Device device) {
|
||||
return switch (tokenType) {
|
||||
case FCM -> device.getGcmId();
|
||||
case APN -> device.getApnId();
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.push;
|
|||
import io.micrometer.core.instrument.Metrics;
|
||||
import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.stream.Collectors;
|
||||
import org.slf4j.Logger;
|
||||
|
@ -60,16 +61,15 @@ public class ReceiptSender {
|
|||
.collect(Collectors.toMap(Device::getId, ignored -> message));
|
||||
|
||||
final Map<Byte, Integer> registrationIdsByDeviceId = destinationAccount.getDevices().stream()
|
||||
.collect(Collectors.toMap(Device::getId, device -> switch (destinationIdentifier.identityType()) {
|
||||
case ACI -> device.getRegistrationId();
|
||||
case PNI -> device.getPhoneNumberIdentityRegistrationId().orElseGet(device::getRegistrationId);
|
||||
}));
|
||||
.collect(Collectors.toMap(Device::getId,
|
||||
device -> device.getRegistrationId(destinationIdentifier.identityType())));
|
||||
|
||||
try {
|
||||
messageSender.sendMessages(destinationAccount,
|
||||
destinationIdentifier,
|
||||
messagesByDeviceId,
|
||||
registrationIdsByDeviceId,
|
||||
Optional.empty(),
|
||||
UserAgentTagUtil.SERVER_UA);
|
||||
} catch (final Exception e) {
|
||||
logger.warn("Could not send delivery receipt", e);
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue