Compare commits
75 Commits
v20250409.
...
main
Author | SHA1 | Date |
---|---|---|
![]() |
74ee1c8c4f | |
![]() |
35604cf151 | |
![]() |
aafcd63a9f | |
![]() |
43a534f05b | |
![]() |
9ec66dac7f | |
![]() |
13fc0ffbca | |
![]() |
93ba6616d1 | |
![]() |
a4b98f38a6 | |
![]() |
b95d08aaea | |
![]() |
b400d49e77 | |
![]() |
e43487155f | |
![]() |
dee3723d97 | |
![]() |
b7e986f43c | |
![]() |
664fb23e97 | |
![]() |
714ef128a1 | |
![]() |
7cf3fce624 | |
![]() |
0cc5431867 | |
![]() |
b8d5b2c8ea | |
![]() |
894ca6d290 | |
![]() |
847b25f695 | |
![]() |
703a05cb15 | |
![]() |
30c194c557 | |
![]() |
cc7b030a41 | |
![]() |
7a91c4d5b7 | |
![]() |
287da6e7e3 | |
![]() |
7cf89764e7 | |
![]() |
d316c72beb | |
![]() |
82d187cc45 | |
![]() |
0c240d21d2 | |
![]() |
009252c831 | |
![]() |
0c1146aaa5 | |
![]() |
4fd06594a0 | |
![]() |
4e175be88f | |
![]() |
771a700acd | |
![]() |
e9bd5da2c3 | |
![]() |
f64244f33a | |
![]() |
ed1417c3e3 | |
![]() |
0398e02690 | |
![]() |
e285bf1a52 | |
![]() |
2c9219d4f7 | |
![]() |
26b3b75054 | |
![]() |
cdb651b68f | |
![]() |
91a36f4421 | |
![]() |
21c1d71551 | |
![]() |
38befdb260 | |
![]() |
63c79173b2 | |
![]() |
d2ad003891 | |
![]() |
eb89773819 | |
![]() |
403abd84f6 | |
![]() |
f62f79c95c | |
![]() |
144c4c9223 | |
![]() |
ab4fc4f459 | |
![]() |
51569ce0a5 | |
![]() |
f191c68efc | |
![]() |
bb8ce6d981 | |
![]() |
e0ee75e0d0 | |
![]() |
1ef3a230a1 | |
![]() |
b1805d4bf1 | |
![]() |
cac979c7fd | |
![]() |
4072dcdda5 | |
![]() |
ed382fff6d | |
![]() |
23bb8277d5 | |
![]() |
8099d6465c | |
![]() |
28a0b9e84e | |
![]() |
9287aaf7ce | |
![]() |
0585f862cb | |
![]() |
7cac6f6f72 | |
![]() |
57be4d798b | |
![]() |
05c74f1997 | |
![]() |
f5e49b6db7 | |
![]() |
3c40e72d27 | |
![]() |
2f2ae7cec5 | |
![]() |
b236b53dc3 | |
![]() |
eb71e30046 | |
![]() |
aa5fd52302 |
|
@ -1,6 +1,7 @@
|
|||
name: Service CI
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
branches-ignore:
|
||||
- gh-pages
|
||||
|
|
|
@ -72,14 +72,12 @@ public final class Operations {
|
|||
}
|
||||
|
||||
public static TestUser newRegisteredUser(final String number) {
|
||||
final byte[] registrationPassword = randomBytes(32);
|
||||
final byte[] registrationPassword = populateRandomRecoveryPassword(number);
|
||||
final String accountPassword = Base64.getEncoder().encodeToString(randomBytes(32));
|
||||
|
||||
final TestUser user = TestUser.create(number, accountPassword, registrationPassword);
|
||||
final AccountAttributes accountAttributes = user.accountAttributes();
|
||||
|
||||
INTEGRATION_TOOLS.populateRecoveryPassword(number, registrationPassword).join();
|
||||
|
||||
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
|
||||
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
|
||||
|
||||
|
@ -108,6 +106,7 @@ public final class Operations {
|
|||
}
|
||||
|
||||
public record PrescribedVerificationNumber(String number, String verificationCode) {}
|
||||
|
||||
public static PrescribedVerificationNumber prescribedVerificationNumber() {
|
||||
return new PrescribedVerificationNumber(
|
||||
CONFIG.prescribedRegistrationNumber(),
|
||||
|
@ -123,6 +122,13 @@ public final class Operations {
|
|||
.orElseThrow(() -> new RuntimeException("push challenge not found for the verification session"));
|
||||
}
|
||||
|
||||
public static byte[] populateRandomRecoveryPassword(final String number) {
|
||||
final byte[] recoveryPassword = randomBytes(32);
|
||||
INTEGRATION_TOOLS.populateRecoveryPassword(number, recoveryPassword).join();
|
||||
|
||||
return recoveryPassword;
|
||||
}
|
||||
|
||||
public static <T> T sendEmptyRequestAuthenticated(
|
||||
final String endpoint,
|
||||
final String method,
|
||||
|
@ -329,15 +335,15 @@ public final class Operations {
|
|||
}
|
||||
}
|
||||
|
||||
private static ECSignedPreKey generateSignedECPreKey(long id, final ECKeyPair identityKeyPair) {
|
||||
public static ECSignedPreKey generateSignedECPreKey(final long id, final ECKeyPair identityKeyPair) {
|
||||
final ECPublicKey pubKey = Curve.generateKeyPair().getPublicKey();
|
||||
final byte[] sig = identityKeyPair.getPrivateKey().calculateSignature(pubKey.serialize());
|
||||
return new ECSignedPreKey(id, pubKey, sig);
|
||||
final byte[] signature = identityKeyPair.getPrivateKey().calculateSignature(pubKey.serialize());
|
||||
return new ECSignedPreKey(id, pubKey, signature);
|
||||
}
|
||||
|
||||
private static KEMSignedPreKey generateSignedKEMPreKey(long id, final ECKeyPair identityKeyPair) {
|
||||
public static KEMSignedPreKey generateSignedKEMPreKey(final long id, final ECKeyPair identityKeyPair) {
|
||||
final KEMPublicKey pubKey = KEMKeyPair.generate(KEMKeyType.KYBER_1024).getPublicKey();
|
||||
final byte[] sig = identityKeyPair.getPrivateKey().calculateSignature(pubKey.serialize());
|
||||
return new KEMSignedPreKey(id, pubKey, sig);
|
||||
final byte[] signature = identityKeyPair.getPrivateKey().calculateSignature(pubKey.serialize());
|
||||
return new KEMSignedPreKey(id, pubKey, signature);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,27 +6,35 @@
|
|||
package org.signal.integration;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotEquals;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Base64;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.apache.http.HttpStatus;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.signal.libsignal.protocol.IdentityKey;
|
||||
import org.signal.libsignal.protocol.ecc.Curve;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.signal.libsignal.usernames.BaseUsernameException;
|
||||
import org.signal.libsignal.usernames.Username;
|
||||
import org.whispersystems.textsecuregcm.entities.AccountIdentifierResponse;
|
||||
import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse;
|
||||
import org.whispersystems.textsecuregcm.entities.ChangeNumberRequest;
|
||||
import org.whispersystems.textsecuregcm.entities.ConfirmUsernameHashRequest;
|
||||
import org.whispersystems.textsecuregcm.entities.ReserveUsernameHashRequest;
|
||||
import org.whispersystems.textsecuregcm.entities.ReserveUsernameHashResponse;
|
||||
import org.whispersystems.textsecuregcm.entities.UsernameHashResponse;
|
||||
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
|
||||
public class AccountTest {
|
||||
|
||||
@Test
|
||||
public void testCreateAccount() throws Exception {
|
||||
public void testCreateAccount() {
|
||||
final TestUser user = Operations.newRegisteredUser("+19995550101");
|
||||
try {
|
||||
final Pair<Integer, AccountIdentityResponse> execute = Operations.apiGet("/v1/accounts/whoami")
|
||||
|
@ -39,7 +47,7 @@ public class AccountTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testCreateAccountAtomic() throws Exception {
|
||||
public void testCreateAccountAtomic() {
|
||||
final TestUser user = Operations.newRegisteredUser("+19995550201");
|
||||
try {
|
||||
final Pair<Integer, AccountIdentityResponse> execute = Operations.apiGet("/v1/accounts/whoami")
|
||||
|
@ -51,6 +59,33 @@ public class AccountTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void changePhoneNumber() {
|
||||
final TestUser user = Operations.newRegisteredUser("+19995550301");
|
||||
final String targetNumber = "+19995550302";
|
||||
|
||||
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
|
||||
|
||||
final ChangeNumberRequest changeNumberRequest = new ChangeNumberRequest(null,
|
||||
Operations.populateRandomRecoveryPassword(targetNumber),
|
||||
targetNumber,
|
||||
null,
|
||||
new IdentityKey(pniIdentityKeyPair.getPublicKey()),
|
||||
Collections.emptyList(),
|
||||
Map.of(Device.PRIMARY_ID, Operations.generateSignedECPreKey(1, pniIdentityKeyPair)),
|
||||
Map.of(Device.PRIMARY_ID, Operations.generateSignedKEMPreKey(2, pniIdentityKeyPair)),
|
||||
Map.of(Device.PRIMARY_ID, 17));
|
||||
|
||||
final AccountIdentityResponse accountIdentityResponse =
|
||||
Operations.apiPut("/v2/accounts/number", changeNumberRequest)
|
||||
.authorized(user)
|
||||
.executeExpectSuccess(AccountIdentityResponse.class);
|
||||
|
||||
assertEquals(user.aciUuid(), accountIdentityResponse.uuid());
|
||||
assertNotEquals(user.pniUuid(), accountIdentityResponse.pni());
|
||||
assertEquals(targetNumber, accountIdentityResponse.number());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testUsernameOperations() throws Exception {
|
||||
final TestUser user = Operations.newRegisteredUser("+19995550102");
|
||||
|
|
17
pom.xml
17
pom.xml
|
@ -46,7 +46,7 @@
|
|||
<!-- can be updated to latest version with Dropwizard 5 (Jetty 12); will then need to disable telemetry -->
|
||||
<dynamodblocal.version>2.2.1</dynamodblocal.version>
|
||||
<google-cloud-libraries.version>26.57.0</google-cloud-libraries.version>
|
||||
<grpc.version>1.69.0</grpc.version> <!-- should be kept in sync with the value from Google libraries-bom -->
|
||||
<grpc.version>1.70.0</grpc.version> <!-- should be kept in sync with the value from Google libraries-bom -->
|
||||
<gson.version>2.12.1</gson.version>
|
||||
<!-- several libraries (AWS, Google Cloud) use Apache http components transitively, and we need to align them -->
|
||||
<httpcore.version>4.4.16</httpcore.version>
|
||||
|
@ -65,9 +65,9 @@
|
|||
<luajava.version>3.5.0</luajava.version>
|
||||
<micrometer.version>1.14.5</micrometer.version>
|
||||
<netty.version>4.1.119.Final</netty.version>
|
||||
<!-- Must be greater than or equal to the value from Google libraries-bom
|
||||
since some of its libraries generate code. See https://protobuf.dev/support/cross-version-runtime-guarantee/. -->
|
||||
<protobuf.version>3.25.5</protobuf.version>
|
||||
<!-- Must be less than or equal to the value from Google libraries-bom which controls the protobuf runtime version.
|
||||
See https://protobuf.dev/support/cross-version-runtime-guarantee/. -->
|
||||
<protoc.version>4.29.4</protoc.version>
|
||||
<pushy.version>0.15.4</pushy.version>
|
||||
<reactive.grpc.version>1.2.4</reactive.grpc.version>
|
||||
<reactor-bom.version>2024.0.4</reactor-bom.version> <!-- 3.7.4, see https://github.com/reactor/reactor#bom-versioning-scheme -->
|
||||
|
@ -127,7 +127,7 @@
|
|||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.google.cloud</groupId>
|
||||
<artifactId>libraries-bom-protobuf3</artifactId>
|
||||
<artifactId>libraries-bom</artifactId>
|
||||
<version>${google-cloud-libraries.version}</version>
|
||||
<type>pom</type>
|
||||
<scope>import</scope>
|
||||
|
@ -175,11 +175,6 @@
|
|||
<artifactId>pushy-dropwizard-metrics-listener</artifactId>
|
||||
<version>${pushy.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.google.protobuf</groupId>
|
||||
<artifactId>protobuf-java</artifactId>
|
||||
<version>${protobuf.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.googlecode.libphonenumber</groupId>
|
||||
<artifactId>libphonenumber</artifactId>
|
||||
|
@ -443,7 +438,7 @@
|
|||
<version>0.6.1</version>
|
||||
<configuration>
|
||||
<checkStaleness>false</checkStaleness>
|
||||
<protocArtifact>com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier}</protocArtifact>
|
||||
<protocArtifact>com.google.protobuf:protoc:${protoc.version}:exe:${os.detected.classifier}</protocArtifact>
|
||||
<pluginId>grpc-java</pluginId>
|
||||
<pluginArtifact>io.grpc:protoc-gen-grpc-java:${grpc.version}:exe:${os.detected.classifier}</pluginArtifact>
|
||||
|
||||
|
|
|
@ -482,7 +482,8 @@ turn:
|
|||
- turn:%s
|
||||
- turn:%s:80?transport=tcp
|
||||
- turns:%s:443?transport=tcp
|
||||
ttl: 86400
|
||||
requestedCredentialTtl: PT24H
|
||||
clientCredentialTtl: PT12H
|
||||
hostname: turn.cloudflare.example.com
|
||||
numHttpClients: 1
|
||||
|
||||
|
|
|
@ -407,10 +407,6 @@ public class WhisperServerConfiguration extends Configuration {
|
|||
return rateLimitersCluster;
|
||||
}
|
||||
|
||||
public Map<String, RateLimiterConfig> getLimitsConfiguration() {
|
||||
return limits;
|
||||
}
|
||||
|
||||
public FcmConfiguration getFcmConfiguration() {
|
||||
return fcm;
|
||||
}
|
||||
|
|
|
@ -154,7 +154,7 @@ import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
|||
import org.whispersystems.textsecuregcm.grpc.net.ManagedDefaultEventLoopGroup;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.ManagedLocalGrpcServer;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.ManagedNioEventLoopGroup;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseWebSocketTunnelServer;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.websocket.NoiseWebSocketTunnelServer;
|
||||
import org.whispersystems.textsecuregcm.jetty.JettyHttpConfigurationCustomizer;
|
||||
import org.whispersystems.textsecuregcm.keytransparency.KeyTransparencyServiceClient;
|
||||
import org.whispersystems.textsecuregcm.limits.CardinalityEstimator;
|
||||
|
@ -269,6 +269,7 @@ import org.whispersystems.textsecuregcm.workers.IdleDeviceNotificationSchedulerF
|
|||
import org.whispersystems.textsecuregcm.workers.MessagePersisterServiceCommand;
|
||||
import org.whispersystems.textsecuregcm.workers.NotifyIdleDevicesCommand;
|
||||
import org.whispersystems.textsecuregcm.workers.ProcessScheduledJobsServiceCommand;
|
||||
import org.whispersystems.textsecuregcm.workers.RegenerateAccountConstraintDataCommand;
|
||||
import org.whispersystems.textsecuregcm.workers.RemoveExpiredAccountsCommand;
|
||||
import org.whispersystems.textsecuregcm.workers.RemoveExpiredBackupsCommand;
|
||||
import org.whispersystems.textsecuregcm.workers.RemoveExpiredLinkedDevicesCommand;
|
||||
|
@ -335,6 +336,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
|||
bootstrap.addCommand(new ProcessScheduledJobsServiceCommand("process-idle-device-notification-jobs",
|
||||
"Processes scheduled jobs to send notifications to idle devices",
|
||||
new IdleDeviceNotificationSchedulerFactory()));
|
||||
|
||||
bootstrap.addCommand(new RegenerateAccountConstraintDataCommand());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -633,11 +636,10 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
|||
PushNotificationScheduler pushNotificationScheduler = new PushNotificationScheduler(pushSchedulerCluster,
|
||||
apnSender, fcmSender, accountsManager, 0, 0);
|
||||
PushNotificationManager pushNotificationManager =
|
||||
new PushNotificationManager(accountsManager, apnSender, fcmSender, pushNotificationScheduler);
|
||||
new PushNotificationManager(accountsManager, apnSender, fcmSender, pushNotificationScheduler, experimentEnrollmentManager);
|
||||
WebSocketConnectionEventManager webSocketConnectionEventManager =
|
||||
new WebSocketConnectionEventManager(accountsManager, pushNotificationManager, messagesCluster, clientEventExecutor, asyncOperationQueueingExecutor);
|
||||
RateLimiters rateLimiters = RateLimiters.createAndValidate(config.getLimitsConfiguration(),
|
||||
dynamicConfigurationManager, rateLimitersCluster);
|
||||
RateLimiters rateLimiters = RateLimiters.create(dynamicConfigurationManager, rateLimitersCluster);
|
||||
ProvisioningManager provisioningManager = new ProvisioningManager(pubsubClient);
|
||||
IssuedReceiptsManager issuedReceiptsManager = new IssuedReceiptsManager(
|
||||
config.getDynamoDbTables().getIssuedReceipts().getTableName(),
|
||||
|
@ -668,12 +670,13 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
|||
|
||||
final AccountAuthenticator accountAuthenticator = new AccountAuthenticator(accountsManager);
|
||||
|
||||
final MessageSender messageSender = new MessageSender(messagesManager, pushNotificationManager);
|
||||
final MessageSender messageSender = new MessageSender(messagesManager, pushNotificationManager, experimentEnrollmentManager);
|
||||
final ReceiptSender receiptSender = new ReceiptSender(accountsManager, messageSender, receiptSenderExecutor);
|
||||
final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager = new CloudflareTurnCredentialsManager(
|
||||
config.getTurnConfiguration().cloudflare().apiToken().value(),
|
||||
config.getTurnConfiguration().cloudflare().endpoint(),
|
||||
config.getTurnConfiguration().cloudflare().ttl(),
|
||||
config.getTurnConfiguration().cloudflare().requestedCredentialTtl(),
|
||||
config.getTurnConfiguration().cloudflare().clientCredentialTtl(),
|
||||
config.getTurnConfiguration().cloudflare().urls(),
|
||||
config.getTurnConfiguration().cloudflare().urlsWithIps(),
|
||||
config.getTurnConfiguration().cloudflare().hostname(),
|
||||
|
@ -693,7 +696,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
|||
PushChallengeManager pushChallengeManager = new PushChallengeManager(pushNotificationManager,
|
||||
pushChallengeDynamoDb);
|
||||
|
||||
ChangeNumberManager changeNumberManager = new ChangeNumberManager(messageSender, accountsManager);
|
||||
ChangeNumberManager changeNumberManager = new ChangeNumberManager(messageSender, accountsManager, Clock.systemUTC());
|
||||
|
||||
HttpClient currencyClient = HttpClient.newBuilder().version(HttpClient.Version.HTTP_2).connectTimeout(Duration.ofSeconds(10)).build();
|
||||
FixerClient fixerClient = config.getPaymentsServiceConfiguration().externalClients()
|
||||
|
@ -987,7 +990,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
|||
webSocketEnvironment.setConnectListener(
|
||||
new AuthenticatedConnectListener(receiptSender, messagesManager, messageMetrics, pushNotificationManager,
|
||||
pushNotificationScheduler, webSocketConnectionEventManager, websocketScheduledExecutor,
|
||||
messageDeliveryScheduler, clientReleaseManager, messageDeliveryLoopMonitor));
|
||||
messageDeliveryScheduler, clientReleaseManager, messageDeliveryLoopMonitor, experimentEnrollmentManager));
|
||||
webSocketEnvironment.jersey()
|
||||
.register(new WebsocketRefreshApplicationEventListener(accountsManager, disconnectionRequestManager));
|
||||
webSocketEnvironment.jersey().register(new RateLimitByIpFilter(rateLimiters));
|
||||
|
|
|
@ -15,6 +15,7 @@ import java.net.Inet6Address;
|
|||
import java.net.URI;
|
||||
import java.net.http.HttpRequest;
|
||||
import java.net.http.HttpResponse;
|
||||
import java.time.Duration;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CompletionException;
|
||||
|
@ -39,16 +40,18 @@ public class CloudflareTurnCredentialsManager {
|
|||
private final List<String> cloudflareTurnUrls;
|
||||
private final List<String> cloudflareTurnUrlsWithIps;
|
||||
private final String cloudflareTurnHostname;
|
||||
private final HttpRequest request;
|
||||
private final HttpRequest getCredentialsRequest;
|
||||
|
||||
private final FaultTolerantHttpClient cloudflareTurnClient;
|
||||
private final DnsNameResolver dnsNameResolver;
|
||||
|
||||
record CredentialRequest(long ttl) {}
|
||||
private final Duration clientCredentialTtl;
|
||||
|
||||
record CloudflareTurnResponse(IceServer iceServers) {
|
||||
private record CredentialRequest(long ttl) {}
|
||||
|
||||
record IceServer(
|
||||
private record CloudflareTurnResponse(IceServer iceServers) {
|
||||
|
||||
private record IceServer(
|
||||
String username,
|
||||
String credential,
|
||||
List<String> urls) {
|
||||
|
@ -56,10 +59,17 @@ public class CloudflareTurnCredentialsManager {
|
|||
}
|
||||
|
||||
public CloudflareTurnCredentialsManager(final String cloudflareTurnApiToken,
|
||||
final String cloudflareTurnEndpoint, final long cloudflareTurnTtl, final List<String> cloudflareTurnUrls,
|
||||
final List<String> cloudflareTurnUrlsWithIps, final String cloudflareTurnHostname,
|
||||
final int cloudflareTurnNumHttpClients, final CircuitBreakerConfiguration circuitBreaker,
|
||||
final ExecutorService executor, final RetryConfiguration retry, final ScheduledExecutorService retryExecutor,
|
||||
final String cloudflareTurnEndpoint,
|
||||
final Duration requestedCredentialTtl,
|
||||
final Duration clientCredentialTtl,
|
||||
final List<String> cloudflareTurnUrls,
|
||||
final List<String> cloudflareTurnUrlsWithIps,
|
||||
final String cloudflareTurnHostname,
|
||||
final int cloudflareTurnNumHttpClients,
|
||||
final CircuitBreakerConfiguration circuitBreaker,
|
||||
final ExecutorService executor,
|
||||
final RetryConfiguration retry,
|
||||
final ScheduledExecutorService retryExecutor,
|
||||
final DnsNameResolver dnsNameResolver) {
|
||||
|
||||
this.cloudflareTurnClient = FaultTolerantHttpClient.newBuilder()
|
||||
|
@ -75,17 +85,24 @@ public class CloudflareTurnCredentialsManager {
|
|||
this.cloudflareTurnHostname = cloudflareTurnHostname;
|
||||
this.dnsNameResolver = dnsNameResolver;
|
||||
|
||||
final String credentialsRequestBody;
|
||||
|
||||
try {
|
||||
final String body = SystemMapper.jsonMapper().writeValueAsString(new CredentialRequest(cloudflareTurnTtl));
|
||||
this.request = HttpRequest.newBuilder()
|
||||
credentialsRequestBody =
|
||||
SystemMapper.jsonMapper().writeValueAsString(new CredentialRequest(requestedCredentialTtl.toSeconds()));
|
||||
} catch (final JsonProcessingException e) {
|
||||
throw new IllegalArgumentException(e);
|
||||
}
|
||||
|
||||
// We repeat the same request to Cloudflare every time, so we can construct it once and re-use it
|
||||
this.getCredentialsRequest = HttpRequest.newBuilder()
|
||||
.uri(URI.create(cloudflareTurnEndpoint))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", String.format("Bearer %s", cloudflareTurnApiToken))
|
||||
.POST(HttpRequest.BodyPublishers.ofString(body))
|
||||
.POST(HttpRequest.BodyPublishers.ofString(credentialsRequestBody))
|
||||
.build();
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new IllegalArgumentException(e);
|
||||
}
|
||||
|
||||
this.clientCredentialTtl = clientCredentialTtl;
|
||||
}
|
||||
|
||||
public TurnToken retrieveFromCloudflare() throws IOException {
|
||||
|
@ -105,7 +122,7 @@ public class CloudflareTurnCredentialsManager {
|
|||
final Timer.Sample sample = Timer.start();
|
||||
final HttpResponse<String> response;
|
||||
try {
|
||||
response = cloudflareTurnClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()).join();
|
||||
response = cloudflareTurnClient.sendAsync(getCredentialsRequest, HttpResponse.BodyHandlers.ofString()).join();
|
||||
sample.stop(Timer.builder(CREDENTIAL_FETCH_TIMER_NAME)
|
||||
.publishPercentileHistogram(true)
|
||||
.tags("outcome", "success")
|
||||
|
@ -130,6 +147,7 @@ public class CloudflareTurnCredentialsManager {
|
|||
return new TurnToken(
|
||||
cloudflareTurnResponse.iceServers().username(),
|
||||
cloudflareTurnResponse.iceServers().credential(),
|
||||
clientCredentialTtl.toSeconds(),
|
||||
cloudflareTurnUrls == null ? Collections.emptyList() : cloudflareTurnUrls,
|
||||
cloudflareTurnComposedUrls,
|
||||
cloudflareTurnHostname
|
||||
|
|
|
@ -5,13 +5,15 @@
|
|||
|
||||
package org.whispersystems.textsecuregcm.auth;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import java.util.List;
|
||||
import javax.annotation.Nonnull;
|
||||
import javax.annotation.Nullable;
|
||||
import java.util.List;
|
||||
|
||||
public record TurnToken(
|
||||
String username,
|
||||
String password,
|
||||
@JsonProperty("ttl") long ttlSeconds,
|
||||
@Nonnull List<String> urls,
|
||||
@Nonnull List<String> urlsWithIps,
|
||||
@Nullable String hostname) {
|
||||
|
|
|
@ -1,34 +1,22 @@
|
|||
package org.whispersystems.textsecuregcm.auth.grpc;
|
||||
|
||||
import io.grpc.Grpc;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerInterceptor;
|
||||
import io.grpc.Status;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
import java.util.Optional;
|
||||
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
|
||||
abstract class AbstractAuthenticationInterceptor implements ServerInterceptor {
|
||||
|
||||
private final GrpcClientConnectionManager grpcClientConnectionManager;
|
||||
|
||||
private static final Metadata EMPTY_TRAILERS = new Metadata();
|
||||
|
||||
AbstractAuthenticationInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
|
||||
this.grpcClientConnectionManager = grpcClientConnectionManager;
|
||||
}
|
||||
|
||||
protected Optional<AuthenticatedDevice> getAuthenticatedDevice(final ServerCall<?, ?> call) {
|
||||
if (call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress) {
|
||||
return grpcClientConnectionManager.getAuthenticatedDevice(localAddress);
|
||||
} else {
|
||||
throw new AssertionError("Unexpected channel type: " + call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
|
||||
}
|
||||
}
|
||||
protected Optional<AuthenticatedDevice> getAuthenticatedDevice(final ServerCall<?, ?> call)
|
||||
throws ChannelNotFoundException {
|
||||
|
||||
protected <ReqT, RespT> ServerCall.Listener<ReqT> closeAsUnauthenticated(final ServerCall<ReqT, RespT> call) {
|
||||
call.close(Status.UNAUTHENTICATED, EMPTY_TRAILERS);
|
||||
return new ServerCall.Listener<>() {};
|
||||
return grpcClientConnectionManager.getAuthenticatedDevice(call);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,12 +3,17 @@ package org.whispersystems.textsecuregcm.auth.grpc;
|
|||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerCallHandler;
|
||||
import io.grpc.Status;
|
||||
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
|
||||
import org.whispersystems.textsecuregcm.grpc.ServerInterceptorUtil;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
|
||||
/**
|
||||
* A "prohibit authentication" interceptor ensures that requests to endpoints that should be invoked anonymously do not
|
||||
* originate from a channel that is associated with an authenticated device. Calls with an associated authenticated
|
||||
* device are closed with an {@code UNAUTHENTICATED} status.
|
||||
* device are closed with an {@code UNAUTHENTICATED} status. If a call's authentication status cannot be determined
|
||||
* (i.e. because the underlying remote channel closed before the {@code ServerCall} started), the interceptor will
|
||||
* reject the call with a status of {@code UNAVAILABLE}.
|
||||
*/
|
||||
public class ProhibitAuthenticationInterceptor extends AbstractAuthenticationInterceptor {
|
||||
|
||||
|
@ -21,8 +26,15 @@ public class ProhibitAuthenticationInterceptor extends AbstractAuthenticationInt
|
|||
final Metadata headers,
|
||||
final ServerCallHandler<ReqT, RespT> next) {
|
||||
|
||||
try {
|
||||
return getAuthenticatedDevice(call)
|
||||
.map(ignored -> closeAsUnauthenticated(call))
|
||||
// Status.INTERNAL may seem a little surprising here, but if a caller is reaching an authentication-prohibited
|
||||
// service via an authenticated connection, then that's actually a server configuration issue and not a
|
||||
// problem with the client's request.
|
||||
.map(ignored -> ServerInterceptorUtil.closeWithStatus(call, Status.INTERNAL))
|
||||
.orElseGet(() -> next.startCall(call, headers));
|
||||
} catch (final ChannelNotFoundException e) {
|
||||
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,12 +5,16 @@ import io.grpc.Contexts;
|
|||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerCallHandler;
|
||||
import io.grpc.Status;
|
||||
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
|
||||
import org.whispersystems.textsecuregcm.grpc.ServerInterceptorUtil;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
|
||||
/**
|
||||
* A "require authentication" interceptor requires that requests be issued from a connection that is associated with an
|
||||
* authenticated device. Calls without an associated authenticated device are closed with an {@code UNAUTHENTICATED}
|
||||
* status.
|
||||
* status. If a call's authentication status cannot be determined (i.e. because the underlying remote channel closed
|
||||
* before the {@code ServerCall} started), the interceptor will reject the call with a status of {@code UNAVAILABLE}.
|
||||
*/
|
||||
public class RequireAuthenticationInterceptor extends AbstractAuthenticationInterceptor {
|
||||
|
||||
|
@ -23,10 +27,17 @@ public class RequireAuthenticationInterceptor extends AbstractAuthenticationInte
|
|||
final Metadata headers,
|
||||
final ServerCallHandler<ReqT, RespT> next) {
|
||||
|
||||
try {
|
||||
return getAuthenticatedDevice(call)
|
||||
.map(authenticatedDevice -> Contexts.interceptCall(Context.current()
|
||||
.withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE, authenticatedDevice),
|
||||
call, headers, next))
|
||||
.orElseGet(() -> closeAsUnauthenticated(call));
|
||||
// Status.INTERNAL may seem a little surprising here, but if a caller is reaching an authentication-required
|
||||
// service via an unauthenticated connection, then that's actually a server configuration issue and not a
|
||||
// problem with the client's request.
|
||||
.orElseGet(() -> ServerInterceptorUtil.closeWithStatus(call, Status.INTERNAL));
|
||||
} catch (final ChannelNotFoundException e) {
|
||||
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,16 +6,36 @@
|
|||
package org.whispersystems.textsecuregcm.configuration;
|
||||
|
||||
import jakarta.validation.Valid;
|
||||
import jakarta.validation.constraints.AssertTrue;
|
||||
import jakarta.validation.constraints.NotBlank;
|
||||
import jakarta.validation.constraints.NotEmpty;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import java.time.Duration;
|
||||
import java.util.List;
|
||||
import jakarta.validation.constraints.Positive;
|
||||
import org.whispersystems.textsecuregcm.configuration.secrets.SecretString;
|
||||
|
||||
/**
|
||||
* Configuration properties for Cloudflare TURN integration.
|
||||
*
|
||||
* @param apiToken the API token to use when requesting TURN tokens from Cloudflare
|
||||
* @param endpoint the URI of the Cloudflare API endpoint that vends TURN tokens
|
||||
* @param requestedCredentialTtl the lifetime of TURN tokens to request from Cloudflare
|
||||
* @param clientCredentialTtl the time clients may cache a TURN token; must be less than or equal to {@link #requestedCredentialTtl}
|
||||
* @param urls a collection of TURN URLs to include verbatim in responses to clients
|
||||
* @param urlsWithIps a collection of {@link String#format(String, Object...)} patterns to be populated with resolved IP
|
||||
* addresses for {@link #hostname} in responses to clients; each pattern must include a single
|
||||
* {@code %s} placeholder for the IP address
|
||||
* @param circuitBreaker a circuit breaker for requests to Cloudflare
|
||||
* @param retry a retry policy for requests to Cloudflare
|
||||
* @param hostname the hostname to resolve to IP addresses for use with {@link #urlsWithIps}; also transmitted to
|
||||
* clients for use as an SNI when connecting to pre-resolved hosts
|
||||
* @param numHttpClients the number of parallel HTTP clients to use to communicate with Cloudflare
|
||||
*/
|
||||
public record CloudflareTurnConfiguration(@NotNull SecretString apiToken,
|
||||
@NotBlank String endpoint,
|
||||
@NotBlank long ttl,
|
||||
@NotNull Duration requestedCredentialTtl,
|
||||
@NotNull Duration clientCredentialTtl,
|
||||
@NotNull @NotEmpty @Valid List<@NotBlank String> urls,
|
||||
@NotNull @NotEmpty @Valid List<@NotBlank String> urlsWithIps,
|
||||
@NotNull @Valid CircuitBreakerConfiguration circuitBreaker,
|
||||
|
@ -35,4 +55,9 @@ public record CloudflareTurnConfiguration(@NotNull SecretString apiToken,
|
|||
retry = new RetryConfiguration();
|
||||
}
|
||||
}
|
||||
|
||||
@AssertTrue
|
||||
public boolean isClientTtlShorterThanRequestedTtl() {
|
||||
return clientCredentialTtl.compareTo(requestedCredentialTtl) <= 0;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -46,10 +46,6 @@ public class DynamicConfiguration {
|
|||
@Valid
|
||||
DynamicMessagePersisterConfiguration messagePersister = new DynamicMessagePersisterConfiguration();
|
||||
|
||||
@JsonProperty
|
||||
@Valid
|
||||
DynamicRateLimitPolicy rateLimitPolicy = new DynamicRateLimitPolicy(false);
|
||||
|
||||
@JsonProperty
|
||||
@Valid
|
||||
DynamicRegistrationConfiguration registrationConfiguration = new DynamicRegistrationConfiguration(false);
|
||||
|
@ -100,10 +96,6 @@ public class DynamicConfiguration {
|
|||
return messagePersister;
|
||||
}
|
||||
|
||||
public DynamicRateLimitPolicy getRateLimitPolicy() {
|
||||
return rateLimitPolicy;
|
||||
}
|
||||
|
||||
public DynamicRegistrationConfiguration getRegistrationConfiguration() {
|
||||
return registrationConfiguration;
|
||||
}
|
||||
|
|
|
@ -1,8 +0,0 @@
|
|||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.configuration.dynamic;
|
||||
|
||||
public record DynamicRateLimitPolicy(boolean failOpen) {}
|
|
@ -102,9 +102,9 @@ public class AccountControllerV2 {
|
|||
name = "Retry-After",
|
||||
description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed"))
|
||||
public AccountIdentityResponse changeNumber(@Mutable @Auth final AuthenticatedDevice authenticatedDevice,
|
||||
@NotNull @Valid final ChangeNumberRequest request, @HeaderParam(HttpHeaders.USER_AGENT) final String userAgentString,
|
||||
@Context final ContainerRequestContext requestContext)
|
||||
throws RateLimitExceededException, InterruptedException {
|
||||
@NotNull @Valid final ChangeNumberRequest request,
|
||||
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgentString,
|
||||
@Context final ContainerRequestContext requestContext) throws RateLimitExceededException, InterruptedException {
|
||||
|
||||
if (!authenticatedDevice.getAuthenticatedDevice().isPrimary()) {
|
||||
throw new ForbiddenException();
|
||||
|
|
|
@ -15,16 +15,12 @@ import io.swagger.v3.oas.annotations.responses.ApiResponse;
|
|||
import jakarta.ws.rs.GET;
|
||||
import jakarta.ws.rs.Path;
|
||||
import jakarta.ws.rs.Produces;
|
||||
import jakarta.ws.rs.container.ContainerRequestContext;
|
||||
import jakarta.ws.rs.core.Context;
|
||||
import jakarta.ws.rs.core.MediaType;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.auth.CloudflareTurnCredentialsManager;
|
||||
import org.whispersystems.textsecuregcm.auth.TurnToken;
|
||||
import org.whispersystems.textsecuregcm.limits.RateLimiters;
|
||||
import org.whispersystems.websocket.auth.ReadOnly;
|
||||
|
||||
|
@ -32,14 +28,16 @@ import org.whispersystems.websocket.auth.ReadOnly;
|
|||
@Path("/v2/calling")
|
||||
public class CallRoutingControllerV2 {
|
||||
|
||||
private static final Counter CLOUDFLARE_TURN_ERROR_COUNTER = Metrics.counter(name(CallRoutingControllerV2.class, "cloudflareTurnError"));
|
||||
private final RateLimiters rateLimiters;
|
||||
private final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager;
|
||||
|
||||
private static final Counter CLOUDFLARE_TURN_ERROR_COUNTER =
|
||||
Metrics.counter(name(CallRoutingControllerV2.class, "cloudflareTurnError"));
|
||||
|
||||
public CallRoutingControllerV2(
|
||||
final RateLimiters rateLimiters,
|
||||
final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager
|
||||
) {
|
||||
final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager) {
|
||||
|
||||
this.rateLimiters = rateLimiters;
|
||||
this.cloudflareTurnCredentialsManager = cloudflareTurnCredentialsManager;
|
||||
}
|
||||
|
@ -58,25 +56,17 @@ public class CallRoutingControllerV2 {
|
|||
@ApiResponse(responseCode = "401", description = "Account authentication check failed.")
|
||||
@ApiResponse(responseCode = "422", description = "Invalid request format.")
|
||||
@ApiResponse(responseCode = "429", description = "Rate limited.")
|
||||
public GetCallingRelaysResponse getCallingRelays(
|
||||
final @ReadOnly @Auth AuthenticatedDevice auth
|
||||
) throws RateLimitExceededException, IOException {
|
||||
UUID aci = auth.getAccount().getUuid();
|
||||
public GetCallingRelaysResponse getCallingRelays(final @ReadOnly @Auth AuthenticatedDevice auth)
|
||||
throws RateLimitExceededException, IOException {
|
||||
|
||||
final UUID aci = auth.getAccount().getUuid();
|
||||
rateLimiters.getCallEndpointLimiter().validate(aci);
|
||||
|
||||
List<TurnToken> tokens = new ArrayList<>();
|
||||
try {
|
||||
tokens.add(cloudflareTurnCredentialsManager.retrieveFromCloudflare());
|
||||
} catch (Exception e) {
|
||||
CallRoutingControllerV2.CLOUDFLARE_TURN_ERROR_COUNTER.increment();
|
||||
return new GetCallingRelaysResponse(List.of(cloudflareTurnCredentialsManager.retrieveFromCloudflare()));
|
||||
} catch (final Exception e) {
|
||||
CLOUDFLARE_TURN_ERROR_COUNTER.increment();
|
||||
throw e;
|
||||
}
|
||||
|
||||
return new GetCallingRelaysResponse(tokens);
|
||||
}
|
||||
|
||||
public record GetCallingRelaysResponse(
|
||||
List<TurnToken> relays
|
||||
) {
|
||||
}
|
||||
}
|
||||
|
|
|
@ -44,7 +44,6 @@ import java.util.EnumMap;
|
|||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.CompletionException;
|
||||
|
@ -52,7 +51,6 @@ import java.util.concurrent.CompletionStage;
|
|||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.stream.Collectors;
|
||||
import javax.annotation.Nullable;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.glassfish.jersey.server.ContainerRequest;
|
||||
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.auth.BasicAuthorizationHeader;
|
||||
|
@ -74,6 +72,7 @@ import org.whispersystems.textsecuregcm.entities.TransferArchiveUploadedRequest;
|
|||
import org.whispersystems.textsecuregcm.identity.IdentityType;
|
||||
import org.whispersystems.textsecuregcm.limits.RateLimitedByIp;
|
||||
import org.whispersystems.textsecuregcm.limits.RateLimiters;
|
||||
import org.whispersystems.textsecuregcm.metrics.DevicePlatformUtil;
|
||||
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
||||
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
|
@ -402,7 +401,7 @@ public class DeviceController {
|
|||
|
||||
private AtomicInteger getCounterForLinkedDeviceListeners(final String userAgent) {
|
||||
try {
|
||||
return linkedDeviceListenersByPlatform.get(UserAgentUtil.parseUserAgentString(userAgent).getPlatform());
|
||||
return linkedDeviceListenersByPlatform.get(UserAgentUtil.parseUserAgentString(userAgent).platform());
|
||||
} catch (final UnrecognizedUserAgentException ignored) {
|
||||
return linkedDeviceListenersForUnrecognizedPlatforms;
|
||||
}
|
||||
|
@ -600,25 +599,9 @@ public class DeviceController {
|
|||
}
|
||||
|
||||
private static io.micrometer.core.instrument.Tag primaryPlatformTag(final Account account) {
|
||||
final Device primaryDevice = account.getPrimaryDevice();
|
||||
|
||||
Optional<ClientPlatform> clientPlatform = Optional.empty();
|
||||
if (StringUtils.isNotBlank(primaryDevice.getGcmId())) {
|
||||
clientPlatform = Optional.of(ClientPlatform.ANDROID);
|
||||
} else if (StringUtils.isNotBlank(primaryDevice.getApnId())) {
|
||||
clientPlatform = Optional.of(ClientPlatform.IOS);
|
||||
}
|
||||
clientPlatform = clientPlatform.or(() -> Optional.ofNullable(
|
||||
switch (primaryDevice.getUserAgent()) {
|
||||
case "OWA" -> ClientPlatform.ANDROID;
|
||||
case "OWI", "OWP" -> ClientPlatform.IOS;
|
||||
case "OWD" -> ClientPlatform.DESKTOP;
|
||||
case null, default -> null;
|
||||
}));
|
||||
|
||||
return io.micrometer.core.instrument.Tag.of(
|
||||
"primaryPlatform",
|
||||
clientPlatform
|
||||
DevicePlatformUtil.getDevicePlatform(account.getPrimaryDevice())
|
||||
.map(p -> p.name().toLowerCase(Locale.ROOT))
|
||||
.orElse("unknown"));
|
||||
}
|
||||
|
|
|
@ -98,19 +98,18 @@ public class DonationController {
|
|||
receiptSerial, receiptExpiration.getEpochSecond(), receiptLevel, auth.getAccount().getUuid())
|
||||
.thenCompose(receiptMatched -> {
|
||||
if (!receiptMatched) {
|
||||
|
||||
return CompletableFuture.completedFuture(
|
||||
Response.status(Status.BAD_REQUEST).entity("receipt serial is already redeemed")
|
||||
.type(MediaType.TEXT_PLAIN_TYPE).build());
|
||||
}
|
||||
|
||||
return accountsManager.getByAccountIdentifierAsync(auth.getAccount().getUuid())
|
||||
.thenCompose(optionalAccount ->
|
||||
optionalAccount.map(account -> accountsManager.updateAsync(account, a -> {
|
||||
return accountsManager.updateAsync(auth.getAccount(), a -> {
|
||||
a.addBadge(clock, new AccountBadge(badgeId, receiptExpiration, request.isVisible()));
|
||||
if (request.isPrimary()) {
|
||||
a.makeBadgePrimaryIfExists(clock, badgeId);
|
||||
}
|
||||
})).orElse(CompletableFuture.completedFuture(null)))
|
||||
})
|
||||
.thenApply(ignored -> Response.ok().build());
|
||||
});
|
||||
}).thenCompose(Function.identity());
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.controllers;
|
||||
|
||||
import org.whispersystems.textsecuregcm.auth.TurnToken;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public record GetCallingRelaysResponse(List<TurnToken> relays) {
|
||||
}
|
|
@ -152,7 +152,7 @@ public class KeysController {
|
|||
|
||||
final List<CompletableFuture<Void>> storeFutures = new ArrayList<>(4);
|
||||
|
||||
if (setKeysRequest.preKeys() != null && !setKeysRequest.preKeys().isEmpty()) {
|
||||
if (!setKeysRequest.preKeys().isEmpty()) {
|
||||
Metrics.counter(STORE_KEYS_COUNTER_NAME,
|
||||
Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "ec")))
|
||||
.increment();
|
||||
|
@ -168,7 +168,7 @@ public class KeysController {
|
|||
storeFutures.add(keysManager.storeEcSignedPreKeys(identifier, device.getId(), setKeysRequest.signedPreKey()));
|
||||
}
|
||||
|
||||
if (setKeysRequest.pqPreKeys() != null && !setKeysRequest.pqPreKeys().isEmpty()) {
|
||||
if (!setKeysRequest.pqPreKeys().isEmpty()) {
|
||||
Metrics.counter(STORE_KEYS_COUNTER_NAME,
|
||||
Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "kyber")))
|
||||
.increment();
|
||||
|
@ -192,11 +192,7 @@ public class KeysController {
|
|||
final IdentityKey identityKey,
|
||||
@Nullable final String userAgent) {
|
||||
|
||||
final List<SignedPreKey<?>> signedPreKeys = new ArrayList<>();
|
||||
|
||||
if (setKeysRequest.pqPreKeys() != null) {
|
||||
signedPreKeys.addAll(setKeysRequest.pqPreKeys());
|
||||
}
|
||||
final List<SignedPreKey<?>> signedPreKeys = new ArrayList<>(setKeysRequest.pqPreKeys());
|
||||
|
||||
if (setKeysRequest.pqLastResortPreKey() != null) {
|
||||
signedPreKeys.add(setKeysRequest.pqLastResortPreKey());
|
||||
|
@ -244,8 +240,7 @@ public class KeysController {
|
|||
@ApiResponse(responseCode = "422", description = "Invalid request format")
|
||||
public CompletableFuture<Response> checkKeys(
|
||||
@ReadOnly @Auth final AuthenticatedDevice auth,
|
||||
@RequestBody @NotNull @Valid final CheckKeysRequest checkKeysRequest,
|
||||
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent) {
|
||||
@RequestBody @NotNull @Valid final CheckKeysRequest checkKeysRequest) {
|
||||
|
||||
final UUID identifier = auth.getAccount().getIdentifier(checkKeysRequest.identityType());
|
||||
final byte deviceId = auth.getAuthenticatedDevice().getId();
|
||||
|
@ -386,10 +381,7 @@ public class KeysController {
|
|||
.increment();
|
||||
|
||||
if (signedEcPreKey != null || unsignedEcPreKey != null || pqPreKey != null) {
|
||||
final int registrationId = switch (targetIdentifier.identityType()) {
|
||||
case ACI -> device.getRegistrationId();
|
||||
case PNI -> device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId());
|
||||
};
|
||||
final int registrationId = device.getRegistrationId(targetIdentifier.identityType());
|
||||
|
||||
responseItems.add(
|
||||
new PreKeyResponseItem(device.getId(), registrationId, signedEcPreKey, unsignedEcPreKey,
|
||||
|
|
|
@ -436,11 +436,16 @@ public class MessageController {
|
|||
final Map<Byte, Integer> registrationIdsByDeviceId = messages.messages().stream()
|
||||
.collect(Collectors.toMap(IncomingMessage::destinationDeviceId, IncomingMessage::destinationRegistrationId));
|
||||
|
||||
final Optional<Byte> syncMessageSenderDeviceId = messageType == MessageType.SYNC
|
||||
? Optional.ofNullable(sender).map(authenticatedDevice -> authenticatedDevice.getAuthenticatedDevice().getId())
|
||||
: Optional.empty();
|
||||
|
||||
try {
|
||||
messageSender.sendMessages(destination,
|
||||
destinationIdentifier,
|
||||
messagesByDeviceId,
|
||||
registrationIdsByDeviceId,
|
||||
syncMessageSenderDeviceId,
|
||||
userAgent);
|
||||
} catch (final MismatchedDevicesException e) {
|
||||
if (!e.getMismatchedDevices().staleDeviceIds().isEmpty()) {
|
||||
|
|
|
@ -428,7 +428,7 @@ public class OneTimeDonationController {
|
|||
@Nullable
|
||||
private static ClientPlatform getClientPlatform(@Nullable final String userAgentString) {
|
||||
try {
|
||||
return UserAgentUtil.parseUserAgentString(userAgentString).getPlatform();
|
||||
return UserAgentUtil.parseUserAgentString(userAgentString).platform();
|
||||
} catch (final UnrecognizedUserAgentException e) {
|
||||
return null;
|
||||
}
|
||||
|
|
|
@ -47,6 +47,7 @@ import java.util.concurrent.CompletableFuture;
|
|||
import java.util.concurrent.Executor;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
import javax.annotation.Nullable;
|
||||
import org.glassfish.jersey.server.ManagedAsync;
|
||||
import org.signal.libsignal.protocol.IdentityKey;
|
||||
import org.signal.libsignal.protocol.ServiceId;
|
||||
|
@ -123,6 +124,7 @@ public class ProfileController {
|
|||
private static final String EXPIRING_PROFILE_KEY_CREDENTIAL_TYPE = "expiringProfileKey";
|
||||
|
||||
private static final String VERSION_NOT_FOUND_COUNTER_NAME = name(ProfileController.class, "versionNotFound");
|
||||
private static final String DUPLICATE_AUTHENTICATION_COUNTER_NAME = name(ProfileController.class, "duplicateAuthentication");
|
||||
|
||||
public ProfileController(
|
||||
Clock clock,
|
||||
|
@ -204,11 +206,12 @@ public class ProfileController {
|
|||
.build()));
|
||||
}
|
||||
|
||||
final List<AccountBadge> updatedBadges = request.badges()
|
||||
.map(badges -> ProfileHelper.mergeBadgeIdsWithExistingAccountBadges(clock, badgeConfigurationMap, badges, auth.getAccount().getBadges()))
|
||||
.orElseGet(() -> auth.getAccount().getBadges());
|
||||
|
||||
accountsManager.update(auth.getAccount(), a -> {
|
||||
|
||||
final List<AccountBadge> updatedBadges = request.badges()
|
||||
.map(badges -> ProfileHelper.mergeBadgeIdsWithExistingAccountBadges(clock, badgeConfigurationMap, badges, a.getBadges()))
|
||||
.orElseGet(a::getBadges);
|
||||
|
||||
a.setBadges(clock, updatedBadges);
|
||||
a.setCurrentProfileVersion(request.version());
|
||||
});
|
||||
|
@ -229,11 +232,12 @@ public class ProfileController {
|
|||
@HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional<Anonymous> accessKey,
|
||||
@Context ContainerRequestContext containerRequestContext,
|
||||
@PathParam("identifier") AciServiceIdentifier accountIdentifier,
|
||||
@PathParam("version") String version)
|
||||
@PathParam("version") String version,
|
||||
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent)
|
||||
throws RateLimitExceededException {
|
||||
|
||||
final Optional<Account> maybeRequester = auth.map(AuthenticatedDevice::getAccount);
|
||||
final Account targetAccount = verifyPermissionToReceiveProfile(maybeRequester, accessKey, accountIdentifier);
|
||||
final Account targetAccount = verifyPermissionToReceiveProfile(maybeRequester, accessKey, accountIdentifier, "getVersionedProfile", userAgent);
|
||||
|
||||
return buildVersionedProfileResponse(targetAccount,
|
||||
version,
|
||||
|
@ -252,7 +256,8 @@ public class ProfileController {
|
|||
@PathParam("identifier") AciServiceIdentifier accountIdentifier,
|
||||
@PathParam("version") String version,
|
||||
@PathParam("credentialRequest") String credentialRequest,
|
||||
@QueryParam("credentialType") String credentialType)
|
||||
@QueryParam("credentialType") String credentialType,
|
||||
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent)
|
||||
throws RateLimitExceededException {
|
||||
|
||||
if (!EXPIRING_PROFILE_KEY_CREDENTIAL_TYPE.equals(credentialType)) {
|
||||
|
@ -260,7 +265,7 @@ public class ProfileController {
|
|||
}
|
||||
|
||||
final Optional<Account> maybeRequester = auth.map(AuthenticatedDevice::getAccount);
|
||||
final Account targetAccount = verifyPermissionToReceiveProfile(maybeRequester, accessKey, accountIdentifier);
|
||||
final Account targetAccount = verifyPermissionToReceiveProfile(maybeRequester, accessKey, accountIdentifier, "credentialRequest", userAgent);
|
||||
final boolean isSelf = maybeRequester.map(requester -> ProfileHelper.isSelfProfileRequest(requester.getUuid(), accountIdentifier)).orElse(false);
|
||||
|
||||
return buildExpiringProfileKeyCredentialProfileResponse(targetAccount,
|
||||
|
@ -282,8 +287,7 @@ public class ProfileController {
|
|||
@HeaderParam(HeaderUtils.GROUP_SEND_TOKEN) Optional<GroupSendTokenHeader> groupSendToken,
|
||||
@Context ContainerRequestContext containerRequestContext,
|
||||
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent,
|
||||
@PathParam("identifier") ServiceIdentifier identifier,
|
||||
@QueryParam("ca") boolean useCaCertificate)
|
||||
@PathParam("identifier") ServiceIdentifier identifier)
|
||||
throws RateLimitExceededException {
|
||||
|
||||
final Optional<Account> maybeRequester = auth.map(AuthenticatedDevice::getAccount);
|
||||
|
@ -302,7 +306,7 @@ public class ProfileController {
|
|||
}
|
||||
} else {
|
||||
targetAccount = verifyPermissionToReceiveProfile(
|
||||
maybeRequester, accessKey.filter(ignored -> identifier.identityType() == IdentityType.ACI), identifier);
|
||||
maybeRequester, accessKey.filter(ignored -> identifier.identityType() == IdentityType.ACI), identifier, "getUnversionedProfile", userAgent);
|
||||
}
|
||||
return switch (identifier.identityType()) {
|
||||
case ACI -> buildBaseProfileResponseForAccountIdentity(targetAccount,
|
||||
|
@ -385,7 +389,7 @@ public class ProfileController {
|
|||
profileKeyCredentialResponse = ProfileHelper.getExpiringProfileKeyCredential(HexFormat.of().parseHex(encodedCredentialRequest),
|
||||
profile, new ServiceId.Aci(account.getUuid()), zkProfileOperations);
|
||||
} catch (VerificationFailedException | InvalidInputException e) {
|
||||
throw new BadRequestException(Response.status(Response.Status.BAD_REQUEST).build(), e);
|
||||
throw new BadRequestException(e);
|
||||
}
|
||||
return profileKeyCredentialResponse;
|
||||
})
|
||||
|
@ -473,7 +477,15 @@ public class ProfileController {
|
|||
*/
|
||||
private Account verifyPermissionToReceiveProfile(final Optional<Account> maybeRequester,
|
||||
final Optional<Anonymous> maybeAccessKey,
|
||||
final ServiceIdentifier accountIdentifier) throws RateLimitExceededException {
|
||||
final ServiceIdentifier accountIdentifier,
|
||||
final String endpoint,
|
||||
@Nullable final String userAgent) throws RateLimitExceededException {
|
||||
|
||||
if (maybeRequester.isPresent() && maybeAccessKey.isPresent()) {
|
||||
Metrics.counter(DUPLICATE_AUTHENTICATION_COUNTER_NAME,
|
||||
Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), io.micrometer.core.instrument.Tag.of("endpoint", endpoint)))
|
||||
.increment();
|
||||
}
|
||||
|
||||
if (maybeRequester.isPresent()) {
|
||||
rateLimiters.getProfileLimiter().validate(maybeRequester.get().getUuid());
|
||||
|
|
|
@ -755,7 +755,7 @@ public class SubscriptionController {
|
|||
@Nullable
|
||||
private static ClientPlatform getClientPlatform(@Nullable final String userAgentString) {
|
||||
try {
|
||||
return UserAgentUtil.parseUserAgentString(userAgentString).getPlatform();
|
||||
return UserAgentUtil.parseUserAgentString(userAgentString).platform();
|
||||
} catch (final UnrecognizedUserAgentException e) {
|
||||
return null;
|
||||
}
|
||||
|
|
|
@ -7,30 +7,30 @@ package org.whispersystems.textsecuregcm.entities;
|
|||
|
||||
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
|
||||
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
|
||||
import org.whispersystems.textsecuregcm.util.ByteArrayBase64UrlAdapter;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import java.util.UUID;
|
||||
import javax.annotation.Nullable;
|
||||
import org.whispersystems.textsecuregcm.util.ByteArrayBase64UrlAdapter;
|
||||
|
||||
public record AccountIdentityResponse(
|
||||
@Schema(description="the account identifier for this account")
|
||||
@Schema(description = "the account identifier for this account")
|
||||
UUID uuid,
|
||||
|
||||
@Schema(description="the phone number associated with this account")
|
||||
@Schema(description = "the phone number associated with this account")
|
||||
String number,
|
||||
|
||||
@Schema(description="the account identifier for this account's phone-number identity")
|
||||
@Schema(description = "the account identifier for this account's phone-number identity")
|
||||
UUID pni,
|
||||
|
||||
@Schema(description="a hash of this account's username, if set")
|
||||
@Schema(description = "a hash of this account's username, if set")
|
||||
@JsonSerialize(using = ByteArrayBase64UrlAdapter.Serializing.class)
|
||||
@JsonDeserialize(using = ByteArrayBase64UrlAdapter.Deserializing.class)
|
||||
@Nullable byte[] usernameHash,
|
||||
|
||||
@Schema(description="this account's username link handle, if set")
|
||||
@Schema(description = "this account's username link handle, if set")
|
||||
@Nullable UUID usernameLinkHandle,
|
||||
|
||||
@Schema(description="whether any of this account's devices support storage")
|
||||
@Schema(description = "whether any of this account's devices support storage")
|
||||
boolean storageCapable,
|
||||
|
||||
@Schema(description = "entitlements for this account and their current expirations")
|
||||
|
|
|
@ -17,6 +17,7 @@ import javax.annotation.Nullable;
|
|||
import jakarta.validation.Valid;
|
||||
import jakarta.validation.constraints.AssertTrue;
|
||||
import jakarta.validation.constraints.NotBlank;
|
||||
import jakarta.validation.constraints.NotEmpty;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import org.signal.libsignal.protocol.IdentityKey;
|
||||
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
|
||||
|
@ -51,34 +52,27 @@ public record ChangeNumberRequest(
|
|||
arraySchema=@Schema(description="""
|
||||
A list of synchronization messages to send to companion devices to supply the private keysManager
|
||||
associated with the new identity key and their new prekeys.
|
||||
Exactly one message must be supplied for each enabled device other than the sending (primary) device."""))
|
||||
Exactly one message must be supplied for each device other than the sending (primary) device."""))
|
||||
@NotNull @Valid List<@NotNull @Valid IncomingMessage> deviceMessages,
|
||||
|
||||
@Schema(description="""
|
||||
A new signed elliptic-curve prekey for each enabled device on the account, including this one.
|
||||
A new signed elliptic-curve prekey for each device on the account, including this one.
|
||||
Each must be accompanied by a valid signature from the new identity key in this request.""")
|
||||
@NotNull @Valid Map<Byte, @NotNull @Valid ECSignedPreKey> devicePniSignedPrekeys,
|
||||
@NotNull @NotEmpty @Valid Map<Byte, @NotNull @Valid ECSignedPreKey> devicePniSignedPrekeys,
|
||||
|
||||
@Schema(description="""
|
||||
A new signed post-quantum last-resort prekey for each enabled device on the account, including this one.
|
||||
May be absent, in which case the last resort PQ prekeys for each device will be deleted if any had been stored.
|
||||
If present, must contain one prekey per enabled device including this one.
|
||||
Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped.
|
||||
A new signed post-quantum last-resort prekey for each device on the account, including this one.
|
||||
Each must be accompanied by a valid signature from the new identity key in this request.""")
|
||||
@Valid Map<Byte, @NotNull @Valid KEMSignedPreKey> devicePniPqLastResortPrekeys,
|
||||
@NotNull @NotEmpty @Valid Map<Byte, @NotNull @Valid KEMSignedPreKey> devicePniPqLastResortPrekeys,
|
||||
|
||||
@Schema(description="the new phone-number-identity registration ID for each enabled device on the account, including this one")
|
||||
@NotNull Map<Byte, Integer> pniRegistrationIds) implements PhoneVerificationRequest {
|
||||
@Schema(description="the new phone-number-identity registration ID for each device on the account, including this one")
|
||||
@NotNull @NotEmpty Map<Byte, Integer> pniRegistrationIds) implements PhoneVerificationRequest {
|
||||
|
||||
public boolean isSignatureValidOnEachSignedPreKey(@Nullable final String userAgent) {
|
||||
List<SignedPreKey<?>> spks = new ArrayList<>();
|
||||
if (devicePniSignedPrekeys != null) {
|
||||
spks.addAll(devicePniSignedPrekeys.values());
|
||||
}
|
||||
if (devicePniPqLastResortPrekeys != null) {
|
||||
final List<SignedPreKey<?>> spks = new ArrayList<>(devicePniSignedPrekeys.values());
|
||||
spks.addAll(devicePniPqLastResortPrekeys.values());
|
||||
}
|
||||
return spks.isEmpty() || PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey, spks, userAgent, "change-number");
|
||||
|
||||
return PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey, spks, userAgent, "change-number");
|
||||
}
|
||||
|
||||
@AssertTrue
|
||||
|
|
|
@ -9,6 +9,7 @@ import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
|
|||
import io.swagger.v3.oas.annotations.media.ArraySchema;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import jakarta.validation.Valid;
|
||||
import jakarta.validation.constraints.NotEmpty;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
@ -29,36 +30,36 @@ public record PhoneNumberIdentityKeyDistributionRequest(
|
|||
arraySchema=@Schema(description="""
|
||||
A list of synchronization messages to send to companion devices to supply the private keys
|
||||
associated with the new identity key and their new prekeys.
|
||||
Exactly one message must be supplied for each enabled device other than the sending (primary) device.
|
||||
Exactly one message must be supplied for each device other than the sending (primary) device.
|
||||
"""))
|
||||
List<@NotNull @Valid IncomingMessage> deviceMessages,
|
||||
|
||||
@NotNull
|
||||
@NotEmpty
|
||||
@Valid
|
||||
@Schema(description="""
|
||||
A new signed elliptic-curve prekey for each enabled device on the account, including this one.
|
||||
A new signed elliptic-curve prekey for each device on the account, including this one.
|
||||
Each must be accompanied by a valid signature from the new identity key in this request.""")
|
||||
Map<Byte, @NotNull @Valid ECSignedPreKey> devicePniSignedPrekeys,
|
||||
|
||||
@NotNull
|
||||
@NotEmpty
|
||||
@Valid
|
||||
@Schema(description="""
|
||||
A new signed post-quantum last-resort prekey for each enabled device on the account, including this one.
|
||||
May be absent, in which case the last resort PQ prekeys for each device will be deleted if any had been stored.
|
||||
If present, must contain one prekey per enabled device including this one.
|
||||
Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped.
|
||||
A new signed post-quantum last-resort prekey for each device on the account, including this one.
|
||||
Each must be accompanied by a valid signature from the new identity key in this request.""")
|
||||
@Valid Map<Byte, @NotNull @Valid KEMSignedPreKey> devicePniPqLastResortPrekeys,
|
||||
Map<Byte, @NotNull @Valid KEMSignedPreKey> devicePniPqLastResortPrekeys,
|
||||
|
||||
@NotNull
|
||||
@NotEmpty
|
||||
@Valid
|
||||
@Schema(description="The new registration ID to use for the phone-number identity of each device, including this one.")
|
||||
Map<Byte, Integer> pniRegistrationIds) {
|
||||
|
||||
public boolean isSignatureValidOnEachSignedPreKey(@Nullable final String userAgent) {
|
||||
List<SignedPreKey<?>> spks = new ArrayList<>(devicePniSignedPrekeys.values());
|
||||
if (devicePniPqLastResortPrekeys != null) {
|
||||
spks.addAll(devicePniPqLastResortPrekeys.values());
|
||||
}
|
||||
return spks.isEmpty() || PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey, spks, userAgent, "distribute-pni-keys");
|
||||
}
|
||||
final List<SignedPreKey<?>> signedPreKeys = new ArrayList<>(devicePniSignedPrekeys.values());
|
||||
signedPreKeys.addAll(devicePniPqLastResortPrekeys.values());
|
||||
|
||||
return PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey, signedPreKeys, userAgent, "distribute-pni-keys");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,11 +5,15 @@
|
|||
package org.whispersystems.textsecuregcm.entities;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import jakarta.validation.constraints.Size;
|
||||
import jakarta.validation.Valid;
|
||||
import java.util.List;
|
||||
|
||||
public record SetKeysRequest(
|
||||
@NotNull
|
||||
@Valid
|
||||
@Size(max=100)
|
||||
@Schema(description = """
|
||||
A list of unsigned elliptic-curve prekeys to use for this device. If present and not empty, replaces all stored
|
||||
unsigned EC prekeys for the device; if absent or empty, any stored unsigned EC prekeys for the device are not
|
||||
|
@ -25,7 +29,9 @@ public record SetKeysRequest(
|
|||
""")
|
||||
ECSignedPreKey signedPreKey,
|
||||
|
||||
@NotNull
|
||||
@Valid
|
||||
@Size(max=100)
|
||||
@Schema(description = """
|
||||
A list of signed post-quantum one-time prekeys to use for this device. Each key must have a valid signature from
|
||||
the identity key in this request. If present and not empty, replaces all stored unsigned PQ prekeys for the
|
||||
|
@ -40,4 +46,16 @@ public record SetKeysRequest(
|
|||
deleted. If present, must have a valid signature from the identity key in this request.
|
||||
""")
|
||||
KEMSignedPreKey pqLastResortPreKey) {
|
||||
public SetKeysRequest {
|
||||
// It’s a little counter-intuitive, but this compact constructor allows a default value
|
||||
// to be used when one isn’t specified, allowing the field to still be
|
||||
// validated as @NotNull
|
||||
if (preKeys == null) {
|
||||
preKeys = List.of();
|
||||
}
|
||||
|
||||
if (pqPreKeys == null) {
|
||||
pqPreKeys = List.of();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -81,7 +81,16 @@ public class RemoteDeprecationFilter implements Filter, ServerInterceptor {
|
|||
final Metadata headers,
|
||||
final ServerCallHandler<ReqT, RespT> next) {
|
||||
|
||||
if (shouldBlock(RequestAttributesUtil.getUserAgent().orElse(null))) {
|
||||
@Nullable final UserAgent userAgent = RequestAttributesUtil.getUserAgent()
|
||||
.map(userAgentString -> {
|
||||
try {
|
||||
return UserAgentUtil.parseUserAgentString(userAgentString);
|
||||
} catch (final UnrecognizedUserAgentException e) {
|
||||
return null;
|
||||
}
|
||||
}).orElse(null);
|
||||
|
||||
if (shouldBlock(userAgent)) {
|
||||
call.close(StatusConstants.UPGRADE_NEEDED_STATUS, new Metadata());
|
||||
return new ServerCall.Listener<>() {};
|
||||
} else {
|
||||
|
@ -108,28 +117,28 @@ public class RemoteDeprecationFilter implements Filter, ServerInterceptor {
|
|||
return true;
|
||||
}
|
||||
|
||||
if (blockedVersionsByPlatform.containsKey(userAgent.getPlatform())) {
|
||||
if (blockedVersionsByPlatform.get(userAgent.getPlatform()).contains(userAgent.getVersion())) {
|
||||
if (blockedVersionsByPlatform.containsKey(userAgent.platform())) {
|
||||
if (blockedVersionsByPlatform.get(userAgent.platform()).contains(userAgent.version())) {
|
||||
recordDeprecation(userAgent, BLOCKED_CLIENT_REASON);
|
||||
shouldBlock = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (minimumVersionsByPlatform.containsKey(userAgent.getPlatform())) {
|
||||
if (userAgent.getVersion().isLowerThan(minimumVersionsByPlatform.get(userAgent.getPlatform()))) {
|
||||
if (minimumVersionsByPlatform.containsKey(userAgent.platform())) {
|
||||
if (userAgent.version().isLowerThan(minimumVersionsByPlatform.get(userAgent.platform()))) {
|
||||
recordDeprecation(userAgent, EXPIRED_CLIENT_REASON);
|
||||
shouldBlock = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (versionsPendingBlockByPlatform.containsKey(userAgent.getPlatform())) {
|
||||
if (versionsPendingBlockByPlatform.get(userAgent.getPlatform()).contains(userAgent.getVersion())) {
|
||||
if (versionsPendingBlockByPlatform.containsKey(userAgent.platform())) {
|
||||
if (versionsPendingBlockByPlatform.get(userAgent.platform()).contains(userAgent.version())) {
|
||||
recordPendingDeprecation(userAgent, BLOCKED_CLIENT_REASON);
|
||||
}
|
||||
}
|
||||
|
||||
if (versionsPendingDeprecationByPlatform.containsKey(userAgent.getPlatform())) {
|
||||
if (userAgent.getVersion().isLowerThan(versionsPendingDeprecationByPlatform.get(userAgent.getPlatform()))) {
|
||||
if (versionsPendingDeprecationByPlatform.containsKey(userAgent.platform())) {
|
||||
if (userAgent.version().isLowerThan(versionsPendingDeprecationByPlatform.get(userAgent.platform()))) {
|
||||
recordPendingDeprecation(userAgent, EXPIRED_CLIENT_REASON);
|
||||
}
|
||||
}
|
||||
|
@ -139,13 +148,13 @@ public class RemoteDeprecationFilter implements Filter, ServerInterceptor {
|
|||
|
||||
private void recordDeprecation(final UserAgent userAgent, final String reason) {
|
||||
Metrics.counter(DEPRECATED_CLIENT_COUNTER_NAME,
|
||||
PLATFORM_TAG, userAgent != null ? userAgent.getPlatform().name().toLowerCase() : "unrecognized",
|
||||
PLATFORM_TAG, userAgent != null ? userAgent.platform().name().toLowerCase() : "unrecognized",
|
||||
REASON_TAG_NAME, reason).increment();
|
||||
}
|
||||
|
||||
private void recordPendingDeprecation(final UserAgent userAgent, final String reason) {
|
||||
Metrics.counter(PENDING_DEPRECATION_COUNTER_NAME,
|
||||
PLATFORM_TAG, userAgent.getPlatform().name().toLowerCase(),
|
||||
PLATFORM_TAG, userAgent.platform().name().toLowerCase(),
|
||||
REASON_TAG_NAME, reason).increment();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,8 +15,6 @@ import jakarta.ws.rs.container.ContainerRequestFilter;
|
|||
import jakarta.ws.rs.core.SecurityContext;
|
||||
import java.io.IOException;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
|
||||
|
@ -70,8 +68,8 @@ public class RestDeprecationFilter implements ContainerRequestFilter {
|
|||
|
||||
try {
|
||||
final UserAgent userAgent = UserAgentUtil.parseUserAgentString(userAgentString);
|
||||
final ClientPlatform platform = userAgent.getPlatform();
|
||||
final Semver version = userAgent.getVersion();
|
||||
final ClientPlatform platform = userAgent.platform();
|
||||
final Semver version = userAgent.version();
|
||||
if (!minimumRestFreeVersion.containsKey(platform)) {
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
/**
|
||||
* Indicates that a remote channel was not found for a given server call or remote address.
|
||||
*/
|
||||
public class ChannelNotFoundException extends Exception {
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import io.grpc.Context;
|
||||
import io.grpc.ForwardingServerCallListener;
|
||||
import io.grpc.Grpc;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerCallHandler;
|
||||
import io.grpc.ServerInterceptor;
|
||||
import io.grpc.Status;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
|
||||
/**
|
||||
* Then channel shutdown interceptor rejects new requests if a channel is shutting down and works in tandem with
|
||||
* {@link GrpcClientConnectionManager} to maintain an active call count for each channel otherwise.
|
||||
*/
|
||||
public class ChannelShutdownInterceptor implements ServerInterceptor {
|
||||
|
||||
private final GrpcClientConnectionManager grpcClientConnectionManager;
|
||||
|
||||
public ChannelShutdownInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
|
||||
this.grpcClientConnectionManager = grpcClientConnectionManager;
|
||||
}
|
||||
|
||||
@Override
|
||||
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
|
||||
final Metadata headers,
|
||||
final ServerCallHandler<ReqT, RespT> next) {
|
||||
|
||||
if (!grpcClientConnectionManager.handleServerCallStart(call)) {
|
||||
// Don't allow new calls if the connection is getting ready to close
|
||||
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
|
||||
}
|
||||
|
||||
return new ForwardingServerCallListener.SimpleForwardingServerCallListener<>(next.startCall(call, headers)) {
|
||||
@Override
|
||||
public void onComplete() {
|
||||
grpcClientConnectionManager.handleServerCallComplete(call);
|
||||
super.onComplete();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onCancel() {
|
||||
grpcClientConnectionManager.handleServerCallComplete(call);
|
||||
super.onCancel();
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
|
@ -10,8 +10,12 @@ import org.whispersystems.textsecuregcm.storage.Device;
|
|||
|
||||
public class DeviceIdUtil {
|
||||
|
||||
public static boolean isValid(int deviceId) {
|
||||
return deviceId >= Device.PRIMARY_ID && deviceId <= Byte.MAX_VALUE;
|
||||
}
|
||||
|
||||
static byte validate(int deviceId) {
|
||||
if (deviceId < Device.PRIMARY_ID || deviceId > Byte.MAX_VALUE) {
|
||||
if (!isValid(deviceId)) {
|
||||
throw Status.INVALID_ARGUMENT.withDescription("Device ID is out of range").asRuntimeException();
|
||||
}
|
||||
|
||||
|
|
|
@ -187,7 +187,8 @@ public class MessagesAnonymousGrpcService extends SimpleMessagesAnonymousGrpc.Me
|
|||
destination,
|
||||
destinationServiceIdentifier,
|
||||
messagesByDeviceId,
|
||||
registrationIdsByDeviceId);
|
||||
registrationIdsByDeviceId,
|
||||
Optional.empty());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -252,7 +253,7 @@ public class MessagesAnonymousGrpcService extends SimpleMessagesAnonymousGrpc.Me
|
|||
story,
|
||||
ephemeral,
|
||||
urgent,
|
||||
RequestAttributesUtil.getRawUserAgent().orElse(null));
|
||||
RequestAttributesUtil.getUserAgent().orElse(null));
|
||||
|
||||
final SendMultiRecipientMessageResponse.Builder responseBuilder = SendMultiRecipientMessageResponse.newBuilder();
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.grpc;
|
|||
import io.grpc.Status;
|
||||
import io.grpc.StatusException;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import org.signal.chat.messages.MismatchedDevices;
|
||||
import org.signal.chat.messages.SendMessageResponse;
|
||||
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
|
||||
|
@ -31,6 +32,8 @@ public class MessagesGrpcHelper {
|
|||
* @param destinationServiceIdentifier the service identifier for the destination account
|
||||
* @param messagesByDeviceId a map of device IDs to message payloads
|
||||
* @param registrationIdsByDeviceId a map of device IDs to device registration IDs
|
||||
* @param syncMessageSenderDeviceId if the message is a sync message (i.e. a message to other devices linked to the
|
||||
* caller's own account), contains the ID of the device that sent the message
|
||||
*
|
||||
* @return a response object to send to callers
|
||||
*
|
||||
|
@ -42,14 +45,17 @@ public class MessagesGrpcHelper {
|
|||
final Account destination,
|
||||
final ServiceIdentifier destinationServiceIdentifier,
|
||||
final Map<Byte, MessageProtos.Envelope> messagesByDeviceId,
|
||||
final Map<Byte, Integer> registrationIdsByDeviceId) throws StatusException, RateLimitExceededException {
|
||||
final Map<Byte, Integer> registrationIdsByDeviceId,
|
||||
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<Byte> syncMessageSenderDeviceId)
|
||||
throws StatusException, RateLimitExceededException {
|
||||
|
||||
try {
|
||||
messageSender.sendMessages(destination,
|
||||
destinationServiceIdentifier,
|
||||
messagesByDeviceId,
|
||||
registrationIdsByDeviceId,
|
||||
RequestAttributesUtil.getRawUserAgent().orElse(null));
|
||||
syncMessageSenderDeviceId,
|
||||
RequestAttributesUtil.getUserAgent().orElse(null));
|
||||
|
||||
return SEND_MESSAGE_SUCCESS_RESPONSE;
|
||||
} catch (final MismatchedDevicesException e) {
|
||||
|
|
|
@ -172,7 +172,8 @@ public class MessagesGrpcService extends SimpleMessagesGrpc.MessagesImplBase {
|
|||
destination,
|
||||
destinationServiceIdentifier,
|
||||
messagesByDeviceId,
|
||||
registrationIdsByDeviceId);
|
||||
registrationIdsByDeviceId,
|
||||
messageType == MessageType.SYNC ? Optional.of(sender.deviceId()) : Optional.empty());
|
||||
}
|
||||
|
||||
private static MessageProtos.Envelope.Type getEnvelopeType(final AuthenticatedSenderMessageType type) {
|
||||
|
|
|
@ -145,11 +145,13 @@ public class ProfileGrpcService extends ReactorProfileGrpc.ProfileImplBase {
|
|||
request.getCommitment().toByteArray())));
|
||||
|
||||
final List<Mono<?>> updates = new ArrayList<>(2);
|
||||
final List<AccountBadge> updatedBadges = Optional.of(request.getBadgeIdsList())
|
||||
.map(badges -> ProfileHelper.mergeBadgeIdsWithExistingAccountBadges(clock, badgeConfigurationMap, badges, account.getBadges()))
|
||||
.orElseGet(account::getBadges);
|
||||
|
||||
updates.add(Mono.fromFuture(() -> accountsManager.updateAsync(account, a -> {
|
||||
|
||||
final List<AccountBadge> updatedBadges = Optional.of(request.getBadgeIdsList())
|
||||
.map(badges -> ProfileHelper.mergeBadgeIdsWithExistingAccountBadges(clock, badgeConfigurationMap, badges, a.getBadges()))
|
||||
.orElseGet(a::getBadges);
|
||||
|
||||
a.setBadges(clock, updatedBadges);
|
||||
a.setCurrentProfileVersion(request.getVersion());
|
||||
})));
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import java.net.InetAddress;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
public record RequestAttributes(InetAddress remoteAddress,
|
||||
@Nullable String userAgent,
|
||||
List<Locale.LanguageRange> acceptLanguage) {
|
||||
}
|
|
@ -2,28 +2,25 @@ package org.whispersystems.textsecuregcm.grpc;
|
|||
|
||||
import io.grpc.Context;
|
||||
import io.grpc.Contexts;
|
||||
import io.grpc.Grpc;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerCallHandler;
|
||||
import io.grpc.ServerInterceptor;
|
||||
import io.grpc.Status;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
|
||||
import java.net.InetAddress;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
* The request attributes interceptor makes request attributes from the underlying remote channel available to service
|
||||
* implementations by attaching them to a {@link Context} attribute that can be read via {@link RequestAttributesUtil}.
|
||||
* All server calls should have request attributes, and calls will be rejected with a status of {@code UNAVAILABLE} if
|
||||
* request attributes are unavailable (i.e. the underlying channel closed before the {@code ServerCall} started).
|
||||
*
|
||||
* @see RequestAttributesUtil
|
||||
*/
|
||||
public class RequestAttributesInterceptor implements ServerInterceptor {
|
||||
|
||||
private final GrpcClientConnectionManager grpcClientConnectionManager;
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(RequestAttributesInterceptor.class);
|
||||
|
||||
public RequestAttributesInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
|
||||
this.grpcClientConnectionManager = grpcClientConnectionManager;
|
||||
}
|
||||
|
@ -33,52 +30,12 @@ public class RequestAttributesInterceptor implements ServerInterceptor {
|
|||
final Metadata headers,
|
||||
final ServerCallHandler<ReqT, RespT> next) {
|
||||
|
||||
if (call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress) {
|
||||
Context context = Context.current();
|
||||
|
||||
{
|
||||
final Optional<InetAddress> maybeRemoteAddress = grpcClientConnectionManager.getRemoteAddress(localAddress);
|
||||
|
||||
if (maybeRemoteAddress.isEmpty()) {
|
||||
// We should never have a call from a party whose remote address we can't identify
|
||||
log.warn("No remote address available");
|
||||
|
||||
call.close(Status.INTERNAL, new Metadata());
|
||||
return new ServerCall.Listener<>() {};
|
||||
}
|
||||
|
||||
context = context.withValue(RequestAttributesUtil.REMOTE_ADDRESS_CONTEXT_KEY, maybeRemoteAddress.get());
|
||||
}
|
||||
|
||||
{
|
||||
final Optional<List<Locale.LanguageRange>> maybeAcceptLanguage =
|
||||
grpcClientConnectionManager.getAcceptableLanguages(localAddress);
|
||||
|
||||
if (maybeAcceptLanguage.isPresent()) {
|
||||
context = context.withValue(RequestAttributesUtil.ACCEPT_LANGUAGE_CONTEXT_KEY, maybeAcceptLanguage.get());
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
final Optional<String> maybeRawUserAgent =
|
||||
grpcClientConnectionManager.getRawUserAgent(localAddress);
|
||||
|
||||
if (maybeRawUserAgent.isPresent()) {
|
||||
context = context.withValue(RequestAttributesUtil.RAW_USER_AGENT_CONTEXT_KEY, maybeRawUserAgent.get());
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
final Optional<UserAgent> maybeUserAgent = grpcClientConnectionManager.getUserAgent(localAddress);
|
||||
|
||||
if (maybeUserAgent.isPresent()) {
|
||||
context = context.withValue(RequestAttributesUtil.USER_AGENT_CONTEXT_KEY, maybeUserAgent.get());
|
||||
}
|
||||
}
|
||||
|
||||
return Contexts.interceptCall(context, call, headers, next);
|
||||
} else {
|
||||
throw new AssertionError("Unexpected channel type: " + call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
|
||||
try {
|
||||
return Contexts.interceptCall(Context.current()
|
||||
.withValue(RequestAttributesUtil.REQUEST_ATTRIBUTES_CONTEXT_KEY,
|
||||
grpcClientConnectionManager.getRequestAttributes(call)), call, headers, next);
|
||||
} catch (final ChannelNotFoundException e) {
|
||||
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,18 +3,13 @@ package org.whispersystems.textsecuregcm.grpc;
|
|||
import io.grpc.Context;
|
||||
import java.net.InetAddress;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Optional;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
|
||||
|
||||
public class RequestAttributesUtil {
|
||||
|
||||
static final Context.Key<List<Locale.LanguageRange>> ACCEPT_LANGUAGE_CONTEXT_KEY = Context.key("accept-language");
|
||||
static final Context.Key<InetAddress> REMOTE_ADDRESS_CONTEXT_KEY = Context.key("remote-address");
|
||||
static final Context.Key<String> RAW_USER_AGENT_CONTEXT_KEY = Context.key("unparsed-user-agent");
|
||||
static final Context.Key<UserAgent> USER_AGENT_CONTEXT_KEY = Context.key("parsed-user-agent");
|
||||
static final Context.Key<RequestAttributes> REQUEST_ATTRIBUTES_CONTEXT_KEY = Context.key("request-attributes");
|
||||
|
||||
private static final List<Locale> AVAILABLE_LOCALES = Arrays.asList(Locale.getAvailableLocales());
|
||||
|
||||
|
@ -23,8 +18,8 @@ public class RequestAttributesUtil {
|
|||
*
|
||||
* @return the acceptable languages listed by the remote client; may be empty if unparseable or not specified
|
||||
*/
|
||||
public static Optional<List<Locale.LanguageRange>> getAcceptableLanguages() {
|
||||
return Optional.ofNullable(ACCEPT_LANGUAGE_CONTEXT_KEY.get());
|
||||
public static List<Locale.LanguageRange> getAcceptableLanguages() {
|
||||
return REQUEST_ATTRIBUTES_CONTEXT_KEY.get().acceptLanguage();
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -35,9 +30,7 @@ public class RequestAttributesUtil {
|
|||
* @return a list of distinct locales acceptable to the remote client and available in this JVM
|
||||
*/
|
||||
public static List<Locale> getAvailableAcceptedLocales() {
|
||||
return getAcceptableLanguages()
|
||||
.map(languageRanges -> Locale.filter(languageRanges, AVAILABLE_LOCALES))
|
||||
.orElseGet(Collections::emptyList);
|
||||
return Locale.filter(getAcceptableLanguages(), AVAILABLE_LOCALES);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -46,16 +39,7 @@ public class RequestAttributesUtil {
|
|||
* @return the remote address of the remote client
|
||||
*/
|
||||
public static InetAddress getRemoteAddress() {
|
||||
return REMOTE_ADDRESS_CONTEXT_KEY.get();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the parsed user-agent of the remote client in the current gRPC request context.
|
||||
*
|
||||
* @return the parsed user-agent of the remote client; may be empty if unparseable or not specified
|
||||
*/
|
||||
public static Optional<UserAgent> getUserAgent() {
|
||||
return Optional.ofNullable(USER_AGENT_CONTEXT_KEY.get());
|
||||
return REQUEST_ATTRIBUTES_CONTEXT_KEY.get().remoteAddress();
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -63,7 +47,7 @@ public class RequestAttributesUtil {
|
|||
*
|
||||
* @return the unparsed user-agent of the remote client; may be empty if not specified
|
||||
*/
|
||||
public static Optional<String> getRawUserAgent() {
|
||||
return Optional.ofNullable(RAW_USER_AGENT_CONTEXT_KEY.get());
|
||||
public static Optional<String> getUserAgent() {
|
||||
return Optional.ofNullable(REQUEST_ATTRIBUTES_CONTEXT_KEY.get().userAgent());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.Status;
|
||||
|
||||
public class ServerInterceptorUtil {
|
||||
|
||||
@SuppressWarnings("rawtypes")
|
||||
private static final ServerCall.Listener NO_OP_LISTENER = new ServerCall.Listener<>() {};
|
||||
|
||||
private static final Metadata EMPTY_TRAILERS = new Metadata();
|
||||
|
||||
private ServerInterceptorUtil() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Closes the given server call with the given status, returning a no-op listener.
|
||||
*
|
||||
* @param call the server call to close
|
||||
* @param status the status with which to close the call
|
||||
*
|
||||
* @return a no-op server call listener
|
||||
*
|
||||
* @param <ReqT> the type of request object handled by the server call
|
||||
* @param <RespT> the type of response object returned by the server call
|
||||
*/
|
||||
public static <ReqT, RespT> ServerCall.Listener<ReqT> closeWithStatus(final ServerCall<ReqT, RespT> call, final Status status) {
|
||||
call.close(status, EMPTY_TRAILERS);
|
||||
|
||||
//noinspection unchecked
|
||||
return NO_OP_LISTENER;
|
||||
}
|
||||
}
|
|
@ -8,7 +8,7 @@ package org.whispersystems.textsecuregcm.grpc;
|
|||
import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.internalError;
|
||||
|
||||
import com.google.protobuf.Descriptors;
|
||||
import com.google.protobuf.GeneratedMessageV3;
|
||||
import com.google.protobuf.Message;
|
||||
import io.grpc.ForwardingServerCallListener;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
|
@ -75,7 +75,7 @@ public class ValidatingInterceptor implements ServerInterceptor {
|
|||
}
|
||||
|
||||
private void validateMessage(final Object message) throws StatusException {
|
||||
if (message instanceof GeneratedMessageV3 msg) {
|
||||
if (message instanceof Message msg) {
|
||||
try {
|
||||
for (final Descriptors.FieldDescriptor fd: msg.getDescriptorForType().getFields()) {
|
||||
for (final Map.Entry<Descriptors.FieldDescriptor, Object> entry: fd.getOptions().getAllFields().entrySet()) {
|
||||
|
|
|
@ -1,7 +0,0 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
/**
|
||||
* Indicates that an attempt to authenticate a remote client failed for some reason.
|
||||
*/
|
||||
class ClientAuthenticationException extends Exception {
|
||||
}
|
|
@ -3,60 +3,39 @@ package org.whispersystems.textsecuregcm.grpc.net;
|
|||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
|
||||
import javax.crypto.BadPaddingException;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
||||
|
||||
/**
|
||||
* An error handler serves as a general backstop for exceptions elsewhere in the pipeline. If the client has completed a
|
||||
* WebSocket handshake, the error handler will send appropriate WebSocket closure codes to the client in an attempt to
|
||||
* identify the problem. If the client has not completed a WebSocket handshake, the handler simply closes the
|
||||
* connection.
|
||||
* An error handler serves as a general backstop for exceptions elsewhere in the pipeline. It translates exceptions
|
||||
* thrown in inbound handlers into {@link OutboundCloseErrorMessage}s.
|
||||
*/
|
||||
class ErrorHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private boolean websocketHandshakeComplete = false;
|
||||
|
||||
public class ErrorHandler extends ChannelInboundHandlerAdapter {
|
||||
private static final Logger log = LoggerFactory.getLogger(ErrorHandler.class);
|
||||
|
||||
@Override
|
||||
public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception {
|
||||
if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
|
||||
setWebsocketHandshakeComplete();
|
||||
}
|
||||
|
||||
context.fireUserEventTriggered(event);
|
||||
}
|
||||
|
||||
protected void setWebsocketHandshakeComplete() {
|
||||
this.websocketHandshakeComplete = true;
|
||||
}
|
||||
private static OutboundCloseErrorMessage NOISE_ENCRYPTION_ERROR_CLOSE = new OutboundCloseErrorMessage(
|
||||
OutboundCloseErrorMessage.Code.NOISE_ERROR,
|
||||
"Noise encryption error");
|
||||
|
||||
@Override
|
||||
public void exceptionCaught(final ChannelHandlerContext context, final Throwable cause) {
|
||||
if (websocketHandshakeComplete) {
|
||||
final WebSocketCloseStatus webSocketCloseStatus = switch (ExceptionUtils.unwrap(cause)) {
|
||||
case NoiseHandshakeException e -> ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.toWebSocketCloseStatus(e.getMessage());
|
||||
case ClientAuthenticationException ignored -> ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.toWebSocketCloseStatus("Not authenticated");
|
||||
case BadPaddingException ignored -> ApplicationWebSocketCloseReason.NOISE_ENCRYPTION_ERROR.toWebSocketCloseStatus("Noise encryption error");
|
||||
case NoiseException ignored -> ApplicationWebSocketCloseReason.NOISE_ENCRYPTION_ERROR.toWebSocketCloseStatus("Noise encryption error");
|
||||
final OutboundCloseErrorMessage closeMessage = switch (ExceptionUtils.unwrap(cause)) {
|
||||
case NoiseHandshakeException e -> new OutboundCloseErrorMessage(
|
||||
OutboundCloseErrorMessage.Code.NOISE_HANDSHAKE_ERROR,
|
||||
e.getMessage());
|
||||
case BadPaddingException ignored -> NOISE_ENCRYPTION_ERROR_CLOSE;
|
||||
case NoiseException ignored -> NOISE_ENCRYPTION_ERROR_CLOSE;
|
||||
default -> {
|
||||
log.warn("An unexpected exception reached the end of the pipeline", cause);
|
||||
yield WebSocketCloseStatus.INTERNAL_SERVER_ERROR;
|
||||
yield new OutboundCloseErrorMessage(
|
||||
OutboundCloseErrorMessage.Code.INTERNAL_SERVER_ERROR,
|
||||
cause.getMessage());
|
||||
}
|
||||
};
|
||||
|
||||
context.writeAndFlush(new CloseWebSocketFrame(webSocketCloseStatus))
|
||||
context.writeAndFlush(closeMessage)
|
||||
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
|
||||
} else {
|
||||
log.debug("Error occurred before websocket handshake complete", cause);
|
||||
// We haven't completed a websocket handshake, so we can't really communicate errors in a semantically-meaningful
|
||||
// way; just close the connection instead.
|
||||
context.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,20 +7,21 @@ import io.netty.channel.ChannelInboundHandlerAdapter;
|
|||
import io.netty.channel.ChannelInitializer;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import io.netty.channel.local.LocalChannel;
|
||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import java.net.InetAddress;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
|
||||
/**
|
||||
* An "establish local connection" handler waits for a Noise handshake to complete upstream in the pipeline, buffering
|
||||
* any inbound messages until the connection is fully-established, and then opens a proxy connection to a local gRPC
|
||||
* server.
|
||||
*/
|
||||
class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
|
||||
public class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private final GrpcClientConnectionManager grpcClientConnectionManager;
|
||||
|
||||
|
@ -48,15 +49,20 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
|
|||
|
||||
@Override
|
||||
public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) {
|
||||
if (event instanceof NoiseIdentityDeterminedEvent noiseIdentityDeterminedEvent) {
|
||||
if (event instanceof NoiseIdentityDeterminedEvent(
|
||||
final Optional<AuthenticatedDevice> authenticatedDevice,
|
||||
InetAddress remoteAddress, String userAgent, String acceptLanguage)) {
|
||||
// We assume that we'll only get a completed handshake event if the handshake met all authentication requirements
|
||||
// for the requested service. If the handshake doesn't have an authenticated device, we assume we're trying to
|
||||
// connect to the anonymous service. If it does have an authenticated device, we assume we're aiming for the
|
||||
// authenticated service.
|
||||
final LocalAddress grpcServerAddress = noiseIdentityDeterminedEvent.authenticatedDevice().isPresent()
|
||||
final LocalAddress grpcServerAddress = authenticatedDevice.isPresent()
|
||||
? authenticatedGrpcServerAddress
|
||||
: anonymousGrpcServerAddress;
|
||||
|
||||
GrpcClientConnectionManager.handleHandshakeInitiated(
|
||||
remoteChannelContext.channel(), remoteAddress, userAgent, acceptLanguage);
|
||||
|
||||
new Bootstrap()
|
||||
.remoteAddress(grpcServerAddress)
|
||||
.channel(LocalChannel.class)
|
||||
|
@ -72,12 +78,14 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
|
|||
if (localChannelFuture.isSuccess()) {
|
||||
grpcClientConnectionManager.handleConnectionEstablished((LocalChannel) localChannelFuture.channel(),
|
||||
remoteChannelContext.channel(),
|
||||
noiseIdentityDeterminedEvent.authenticatedDevice());
|
||||
authenticatedDevice);
|
||||
|
||||
// Close the local connection if the remote channel closes and vice versa
|
||||
remoteChannelContext.channel().closeFuture().addListener(closeFuture -> localChannelFuture.channel().close());
|
||||
localChannelFuture.channel().closeFuture().addListener(closeFuture ->
|
||||
remoteChannelContext.write(new CloseWebSocketFrame(WebSocketCloseStatus.SERVICE_RESTART)));
|
||||
remoteChannelContext.channel()
|
||||
.write(new OutboundCloseErrorMessage(OutboundCloseErrorMessage.Code.SERVER_CLOSED, "server closed"))
|
||||
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE));
|
||||
|
||||
remoteChannelContext.pipeline()
|
||||
.addAfter(remoteChannelContext.name(), null, new ProxyHandler(localChannelFuture.channel()));
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import io.grpc.Grpc;
|
||||
import io.grpc.ServerCall;
|
||||
import io.netty.channel.Channel;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import io.netty.channel.local.LocalChannel;
|
||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||
import io.netty.util.AttributeKey;
|
||||
import java.net.InetAddress;
|
||||
import java.util.ArrayList;
|
||||
|
@ -23,15 +24,26 @@ import org.slf4j.Logger;
|
|||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
|
||||
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
|
||||
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
|
||||
import org.whispersystems.textsecuregcm.util.ClosableEpoch;
|
||||
|
||||
/**
|
||||
* A client connection manager associates a local connection to a local gRPC server with a remote connection through a
|
||||
* Noise-over-WebSocket tunnel. It provides access to metadata associated with the remote connection, including the
|
||||
* authenticated identity of the device that opened the connection (for non-anonymous connections). It can also close
|
||||
* connections associated with a given device if that device's credentials have changed and clients must reauthenticate.
|
||||
* Noise tunnel. It provides access to metadata associated with the remote connection, including the authenticated
|
||||
* identity of the device that opened the connection (for non-anonymous connections). It can also close connections
|
||||
* associated with a given device if that device's credentials have changed and clients must reauthenticate.
|
||||
* <p>
|
||||
* In general, all {@link ServerCall}s <em>must</em> have a local address that in turn <em>should</em> be resolvable to
|
||||
* a remote channel, which <em>must</em> have associated request attributes and authentication status. It is possible
|
||||
* that a server call's local address may not be resolvable to a remote channel if the remote channel closed in the
|
||||
* narrow window between a server call being created and the start of call execution, in which case accessor methods
|
||||
* in this class will throw a {@link ChannelNotFoundException}.
|
||||
* <p>
|
||||
* A gRPC client connection manager's methods for getting request attributes accept {@link ServerCall} entities to
|
||||
* identify connections. In general, these methods should only be called from {@link io.grpc.ServerInterceptor}s.
|
||||
* Methods for requesting connection closure accept an {@link AuthenticatedDevice} to identify the connection and may
|
||||
* be called from any application code.
|
||||
*/
|
||||
public class GrpcClientConnectionManager implements DisconnectionRequestListener {
|
||||
|
||||
|
@ -43,94 +55,96 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
|
|||
AttributeKey.valueOf(GrpcClientConnectionManager.class, "authenticatedDevice");
|
||||
|
||||
@VisibleForTesting
|
||||
static final AttributeKey<InetAddress> REMOTE_ADDRESS_ATTRIBUTE_KEY =
|
||||
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "remoteAddress");
|
||||
public static final AttributeKey<RequestAttributes> REQUEST_ATTRIBUTES_KEY =
|
||||
AttributeKey.valueOf(GrpcClientConnectionManager.class, "requestAttributes");
|
||||
|
||||
@VisibleForTesting
|
||||
static final AttributeKey<String> RAW_USER_AGENT_ATTRIBUTE_KEY =
|
||||
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "rawUserAgent");
|
||||
static final AttributeKey<ClosableEpoch> EPOCH_ATTRIBUTE_KEY =
|
||||
AttributeKey.valueOf(GrpcClientConnectionManager.class, "epoch");
|
||||
|
||||
@VisibleForTesting
|
||||
static final AttributeKey<UserAgent> PARSED_USER_AGENT_ATTRIBUTE_KEY =
|
||||
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "userAgent");
|
||||
|
||||
@VisibleForTesting
|
||||
static final AttributeKey<List<Locale.LanguageRange>> ACCEPT_LANGUAGE_ATTRIBUTE_KEY =
|
||||
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "acceptLanguage");
|
||||
private static OutboundCloseErrorMessage SERVER_CLOSED =
|
||||
new OutboundCloseErrorMessage(OutboundCloseErrorMessage.Code.SERVER_CLOSED, "server closed");
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(GrpcClientConnectionManager.class);
|
||||
|
||||
/**
|
||||
* Returns the authenticated device associated with the given local address, if any. An authenticated device is
|
||||
* available if and only if the given local address maps to an active local connection and that connection is
|
||||
* authenticated (i.e. not anonymous).
|
||||
* Returns the authenticated device associated with the given server call, if any. If the connection is anonymous
|
||||
* (i.e. unauthenticated), the returned value will be empty.
|
||||
*
|
||||
* @param localAddress the local address for which to find an authenticated device
|
||||
* @param serverCall the gRPC server call for which to find an authenticated device
|
||||
*
|
||||
* @return the authenticated device associated with the given local address, if any
|
||||
*
|
||||
* @throws ChannelNotFoundException if the server call is not associated with a known channel; in practice, this
|
||||
* generally indicates that the channel has closed while request processing is still in progress
|
||||
*/
|
||||
public Optional<AuthenticatedDevice> getAuthenticatedDevice(final LocalAddress localAddress) {
|
||||
return getAuthenticatedDevice(remoteChannelsByLocalAddress.get(localAddress));
|
||||
public Optional<AuthenticatedDevice> getAuthenticatedDevice(final ServerCall<?, ?> serverCall)
|
||||
throws ChannelNotFoundException {
|
||||
|
||||
return getAuthenticatedDevice(getRemoteChannel(serverCall));
|
||||
}
|
||||
|
||||
private Optional<AuthenticatedDevice> getAuthenticatedDevice(@Nullable final Channel remoteChannel) {
|
||||
return Optional.ofNullable(remoteChannel)
|
||||
.map(channel -> channel.attr(AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).get());
|
||||
@VisibleForTesting
|
||||
Optional<AuthenticatedDevice> getAuthenticatedDevice(final Channel remoteChannel) {
|
||||
return Optional.ofNullable(remoteChannel.attr(AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).get());
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the parsed acceptable languages associated with the given local address, if any. Acceptable languages may
|
||||
* be unavailable if the local connection associated with the given local address has already closed, if the client
|
||||
* did not provide a list of acceptable languages, or the list provided by the client could not be parsed.
|
||||
* Returns the request attributes associated with the given server call.
|
||||
*
|
||||
* @param localAddress the local address for which to find acceptable languages
|
||||
* @param serverCall the gRPC server call for which to retrieve request attributes
|
||||
*
|
||||
* @return the acceptable languages associated with the given local address, if any
|
||||
* @return the request attributes associated with the given server call
|
||||
*
|
||||
* @throws ChannelNotFoundException if the server call is not associated with a known channel; in practice, this
|
||||
* generally indicates that the channel has closed while request processing is still in progress
|
||||
*/
|
||||
public Optional<List<Locale.LanguageRange>> getAcceptableLanguages(final LocalAddress localAddress) {
|
||||
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
|
||||
.map(remoteChannel -> remoteChannel.attr(ACCEPT_LANGUAGE_ATTRIBUTE_KEY).get());
|
||||
public RequestAttributes getRequestAttributes(final ServerCall<?, ?> serverCall) throws ChannelNotFoundException {
|
||||
return getRequestAttributes(getRemoteChannel(serverCall));
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
RequestAttributes getRequestAttributes(final Channel remoteChannel) {
|
||||
final RequestAttributes requestAttributes = remoteChannel.attr(REQUEST_ATTRIBUTES_KEY).get();
|
||||
|
||||
if (requestAttributes == null) {
|
||||
throw new IllegalStateException("Channel does not have request attributes");
|
||||
}
|
||||
|
||||
return requestAttributes;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the remote address associated with the given local address, if any. A remote address may be unavailable if
|
||||
* the local connection associated with the given local address has already closed.
|
||||
* Handles the start of a server call, incrementing the active call count for the remote channel associated with the
|
||||
* given server call.
|
||||
*
|
||||
* @param localAddress the local address for which to find a remote address
|
||||
* @param serverCall the server call to start
|
||||
*
|
||||
* @return the remote address associated with the given local address, if any
|
||||
* @return {@code true} if the call should start normally or {@code false} if the call should be aborted because the
|
||||
* underlying channel is closing
|
||||
*/
|
||||
public Optional<InetAddress> getRemoteAddress(final LocalAddress localAddress) {
|
||||
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
|
||||
.map(remoteChannel -> remoteChannel.attr(REMOTE_ADDRESS_ATTRIBUTE_KEY).get());
|
||||
public boolean handleServerCallStart(final ServerCall<?, ?> serverCall) {
|
||||
try {
|
||||
return getRemoteChannel(serverCall).attr(EPOCH_ATTRIBUTE_KEY).get().tryArrive();
|
||||
} catch (final ChannelNotFoundException e) {
|
||||
// This would only happen if the channel had already closed, which is certainly possible. In this case, the call
|
||||
// should certainly not proceed.
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the unparsed user agent provided by the client that opened the connection associated with the given local
|
||||
* address. This method may return an empty value if no active local connection is associated with the given local
|
||||
* address.
|
||||
* Handles completion (successful or not) of a server call, decrementing the active call count for the remote channel
|
||||
* associated with the given server call.
|
||||
*
|
||||
* @param localAddress the local address for which to find a User-Agent string
|
||||
*
|
||||
* @return the user agent string associated with the given local address
|
||||
* @param serverCall the server call to complete
|
||||
*/
|
||||
public Optional<String> getRawUserAgent(final LocalAddress localAddress) {
|
||||
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
|
||||
.map(remoteChannel -> remoteChannel.attr(RAW_USER_AGENT_ATTRIBUTE_KEY).get());
|
||||
public void handleServerCallComplete(final ServerCall<?, ?> serverCall) {
|
||||
try {
|
||||
getRemoteChannel(serverCall).attr(EPOCH_ATTRIBUTE_KEY).get().depart();
|
||||
} catch (final ChannelNotFoundException ignored) {
|
||||
// In practice, we'd only get here if the channel has already closed, so we can just ignore the exception
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the parsed user agent provided by the client that opened the connection associated with the given local
|
||||
* address. This method may return an empty value if no active local connection is associated with the given local
|
||||
* address or if the client's user-agent string was not recognized.
|
||||
*
|
||||
* @param localAddress the local address for which to find a User-Agent string
|
||||
*
|
||||
* @return the user agent associated with the given local address
|
||||
*/
|
||||
public Optional<UserAgent> getUserAgent(final LocalAddress localAddress) {
|
||||
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
|
||||
.map(remoteChannel -> remoteChannel.attr(PARSED_USER_AGENT_ATTRIBUTE_KEY).get());
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -145,10 +159,11 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
|
|||
final List<Channel> channelsToClose =
|
||||
new ArrayList<>(remoteChannelsByAuthenticatedDevice.getOrDefault(authenticatedDevice, Collections.emptyList()));
|
||||
|
||||
channelsToClose.forEach(channel ->
|
||||
channel.writeAndFlush(new CloseWebSocketFrame(ApplicationWebSocketCloseReason.REAUTHENTICATION_REQUIRED
|
||||
.toWebSocketCloseStatus("Reauthentication required")))
|
||||
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE));
|
||||
channelsToClose.forEach(channel -> channel.attr(EPOCH_ATTRIBUTE_KEY).get().close());
|
||||
}
|
||||
|
||||
private static void closeRemoteChannel(final Channel channel) {
|
||||
channel.writeAndFlush(SERVER_CLOSED).addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
|
@ -156,53 +171,66 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
|
|||
return remoteChannelsByAuthenticatedDevice.get(authenticatedDevice);
|
||||
}
|
||||
|
||||
private Channel getRemoteChannel(final ServerCall<?, ?> serverCall) throws ChannelNotFoundException {
|
||||
return getRemoteChannel(getLocalAddress(serverCall));
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
Channel getRemoteChannelByLocalAddress(final LocalAddress localAddress) {
|
||||
Channel getRemoteChannel(final LocalAddress localAddress) throws ChannelNotFoundException {
|
||||
final Channel remoteChannel = remoteChannelsByLocalAddress.get(localAddress);
|
||||
|
||||
if (remoteChannel == null) {
|
||||
throw new ChannelNotFoundException();
|
||||
}
|
||||
|
||||
return remoteChannelsByLocalAddress.get(localAddress);
|
||||
}
|
||||
|
||||
private static LocalAddress getLocalAddress(final ServerCall<?, ?> serverCall) {
|
||||
// In this server, gRPC's "remote" channel is actually a local channel that proxies to a distinct Noise channel.
|
||||
// The gRPC "remote" address is the "local address" for the proxy connection, and the local address uniquely maps to
|
||||
// a proxied Noise channel.
|
||||
if (!(serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress)) {
|
||||
throw new IllegalArgumentException("Unexpected channel type: " + serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
|
||||
}
|
||||
|
||||
return localAddress;
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles successful completion of a WebSocket handshake and associates attributes and headers from the handshake
|
||||
* Handles receipt of a handshake message and associates attributes and headers from the handshake
|
||||
* request with the channel via which the handshake took place.
|
||||
*
|
||||
* @param channel the channel that completed a WebSocket handshake
|
||||
* @param channel the channel where the handshake was initiated
|
||||
* @param preferredRemoteAddress the preferred remote address (potentially from a request header) for the handshake
|
||||
* @param userAgentHeader the value of the User-Agent header provided in the handshake request; may be {@code null}
|
||||
* @param acceptLanguageHeader the value of the Accept-Language header provided in the handshake request; may be
|
||||
* {@code null}
|
||||
*/
|
||||
static void handleWebSocketHandshakeComplete(final Channel channel,
|
||||
public static void handleHandshakeInitiated(final Channel channel,
|
||||
final InetAddress preferredRemoteAddress,
|
||||
@Nullable final String userAgentHeader,
|
||||
@Nullable final String acceptLanguageHeader) {
|
||||
|
||||
channel.attr(GrpcClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).set(preferredRemoteAddress);
|
||||
|
||||
if (StringUtils.isNotBlank(userAgentHeader)) {
|
||||
channel.attr(GrpcClientConnectionManager.RAW_USER_AGENT_ATTRIBUTE_KEY).set(userAgentHeader);
|
||||
|
||||
try {
|
||||
channel.attr(GrpcClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY)
|
||||
.set(UserAgentUtil.parseUserAgentString(userAgentHeader));
|
||||
} catch (final UnrecognizedUserAgentException ignored) {
|
||||
}
|
||||
}
|
||||
@Nullable List<Locale.LanguageRange> acceptLanguages = Collections.emptyList();
|
||||
|
||||
if (StringUtils.isNotBlank(acceptLanguageHeader)) {
|
||||
try {
|
||||
channel.attr(GrpcClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).set(Locale.LanguageRange.parse(acceptLanguageHeader));
|
||||
acceptLanguages = Locale.LanguageRange.parse(acceptLanguageHeader);
|
||||
} catch (final IllegalArgumentException e) {
|
||||
log.debug("Invalid Accept-Language header from User-Agent {}: {}", userAgentHeader, acceptLanguageHeader, e);
|
||||
}
|
||||
}
|
||||
|
||||
channel.attr(REQUEST_ATTRIBUTES_KEY)
|
||||
.set(new RequestAttributes(preferredRemoteAddress, userAgentHeader, acceptLanguages));
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles successful establishment of a Noise-over-WebSocket connection from a remote client to a local gRPC server.
|
||||
* Handles successful establishment of a Noise connection from a remote client to a local gRPC server.
|
||||
*
|
||||
* @param localChannel the newly-opened local channel between the Noise-over-WebSocket tunnel and the local gRPC
|
||||
* server
|
||||
* @param remoteChannel the channel from the remote client to the Noise-over-WebSocket tunnel
|
||||
* @param localChannel the newly-opened local channel between the Noise tunnel and the local gRPC server
|
||||
* @param remoteChannel the channel from the remote client to the Noise tunnel
|
||||
* @param maybeAuthenticatedDevice the authenticated device (if any) associated with the new connection
|
||||
*/
|
||||
void handleConnectionEstablished(final LocalChannel localChannel,
|
||||
|
@ -212,6 +240,9 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
|
|||
maybeAuthenticatedDevice.ifPresent(authenticatedDevice ->
|
||||
remoteChannel.attr(GrpcClientConnectionManager.AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).set(authenticatedDevice));
|
||||
|
||||
remoteChannel.attr(EPOCH_ATTRIBUTE_KEY)
|
||||
.set(new ClosableEpoch(() -> closeRemoteChannel(remoteChannel)));
|
||||
|
||||
remoteChannelsByLocalAddress.put(localChannel.localAddress(), remoteChannel);
|
||||
|
||||
getAuthenticatedDevice(remoteChannel).ifPresent(authenticatedDevice ->
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
enum HandshakePattern {
|
||||
public enum HandshakePattern {
|
||||
NK("Noise_NK_25519_ChaChaPoly_BLAKE2b"),
|
||||
IK("Noise_IK_25519_ChaChaPoly_BLAKE2b");
|
||||
|
||||
|
|
|
@ -1,34 +0,0 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
|
||||
/**
|
||||
* A NoiseAnonymousHandler is a netty pipeline element that handles the responder side of an unauthenticated handshake
|
||||
* and noise encryption/decryption.
|
||||
* <p>
|
||||
* A noise NK handshake must be used for unauthenticated connections. Optionally, the initiator can also include an
|
||||
* initial request in their payload. If provided, this allows the server to begin processing the request without an
|
||||
* initial message delay (fast open).
|
||||
* <p>
|
||||
* Once the handler receives the handshake initiator message, it will fire a {@link NoiseIdentityDeterminedEvent}
|
||||
* indicating that initiator connected anonymously.
|
||||
*/
|
||||
class NoiseAnonymousHandler extends NoiseHandler {
|
||||
|
||||
public NoiseAnonymousHandler(final ECKeyPair ecKeyPair) {
|
||||
super(new NoiseHandshakeHelper(HandshakePattern.NK, ecKeyPair));
|
||||
}
|
||||
|
||||
@Override
|
||||
CompletableFuture<HandshakeResult> handleHandshakePayload(final ChannelHandlerContext context,
|
||||
final Optional<byte[]> initiatorPublicKey, final ByteBuf handshakePayload) {
|
||||
return CompletableFuture.completedFuture(new HandshakeResult(
|
||||
handshakePayload,
|
||||
Optional.empty()
|
||||
));
|
||||
}
|
||||
}
|
|
@ -1,96 +0,0 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import java.security.MessageDigest;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
||||
|
||||
/**
|
||||
* A NoiseAuthenticatedHandler is a netty pipeline element that handles the responder side of an authenticated handshake
|
||||
* and noise encryption/decryption. Authenticated handshakes are noise IK handshakes where the initiator's static public
|
||||
* key is authenticated by the responder.
|
||||
* <p>
|
||||
* The authenticated handshake requires the initiator to provide a payload with their first handshake message that
|
||||
* includes their account identifier and device id in network byte-order. Optionally, the initiator can also include an
|
||||
* initial request in their payload. If provided, this allows the server to begin processing the request without an
|
||||
* initial message delay (fast open).
|
||||
* <pre>
|
||||
* +-----------------+----------------+------------------------+
|
||||
* | UUID (16) | deviceId (1) | request bytes (N) |
|
||||
* +-----------------+----------------+------------------------+
|
||||
* </pre>
|
||||
* <p>
|
||||
* For a successful handshake, the static key provided in the handshake message must match the server's stored public
|
||||
* key for the device identified by the provided ACI and deviceId.
|
||||
* <p>
|
||||
* As soon as the handler authenticates the caller, it will fire a {@link NoiseIdentityDeterminedEvent}.
|
||||
*/
|
||||
class NoiseAuthenticatedHandler extends NoiseHandler {
|
||||
|
||||
private final ClientPublicKeysManager clientPublicKeysManager;
|
||||
|
||||
NoiseAuthenticatedHandler(final ClientPublicKeysManager clientPublicKeysManager,
|
||||
final ECKeyPair ecKeyPair) {
|
||||
super(new NoiseHandshakeHelper(HandshakePattern.IK, ecKeyPair));
|
||||
this.clientPublicKeysManager = clientPublicKeysManager;
|
||||
}
|
||||
|
||||
@Override
|
||||
CompletableFuture<HandshakeResult> handleHandshakePayload(
|
||||
final ChannelHandlerContext context,
|
||||
final Optional<byte[]> initiatorPublicKey,
|
||||
final ByteBuf handshakePayload) throws NoiseHandshakeException {
|
||||
if (handshakePayload.readableBytes() < 17) {
|
||||
throw new NoiseHandshakeException("Invalid handshake payload");
|
||||
}
|
||||
|
||||
final byte[] publicKeyFromClient = initiatorPublicKey
|
||||
.orElseThrow(() -> new IllegalStateException("No remote public key"));
|
||||
|
||||
// Advances the read index by 16 bytes
|
||||
final UUID accountIdentifier = parseUUID(handshakePayload);
|
||||
|
||||
// Advances the read index by 1 byte
|
||||
final byte deviceId = handshakePayload.readByte();
|
||||
|
||||
final ByteBuf fastOpenRequest = handshakePayload.slice();
|
||||
return clientPublicKeysManager
|
||||
.findPublicKey(accountIdentifier, deviceId)
|
||||
.handleAsync((storedPublicKey, throwable) -> {
|
||||
if (throwable != null) {
|
||||
ReferenceCountUtil.release(fastOpenRequest);
|
||||
throw ExceptionUtils.wrap(throwable);
|
||||
}
|
||||
final boolean valid = storedPublicKey
|
||||
.map(spk -> MessageDigest.isEqual(publicKeyFromClient, spk.getPublicKeyBytes()))
|
||||
.orElse(false);
|
||||
if (!valid) {
|
||||
throw ExceptionUtils.wrap(new ClientAuthenticationException());
|
||||
}
|
||||
return new HandshakeResult(
|
||||
fastOpenRequest,
|
||||
Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId)));
|
||||
}, context.executor());
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse a {@link UUID} out of bytes, advancing the readerIdx by 16
|
||||
*
|
||||
* @param bytes The {@link ByteBuf} to read from
|
||||
* @return The parsed UUID
|
||||
* @throws NoiseHandshakeException If a UUID could not be parsed from bytes
|
||||
*/
|
||||
private UUID parseUUID(final ByteBuf bytes) throws NoiseHandshakeException {
|
||||
if (bytes.readableBytes() < 16) {
|
||||
throw new NoiseHandshakeException("Could not parse account identifier");
|
||||
}
|
||||
return new UUID(bytes.readLong(), bytes.readLong());
|
||||
}
|
||||
}
|
|
@ -1,10 +1,12 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import org.whispersystems.textsecuregcm.util.NoStackTraceException;
|
||||
|
||||
/**
|
||||
* Indicates that some problem occurred while processing an encrypted noise message (e.g. an unexpected message size/
|
||||
* format or a general encryption error).
|
||||
*/
|
||||
class NoiseException extends Exception {
|
||||
public class NoiseException extends NoStackTraceException {
|
||||
public NoiseException(final String message) {
|
||||
super(message);
|
||||
}
|
||||
|
|
|
@ -11,135 +11,50 @@ import io.netty.buffer.ByteBuf;
|
|||
import io.netty.buffer.ByteBufUtil;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelDuplexHandler;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import io.netty.util.concurrent.PromiseCombiner;
|
||||
import io.netty.util.internal.EmptyArrays;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import javax.crypto.BadPaddingException;
|
||||
import javax.crypto.ShortBufferException;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
||||
|
||||
/**
|
||||
* A bidirectional {@link io.netty.channel.ChannelHandler} that establishes a noise session with an initiator, decrypts
|
||||
* inbound messages, and encrypts outbound messages
|
||||
* A bidirectional {@link io.netty.channel.ChannelHandler} that decrypts inbound messages, and encrypts outbound
|
||||
* messages
|
||||
*/
|
||||
abstract class NoiseHandler extends ChannelDuplexHandler {
|
||||
public class NoiseHandler extends ChannelDuplexHandler {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(NoiseHandler.class);
|
||||
private final CipherStatePair cipherStatePair;
|
||||
|
||||
private enum State {
|
||||
// Waiting for handshake to complete
|
||||
HANDSHAKE,
|
||||
// Can freely exchange encrypted noise messages on an established session
|
||||
TRANSPORT,
|
||||
// Finished with error
|
||||
ERROR
|
||||
NoiseHandler(CipherStatePair cipherStatePair) {
|
||||
this.cipherStatePair = cipherStatePair;
|
||||
}
|
||||
|
||||
private final NoiseHandshakeHelper handshakeHelper;
|
||||
|
||||
private State state = State.HANDSHAKE;
|
||||
private CipherStatePair cipherStatePair;
|
||||
|
||||
NoiseHandler(NoiseHandshakeHelper handshakeHelper) {
|
||||
this.handshakeHelper = handshakeHelper;
|
||||
}
|
||||
|
||||
/**
|
||||
* The result of processing an initiator handshake payload
|
||||
*
|
||||
* @param fastOpenRequest A fast-open request included in the handshake. If none was present, this should be an
|
||||
* empty ByteBuf
|
||||
* @param authenticatedDevice If present, the successfully authenticated initiator identity
|
||||
*/
|
||||
record HandshakeResult(ByteBuf fastOpenRequest, Optional<AuthenticatedDevice> authenticatedDevice) {}
|
||||
|
||||
/**
|
||||
* Parse and potentially authenticate the initiator handshake message
|
||||
*
|
||||
* @param context A {@link ChannelHandlerContext}
|
||||
* @param initiatorPublicKey The initiator's static public key, if a handshake pattern that includes it was used
|
||||
* @param handshakePayload The handshake payload provided in the initiator message
|
||||
* @return A {@link HandshakeResult} that includes an authenticated device and a parsed fast-open request if one was
|
||||
* present in the handshake payload.
|
||||
* @throws NoiseHandshakeException If the handshake payload was invalid
|
||||
* @throws ClientAuthenticationException If the initiatorPublicKey could not be authenticated
|
||||
*/
|
||||
abstract CompletableFuture<HandshakeResult> handleHandshakePayload(
|
||||
final ChannelHandlerContext context,
|
||||
final Optional<byte[]> initiatorPublicKey,
|
||||
final ByteBuf handshakePayload) throws NoiseHandshakeException, ClientAuthenticationException;
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
||||
try {
|
||||
if (message instanceof BinaryWebSocketFrame frame) {
|
||||
if (frame.content().readableBytes() > Noise.MAX_PACKET_LEN) {
|
||||
final String error = "Invalid noise message length " + frame.content().readableBytes();
|
||||
throw state == State.HANDSHAKE ? new NoiseHandshakeException(error) : new NoiseException(error);
|
||||
if (message instanceof ByteBuf frame) {
|
||||
if (frame.readableBytes() > Noise.MAX_PACKET_LEN) {
|
||||
throw new NoiseException("Invalid noise message length " + frame.readableBytes());
|
||||
}
|
||||
// We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array.
|
||||
// We'll need to copy it to a heap buffer.
|
||||
handleInboundMessage(context, ByteBufUtil.getBytes(frame.content()));
|
||||
handleInboundDataMessage(context, ByteBufUtil.getBytes(frame));
|
||||
} else {
|
||||
// Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an
|
||||
// error
|
||||
// Anything except ByteBufs should have been filtered out of the pipeline by now; treat this as an error
|
||||
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
fail(context, e);
|
||||
} finally {
|
||||
ReferenceCountUtil.release(message);
|
||||
}
|
||||
}
|
||||
|
||||
private void handleInboundMessage(final ChannelHandlerContext context, final byte[] frameBytes)
|
||||
throws NoiseHandshakeException, ShortBufferException, BadPaddingException, ClientAuthenticationException {
|
||||
switch (state) {
|
||||
|
||||
// Got an initiator handshake message
|
||||
case HANDSHAKE -> {
|
||||
final ByteBuf payload = handshakeHelper.read(frameBytes);
|
||||
handleHandshakePayload(context, handshakeHelper.remotePublicKey(), payload).whenCompleteAsync(
|
||||
(result, throwable) -> {
|
||||
if (state == State.ERROR) {
|
||||
return;
|
||||
}
|
||||
if (throwable != null) {
|
||||
fail(context, ExceptionUtils.unwrap(throwable));
|
||||
return;
|
||||
}
|
||||
context.fireUserEventTriggered(new NoiseIdentityDeterminedEvent(result.authenticatedDevice()));
|
||||
|
||||
// Now that we've authenticated, write the handshake response
|
||||
byte[] handshakeMessage = handshakeHelper.write(EmptyArrays.EMPTY_BYTES);
|
||||
context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(handshakeMessage)))
|
||||
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
|
||||
|
||||
// The handshake is complete. We can start intercepting read/write for noise encryption/decryption
|
||||
this.state = State.TRANSPORT;
|
||||
this.cipherStatePair = handshakeHelper.getHandshakeState().split();
|
||||
if (result.fastOpenRequest().isReadable()) {
|
||||
// The handshake had a fast-open request. Forward the plaintext of the request to the server, we'll
|
||||
// encrypt the response when the server writes back through us
|
||||
context.fireChannelRead(result.fastOpenRequest());
|
||||
} else {
|
||||
ReferenceCountUtil.release(result.fastOpenRequest());
|
||||
}
|
||||
}, context.executor());
|
||||
}
|
||||
|
||||
// Got a client message that should be decrypted and forwarded
|
||||
case TRANSPORT -> {
|
||||
private void handleInboundDataMessage(final ChannelHandlerContext context, final byte[] frameBytes)
|
||||
throws ShortBufferException, BadPaddingException {
|
||||
final CipherState cipherState = cipherStatePair.getReceiver();
|
||||
// Overwrite the ciphertext with the plaintext to avoid an extra allocation for a dedicated plaintext buffer
|
||||
final int plaintextLength = cipherState.decryptWithAd(null,
|
||||
|
@ -151,21 +66,6 @@ abstract class NoiseHandler extends ChannelDuplexHandler {
|
|||
context.fireChannelRead(Unpooled.wrappedBuffer(frameBytes, 0, plaintextLength));
|
||||
}
|
||||
|
||||
// The session is already in an error state, drop the message
|
||||
case ERROR -> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the state to the error state (so subsequent messages fast-fail) and propagate the failure reason on the
|
||||
* context
|
||||
*/
|
||||
private void fail(final ChannelHandlerContext context, final Throwable cause) {
|
||||
this.state = State.ERROR;
|
||||
context.fireExceptionCaught(cause);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise)
|
||||
throws Exception {
|
||||
|
@ -193,19 +93,27 @@ abstract class NoiseHandler extends ChannelDuplexHandler {
|
|||
// Overwrite the plaintext with the ciphertext to avoid an extra allocation for a dedicated ciphertext buffer
|
||||
cipherState.encryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, plaintextLength);
|
||||
|
||||
pc.add(context.write(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(noiseBuffer))));
|
||||
pc.add(context.write(Unpooled.wrappedBuffer(noiseBuffer)));
|
||||
}
|
||||
pc.finish(promise);
|
||||
} finally {
|
||||
ReferenceCountUtil.release(byteBuf);
|
||||
}
|
||||
} else {
|
||||
if (!(message instanceof WebSocketFrame)) {
|
||||
// Downstream handlers may write WebSocket frames that don't need to be encrypted (e.g. "close" frames that
|
||||
// get issued in response to exceptions)
|
||||
if (!(message instanceof OutboundCloseErrorMessage)) {
|
||||
// Downstream handlers may write OutboundCloseErrorMessages that don't need to be encrypted (e.g. "close" frames
|
||||
// that get issued in response to exceptions)
|
||||
log.warn("Unexpected object in pipeline: {}", message);
|
||||
}
|
||||
context.write(message, promise);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handlerRemoved(ChannelHandlerContext var1) {
|
||||
if (cipherStatePair != null) {
|
||||
cipherStatePair.destroy();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import org.whispersystems.textsecuregcm.util.NoStackTraceException;
|
||||
|
||||
/**
|
||||
* Indicates that some problem occurred while completing a Noise handshake (e.g. an unexpected message size/format or
|
||||
* a general encryption error).
|
||||
*/
|
||||
class NoiseHandshakeException extends Exception {
|
||||
public class NoiseHandshakeException extends NoStackTraceException {
|
||||
|
||||
public NoiseHandshakeException(final String message) {
|
||||
super(message);
|
||||
|
|
|
@ -0,0 +1,197 @@
|
|||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import com.southernstorm.noise.protocol.Noise;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufInputStream;
|
||||
import io.netty.buffer.ByteBufUtil;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import java.io.IOException;
|
||||
import java.net.InetAddress;
|
||||
import java.security.MessageDigest;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.grpc.DeviceIdUtil;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
||||
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||
|
||||
/**
|
||||
* Handles the responder side of a noise handshake and then replaces itself with a {@link NoiseHandler} which will
|
||||
* encrypt/decrypt subsequent data frames
|
||||
* <p>
|
||||
* The handler expects to receive a single inbound message, a {@link NoiseHandshakeInit} that includes the initiator
|
||||
* handshake message, connection metadata, and the type of handshake determined by the framing layer. This handler
|
||||
* currently supports two types of handshakes.
|
||||
* <p>
|
||||
* The first are IK handshakes where the initiator's static public key is authenticated by the responder. The initiator
|
||||
* handshake message must contain the ACI and deviceId of the initiator. To be authenticated, the static key provided in
|
||||
* the handshake message must match the server's stored public key for the device identified by the provided ACI and
|
||||
* deviceId.
|
||||
* <p>
|
||||
* The second are NK handshakes which are anonymous.
|
||||
* <p>
|
||||
* Optionally, the initiator can also include an initial request in their payload. If provided, this allows the server
|
||||
* to begin processing the request without an initial message delay (fast open).
|
||||
* <p>
|
||||
* Once the handshake has been validated, a {@link NoiseIdentityDeterminedEvent} will be fired. For an IK handshake,
|
||||
* this will include the {@link org.whispersystems.textsecuregcm.auth.AuthenticatedDevice} of the initiator. This
|
||||
* handler will then replace itself with a {@link NoiseHandler} with a noise state pair ready to encrypt/decrypt data
|
||||
* frames.
|
||||
*/
|
||||
public class NoiseHandshakeHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private static final byte[] HANDSHAKE_WRONG_PK = NoiseTunnelProtos.HandshakeResponse.newBuilder()
|
||||
.setCode(NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY)
|
||||
.build().toByteArray();
|
||||
private static final byte[] HANDSHAKE_OK = NoiseTunnelProtos.HandshakeResponse.newBuilder()
|
||||
.setCode(NoiseTunnelProtos.HandshakeResponse.Code.OK)
|
||||
.build().toByteArray();
|
||||
|
||||
// We might get additional messages while we're waiting to process a handshake, so keep track of where we are
|
||||
private boolean receivedHandshakeInit = false;
|
||||
|
||||
private final ClientPublicKeysManager clientPublicKeysManager;
|
||||
private final ECKeyPair ecKeyPair;
|
||||
|
||||
public NoiseHandshakeHandler(final ClientPublicKeysManager clientPublicKeysManager, final ECKeyPair ecKeyPair) {
|
||||
this.clientPublicKeysManager = clientPublicKeysManager;
|
||||
this.ecKeyPair = ecKeyPair;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
||||
try {
|
||||
if (!(message instanceof NoiseHandshakeInit handshakeInit)) {
|
||||
// Anything except HandshakeInit should have been filtered out of the pipeline by now; treat this as an error
|
||||
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
|
||||
}
|
||||
if (receivedHandshakeInit) {
|
||||
throw new NoiseHandshakeException("Should not receive messages until handshake complete");
|
||||
}
|
||||
receivedHandshakeInit = true;
|
||||
|
||||
if (handshakeInit.content().readableBytes() > Noise.MAX_PACKET_LEN) {
|
||||
throw new NoiseHandshakeException("Invalid noise message length " + handshakeInit.content().readableBytes());
|
||||
}
|
||||
|
||||
// We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array.
|
||||
// We'll need to copy it to a heap buffer
|
||||
handleInboundHandshake(context,
|
||||
handshakeInit.getRemoteAddress(),
|
||||
handshakeInit.getHandshakePattern(),
|
||||
ByteBufUtil.getBytes(handshakeInit.content()));
|
||||
} finally {
|
||||
ReferenceCountUtil.release(message);
|
||||
}
|
||||
}
|
||||
|
||||
private void handleInboundHandshake(
|
||||
final ChannelHandlerContext context,
|
||||
final InetAddress remoteAddress,
|
||||
final HandshakePattern handshakePattern,
|
||||
final byte[] frameBytes) throws NoiseHandshakeException {
|
||||
final NoiseHandshakeHelper handshakeHelper = new NoiseHandshakeHelper(handshakePattern, ecKeyPair);
|
||||
final ByteBuf payload = handshakeHelper.read(frameBytes);
|
||||
|
||||
// Parse the handshake message
|
||||
final NoiseTunnelProtos.HandshakeInit handshakeInit;
|
||||
try {
|
||||
handshakeInit = NoiseTunnelProtos.HandshakeInit.parseFrom(new ByteBufInputStream(payload));
|
||||
} catch (IOException e) {
|
||||
throw new NoiseHandshakeException("Failed to parse handshake message");
|
||||
}
|
||||
|
||||
switch (handshakePattern) {
|
||||
case NK -> {
|
||||
if (handshakeInit.getDeviceId() != 0 || !handshakeInit.getAci().isEmpty()) {
|
||||
throw new NoiseHandshakeException("Anonymous handshake should not include identifiers");
|
||||
}
|
||||
handleAuthenticated(context, handshakeHelper, remoteAddress, handshakeInit, Optional.empty());
|
||||
}
|
||||
case IK -> {
|
||||
final byte[] publicKeyFromClient = handshakeHelper.remotePublicKey()
|
||||
.orElseThrow(() -> new IllegalStateException("No remote public key"));
|
||||
final UUID accountIdentifier = aci(handshakeInit);
|
||||
final byte deviceId = deviceId(handshakeInit);
|
||||
clientPublicKeysManager
|
||||
.findPublicKey(accountIdentifier, deviceId)
|
||||
.whenCompleteAsync((storedPublicKey, throwable) -> {
|
||||
if (throwable != null) {
|
||||
context.fireExceptionCaught(ExceptionUtils.unwrap(throwable));
|
||||
return;
|
||||
}
|
||||
final boolean valid = storedPublicKey
|
||||
.map(spk -> MessageDigest.isEqual(publicKeyFromClient, spk.getPublicKeyBytes()))
|
||||
.orElse(false);
|
||||
if (!valid) {
|
||||
// Write a handshake response indicating that the client used the wrong public key
|
||||
final byte[] handshakeMessage = handshakeHelper.write(HANDSHAKE_WRONG_PK);
|
||||
context.writeAndFlush(Unpooled.wrappedBuffer(handshakeMessage))
|
||||
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
|
||||
|
||||
context.fireExceptionCaught(new NoiseHandshakeException("Bad public key"));
|
||||
return;
|
||||
}
|
||||
handleAuthenticated(context,
|
||||
handshakeHelper, remoteAddress, handshakeInit,
|
||||
Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId)));
|
||||
}, context.executor());
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
private void handleAuthenticated(final ChannelHandlerContext context,
|
||||
final NoiseHandshakeHelper handshakeHelper,
|
||||
final InetAddress remoteAddress,
|
||||
final NoiseTunnelProtos.HandshakeInit handshakeInit,
|
||||
final Optional<AuthenticatedDevice> maybeAuthenticatedDevice) {
|
||||
context.fireUserEventTriggered(new NoiseIdentityDeterminedEvent(
|
||||
maybeAuthenticatedDevice,
|
||||
remoteAddress,
|
||||
handshakeInit.getUserAgent(),
|
||||
handshakeInit.getAcceptLanguage()));
|
||||
|
||||
// Now that we've authenticated, write the handshake response
|
||||
final byte[] handshakeMessage = handshakeHelper.write(HANDSHAKE_OK);
|
||||
context.writeAndFlush(Unpooled.wrappedBuffer(handshakeMessage))
|
||||
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
|
||||
|
||||
// The handshake is complete. We can start intercepting read/write for noise encryption/decryption
|
||||
// Note: It may be tempting to swap the before/remove for a replace, but then when we forward the fast open
|
||||
// request it will go through the NoiseHandler. We want to skip the NoiseHandler because we've already
|
||||
// decrypted the fastOpen request
|
||||
context.pipeline()
|
||||
.addBefore(context.name(), null, new NoiseHandler(handshakeHelper.getHandshakeState().split()));
|
||||
context.pipeline().remove(NoiseHandshakeHandler.class);
|
||||
if (!handshakeInit.getFastOpenRequest().isEmpty()) {
|
||||
// The handshake had a fast-open request. Forward the plaintext of the request to the server, we'll
|
||||
// encrypt the response when the server writes back through us
|
||||
context.fireChannelRead(Unpooled.wrappedBuffer(handshakeInit.getFastOpenRequest().asReadOnlyByteBuffer()));
|
||||
}
|
||||
}
|
||||
|
||||
private static UUID aci(final NoiseTunnelProtos.HandshakeInit handshakePayload) throws NoiseHandshakeException {
|
||||
try {
|
||||
return UUIDUtil.fromByteString(handshakePayload.getAci());
|
||||
} catch (IllegalArgumentException e) {
|
||||
throw new NoiseHandshakeException("Could not parse aci");
|
||||
}
|
||||
}
|
||||
|
||||
private static byte deviceId(final NoiseTunnelProtos.HandshakeInit handshakePayload) throws NoiseHandshakeException {
|
||||
if (!DeviceIdUtil.isValid(handshakePayload.getDeviceId())) {
|
||||
throw new NoiseHandshakeException("Invalid deviceId");
|
||||
}
|
||||
return (byte) handshakePayload.getDeviceId();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.DefaultByteBufHolder;
|
||||
import java.net.InetAddress;
|
||||
|
||||
/**
|
||||
* A message that includes the initiator's handshake message, connection metadata, and the handshake type. The metadata
|
||||
* and handshake type are extracted from the framing layer, so this allows receivers to be framing layer agnostic.
|
||||
*/
|
||||
public class NoiseHandshakeInit extends DefaultByteBufHolder {
|
||||
|
||||
private final InetAddress remoteAddress;
|
||||
private final HandshakePattern handshakePattern;
|
||||
|
||||
public NoiseHandshakeInit(
|
||||
final InetAddress remoteAddress,
|
||||
final HandshakePattern handshakePattern,
|
||||
final ByteBuf initiatorHandshakeMessage) {
|
||||
super(initiatorHandshakeMessage);
|
||||
this.remoteAddress = remoteAddress;
|
||||
this.handshakePattern = handshakePattern;
|
||||
}
|
||||
|
||||
public InetAddress getRemoteAddress() {
|
||||
return remoteAddress;
|
||||
}
|
||||
|
||||
public HandshakePattern getHandshakePattern() {
|
||||
return handshakePattern;
|
||||
}
|
||||
|
||||
}
|
|
@ -1,5 +1,6 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import java.net.InetAddress;
|
||||
import java.util.Optional;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
|
||||
|
@ -9,5 +10,12 @@ import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
|||
*
|
||||
* @param authenticatedDevice the device authenticated as part of the handshake, or empty if the handshake was not of a
|
||||
* type that performs authentication
|
||||
* @param remoteAddress the remote address of the connecting client
|
||||
* @param userAgent the client supplied userAgent
|
||||
* @param acceptLanguage the client supplied acceptLanguage
|
||||
*/
|
||||
record NoiseIdentityDeterminedEvent(Optional<AuthenticatedDevice> authenticatedDevice) {}
|
||||
public record NoiseIdentityDeterminedEvent(
|
||||
Optional<AuthenticatedDevice> authenticatedDevice,
|
||||
InetAddress remoteAddress,
|
||||
String userAgent,
|
||||
String acceptLanguage) {}
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
/**
|
||||
* An error written to the outbound pipeline that indicates the connection should be closed
|
||||
*/
|
||||
public record OutboundCloseErrorMessage(Code code, String message) {
|
||||
public enum Code {
|
||||
|
||||
/**
|
||||
* The server decided to close the connection. This could be because the server is going away, or it could be
|
||||
* because the credentials for the connected client have been updated.
|
||||
*/
|
||||
SERVER_CLOSED,
|
||||
|
||||
/**
|
||||
* There was a noise decryption error after the noise session was established
|
||||
*/
|
||||
NOISE_ERROR,
|
||||
|
||||
/**
|
||||
* There was an error establishing the noise handshake
|
||||
*/
|
||||
NOISE_HANDSHAKE_ERROR,
|
||||
|
||||
INTERNAL_SERVER_ERROR
|
||||
}
|
||||
}
|
|
@ -8,7 +8,7 @@ import io.netty.channel.ChannelInboundHandlerAdapter;
|
|||
/**
|
||||
* A proxy handler writes all data read from one channel to another peer channel.
|
||||
*/
|
||||
class ProxyHandler extends ChannelInboundHandlerAdapter {
|
||||
public class ProxyHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private final Channel peerChannel;
|
||||
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.channel.ChannelDuplexHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseException;
|
||||
|
||||
/**
|
||||
* In the inbound direction, this handler strips the NoiseDirectFrame wrapper we read off the wire and then forwards the
|
||||
* noise packet to the noise layer as a {@link ByteBuf} for decryption.
|
||||
* <p>
|
||||
* In the outbound direction, this handler wraps encrypted noise packet {@link ByteBuf}s in a NoiseDirectFrame wrapper
|
||||
* so it can be wire serialized. This handler assumes the first outbound message received will correspond to the
|
||||
* handshake response, and then the subsequent messages are all data frame payloads.
|
||||
*/
|
||||
public class NoiseDirectDataFrameCodec extends ChannelDuplexHandler {
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
|
||||
if (msg instanceof NoiseDirectFrame frame) {
|
||||
if (frame.frameType() != NoiseDirectFrame.FrameType.DATA) {
|
||||
ReferenceCountUtil.release(msg);
|
||||
throw new NoiseException("Invalid frame type received (expected DATA): " + frame.frameType());
|
||||
}
|
||||
ctx.fireChannelRead(frame.content());
|
||||
} else {
|
||||
ctx.fireChannelRead(msg);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
|
||||
if (msg instanceof ByteBuf bb) {
|
||||
ctx.write(new NoiseDirectFrame(NoiseDirectFrame.FrameType.DATA, bb), promise);
|
||||
} else {
|
||||
ctx.write(msg, promise);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,72 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.DefaultByteBufHolder;
|
||||
|
||||
public class NoiseDirectFrame extends DefaultByteBufHolder {
|
||||
|
||||
static final byte VERSION = 0x00;
|
||||
|
||||
private final FrameType frameType;
|
||||
|
||||
public NoiseDirectFrame(final FrameType frameType, final ByteBuf data) {
|
||||
super(data);
|
||||
this.frameType = frameType;
|
||||
}
|
||||
|
||||
public FrameType frameType() {
|
||||
return frameType;
|
||||
}
|
||||
|
||||
public byte versionedFrameTypeByte() {
|
||||
final byte frameBits = frameType().getFrameBits();
|
||||
return (byte) ((NoiseDirectFrame.VERSION << 4) | frameBits);
|
||||
}
|
||||
|
||||
|
||||
public enum FrameType {
|
||||
/**
|
||||
* The payload is the initiator message for a Noise NK handshake. If established, the
|
||||
* session will be unauthenticated.
|
||||
*/
|
||||
NK_HANDSHAKE((byte) 1),
|
||||
/**
|
||||
* The payload is the initiator message for a Noise IK handshake. If established, the
|
||||
* session will be authenticated.
|
||||
*/
|
||||
IK_HANDSHAKE((byte) 2),
|
||||
/**
|
||||
* The payload is an encrypted noise packet.
|
||||
*/
|
||||
DATA((byte) 3),
|
||||
/**
|
||||
* A frame sent before the connection is closed. The payload is a protobuf indicating why the connection is being
|
||||
* closed.
|
||||
*/
|
||||
CLOSE((byte) 4);
|
||||
|
||||
private final byte frameType;
|
||||
|
||||
FrameType(byte frameType) {
|
||||
if (frameType != (0x0F & frameType)) {
|
||||
throw new IllegalStateException("Frame type must fit in 4 bits");
|
||||
}
|
||||
this.frameType = frameType;
|
||||
}
|
||||
|
||||
public byte getFrameBits() {
|
||||
return frameType;
|
||||
}
|
||||
|
||||
public boolean isHandshake() {
|
||||
return switch (this) {
|
||||
case IK_HANDSHAKE, NK_HANDSHAKE -> true;
|
||||
case DATA, CLOSE -> false;
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,90 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
|
||||
|
||||
import com.southernstorm.noise.protocol.Noise;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.channel.ChannelDuplexHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
|
||||
|
||||
/**
|
||||
* Handles conversion between bytes on the wire and {@link NoiseDirectFrame}s. This handler assumes that inbound bytes
|
||||
* have already been framed using a {@link io.netty.handler.codec.LengthFieldBasedFrameDecoder}
|
||||
*/
|
||||
public class NoiseDirectFrameCodec extends ChannelDuplexHandler {
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
|
||||
if (msg instanceof ByteBuf byteBuf) {
|
||||
try {
|
||||
ctx.fireChannelRead(deserialize(byteBuf));
|
||||
} catch (Exception e) {
|
||||
ReferenceCountUtil.release(byteBuf);
|
||||
throw e;
|
||||
}
|
||||
} else {
|
||||
ctx.fireChannelRead(msg);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
|
||||
if (msg instanceof NoiseDirectFrame noiseDirectFrame) {
|
||||
try {
|
||||
// Serialize the frame into a newly allocated direct buffer. Since this is the last handler before the
|
||||
// network, nothing should have to make another copy of this. If later another layer is added, it may be more
|
||||
// efficient to reuse the input buffer (typically not direct) by using a composite byte buffer
|
||||
final ByteBuf serialized = serialize(ctx, noiseDirectFrame);
|
||||
ctx.writeAndFlush(serialized, promise);
|
||||
} finally {
|
||||
ReferenceCountUtil.release(noiseDirectFrame);
|
||||
}
|
||||
} else {
|
||||
ctx.write(msg, promise);
|
||||
}
|
||||
}
|
||||
|
||||
private ByteBuf serialize(
|
||||
final ChannelHandlerContext ctx,
|
||||
final NoiseDirectFrame noiseDirectFrame) {
|
||||
if (noiseDirectFrame.content().readableBytes() > Noise.MAX_PACKET_LEN) {
|
||||
throw new IllegalStateException("Payload too long: " + noiseDirectFrame.content().readableBytes());
|
||||
}
|
||||
|
||||
// 1 version/frametype byte, 2 length bytes, content
|
||||
final ByteBuf byteBuf = ctx.alloc().buffer(1 + 2 + noiseDirectFrame.content().readableBytes());
|
||||
|
||||
byteBuf.writeByte(noiseDirectFrame.versionedFrameTypeByte());
|
||||
byteBuf.writeShort(noiseDirectFrame.content().readableBytes());
|
||||
byteBuf.writeBytes(noiseDirectFrame.content());
|
||||
return byteBuf;
|
||||
}
|
||||
|
||||
private NoiseDirectFrame deserialize(final ByteBuf byteBuf) throws Exception {
|
||||
final byte versionAndFrameByte = byteBuf.readByte();
|
||||
final int version = (versionAndFrameByte & 0xF0) >> 4;
|
||||
if (version != NoiseDirectFrame.VERSION) {
|
||||
throw new NoiseHandshakeException("Invalid NoiseDirect version: " + version);
|
||||
}
|
||||
final byte frameTypeBits = (byte) (versionAndFrameByte & 0x0F);
|
||||
final NoiseDirectFrame.FrameType frameType = switch (frameTypeBits) {
|
||||
case 1 -> NoiseDirectFrame.FrameType.NK_HANDSHAKE;
|
||||
case 2 -> NoiseDirectFrame.FrameType.IK_HANDSHAKE;
|
||||
case 3 -> NoiseDirectFrame.FrameType.DATA;
|
||||
case 4 -> NoiseDirectFrame.FrameType.CLOSE;
|
||||
default -> throw new NoiseHandshakeException("Invalid NoiseDirect frame type: " + frameTypeBits);
|
||||
};
|
||||
|
||||
final int length = Short.toUnsignedInt(byteBuf.readShort());
|
||||
if (length != byteBuf.readableBytes()) {
|
||||
throw new IllegalArgumentException(
|
||||
"Payload length did not match remaining buffer, should have been guaranteed by a previous handler");
|
||||
}
|
||||
return new NoiseDirectFrame(frameType, byteBuf.readSlice(length));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,53 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
|
||||
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import java.io.IOException;
|
||||
import java.net.InetSocketAddress;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.HandshakePattern;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeInit;
|
||||
|
||||
/**
|
||||
* Waits for a Handshake {@link NoiseDirectFrame} and then replaces itself with a {@link NoiseDirectDataFrameCodec} and
|
||||
* forwards the handshake frame along as a {@link NoiseHandshakeInit} message
|
||||
*/
|
||||
public class NoiseDirectHandshakeSelector extends ChannelInboundHandlerAdapter {
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
|
||||
if (msg instanceof NoiseDirectFrame frame) {
|
||||
try {
|
||||
if (!(ctx.channel().remoteAddress() instanceof InetSocketAddress inetSocketAddress)) {
|
||||
throw new IOException("Could not determine remote address");
|
||||
}
|
||||
// We've received an inbound handshake frame. Pull the framing-protocol specific data the downstream handler
|
||||
// needs into a NoiseHandshakeInit message and forward that along
|
||||
final NoiseHandshakeInit handshakeMessage = new NoiseHandshakeInit(inetSocketAddress.getAddress(),
|
||||
switch (frame.frameType()) {
|
||||
case DATA -> throw new NoiseHandshakeException("First message must have handshake frame type");
|
||||
case CLOSE -> throw new IllegalStateException("Close frames should not reach handshake selector");
|
||||
case IK_HANDSHAKE -> HandshakePattern.IK;
|
||||
case NK_HANDSHAKE -> HandshakePattern.NK;
|
||||
}, frame.content());
|
||||
|
||||
// Subsequent inbound messages and outbound should be data type frames or close frames. Inbound data frames
|
||||
// should be unwrapped and forwarded to the noise handler, outbound buffers should be wrapped and forwarded
|
||||
// for network serialization. Note that we need to install the Data frame handler before firing the read,
|
||||
// because we may receive an outbound message from the noiseHandler
|
||||
ctx.pipeline().replace(ctx.name(), null, new NoiseDirectDataFrameCodec());
|
||||
ctx.fireChannelRead(handshakeMessage);
|
||||
} catch (Exception e) {
|
||||
ReferenceCountUtil.release(msg);
|
||||
throw e;
|
||||
}
|
||||
} else {
|
||||
ctx.fireChannelRead(msg);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
|
||||
|
||||
import io.micrometer.core.instrument.Metrics;
|
||||
import io.netty.buffer.ByteBufUtil;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
||||
|
||||
|
||||
/**
|
||||
* Watches for inbound close frames and closes the connection in response
|
||||
*/
|
||||
public class NoiseDirectInboundCloseHandler extends ChannelInboundHandlerAdapter {
|
||||
private static String CLIENT_CLOSE_COUNTER_NAME = MetricsUtil.name(ChannelInboundHandlerAdapter.class, "clientClose");
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
|
||||
if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.CLOSE) {
|
||||
try {
|
||||
final NoiseDirectProtos.CloseReason closeReason = NoiseDirectProtos.CloseReason
|
||||
.parseFrom(ByteBufUtil.getBytes(ndf.content()));
|
||||
|
||||
Metrics.counter(CLIENT_CLOSE_COUNTER_NAME, "reason", closeReason.getCode().name()).increment();
|
||||
} finally {
|
||||
ReferenceCountUtil.release(msg);
|
||||
ctx.close();
|
||||
}
|
||||
} else {
|
||||
ctx.fireChannelRead(msg);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufOutputStream;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelOutboundHandlerAdapter;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage;
|
||||
|
||||
/**
|
||||
* Translates {@link OutboundCloseErrorMessage}s into {@link NoiseDirectFrame} error frames. After error frames are
|
||||
* written, the channel is closed
|
||||
*/
|
||||
class NoiseDirectOutboundErrorHandler extends ChannelOutboundHandlerAdapter {
|
||||
|
||||
@Override
|
||||
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
|
||||
if (msg instanceof OutboundCloseErrorMessage err) {
|
||||
final NoiseDirectProtos.CloseReason.Code code = switch (err.code()) {
|
||||
case SERVER_CLOSED -> NoiseDirectProtos.CloseReason.Code.UNAVAILABLE;
|
||||
case NOISE_ERROR -> NoiseDirectProtos.CloseReason.Code.ENCRYPTION_ERROR;
|
||||
case NOISE_HANDSHAKE_ERROR -> NoiseDirectProtos.CloseReason.Code.HANDSHAKE_ERROR;
|
||||
case INTERNAL_SERVER_ERROR -> NoiseDirectProtos.CloseReason.Code.INTERNAL_ERROR;
|
||||
};
|
||||
final NoiseDirectProtos.CloseReason proto = NoiseDirectProtos.CloseReason.newBuilder()
|
||||
.setCode(code)
|
||||
.setMessage(err.message())
|
||||
.build();
|
||||
final ByteBuf byteBuf = ctx.alloc().buffer(proto.getSerializedSize());
|
||||
proto.writeTo(new ByteBufOutputStream(byteBuf));
|
||||
ctx.writeAndFlush(new NoiseDirectFrame(NoiseDirectFrame.FrameType.CLOSE, byteBuf))
|
||||
.addListener(ChannelFutureListener.CLOSE);
|
||||
} else {
|
||||
ctx.write(msg, promise);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,93 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.southernstorm.noise.protocol.Noise;
|
||||
import io.dropwizard.lifecycle.Managed;
|
||||
import io.netty.bootstrap.ServerBootstrap;
|
||||
import io.netty.channel.ChannelInitializer;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import io.netty.channel.nio.NioEventLoopGroup;
|
||||
import io.netty.channel.socket.ServerSocketChannel;
|
||||
import io.netty.channel.socket.SocketChannel;
|
||||
import io.netty.channel.socket.nio.NioServerSocketChannel;
|
||||
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
|
||||
import java.net.InetSocketAddress;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.ErrorHandler;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.EstablishLocalGrpcConnectionHandler;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.HAProxyMessageHandler;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeHandler;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.ProxyProtocolDetectionHandler;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||
|
||||
/**
|
||||
* A NoiseDirectTunnelServer accepts traffic from the public internet (in the form of Noise packets framed by a custom
|
||||
* binary framing protocol) and passes it through to a local gRPC server.
|
||||
*/
|
||||
public class NoiseDirectTunnelServer implements Managed {
|
||||
|
||||
private final ServerBootstrap bootstrap;
|
||||
private ServerSocketChannel channel;
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(NoiseDirectTunnelServer.class);
|
||||
|
||||
public NoiseDirectTunnelServer(final int port,
|
||||
final NioEventLoopGroup eventLoopGroup,
|
||||
final GrpcClientConnectionManager grpcClientConnectionManager,
|
||||
final ClientPublicKeysManager clientPublicKeysManager,
|
||||
final ECKeyPair ecKeyPair,
|
||||
final LocalAddress authenticatedGrpcServerAddress,
|
||||
final LocalAddress anonymousGrpcServerAddress) {
|
||||
|
||||
this.bootstrap = new ServerBootstrap()
|
||||
.group(eventLoopGroup)
|
||||
.channel(NioServerSocketChannel.class)
|
||||
.localAddress(port)
|
||||
.childHandler(new ChannelInitializer<SocketChannel>() {
|
||||
@Override
|
||||
protected void initChannel(SocketChannel socketChannel) {
|
||||
socketChannel.pipeline()
|
||||
.addLast(new ProxyProtocolDetectionHandler())
|
||||
.addLast(new HAProxyMessageHandler())
|
||||
// frame byte followed by a 2-byte length field
|
||||
.addLast(new LengthFieldBasedFrameDecoder(Noise.MAX_PACKET_LEN, 1, 2))
|
||||
// Parses NoiseDirectFrames from wire bytes and vice versa
|
||||
.addLast(new NoiseDirectFrameCodec())
|
||||
// Terminate the connection if the client sends us a close frame
|
||||
.addLast(new NoiseDirectInboundCloseHandler())
|
||||
// Turn generic OutboundCloseErrorMessages into noise direct error frames
|
||||
.addLast(new NoiseDirectOutboundErrorHandler())
|
||||
// Forwards the first payload supplemented with handshake metadata, and then replaces itself with a
|
||||
// NoiseDirectDataFrameCodec to handle subsequent data frames
|
||||
.addLast(new NoiseDirectHandshakeSelector())
|
||||
// Performs the noise handshake and then replace itself with a NoiseHandler
|
||||
.addLast(new NoiseHandshakeHandler(clientPublicKeysManager, ecKeyPair))
|
||||
// This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler
|
||||
// once the Noise handshake has completed
|
||||
.addLast(new EstablishLocalGrpcConnectionHandler(
|
||||
grpcClientConnectionManager, authenticatedGrpcServerAddress, anonymousGrpcServerAddress))
|
||||
.addLast(new ErrorHandler());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
public InetSocketAddress getLocalAddress() {
|
||||
return channel.localAddress();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void start() throws InterruptedException {
|
||||
channel = (ServerSocketChannel) bootstrap.bind().await().channel();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void stop() throws InterruptedException {
|
||||
if (channel != null) {
|
||||
channel.close().await();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,12 +1,10 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
||||
|
||||
enum ApplicationWebSocketCloseReason {
|
||||
NOISE_HANDSHAKE_ERROR(4001),
|
||||
CLIENT_AUTHENTICATION_ERROR(4002),
|
||||
NOISE_ENCRYPTION_ERROR(4003),
|
||||
REAUTHENTICATION_REQUIRED(4004);
|
||||
NOISE_ENCRYPTION_ERROR(4002);
|
||||
|
||||
private final int statusCode;
|
||||
|
||||
|
@ -17,8 +15,4 @@ enum ApplicationWebSocketCloseReason {
|
|||
public int getStatusCode() {
|
||||
return statusCode;
|
||||
}
|
||||
|
||||
WebSocketCloseStatus toWebSocketCloseStatus(final String reason) {
|
||||
return new WebSocketCloseStatus(statusCode, reason);
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.southernstorm.noise.protocol.Noise;
|
||||
|
@ -28,6 +28,7 @@ import javax.net.ssl.SSLException;
|
|||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.*;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||
|
||||
/**
|
||||
|
@ -103,10 +104,16 @@ public class NoiseWebSocketTunnelServer implements Managed {
|
|||
// request and passed it down the pipeline
|
||||
.addLast(new WebSocketOpeningHandshakeHandler(AUTHENTICATED_SERVICE_PATH, ANONYMOUS_SERVICE_PATH, HEALTH_CHECK_PATH))
|
||||
.addLast(new WebSocketServerProtocolHandler("/", true))
|
||||
// Turn generic OutboundCloseErrorMessages into websocket close frames
|
||||
.addLast(new WebSocketOutboundErrorHandler())
|
||||
.addLast(new RejectUnsupportedMessagesHandler())
|
||||
// The WebSocket handshake complete listener will replace itself with an appropriate Noise handshake handler once
|
||||
// a WebSocket handshake has been completed
|
||||
.addLast(new WebsocketHandshakeCompleteHandler(clientPublicKeysManager, ecKeyPair, recognizedProxySecret))
|
||||
.addLast(new WebsocketPayloadCodec())
|
||||
// The WebSocket handshake complete listener will forward the first payload supplemented with
|
||||
// data from the websocket handshake completion event, and then remove itself from the pipeline
|
||||
.addLast(new WebsocketHandshakeCompleteHandler(recognizedProxySecret))
|
||||
// The NoiseHandshakeHandler will perform the noise handshake and then replace itself with a
|
||||
// NoiseHandler
|
||||
.addLast(new NoiseHandshakeHandler(clientPublicKeysManager, ecKeyPair))
|
||||
// This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler
|
||||
// once the Noise handshake has completed
|
||||
.addLast(new EstablishLocalGrpcConnectionHandler(grpcClientConnectionManager, authenticatedGrpcServerAddress, anonymousGrpcServerAddress))
|
|
@ -1,4 +1,4 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
|
@ -1,4 +1,4 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
|
@ -0,0 +1,58 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||
|
||||
import io.netty.channel.ChannelDuplexHandler;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage;
|
||||
|
||||
/**
|
||||
* Converts {@link OutboundCloseErrorMessage}s written to the pipeline into WebSocket close frames
|
||||
*/
|
||||
class WebSocketOutboundErrorHandler extends ChannelDuplexHandler {
|
||||
|
||||
private boolean websocketHandshakeComplete = false;
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(WebSocketOutboundErrorHandler.class);
|
||||
|
||||
@Override
|
||||
public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception {
|
||||
if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
|
||||
setWebsocketHandshakeComplete();
|
||||
}
|
||||
|
||||
context.fireUserEventTriggered(event);
|
||||
}
|
||||
|
||||
protected void setWebsocketHandshakeComplete() {
|
||||
this.websocketHandshakeComplete = true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
|
||||
if (msg instanceof OutboundCloseErrorMessage err) {
|
||||
if (websocketHandshakeComplete) {
|
||||
final int status = switch (err.code()) {
|
||||
case SERVER_CLOSED -> WebSocketCloseStatus.SERVICE_RESTART.code();
|
||||
case NOISE_ERROR -> ApplicationWebSocketCloseReason.NOISE_ENCRYPTION_ERROR.getStatusCode();
|
||||
case NOISE_HANDSHAKE_ERROR -> ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode();
|
||||
case INTERNAL_SERVER_ERROR -> WebSocketCloseStatus.INTERNAL_SERVER_ERROR.code();
|
||||
};
|
||||
ctx.write(new CloseWebSocketFrame(new WebSocketCloseStatus(status, err.message())), promise)
|
||||
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
|
||||
} else {
|
||||
log.debug("Error {} occurred before websocket handshake complete", err);
|
||||
// We haven't completed a websocket handshake, so we can't really communicate errors in a semantically-meaningful
|
||||
// way; just close the connection instead.
|
||||
ctx.close();
|
||||
}
|
||||
} else {
|
||||
ctx.write(msg, promise);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,15 +1,15 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.common.net.InetAddresses;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.handler.codec.http.HttpHeaderNames;
|
||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import java.net.InetAddress;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
|
@ -17,10 +17,10 @@ import java.security.MessageDigest;
|
|||
import java.util.Optional;
|
||||
import javax.annotation.Nullable;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.HandshakePattern;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeInit;
|
||||
|
||||
/**
|
||||
* A WebSocket handshake handler waits for a WebSocket handshake to complete, then replaces itself with the appropriate
|
||||
|
@ -28,10 +28,6 @@ import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
|||
*/
|
||||
class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private final ClientPublicKeysManager clientPublicKeysManager;
|
||||
|
||||
private final ECKeyPair ecKeyPair;
|
||||
|
||||
private final byte[] recognizedProxySecret;
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(WebsocketHandshakeCompleteHandler.class);
|
||||
|
@ -42,12 +38,10 @@ class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
|
|||
@VisibleForTesting
|
||||
static final String FORWARDED_FOR_HEADER = "X-Forwarded-For";
|
||||
|
||||
WebsocketHandshakeCompleteHandler(final ClientPublicKeysManager clientPublicKeysManager,
|
||||
final ECKeyPair ecKeyPair,
|
||||
final String recognizedProxySecret) {
|
||||
private InetAddress remoteAddress = null;
|
||||
private HandshakePattern handshakePattern = null;
|
||||
|
||||
this.clientPublicKeysManager = clientPublicKeysManager;
|
||||
this.ecKeyPair = ecKeyPair;
|
||||
WebsocketHandshakeCompleteHandler(final String recognizedProxySecret) {
|
||||
|
||||
// The recognized proxy secret is an arbitrary string, and not an encoded byte sequence (i.e. a base64- or hex-
|
||||
// encoded value). We convert it into a byte array here for easier constant-time comparisons via
|
||||
|
@ -58,8 +52,6 @@ class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
|
|||
@Override
|
||||
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
|
||||
if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) {
|
||||
final InetAddress preferredRemoteAddress;
|
||||
{
|
||||
final Optional<InetAddress> maybePreferredRemoteAddress =
|
||||
getPreferredRemoteAddress(context, handshakeCompleteEvent);
|
||||
|
||||
|
@ -71,34 +63,41 @@ class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
|
|||
return;
|
||||
}
|
||||
|
||||
preferredRemoteAddress = maybePreferredRemoteAddress.get();
|
||||
}
|
||||
|
||||
GrpcClientConnectionManager.handleWebSocketHandshakeComplete(context.channel(),
|
||||
preferredRemoteAddress,
|
||||
handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.USER_AGENT),
|
||||
handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.ACCEPT_LANGUAGE));
|
||||
|
||||
final ChannelHandler noiseHandshakeHandler = switch (handshakeCompleteEvent.requestUri()) {
|
||||
case NoiseWebSocketTunnelServer.AUTHENTICATED_SERVICE_PATH ->
|
||||
new NoiseAuthenticatedHandler(clientPublicKeysManager, ecKeyPair);
|
||||
|
||||
case NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH ->
|
||||
new NoiseAnonymousHandler(ecKeyPair);
|
||||
|
||||
default -> {
|
||||
remoteAddress = maybePreferredRemoteAddress.get();
|
||||
handshakePattern = switch (handshakeCompleteEvent.requestUri()) {
|
||||
case NoiseWebSocketTunnelServer.AUTHENTICATED_SERVICE_PATH -> HandshakePattern.IK;
|
||||
case NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH -> HandshakePattern.NK;
|
||||
// The WebSocketOpeningHandshakeHandler should have caught all of these cases already; we'll consider it an
|
||||
// internal error if something slipped through.
|
||||
throw new IllegalArgumentException("Unexpected URI: " + handshakeCompleteEvent.requestUri());
|
||||
}
|
||||
default -> throw new IllegalArgumentException("Unexpected URI: " + handshakeCompleteEvent.requestUri());
|
||||
};
|
||||
|
||||
context.pipeline().replace(WebsocketHandshakeCompleteHandler.this, null, noiseHandshakeHandler);
|
||||
}
|
||||
|
||||
context.fireUserEventTriggered(event);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object msg) {
|
||||
try {
|
||||
if (!(msg instanceof ByteBuf frame)) {
|
||||
throw new IllegalStateException("Unexpected msg type: " + msg.getClass());
|
||||
}
|
||||
|
||||
if (handshakePattern == null || remoteAddress == null) {
|
||||
throw new IllegalStateException("Received payload before websocket handshake complete");
|
||||
}
|
||||
|
||||
final NoiseHandshakeInit handshakeMessage =
|
||||
new NoiseHandshakeInit(remoteAddress, handshakePattern, frame);
|
||||
|
||||
context.pipeline().remove(WebsocketHandshakeCompleteHandler.class);
|
||||
context.fireChannelRead(handshakeMessage);
|
||||
} catch (Exception e) {
|
||||
ReferenceCountUtil.release(msg);
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
private Optional<InetAddress> getPreferredRemoteAddress(final ChannelHandlerContext context,
|
||||
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) {
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.channel.ChannelDuplexHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||
|
||||
/**
|
||||
* Extracts buffers from inbound BinaryWebsocketFrames before forwarding to a
|
||||
* {@link org.whispersystems.textsecuregcm.grpc.net.NoiseHandler} for decryption and wraps outbound encrypted noise
|
||||
* packet buffers in BinaryWebsocketFrames for writing through the websocket layer.
|
||||
*/
|
||||
public class WebsocketPayloadCodec extends ChannelDuplexHandler {
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext ctx, final Object msg) {
|
||||
if (msg instanceof BinaryWebSocketFrame frame) {
|
||||
ctx.fireChannelRead(frame.content());
|
||||
} else {
|
||||
ctx.fireChannelRead(msg);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
|
||||
if (msg instanceof ByteBuf bb) {
|
||||
ctx.write(new BinaryWebSocketFrame(bb), promise);
|
||||
} else {
|
||||
ctx.write(msg, promise);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -11,7 +11,6 @@ import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.in
|
|||
|
||||
import com.google.protobuf.ByteString;
|
||||
import com.google.protobuf.Descriptors;
|
||||
import com.google.protobuf.GeneratedMessageV3;
|
||||
import com.google.protobuf.Message;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.StatusException;
|
||||
|
@ -49,7 +48,7 @@ public abstract class BaseFieldValidator<T> implements FieldValidator {
|
|||
public void validate(
|
||||
final Object extensionValue,
|
||||
final Descriptors.FieldDescriptor fd,
|
||||
final GeneratedMessageV3 msg) throws StatusException {
|
||||
final Message msg) throws StatusException {
|
||||
try {
|
||||
final T extensionValueTyped = resolveExtensionValue(extensionValue);
|
||||
|
||||
|
@ -116,7 +115,7 @@ public abstract class BaseFieldValidator<T> implements FieldValidator {
|
|||
protected void validateRepeatedField(
|
||||
final T extensionValue,
|
||||
final Descriptors.FieldDescriptor fd,
|
||||
final GeneratedMessageV3 msg) throws StatusException {
|
||||
final Message msg) throws StatusException {
|
||||
throw internalError("`validateRepeatedField` method needs to be implemented");
|
||||
}
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.in
|
|||
|
||||
import com.google.protobuf.ByteString;
|
||||
import com.google.protobuf.Descriptors;
|
||||
import com.google.protobuf.GeneratedMessageV3;
|
||||
import com.google.protobuf.Message;
|
||||
import io.grpc.StatusException;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
@ -53,7 +53,7 @@ public class ExactlySizeFieldValidator extends BaseFieldValidator<Set<Integer>>
|
|||
protected void validateRepeatedField(
|
||||
final Set<Integer> permittedSizes,
|
||||
final Descriptors.FieldDescriptor fd,
|
||||
final GeneratedMessageV3 msg) throws StatusException {
|
||||
final Message msg) throws StatusException {
|
||||
final int size = msg.getRepeatedFieldCount(fd);
|
||||
if (permittedSizes.contains(size)) {
|
||||
return;
|
||||
|
|
|
@ -6,11 +6,11 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.validators;
|
||||
|
||||
import com.google.protobuf.Descriptors;
|
||||
import com.google.protobuf.GeneratedMessageV3;
|
||||
import com.google.protobuf.Message;
|
||||
import io.grpc.StatusException;
|
||||
|
||||
public interface FieldValidator {
|
||||
|
||||
void validate(Object extensionValue, Descriptors.FieldDescriptor fd, GeneratedMessageV3 msg)
|
||||
void validate(Object extensionValue, Descriptors.FieldDescriptor fd, Message msg)
|
||||
throws StatusException;
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@ import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.in
|
|||
|
||||
import com.google.protobuf.ByteString;
|
||||
import com.google.protobuf.Descriptors;
|
||||
import com.google.protobuf.GeneratedMessageV3;
|
||||
import com.google.protobuf.Message;
|
||||
import io.grpc.StatusException;
|
||||
import java.util.Set;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
@ -52,7 +52,7 @@ public class NonEmptyFieldValidator extends BaseFieldValidator<Boolean> {
|
|||
protected void validateRepeatedField(
|
||||
final Boolean extensionValue,
|
||||
final Descriptors.FieldDescriptor fd,
|
||||
final GeneratedMessageV3 msg) throws StatusException {
|
||||
final Message msg) throws StatusException {
|
||||
if (msg.getRepeatedFieldCount(fd) > 0) {
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@ import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.in
|
|||
|
||||
import com.google.protobuf.ByteString;
|
||||
import com.google.protobuf.Descriptors;
|
||||
import com.google.protobuf.GeneratedMessageV3;
|
||||
import com.google.protobuf.Message;
|
||||
import io.grpc.StatusException;
|
||||
import java.util.Set;
|
||||
import org.signal.chat.require.SizeConstraint;
|
||||
|
@ -48,7 +48,7 @@ public class SizeFieldValidator extends BaseFieldValidator<Range> {
|
|||
}
|
||||
|
||||
@Override
|
||||
protected void validateRepeatedField(final Range range, final Descriptors.FieldDescriptor fd, final GeneratedMessageV3 msg) throws StatusException {
|
||||
protected void validateRepeatedField(final Range range, final Descriptors.FieldDescriptor fd, final Message msg) throws StatusException {
|
||||
final int size = msg.getRepeatedFieldCount(fd);
|
||||
if (size < range.min() || size > range.max()) {
|
||||
throw invalidArgument("field value is [%d] but expected to be within the [%d, %d] range".formatted(
|
||||
|
|
|
@ -10,16 +10,12 @@ import static java.util.Objects.requireNonNull;
|
|||
import io.lettuce.core.ScriptOutputType;
|
||||
import java.io.IOException;
|
||||
import java.io.UncheckedIOException;
|
||||
import java.lang.invoke.MethodHandles;
|
||||
import java.time.Clock;
|
||||
import java.util.Arrays;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.stream.Collectors;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
||||
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
|
||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
|
||||
|
@ -27,25 +23,18 @@ import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
|
|||
|
||||
public abstract class BaseRateLimiters<T extends RateLimiterDescriptor> {
|
||||
|
||||
private final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
|
||||
|
||||
private final Map<T, RateLimiter> rateLimiterByDescriptor;
|
||||
|
||||
private final Map<String, RateLimiterConfig> configs;
|
||||
|
||||
|
||||
protected BaseRateLimiters(
|
||||
final T[] values,
|
||||
final Map<String, RateLimiterConfig> configs,
|
||||
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
|
||||
final ClusterLuaScript validateScript,
|
||||
final FaultTolerantRedisClusterClient cacheCluster,
|
||||
final Clock clock) {
|
||||
this.configs = configs;
|
||||
this.rateLimiterByDescriptor = Arrays.stream(values)
|
||||
.map(descriptor -> Pair.of(
|
||||
descriptor,
|
||||
createForDescriptor(descriptor, configs, dynamicConfigurationManager, validateScript, cacheCluster, clock)))
|
||||
createForDescriptor(descriptor, dynamicConfigurationManager, validateScript, cacheCluster, clock)))
|
||||
.collect(Collectors.toUnmodifiableMap(Pair::getKey, Pair::getValue));
|
||||
}
|
||||
|
||||
|
@ -53,22 +42,6 @@ public abstract class BaseRateLimiters<T extends RateLimiterDescriptor> {
|
|||
return requireNonNull(rateLimiterByDescriptor.get(handle));
|
||||
}
|
||||
|
||||
public void validateValuesAndConfigs() {
|
||||
final Set<String> ids = rateLimiterByDescriptor.keySet().stream()
|
||||
.map(RateLimiterDescriptor::id)
|
||||
.collect(Collectors.toSet());
|
||||
for (final String key: configs.keySet()) {
|
||||
if (!ids.contains(key)) {
|
||||
final String message = String.format(
|
||||
"Static configuration has an unexpected field '%s' that doesn't match any RateLimiterDescriptor",
|
||||
key
|
||||
);
|
||||
logger.error(message);
|
||||
throw new IllegalArgumentException(message);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected static ClusterLuaScript defaultScript(final FaultTolerantRedisClusterClient cacheCluster) {
|
||||
try {
|
||||
return ClusterLuaScript.fromResource(
|
||||
|
@ -80,21 +53,12 @@ public abstract class BaseRateLimiters<T extends RateLimiterDescriptor> {
|
|||
|
||||
private static RateLimiter createForDescriptor(
|
||||
final RateLimiterDescriptor descriptor,
|
||||
final Map<String, RateLimiterConfig> configs,
|
||||
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
|
||||
final ClusterLuaScript validateScript,
|
||||
final FaultTolerantRedisClusterClient cacheCluster,
|
||||
final Clock clock) {
|
||||
if (descriptor.isDynamic()) {
|
||||
final Supplier<RateLimiterConfig> configResolver = () -> {
|
||||
final RateLimiterConfig config = dynamicConfigurationManager.getConfiguration().getLimits().get(descriptor.id());
|
||||
return config != null
|
||||
? config
|
||||
: configs.getOrDefault(descriptor.id(), descriptor.defaultConfig());
|
||||
};
|
||||
return new DynamicRateLimiter(descriptor.id(), dynamicConfigurationManager, configResolver, validateScript, cacheCluster, clock);
|
||||
}
|
||||
final RateLimiterConfig cfg = configs.getOrDefault(descriptor.id(), descriptor.defaultConfig());
|
||||
return new StaticRateLimiter(descriptor.id(), cfg, validateScript, cacheCluster, clock, dynamicConfigurationManager);
|
||||
final Supplier<RateLimiterConfig> configResolver =
|
||||
() -> dynamicConfigurationManager.getConfiguration().getLimits().getOrDefault(descriptor.id(), descriptor.defaultConfig());
|
||||
return new DynamicRateLimiter(descriptor.id(), configResolver, validateScript, cacheCluster, clock);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,87 +7,167 @@ package org.whispersystems.textsecuregcm.limits;
|
|||
|
||||
import static java.util.Objects.requireNonNull;
|
||||
|
||||
import io.micrometer.core.instrument.Counter;
|
||||
import io.micrometer.core.instrument.Metrics;
|
||||
import java.time.Clock;
|
||||
import java.time.Duration;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.CompletionStage;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import java.util.function.Supplier;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
||||
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
|
||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
|
||||
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
|
||||
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
||||
import org.whispersystems.textsecuregcm.util.Util;
|
||||
|
||||
public class DynamicRateLimiter implements RateLimiter {
|
||||
|
||||
private final String name;
|
||||
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
|
||||
private final Supplier<RateLimiterConfig> configResolver;
|
||||
|
||||
private final ClusterLuaScript validateScript;
|
||||
|
||||
private final FaultTolerantRedisClusterClient cluster;
|
||||
|
||||
private final Clock clock;
|
||||
private final Counter limitExceededCounter;
|
||||
|
||||
private final AtomicReference<Pair<RateLimiterConfig, RateLimiter>> currentHolder = new AtomicReference<>();
|
||||
private final Clock clock;
|
||||
|
||||
|
||||
public DynamicRateLimiter(
|
||||
final String name,
|
||||
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
|
||||
final Supplier<RateLimiterConfig> configResolver,
|
||||
final ClusterLuaScript validateScript,
|
||||
final FaultTolerantRedisClusterClient cluster,
|
||||
final Clock clock) {
|
||||
this.name = requireNonNull(name);
|
||||
this.dynamicConfigurationManager = dynamicConfigurationManager;
|
||||
this.configResolver = requireNonNull(configResolver);
|
||||
this.validateScript = requireNonNull(validateScript);
|
||||
this.cluster = requireNonNull(cluster);
|
||||
this.clock = requireNonNull(clock);
|
||||
this.limitExceededCounter = Metrics.counter(MetricsUtil.name(getClass(), "exceeded"), "rateLimiterName", name);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void validate(final String key, final int amount) throws RateLimitExceededException {
|
||||
current().getRight().validate(key, amount);
|
||||
final RateLimiterConfig config = config();
|
||||
try {
|
||||
final long deficitPermitsAmount = executeValidateScript(config, key, amount, true);
|
||||
if (deficitPermitsAmount > 0) {
|
||||
limitExceededCounter.increment();
|
||||
final Duration retryAfter = Duration.ofMillis(
|
||||
(long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis()));
|
||||
throw new RateLimitExceededException(retryAfter);
|
||||
}
|
||||
} catch (final Exception e) {
|
||||
if (e instanceof RateLimitExceededException rateLimitExceededException) {
|
||||
throw rateLimitExceededException;
|
||||
}
|
||||
|
||||
if (!config.failOpen()) {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public CompletionStage<Void> validateAsync(final String key, final int amount) {
|
||||
return current().getRight().validateAsync(key, amount);
|
||||
final RateLimiterConfig config = config();
|
||||
|
||||
return executeValidateScriptAsync(config, key, amount, true)
|
||||
.thenCompose(deficitPermitsAmount -> {
|
||||
if (deficitPermitsAmount == 0) {
|
||||
return CompletableFuture.completedFuture((Void) null);
|
||||
}
|
||||
limitExceededCounter.increment();
|
||||
final Duration retryAfter = Duration.ofMillis(
|
||||
(long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis()));
|
||||
return CompletableFuture.failedFuture(new RateLimitExceededException(retryAfter));
|
||||
})
|
||||
.exceptionally(throwable -> {
|
||||
if (ExceptionUtils.unwrap(throwable) instanceof RateLimitExceededException rateLimitExceededException) {
|
||||
throw ExceptionUtils.wrap(rateLimitExceededException);
|
||||
}
|
||||
|
||||
if (config.failOpen()) {
|
||||
return null;
|
||||
}
|
||||
|
||||
throw ExceptionUtils.wrap(throwable);
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasAvailablePermits(final String key, final int permits) {
|
||||
return current().getRight().hasAvailablePermits(key, permits);
|
||||
final RateLimiterConfig config = config();
|
||||
try {
|
||||
final long deficitPermitsAmount = executeValidateScript(config, key, permits, false);
|
||||
return deficitPermitsAmount == 0;
|
||||
} catch (final Exception e) {
|
||||
if (config.failOpen()) {
|
||||
return true;
|
||||
} else {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public CompletionStage<Boolean> hasAvailablePermitsAsync(final String key, final int amount) {
|
||||
return current().getRight().hasAvailablePermitsAsync(key, amount);
|
||||
final RateLimiterConfig config = config();
|
||||
return executeValidateScriptAsync(config, key, amount, false)
|
||||
.thenApply(deficitPermitsAmount -> deficitPermitsAmount == 0)
|
||||
.exceptionally(throwable -> {
|
||||
if (config.failOpen()) {
|
||||
return true;
|
||||
}
|
||||
throw ExceptionUtils.wrap(throwable);
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public void clear(final String key) {
|
||||
current().getRight().clear(key);
|
||||
cluster.useCluster(connection -> connection.sync().del(bucketName(name, key)));
|
||||
}
|
||||
|
||||
@Override
|
||||
public CompletionStage<Void> clearAsync(final String key) {
|
||||
return current().getRight().clearAsync(key);
|
||||
return cluster.withCluster(connection -> connection.async().del(bucketName(name, key)))
|
||||
.thenRun(Util.NOOP);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RateLimiterConfig config() {
|
||||
return current().getLeft();
|
||||
return configResolver.get();
|
||||
}
|
||||
|
||||
private Pair<RateLimiterConfig, RateLimiter> current() {
|
||||
final RateLimiterConfig cfg = configResolver.get();
|
||||
return currentHolder.updateAndGet(p -> p != null && p.getLeft().equals(cfg)
|
||||
? p
|
||||
: Pair.of(cfg, new StaticRateLimiter(name, cfg, validateScript, cluster, clock, dynamicConfigurationManager))
|
||||
private long executeValidateScript(final RateLimiterConfig config, final String key, final int amount, final boolean applyChanges) {
|
||||
final List<String> keys = List.of(bucketName(name, key));
|
||||
final List<String> arguments = List.of(
|
||||
String.valueOf(config.bucketSize()),
|
||||
String.valueOf(config.leakRatePerMillis()),
|
||||
String.valueOf(clock.millis()),
|
||||
String.valueOf(amount),
|
||||
String.valueOf(applyChanges)
|
||||
);
|
||||
return (Long) validateScript.execute(keys, arguments);
|
||||
}
|
||||
|
||||
private CompletionStage<Long> executeValidateScriptAsync(final RateLimiterConfig config, final String key, final int amount, final boolean applyChanges) {
|
||||
final List<String> keys = List.of(bucketName(name, key));
|
||||
final List<String> arguments = List.of(
|
||||
String.valueOf(config.bucketSize()),
|
||||
String.valueOf(config.leakRatePerMillis()),
|
||||
String.valueOf(clock.millis()),
|
||||
String.valueOf(amount),
|
||||
String.valueOf(applyChanges)
|
||||
);
|
||||
return validateScript.executeAsync(keys, arguments).thenApply(o -> (Long) o);
|
||||
}
|
||||
|
||||
private static String bucketName(final String name, final String key) {
|
||||
return "leaky_bucket::" + name + "::" + key;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,7 +8,7 @@ package org.whispersystems.textsecuregcm.limits;
|
|||
import jakarta.validation.constraints.AssertTrue;
|
||||
import java.time.Duration;
|
||||
|
||||
public record RateLimiterConfig(int bucketSize, Duration permitRegenerationDuration) {
|
||||
public record RateLimiterConfig(int bucketSize, Duration permitRegenerationDuration, boolean failOpen) {
|
||||
|
||||
public double leakRatePerMillis() {
|
||||
return 1.0 / (permitRegenerationDuration.toNanos() / 1e6);
|
||||
|
|
|
@ -15,14 +15,9 @@ public interface RateLimiterDescriptor {
|
|||
*/
|
||||
String id();
|
||||
|
||||
/**
|
||||
* @return {@code true} if this rate limiter needs to watch for dynamic configuration changes.
|
||||
*/
|
||||
boolean isDynamic();
|
||||
|
||||
/**
|
||||
* @return an instance of {@link RateLimiterConfig} to be used by default,
|
||||
* i.e. if there is no overrides in the application configuration files (static or dynamic).
|
||||
* i.e. if there is no override in the application dynamic configuration.
|
||||
*/
|
||||
RateLimiterConfig defaultConfig();
|
||||
}
|
||||
|
|
|
@ -4,11 +4,9 @@
|
|||
*/
|
||||
package org.whispersystems.textsecuregcm.limits;
|
||||
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import java.time.Clock;
|
||||
import java.time.Duration;
|
||||
import java.util.Map;
|
||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
||||
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
|
||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
|
||||
|
@ -17,59 +15,54 @@ import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
|
|||
public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
|
||||
|
||||
public enum For implements RateLimiterDescriptor {
|
||||
BACKUP_AUTH_CHECK("backupAuthCheck", false, new RateLimiterConfig(100, Duration.ofMinutes(15))),
|
||||
PIN("pin", false, new RateLimiterConfig(10, Duration.ofDays(1))),
|
||||
ATTACHMENT("attachmentCreate", false, new RateLimiterConfig(50, Duration.ofMillis(1200))),
|
||||
BACKUP_ATTACHMENT("backupAttachmentCreate", true, new RateLimiterConfig(10_000, Duration.ofSeconds(1))),
|
||||
PRE_KEYS("prekeys", false, new RateLimiterConfig(6, Duration.ofMinutes(10))),
|
||||
MESSAGES("messages", false, new RateLimiterConfig(60, Duration.ofSeconds(1))),
|
||||
STORIES("stories", false, new RateLimiterConfig(5_000, Duration.ofSeconds(8))),
|
||||
ALLOCATE_DEVICE("allocateDevice", false, new RateLimiterConfig(6, Duration.ofMinutes(2))),
|
||||
VERIFY_DEVICE("verifyDevice", false, new RateLimiterConfig(6, Duration.ofMinutes(2))),
|
||||
TURN("turnAllocate", false, new RateLimiterConfig(60, Duration.ofSeconds(1))),
|
||||
PROFILE("profile", false, new RateLimiterConfig(4320, Duration.ofSeconds(20))),
|
||||
STICKER_PACK("stickerPack", false, new RateLimiterConfig(50, Duration.ofMinutes(72))),
|
||||
USERNAME_LOOKUP("usernameLookup", false, new RateLimiterConfig(100, Duration.ofMinutes(15))),
|
||||
USERNAME_SET("usernameSet", false, new RateLimiterConfig(100, Duration.ofMinutes(15))),
|
||||
USERNAME_RESERVE("usernameReserve", false, new RateLimiterConfig(100, Duration.ofMinutes(15))),
|
||||
USERNAME_LINK_OPERATION("usernameLinkOperation", false, new RateLimiterConfig(10, Duration.ofMinutes(1))),
|
||||
USERNAME_LINK_LOOKUP_PER_IP("usernameLinkLookupPerIp", false, new RateLimiterConfig(100, Duration.ofSeconds(15))),
|
||||
CHECK_ACCOUNT_EXISTENCE("checkAccountExistence", false, new RateLimiterConfig(1000, Duration.ofSeconds(4))),
|
||||
REGISTRATION("registration", false, new RateLimiterConfig(6, Duration.ofSeconds(30))),
|
||||
VERIFICATION_PUSH_CHALLENGE("verificationPushChallenge", false, new RateLimiterConfig(5, Duration.ofSeconds(30))),
|
||||
VERIFICATION_CAPTCHA("verificationCaptcha", false, new RateLimiterConfig(10, Duration.ofSeconds(30))),
|
||||
RATE_LIMIT_RESET("rateLimitReset", true, new RateLimiterConfig(2, Duration.ofHours(12))),
|
||||
CAPTCHA_CHALLENGE_ATTEMPT("captchaChallengeAttempt", true, new RateLimiterConfig(10, Duration.ofMinutes(144))),
|
||||
CAPTCHA_CHALLENGE_SUCCESS("captchaChallengeSuccess", true, new RateLimiterConfig(2, Duration.ofHours(12))),
|
||||
SET_BACKUP_ID("setBackupId", true, new RateLimiterConfig(10, Duration.ofHours(1))),
|
||||
SET_PAID_MEDIA_BACKUP_ID("setPaidMediaBackupId", true, new RateLimiterConfig(5, Duration.ofDays(7))),
|
||||
PUSH_CHALLENGE_ATTEMPT("pushChallengeAttempt", true, new RateLimiterConfig(10, Duration.ofMinutes(144))),
|
||||
PUSH_CHALLENGE_SUCCESS("pushChallengeSuccess", true, new RateLimiterConfig(2, Duration.ofHours(12))),
|
||||
GET_CALLING_RELAYS("getCallingRelays", false, new RateLimiterConfig(100, Duration.ofMinutes(10))),
|
||||
CREATE_CALL_LINK("createCallLink", false, new RateLimiterConfig(100, Duration.ofMinutes(15))),
|
||||
INBOUND_MESSAGE_BYTES("inboundMessageBytes", true, new RateLimiterConfig(128 * 1024 * 1024, Duration.ofNanos(500_000))),
|
||||
EXTERNAL_SERVICE_CREDENTIALS("externalServiceCredentials", true, new RateLimiterConfig(100, Duration.ofMinutes(15))),
|
||||
KEY_TRANSPARENCY_DISTINGUISHED_PER_IP("keyTransparencyDistinguished", true,
|
||||
new RateLimiterConfig(100, Duration.ofSeconds(15))),
|
||||
KEY_TRANSPARENCY_SEARCH_PER_IP("keyTransparencySearch", true, new RateLimiterConfig(100, Duration.ofSeconds(15))),
|
||||
KEY_TRANSPARENCY_MONITOR_PER_IP("keyTransparencyMonitor", true, new RateLimiterConfig(100, Duration.ofSeconds(15))),
|
||||
WAIT_FOR_LINKED_DEVICE("waitForLinkedDevice", true, new RateLimiterConfig(10, Duration.ofSeconds(30))),
|
||||
UPLOAD_TRANSFER_ARCHIVE("uploadTransferArchive", true, new RateLimiterConfig(10, Duration.ofMinutes(1))),
|
||||
WAIT_FOR_TRANSFER_ARCHIVE("waitForTransferArchive", true, new RateLimiterConfig(10, Duration.ofSeconds(30))),
|
||||
RECORD_DEVICE_TRANSFER_REQUEST("recordDeviceTransferRequest", true, new RateLimiterConfig(10, Duration.ofMillis(100))),
|
||||
WAIT_FOR_DEVICE_TRANSFER_REQUEST("waitForDeviceTransferRequest", true, new RateLimiterConfig(10, Duration.ofMillis(100))),
|
||||
DEVICE_CHECK_CHALLENGE("deviceCheckChallenge", true, new RateLimiterConfig(10, Duration.ofMinutes(1))),
|
||||
BACKUP_AUTH_CHECK("backupAuthCheck", new RateLimiterConfig(100, Duration.ofMinutes(15), true)),
|
||||
PIN("pin", new RateLimiterConfig(10, Duration.ofDays(1), false)),
|
||||
ATTACHMENT("attachmentCreate", new RateLimiterConfig(50, Duration.ofMillis(1200), true)),
|
||||
BACKUP_ATTACHMENT("backupAttachmentCreate", new RateLimiterConfig(10_000, Duration.ofSeconds(1), true)),
|
||||
PRE_KEYS("prekeys", new RateLimiterConfig(6, Duration.ofMinutes(10), false)),
|
||||
MESSAGES("messages", new RateLimiterConfig(60, Duration.ofSeconds(1), true)),
|
||||
STORIES("stories", new RateLimiterConfig(5_000, Duration.ofSeconds(8), true)),
|
||||
ALLOCATE_DEVICE("allocateDevice", new RateLimiterConfig(6, Duration.ofMinutes(2), false)),
|
||||
VERIFY_DEVICE("verifyDevice", new RateLimiterConfig(6, Duration.ofMinutes(2), false)),
|
||||
PROFILE("profile", new RateLimiterConfig(4320, Duration.ofSeconds(20), true)),
|
||||
STICKER_PACK("stickerPack", new RateLimiterConfig(50, Duration.ofMinutes(72), false)),
|
||||
USERNAME_LOOKUP("usernameLookup", new RateLimiterConfig(100, Duration.ofMinutes(15), true)),
|
||||
USERNAME_SET("usernameSet", new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
|
||||
USERNAME_RESERVE("usernameReserve", new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
|
||||
USERNAME_LINK_OPERATION("usernameLinkOperation", new RateLimiterConfig(10, Duration.ofMinutes(1), false)),
|
||||
USERNAME_LINK_LOOKUP_PER_IP("usernameLinkLookupPerIp", new RateLimiterConfig(100, Duration.ofSeconds(15), true)),
|
||||
CHECK_ACCOUNT_EXISTENCE("checkAccountExistence", new RateLimiterConfig(1000, Duration.ofSeconds(4), true)),
|
||||
REGISTRATION("registration", new RateLimiterConfig(6, Duration.ofSeconds(30), false)),
|
||||
VERIFICATION_PUSH_CHALLENGE("verificationPushChallenge", new RateLimiterConfig(5, Duration.ofSeconds(30), false)),
|
||||
VERIFICATION_CAPTCHA("verificationCaptcha", new RateLimiterConfig(10, Duration.ofSeconds(30), false)),
|
||||
RATE_LIMIT_RESET("rateLimitReset", new RateLimiterConfig(2, Duration.ofHours(12), false)),
|
||||
CAPTCHA_CHALLENGE_ATTEMPT("captchaChallengeAttempt", new RateLimiterConfig(10, Duration.ofMinutes(144), false)),
|
||||
CAPTCHA_CHALLENGE_SUCCESS("captchaChallengeSuccess", new RateLimiterConfig(2, Duration.ofHours(12), false)),
|
||||
SET_BACKUP_ID("setBackupId", new RateLimiterConfig(10, Duration.ofHours(1), false)),
|
||||
SET_PAID_MEDIA_BACKUP_ID("setPaidMediaBackupId", new RateLimiterConfig(5, Duration.ofDays(7), false)),
|
||||
PUSH_CHALLENGE_ATTEMPT("pushChallengeAttempt", new RateLimiterConfig(10, Duration.ofMinutes(144), false)),
|
||||
PUSH_CHALLENGE_SUCCESS("pushChallengeSuccess", new RateLimiterConfig(2, Duration.ofHours(12), false)),
|
||||
GET_CALLING_RELAYS("getCallingRelays", new RateLimiterConfig(100, Duration.ofMinutes(10), false)),
|
||||
CREATE_CALL_LINK("createCallLink", new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
|
||||
INBOUND_MESSAGE_BYTES("inboundMessageBytes", new RateLimiterConfig(128 * 1024 * 1024, Duration.ofNanos(500_000), true)),
|
||||
EXTERNAL_SERVICE_CREDENTIALS("externalServiceCredentials", new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
|
||||
KEY_TRANSPARENCY_DISTINGUISHED_PER_IP("keyTransparencyDistinguished", new RateLimiterConfig(100, Duration.ofSeconds(15), true)),
|
||||
KEY_TRANSPARENCY_SEARCH_PER_IP("keyTransparencySearch", new RateLimiterConfig(100, Duration.ofSeconds(15), true)),
|
||||
KEY_TRANSPARENCY_MONITOR_PER_IP("keyTransparencyMonitor", new RateLimiterConfig(100, Duration.ofSeconds(15), true)),
|
||||
WAIT_FOR_LINKED_DEVICE("waitForLinkedDevice", new RateLimiterConfig(10, Duration.ofSeconds(30), false)),
|
||||
UPLOAD_TRANSFER_ARCHIVE("uploadTransferArchive", new RateLimiterConfig(10, Duration.ofMinutes(1), false)),
|
||||
WAIT_FOR_TRANSFER_ARCHIVE("waitForTransferArchive", new RateLimiterConfig(10, Duration.ofSeconds(30), false)),
|
||||
RECORD_DEVICE_TRANSFER_REQUEST("recordDeviceTransferRequest", new RateLimiterConfig(10, Duration.ofMillis(100), true)),
|
||||
WAIT_FOR_DEVICE_TRANSFER_REQUEST("waitForDeviceTransferRequest", new RateLimiterConfig(10, Duration.ofMillis(100), true)),
|
||||
DEVICE_CHECK_CHALLENGE("deviceCheckChallenge", new RateLimiterConfig(10, Duration.ofMinutes(1), false)),
|
||||
;
|
||||
|
||||
private final String id;
|
||||
|
||||
private final boolean dynamic;
|
||||
|
||||
private final RateLimiterConfig defaultConfig;
|
||||
|
||||
For(final String id, final boolean dynamic, final RateLimiterConfig defaultConfig) {
|
||||
For(final String id, final RateLimiterConfig defaultConfig) {
|
||||
this.id = id;
|
||||
this.dynamic = dynamic;
|
||||
this.defaultConfig = defaultConfig;
|
||||
}
|
||||
|
||||
|
@ -77,34 +70,25 @@ public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
|
|||
return id;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isDynamic() {
|
||||
return dynamic;
|
||||
}
|
||||
|
||||
public RateLimiterConfig defaultConfig() {
|
||||
return defaultConfig;
|
||||
}
|
||||
}
|
||||
|
||||
public static RateLimiters createAndValidate(
|
||||
final Map<String, RateLimiterConfig> configs,
|
||||
public static RateLimiters create(
|
||||
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
|
||||
final FaultTolerantRedisClusterClient cacheCluster) {
|
||||
final RateLimiters rateLimiters = new RateLimiters(
|
||||
configs, dynamicConfigurationManager, defaultScript(cacheCluster), cacheCluster, Clock.systemUTC());
|
||||
rateLimiters.validateValuesAndConfigs();
|
||||
return rateLimiters;
|
||||
return new RateLimiters(
|
||||
dynamicConfigurationManager, defaultScript(cacheCluster), cacheCluster, Clock.systemUTC());
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
RateLimiters(
|
||||
final Map<String, RateLimiterConfig> configs,
|
||||
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
|
||||
final ClusterLuaScript validateScript,
|
||||
final FaultTolerantRedisClusterClient cacheCluster,
|
||||
final Clock clock) {
|
||||
super(For.values(), configs, dynamicConfigurationManager, validateScript, cacheCluster, clock);
|
||||
super(For.values(), dynamicConfigurationManager, validateScript, cacheCluster, clock);
|
||||
}
|
||||
|
||||
public RateLimiter getAllocateDeviceLimiter() {
|
||||
|
@ -131,10 +115,6 @@ public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
|
|||
return forDescriptor(For.PIN);
|
||||
}
|
||||
|
||||
public RateLimiter getTurnLimiter() {
|
||||
return forDescriptor(For.TURN);
|
||||
}
|
||||
|
||||
public RateLimiter getProfileLimiter() {
|
||||
return forDescriptor(For.PROFILE);
|
||||
}
|
||||
|
|
|
@ -1,171 +0,0 @@
|
|||
/*
|
||||
* Copyright 2013 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.limits;
|
||||
|
||||
import static java.util.Objects.requireNonNull;
|
||||
import static java.util.concurrent.CompletableFuture.completedFuture;
|
||||
import static java.util.concurrent.CompletableFuture.failedFuture;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import io.lettuce.core.RedisException;
|
||||
import io.micrometer.core.instrument.Counter;
|
||||
import io.micrometer.core.instrument.Metrics;
|
||||
import java.time.Clock;
|
||||
import java.time.Duration;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CompletionStage;
|
||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
||||
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
|
||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
|
||||
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
|
||||
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
||||
import org.whispersystems.textsecuregcm.util.Util;
|
||||
|
||||
public class StaticRateLimiter implements RateLimiter {
|
||||
|
||||
protected final String name;
|
||||
|
||||
private final RateLimiterConfig config;
|
||||
|
||||
private final Counter counter;
|
||||
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
|
||||
|
||||
private final ClusterLuaScript validateScript;
|
||||
|
||||
private final FaultTolerantRedisClusterClient cacheCluster;
|
||||
|
||||
private final Clock clock;
|
||||
|
||||
|
||||
public StaticRateLimiter(
|
||||
final String name,
|
||||
final RateLimiterConfig config,
|
||||
final ClusterLuaScript validateScript,
|
||||
final FaultTolerantRedisClusterClient cacheCluster,
|
||||
final Clock clock,
|
||||
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager) {
|
||||
this.name = requireNonNull(name);
|
||||
this.config = requireNonNull(config);
|
||||
this.validateScript = requireNonNull(validateScript);
|
||||
this.cacheCluster = requireNonNull(cacheCluster);
|
||||
this.clock = requireNonNull(clock);
|
||||
this.counter = Metrics.counter(MetricsUtil.name(getClass(), "exceeded"), "rateLimiterName", name);
|
||||
this.dynamicConfigurationManager = dynamicConfigurationManager;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void validate(final String key, final int amount) throws RateLimitExceededException {
|
||||
try {
|
||||
final long deficitPermitsAmount = executeValidateScript(key, amount, true);
|
||||
if (deficitPermitsAmount > 0) {
|
||||
counter.increment();
|
||||
final Duration retryAfter = Duration.ofMillis(
|
||||
(long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis()));
|
||||
throw new RateLimitExceededException(retryAfter);
|
||||
}
|
||||
} catch (RedisException e) {
|
||||
if (!failOpen()) {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public CompletionStage<Void> validateAsync(final String key, final int amount) {
|
||||
return executeValidateScriptAsync(key, amount, true)
|
||||
.thenCompose(deficitPermitsAmount -> {
|
||||
if (deficitPermitsAmount == 0) {
|
||||
return completedFuture((Void) null);
|
||||
}
|
||||
counter.increment();
|
||||
final Duration retryAfter = Duration.ofMillis(
|
||||
(long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis()));
|
||||
return failedFuture(new RateLimitExceededException(retryAfter));
|
||||
})
|
||||
.exceptionally(throwable -> {
|
||||
if (ExceptionUtils.unwrap(throwable) instanceof RedisException && failOpen()) {
|
||||
return null;
|
||||
}
|
||||
throw ExceptionUtils.wrap(throwable);
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasAvailablePermits(final String key, final int amount) {
|
||||
try {
|
||||
final long deficitPermitsAmount = executeValidateScript(key, amount, false);
|
||||
return deficitPermitsAmount == 0;
|
||||
} catch (RedisException e) {
|
||||
if (failOpen()) {
|
||||
return true;
|
||||
} else {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public CompletionStage<Boolean> hasAvailablePermitsAsync(final String key, final int amount) {
|
||||
return executeValidateScriptAsync(key, amount, false)
|
||||
.thenApply(deficitPermitsAmount -> deficitPermitsAmount == 0)
|
||||
.exceptionally(throwable -> {
|
||||
if (ExceptionUtils.unwrap(throwable) instanceof RedisException && failOpen()) {
|
||||
return true;
|
||||
}
|
||||
throw ExceptionUtils.wrap(throwable);
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public void clear(final String key) {
|
||||
cacheCluster.useCluster(connection -> connection.sync().del(bucketName(name, key)));
|
||||
}
|
||||
|
||||
@Override
|
||||
public CompletionStage<Void> clearAsync(final String key) {
|
||||
return cacheCluster.withCluster(connection -> connection.async().del(bucketName(name, key)))
|
||||
.thenRun(Util.NOOP);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RateLimiterConfig config() {
|
||||
return config;
|
||||
}
|
||||
|
||||
private boolean failOpen() {
|
||||
return this.dynamicConfigurationManager.getConfiguration().getRateLimitPolicy().failOpen();
|
||||
}
|
||||
|
||||
private long executeValidateScript(final String key, final int amount, final boolean applyChanges) {
|
||||
final List<String> keys = List.of(bucketName(name, key));
|
||||
final List<String> arguments = List.of(
|
||||
String.valueOf(config.bucketSize()),
|
||||
String.valueOf(config.leakRatePerMillis()),
|
||||
String.valueOf(clock.millis()),
|
||||
String.valueOf(amount),
|
||||
String.valueOf(applyChanges)
|
||||
);
|
||||
return (Long) validateScript.execute(keys, arguments);
|
||||
}
|
||||
|
||||
private CompletionStage<Long> executeValidateScriptAsync(final String key, final int amount, final boolean applyChanges) {
|
||||
final List<String> keys = List.of(bucketName(name, key));
|
||||
final List<String> arguments = List.of(
|
||||
String.valueOf(config.bucketSize()),
|
||||
String.valueOf(config.leakRatePerMillis()),
|
||||
String.valueOf(clock.millis()),
|
||||
String.valueOf(amount),
|
||||
String.valueOf(applyChanges)
|
||||
);
|
||||
return validateScript.executeAsync(keys, arguments).thenApply(o -> (Long) o);
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
protected static String bucketName(final String name, final String key) {
|
||||
return "leaky_bucket::" + name + "::" + key;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,44 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.metrics;
|
||||
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
|
||||
import java.util.Optional;
|
||||
|
||||
public class DevicePlatformUtil {
|
||||
|
||||
private DevicePlatformUtil() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the most likely client platform for a device.
|
||||
*
|
||||
* @param device the device for which to find a client platform
|
||||
*
|
||||
* @return the most likely client platform for the given device or empty if no likely platform could be determined
|
||||
*/
|
||||
public static Optional<ClientPlatform> getDevicePlatform(final Device device) {
|
||||
final Optional<ClientPlatform> clientPlatform;
|
||||
|
||||
if (StringUtils.isNotBlank(device.getGcmId())) {
|
||||
clientPlatform = Optional.of(ClientPlatform.ANDROID);
|
||||
} else if (StringUtils.isNotBlank(device.getApnId())) {
|
||||
clientPlatform = Optional.of(ClientPlatform.IOS);
|
||||
} else {
|
||||
clientPlatform = Optional.empty();
|
||||
}
|
||||
|
||||
return clientPlatform.or(() -> Optional.ofNullable(
|
||||
switch (device.getUserAgent()) {
|
||||
case "OWA" -> ClientPlatform.ANDROID;
|
||||
case "OWI", "OWP" -> ClientPlatform.IOS;
|
||||
case "OWD" -> ClientPlatform.DESKTOP;
|
||||
case null, default -> null;
|
||||
}));
|
||||
}
|
||||
}
|
|
@ -1,32 +0,0 @@
|
|||
/*
|
||||
* Copyright 2013-2020 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.metrics;
|
||||
|
||||
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
|
||||
|
||||
import com.sun.management.OperatingSystemMXBean;
|
||||
import io.micrometer.core.instrument.Gauge;
|
||||
import io.micrometer.core.instrument.MeterRegistry;
|
||||
import io.micrometer.core.instrument.binder.MeterBinder;
|
||||
import java.lang.management.ManagementFactory;
|
||||
|
||||
public class FreeMemoryGauge implements MeterBinder {
|
||||
|
||||
private final OperatingSystemMXBean operatingSystemMXBean;
|
||||
|
||||
public FreeMemoryGauge() {
|
||||
this.operatingSystemMXBean = (com.sun.management.OperatingSystemMXBean)
|
||||
ManagementFactory.getOperatingSystemMXBean();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void bindTo(final MeterRegistry registry) {
|
||||
Gauge.builder(name(FreeMemoryGauge.class, "freeMemory"), operatingSystemMXBean,
|
||||
OperatingSystemMXBean::getFreeMemorySize)
|
||||
.register(registry);
|
||||
|
||||
}
|
||||
}
|
|
@ -120,10 +120,7 @@ public class MetricsUtil {
|
|||
|
||||
public static void registerSystemResourceMetrics(final Environment environment) {
|
||||
new ProcessorMetrics().bindTo(Metrics.globalRegistry);
|
||||
new FreeMemoryGauge().bindTo(Metrics.globalRegistry);
|
||||
new FileDescriptorMetrics().bindTo(Metrics.globalRegistry);
|
||||
new OperatingSystemMemoryGauge("Buffers").bindTo(Metrics.globalRegistry);
|
||||
new OperatingSystemMemoryGauge("Cached").bindTo(Metrics.globalRegistry);
|
||||
|
||||
new JvmMemoryMetrics().bindTo(Metrics.globalRegistry);
|
||||
new JvmThreadMetrics().bindTo(Metrics.globalRegistry);
|
||||
|
|
|
@ -67,7 +67,7 @@ public class OpenWebSocketCounter {
|
|||
|
||||
try {
|
||||
final ClientPlatform clientPlatform =
|
||||
UserAgentUtil.parseUserAgentString(context.getClient().getUserAgent()).getPlatform();
|
||||
UserAgentUtil.parseUserAgentString(context.getClient().getUserAgent()).platform();
|
||||
|
||||
calculatedOpenWebSocketCounter = openWebsocketsByClientPlatform.get(clientPlatform);
|
||||
calculatedDurationTimer = durationTimersByClientPlatform.get(clientPlatform);
|
||||
|
|
|
@ -1,56 +0,0 @@
|
|||
/*
|
||||
* Copyright 2013-2020 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.metrics;
|
||||
|
||||
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import io.micrometer.core.instrument.Gauge;
|
||||
import io.micrometer.core.instrument.MeterRegistry;
|
||||
import io.micrometer.core.instrument.binder.MeterBinder;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.File;
|
||||
import java.io.FileReader;
|
||||
import java.io.IOException;
|
||||
import java.util.Locale;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public class OperatingSystemMemoryGauge implements MeterBinder {
|
||||
|
||||
private final String metricName;
|
||||
|
||||
private static final File MEMINFO_FILE = new File("/proc/meminfo");
|
||||
private static final Pattern MEMORY_METRIC_PATTERN = Pattern.compile("^([^:]+):\\s+([0-9]+).*$");
|
||||
|
||||
public OperatingSystemMemoryGauge(final String metricName) {
|
||||
this.metricName = metricName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void bindTo(MeterRegistry registry) {
|
||||
final String metricName = this.metricName;
|
||||
Gauge.builder(name(OperatingSystemMemoryGauge.class, metricName.toLowerCase(Locale.ROOT)), () -> {
|
||||
try (final BufferedReader bufferedReader = new BufferedReader(new FileReader(MEMINFO_FILE))) {
|
||||
return getValue(bufferedReader.lines(), metricName);
|
||||
} catch (final IOException e) {
|
||||
return 0L;
|
||||
}
|
||||
})
|
||||
.register(registry);
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
static double getValue(final Stream<String> lines, final String metricName) {
|
||||
return lines.map(MEMORY_METRIC_PATTERN::matcher)
|
||||
.filter(Matcher::matches)
|
||||
.filter(matcher -> metricName.equalsIgnoreCase(matcher.group(1)))
|
||||
.map(matcher -> Double.parseDouble(matcher.group(2)))
|
||||
.findFirst()
|
||||
.orElse(0d);
|
||||
}
|
||||
}
|
|
@ -9,6 +9,7 @@ import io.micrometer.core.instrument.Tag;
|
|||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.whispersystems.textsecuregcm.WhisperServerVersion;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
|
||||
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
|
||||
|
@ -48,15 +49,15 @@ public class UserAgentTagUtil {
|
|||
}
|
||||
|
||||
public static Tag getPlatformTag(@Nullable final UserAgent userAgent) {
|
||||
return Tag.of(PLATFORM_TAG, userAgent != null ? userAgent.getPlatform().name().toLowerCase() : "unrecognized");
|
||||
return Tag.of(PLATFORM_TAG, userAgent != null ? userAgent.platform().name().toLowerCase() : "unrecognized");
|
||||
}
|
||||
|
||||
public static Optional<Tag> getClientVersionTag(final String userAgentString, final ClientReleaseManager clientReleaseManager) {
|
||||
try {
|
||||
final UserAgent userAgent = UserAgentUtil.parseUserAgentString(userAgentString);
|
||||
|
||||
if (clientReleaseManager.isVersionActive(userAgent.getPlatform(), userAgent.getVersion())) {
|
||||
return Optional.of(Tag.of(VERSION_TAG, userAgent.getVersion().toString()));
|
||||
if (clientReleaseManager.isVersionActive(userAgent.platform(), userAgent.version())) {
|
||||
return Optional.of(Tag.of(VERSION_TAG, userAgent.version().toString()));
|
||||
}
|
||||
} catch (final UnrecognizedUserAgentException ignored) {
|
||||
}
|
||||
|
@ -70,10 +71,8 @@ public class UserAgentTagUtil {
|
|||
|
||||
try {
|
||||
final UserAgent userAgent = UserAgentUtil.parseUserAgentString(userAgentString);
|
||||
platform = userAgent.getPlatform().name().toLowerCase();
|
||||
libsignal = userAgent.getAdditionalSpecifiers()
|
||||
.map(additionalSpecifiers -> additionalSpecifiers.contains("libsignal"))
|
||||
.orElse(false);
|
||||
platform = userAgent.platform().name().toLowerCase();
|
||||
libsignal = StringUtils.contains(userAgent.additionalSpecifiers(), "libsignal");
|
||||
} catch (final UnrecognizedUserAgentException e) {
|
||||
platform = "unrecognized";
|
||||
libsignal = false;
|
||||
|
|
|
@ -13,7 +13,6 @@ import io.micrometer.core.instrument.DistributionSummary;
|
|||
import io.micrometer.core.instrument.Metrics;
|
||||
import io.micrometer.core.instrument.Tag;
|
||||
import io.micrometer.core.instrument.Tags;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.Map;
|
||||
|
@ -21,6 +20,7 @@ import java.util.Optional;
|
|||
import java.util.Set;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.stream.Collectors;
|
||||
import javax.annotation.Nullable;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
|
||||
import org.signal.libsignal.protocol.util.Pair;
|
||||
|
@ -28,6 +28,7 @@ import org.whispersystems.textsecuregcm.controllers.MessageController;
|
|||
import org.whispersystems.textsecuregcm.controllers.MismatchedDevices;
|
||||
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
|
||||
import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException;
|
||||
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
|
||||
import org.whispersystems.textsecuregcm.identity.IdentityType;
|
||||
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
||||
|
@ -36,7 +37,6 @@ import org.whispersystems.textsecuregcm.storage.Account;
|
|||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.storage.MessagesManager;
|
||||
import org.whispersystems.textsecuregcm.util.Util;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
/**
|
||||
* A MessageSender sends Signal messages to destination devices. Messages may be "normal" user-to-user messages,
|
||||
|
@ -53,10 +53,12 @@ public class MessageSender {
|
|||
|
||||
private final MessagesManager messagesManager;
|
||||
private final PushNotificationManager pushNotificationManager;
|
||||
private final ExperimentEnrollmentManager experimentEnrollmentManager;
|
||||
|
||||
// Note that these names deliberately reference `MessageController` for metric continuity
|
||||
private static final String REJECT_OVERSIZE_MESSAGE_COUNTER_NAME = name(MessageController.class, "rejectOversizeMessage");
|
||||
private static final String CONTENT_SIZE_DISTRIBUTION_NAME = MetricsUtil.name(MessageController.class, "messageContentSize");
|
||||
private static final String EMPTY_MESSAGE_LIST_COUNTER_NAME = MetricsUtil.name(MessageSender.class, "emptyMessageList");
|
||||
|
||||
private static final String SEND_COUNTER_NAME = name(MessageSender.class, "sendMessage");
|
||||
private static final String EPHEMERAL_TAG_NAME = "ephemeral";
|
||||
|
@ -65,6 +67,7 @@ public class MessageSender {
|
|||
private static final String STORY_TAG_NAME = "story";
|
||||
private static final String SEALED_SENDER_TAG_NAME = "sealedSender";
|
||||
private static final String MULTI_RECIPIENT_TAG_NAME = "multiRecipient";
|
||||
private static final String SYNC_MESSAGE_TAG_NAME = "sync";
|
||||
|
||||
@VisibleForTesting
|
||||
public static final int MAX_MESSAGE_SIZE = (int) DataSize.kibibytes(256).toBytes();
|
||||
|
@ -72,9 +75,13 @@ public class MessageSender {
|
|||
@VisibleForTesting
|
||||
static final byte NO_EXCLUDED_DEVICE_ID = -1;
|
||||
|
||||
public MessageSender(final MessagesManager messagesManager, final PushNotificationManager pushNotificationManager) {
|
||||
public MessageSender(
|
||||
final MessagesManager messagesManager,
|
||||
final PushNotificationManager pushNotificationManager,
|
||||
final ExperimentEnrollmentManager experimentEnrollmentManager) {
|
||||
this.messagesManager = messagesManager;
|
||||
this.pushNotificationManager = pushNotificationManager;
|
||||
this.experimentEnrollmentManager = experimentEnrollmentManager;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -86,6 +93,8 @@ public class MessageSender {
|
|||
* @param destinationIdentifier the service identifier to which the messages are addressed
|
||||
* @param messagesByDeviceId a map of device IDs to message payloads
|
||||
* @param registrationIdsByDeviceId a map of device IDs to device registration IDs
|
||||
* @param syncMessageSenderDeviceId if the message is a sync message (i.e. a message to other devices linked to the
|
||||
* caller's own account), contains the ID of the device that sent the message
|
||||
* @param userAgent the User-Agent string for the sender; may be {@code null} if not known
|
||||
*
|
||||
* @throws MismatchedDevicesException if the given bundle of messages did not include a message for all required
|
||||
|
@ -97,38 +106,48 @@ public class MessageSender {
|
|||
final ServiceIdentifier destinationIdentifier,
|
||||
final Map<Byte, Envelope> messagesByDeviceId,
|
||||
final Map<Byte, Integer> registrationIdsByDeviceId,
|
||||
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<Byte> syncMessageSenderDeviceId,
|
||||
@Nullable final String userAgent) throws MismatchedDevicesException, MessageTooLargeException {
|
||||
|
||||
if (messagesByDeviceId.isEmpty()) {
|
||||
// TODO Simply return and don't throw an exception when iOS clients no longer depend on this behavior
|
||||
throw new MismatchedDevicesException(new MismatchedDevices(
|
||||
destination.getDevices().stream().map(Device::getId).collect(Collectors.toSet()),
|
||||
Collections.emptySet(),
|
||||
Collections.emptySet()));
|
||||
}
|
||||
|
||||
if (!destination.isIdentifiedBy(destinationIdentifier)) {
|
||||
throw new IllegalArgumentException("Destination account not identified by destination service identifier");
|
||||
}
|
||||
|
||||
final Envelope firstMessage = messagesByDeviceId.values().iterator().next();
|
||||
final Tag platformTag = UserAgentTagUtil.getPlatformTag(userAgent);
|
||||
|
||||
final boolean isSyncMessage = StringUtils.isNotBlank(firstMessage.getSourceServiceId()) &&
|
||||
destination.isIdentifiedBy(ServiceIdentifier.valueOf(firstMessage.getSourceServiceId()));
|
||||
if (messagesByDeviceId.isEmpty()) {
|
||||
Metrics.counter(EMPTY_MESSAGE_LIST_COUNTER_NAME,
|
||||
Tags.of(SYNC_MESSAGE_TAG_NAME, String.valueOf(syncMessageSenderDeviceId.isPresent())).and(platformTag)).increment();
|
||||
}
|
||||
|
||||
final boolean isStory = firstMessage.getStory();
|
||||
final byte excludedDeviceId;
|
||||
if (syncMessageSenderDeviceId.isPresent()) {
|
||||
if (messagesByDeviceId.values().stream().anyMatch(message -> StringUtils.isBlank(message.getSourceServiceId()) ||
|
||||
!destination.isIdentifiedBy(ServiceIdentifier.valueOf(message.getSourceServiceId())))) {
|
||||
|
||||
validateIndividualMessageContentLength(messagesByDeviceId.values(), isSyncMessage, isStory, userAgent);
|
||||
throw new IllegalArgumentException("Sync message sender device ID specified, but one or more messages are not addressed to sender");
|
||||
}
|
||||
excludedDeviceId = syncMessageSenderDeviceId.get();
|
||||
} else {
|
||||
if (messagesByDeviceId.values().stream().anyMatch(message -> StringUtils.isNotBlank(message.getSourceServiceId()) &&
|
||||
destination.isIdentifiedBy(ServiceIdentifier.valueOf(message.getSourceServiceId())))) {
|
||||
|
||||
throw new IllegalArgumentException("Sync message sender device ID not specified, but one or more messages are addressed to sender");
|
||||
}
|
||||
excludedDeviceId = NO_EXCLUDED_DEVICE_ID;
|
||||
}
|
||||
|
||||
final Optional<MismatchedDevices> maybeMismatchedDevices = getMismatchedDevices(destination,
|
||||
destinationIdentifier,
|
||||
registrationIdsByDeviceId,
|
||||
isSyncMessage ? (byte) firstMessage.getSourceDevice() : NO_EXCLUDED_DEVICE_ID);
|
||||
excludedDeviceId);
|
||||
|
||||
if (maybeMismatchedDevices.isPresent()) {
|
||||
throw new MismatchedDevicesException(maybeMismatchedDevices.get());
|
||||
}
|
||||
|
||||
validateIndividualMessageContentLength(messagesByDeviceId.values(), syncMessageSenderDeviceId.isPresent(), userAgent);
|
||||
|
||||
messagesManager.insert(destination.getIdentifier(IdentityType.ACI), messagesByDeviceId)
|
||||
.forEach((deviceId, destinationPresent) -> {
|
||||
final Envelope message = messagesByDeviceId.get(deviceId);
|
||||
|
@ -146,8 +165,9 @@ public class MessageSender {
|
|||
URGENT_TAG_NAME, String.valueOf(message.getUrgent()),
|
||||
STORY_TAG_NAME, String.valueOf(message.getStory()),
|
||||
SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceServiceId()),
|
||||
SYNC_MESSAGE_TAG_NAME, String.valueOf(syncMessageSenderDeviceId.isPresent()),
|
||||
MULTI_RECIPIENT_TAG_NAME, "false")
|
||||
.and(UserAgentTagUtil.getPlatformTag(userAgent));
|
||||
.and(platformTag);
|
||||
|
||||
Metrics.counter(SEND_COUNTER_NAME, tags).increment();
|
||||
});
|
||||
|
@ -230,6 +250,7 @@ public class MessageSender {
|
|||
URGENT_TAG_NAME, String.valueOf(isUrgent),
|
||||
STORY_TAG_NAME, String.valueOf(isStory),
|
||||
SEALED_SENDER_TAG_NAME, "true",
|
||||
SYNC_MESSAGE_TAG_NAME, "false",
|
||||
MULTI_RECIPIENT_TAG_NAME, "true")
|
||||
.and(UserAgentTagUtil.getPlatformTag(userAgent));
|
||||
|
||||
|
@ -295,11 +316,7 @@ public class MessageSender {
|
|||
// We know the device must be present because we've already filtered out device IDs that aren't associated
|
||||
// with the given account
|
||||
final Device device = account.getDevice(deviceId).orElseThrow();
|
||||
|
||||
final int expectedRegistrationId = switch (serviceIdentifier.identityType()) {
|
||||
case ACI -> device.getRegistrationId();
|
||||
case PNI -> device.getPhoneNumberIdentityRegistrationId().orElseGet(device::getRegistrationId);
|
||||
};
|
||||
final int expectedRegistrationId = device.getRegistrationId(serviceIdentifier.identityType());
|
||||
|
||||
return registrationId != expectedRegistrationId;
|
||||
})
|
||||
|
@ -313,14 +330,13 @@ public class MessageSender {
|
|||
|
||||
private static void validateIndividualMessageContentLength(final Iterable<Envelope> messages,
|
||||
final boolean isSyncMessage,
|
||||
final boolean isStory,
|
||||
@Nullable final String userAgent) throws MessageTooLargeException {
|
||||
|
||||
for (final Envelope message : messages) {
|
||||
MessageSender.validateContentLength(message.getContent().size(),
|
||||
false,
|
||||
isSyncMessage,
|
||||
isStory,
|
||||
message.getStory(),
|
||||
userAgent);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@ import java.util.function.BiConsumer;
|
|||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
|
@ -28,6 +29,9 @@ public class PushNotificationManager {
|
|||
private final APNSender apnSender;
|
||||
private final FcmSender fcmSender;
|
||||
private final PushNotificationScheduler pushNotificationScheduler;
|
||||
private final ExperimentEnrollmentManager experimentEnrollmentManager;
|
||||
|
||||
public static final String SCHEDULE_LOW_URGENCY_FCM_PUSH_EXPERIMENT = "scheduleLowUregencyFcmPush";
|
||||
|
||||
private static final String SENT_NOTIFICATION_COUNTER_NAME = name(PushNotificationManager.class, "sentPushNotification");
|
||||
private static final String FAILED_NOTIFICATION_COUNTER_NAME = name(PushNotificationManager.class, "failedPushNotification");
|
||||
|
@ -38,12 +42,14 @@ public class PushNotificationManager {
|
|||
public PushNotificationManager(final AccountsManager accountsManager,
|
||||
final APNSender apnSender,
|
||||
final FcmSender fcmSender,
|
||||
final PushNotificationScheduler pushNotificationScheduler) {
|
||||
final PushNotificationScheduler pushNotificationScheduler,
|
||||
final ExperimentEnrollmentManager experimentEnrollmentManager) {
|
||||
|
||||
this.accountsManager = accountsManager;
|
||||
this.apnSender = apnSender;
|
||||
this.fcmSender = fcmSender;
|
||||
this.pushNotificationScheduler = pushNotificationScheduler;
|
||||
this.experimentEnrollmentManager = experimentEnrollmentManager;
|
||||
}
|
||||
|
||||
public CompletableFuture<Optional<SendPushNotificationResult>> sendNewMessageNotification(final Account destination, final byte destinationDeviceId, final boolean urgent) throws NotPushRegisteredException {
|
||||
|
@ -101,11 +107,11 @@ public class PushNotificationManager {
|
|||
|
||||
@VisibleForTesting
|
||||
CompletableFuture<Optional<SendPushNotificationResult>> sendNotification(final PushNotification pushNotification) {
|
||||
if (pushNotification.tokenType() == PushNotification.TokenType.APN && !pushNotification.urgent()) {
|
||||
// APNs imposes a per-device limit on background push notifications; schedule a notification for some time in the
|
||||
// future (possibly even now!) rather than sending a notification directly
|
||||
if (shouldScheduleNotification(pushNotification)) {
|
||||
// Schedule a notification for some time in the future (possibly even now!) rather than sending a notification
|
||||
// directly
|
||||
return pushNotificationScheduler
|
||||
.scheduleBackgroundApnsNotification(pushNotification.destination(), pushNotification.destinationDevice())
|
||||
.scheduleBackgroundNotification(pushNotification.tokenType(), pushNotification.destination(), pushNotification.destinationDevice())
|
||||
.whenComplete(logErrors())
|
||||
.thenApply(ignored -> Optional.<SendPushNotificationResult>empty())
|
||||
.toCompletableFuture();
|
||||
|
@ -150,6 +156,16 @@ public class PushNotificationManager {
|
|||
.thenApply(Optional::of);
|
||||
}
|
||||
|
||||
private boolean shouldScheduleNotification(final PushNotification pushNotification) {
|
||||
return !pushNotification.urgent() && switch (pushNotification.tokenType()) {
|
||||
// APNs imposes a per-device limit on background push notifications
|
||||
case APN -> true;
|
||||
case FCM -> experimentEnrollmentManager.isEnrolled(
|
||||
pushNotification.destination().getUuid(),
|
||||
SCHEDULE_LOW_URGENCY_FCM_PUSH_EXPERIMENT);
|
||||
};
|
||||
}
|
||||
|
||||
private static <T> BiConsumer<T, Throwable> logErrors() {
|
||||
return (ignored, throwable) -> {
|
||||
if (throwable != null) {
|
||||
|
|
|
@ -13,7 +13,6 @@ import io.lettuce.core.Range;
|
|||
import io.lettuce.core.ScriptOutputType;
|
||||
import io.lettuce.core.SetArgs;
|
||||
import io.lettuce.core.cluster.SlotHash;
|
||||
import io.micrometer.core.instrument.Counter;
|
||||
import io.micrometer.core.instrument.Metrics;
|
||||
import java.io.IOException;
|
||||
import java.time.Clock;
|
||||
|
@ -44,14 +43,15 @@ public class PushNotificationScheduler implements Managed {
|
|||
|
||||
private static final Logger logger = LoggerFactory.getLogger(PushNotificationScheduler.class);
|
||||
|
||||
private static final String PENDING_BACKGROUND_NOTIFICATIONS_KEY_PREFIX = "PENDING_BACKGROUND_APN";
|
||||
private static final String PENDING_BACKGROUND_APN_NOTIFICATIONS_KEY_PREFIX = "PENDING_BACKGROUND_APN";
|
||||
private static final String PENDING_BACKGROUND_FCM_NOTIFICATIONS_KEY_PREFIX = "PENDING_BACKGROUND_FCM";
|
||||
private static final String LAST_BACKGROUND_NOTIFICATION_TIMESTAMP_KEY_PREFIX = "LAST_BACKGROUND_NOTIFICATION";
|
||||
private static final String PENDING_DELAYED_NOTIFICATIONS_KEY_PREFIX = "DELAYED";
|
||||
|
||||
@VisibleForTesting
|
||||
static final String NEXT_SLOT_TO_PROCESS_KEY = "pending_notification_next_slot";
|
||||
|
||||
private static final Counter BACKGROUND_NOTIFICATION_SCHEDULED_COUNTER = Metrics.counter(name(PushNotificationScheduler.class, "backgroundNotification", "scheduled"));
|
||||
private static final String BACKGROUND_NOTIFICATION_SCHEDULED_COUNTER_NAME = name(PushNotificationScheduler.class, "backgroundNotification", "scheduled");
|
||||
private static final String BACKGROUND_NOTIFICATION_SENT_COUNTER_NAME = name(PushNotificationScheduler.class, "backgroundNotification", "sent");
|
||||
|
||||
private static final String DELAYED_NOTIFICATION_SCHEDULED_COUNTER_NAME = name(PushNotificationScheduler.class, "delayedNotificationScheduled");
|
||||
|
@ -65,7 +65,7 @@ public class PushNotificationScheduler implements Managed {
|
|||
private final FaultTolerantRedisClusterClient pushSchedulingCluster;
|
||||
private final Clock clock;
|
||||
|
||||
private final ClusterLuaScript scheduleBackgroundApnsNotificationScript;
|
||||
private final ClusterLuaScript scheduleBackgroundNotificationScript;
|
||||
|
||||
private final Thread[] workerThreads;
|
||||
|
||||
|
@ -103,15 +103,18 @@ public class PushNotificationScheduler implements Managed {
|
|||
final int slot = (int) (pushSchedulingCluster.withCluster(connection ->
|
||||
connection.sync().incr(NEXT_SLOT_TO_PROCESS_KEY)) % SlotHash.SLOT_COUNT);
|
||||
|
||||
return processScheduledBackgroundApnsNotifications(slot) + processScheduledDelayedNotifications(slot);
|
||||
return processScheduledBackgroundNotifications(PushNotification.TokenType.APN, slot)
|
||||
+ processScheduledBackgroundNotifications(PushNotification.TokenType.FCM, slot)
|
||||
+ processScheduledDelayedNotifications(slot);
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
long processScheduledBackgroundApnsNotifications(final int slot) {
|
||||
return processScheduledNotifications(getPendingBackgroundApnsNotificationQueueKey(slot),
|
||||
PushNotificationScheduler.this::sendBackgroundApnsNotification);
|
||||
long processScheduledBackgroundNotifications(PushNotification.TokenType tokenType, final int slot) {
|
||||
return processScheduledNotifications(getPendingBackgroundNotificationQueueKey(tokenType, slot),
|
||||
(account, device) -> sendBackgroundNotification(tokenType, account, device));
|
||||
}
|
||||
|
||||
|
||||
@VisibleForTesting
|
||||
long processScheduledDelayedNotifications(final int slot) {
|
||||
return processScheduledNotifications(getDelayedNotificationQueueKey(slot),
|
||||
|
@ -172,7 +175,7 @@ public class PushNotificationScheduler implements Managed {
|
|||
this.pushSchedulingCluster = pushSchedulingCluster;
|
||||
this.clock = clock;
|
||||
|
||||
this.scheduleBackgroundApnsNotificationScript = ClusterLuaScript.fromResource(pushSchedulingCluster,
|
||||
this.scheduleBackgroundNotificationScript = ClusterLuaScript.fromResource(pushSchedulingCluster,
|
||||
"lua/apn/schedule_background_notification.lua", ScriptOutputType.VALUE);
|
||||
|
||||
this.workerThreads = new Thread[dedicatedProcessThreadCount];
|
||||
|
@ -183,23 +186,21 @@ public class PushNotificationScheduler implements Managed {
|
|||
}
|
||||
|
||||
/**
|
||||
* Schedule a background APNs notification to be sent some time in the future.
|
||||
* Schedule a background push notification to be sent some time in the future.
|
||||
*
|
||||
* @return A CompletionStage that completes when the notification has successfully been scheduled
|
||||
*
|
||||
* @throws IllegalArgumentException if the given device does not have an APNs token
|
||||
* @throws IllegalArgumentException if the given device does not have a push token
|
||||
*/
|
||||
public CompletionStage<Void> scheduleBackgroundApnsNotification(final Account account, final Device device) {
|
||||
if (StringUtils.isBlank(device.getApnId())) {
|
||||
throw new IllegalArgumentException("Device must have an APNs token");
|
||||
public CompletionStage<Void> scheduleBackgroundNotification(final PushNotification.TokenType tokenType, final Account account, final Device device) {
|
||||
if (StringUtils.isBlank(getPushToken(tokenType, device))) {
|
||||
throw new IllegalArgumentException("Device must have an " + tokenType + " token");
|
||||
}
|
||||
|
||||
BACKGROUND_NOTIFICATION_SCHEDULED_COUNTER.increment();
|
||||
|
||||
return scheduleBackgroundApnsNotificationScript.executeAsync(
|
||||
Metrics.counter(BACKGROUND_NOTIFICATION_SCHEDULED_COUNTER_NAME, "type", tokenType.name()).increment();
|
||||
return scheduleBackgroundNotificationScript.executeAsync(
|
||||
List.of(
|
||||
getLastBackgroundApnsNotificationTimestampKey(account, device),
|
||||
getPendingBackgroundApnsNotificationQueueKey(account, device)),
|
||||
getLastBackgroundNotificationTimestampKey(account, device),
|
||||
getPendingBackgroundNotificationQueueKey(tokenType, account, device)),
|
||||
List.of(
|
||||
encodeAciAndDeviceId(account, device),
|
||||
String.valueOf(clock.millis()),
|
||||
|
@ -236,14 +237,15 @@ public class PushNotificationScheduler implements Managed {
|
|||
*/
|
||||
public CompletionStage<Void> cancelScheduledNotifications(Account account, Device device) {
|
||||
return CompletableFuture.allOf(
|
||||
cancelBackgroundApnsNotifications(account, device),
|
||||
cancelBackgroundNotifications(PushNotification.TokenType.FCM, account, device),
|
||||
cancelBackgroundNotifications(PushNotification.TokenType.APN, account, device),
|
||||
cancelDelayedNotifications(account, device));
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
CompletableFuture<Void> cancelBackgroundApnsNotifications(final Account account, final Device device) {
|
||||
CompletableFuture<Void> cancelBackgroundNotifications(final PushNotification.TokenType tokenType, final Account account, final Device device) {
|
||||
return pushSchedulingCluster.withCluster(connection -> connection.async()
|
||||
.zrem(getPendingBackgroundApnsNotificationQueueKey(account, device), encodeAciAndDeviceId(account, device)))
|
||||
.zrem(getPendingBackgroundNotificationQueueKey(tokenType, account, device), encodeAciAndDeviceId(account, device)))
|
||||
.thenRun(Util.NOOP)
|
||||
.toCompletableFuture();
|
||||
}
|
||||
|
@ -276,17 +278,23 @@ public class PushNotificationScheduler implements Managed {
|
|||
}
|
||||
|
||||
@VisibleForTesting
|
||||
CompletableFuture<Void> sendBackgroundApnsNotification(final Account account, final Device device) {
|
||||
if (StringUtils.isBlank(device.getApnId())) {
|
||||
CompletableFuture<Void> sendBackgroundNotification(PushNotification.TokenType tokenType, final Account account, final Device device) {
|
||||
final String pushToken = getPushToken(tokenType, device);
|
||||
if (StringUtils.isBlank(pushToken)) {
|
||||
return CompletableFuture.completedFuture(null);
|
||||
}
|
||||
|
||||
final PushNotificationSender sender = switch (tokenType) {
|
||||
case FCM -> fcmSender;
|
||||
case APN -> apnSender;
|
||||
};
|
||||
|
||||
// It's okay for the "last notification" timestamp to expire after the "cooldown" period has elapsed; a missing
|
||||
// timestamp and a timestamp older than the period are functionally equivalent.
|
||||
return pushSchedulingCluster.withCluster(connection -> connection.async().set(
|
||||
getLastBackgroundApnsNotificationTimestampKey(account, device),
|
||||
getLastBackgroundNotificationTimestampKey(account, device),
|
||||
String.valueOf(clock.millis()), new SetArgs().ex(BACKGROUND_NOTIFICATION_PERIOD)))
|
||||
.thenCompose(ignored -> apnSender.sendNotification(new PushNotification(device.getApnId(), PushNotification.TokenType.APN, PushNotification.NotificationType.NOTIFICATION, null, account, device, false)))
|
||||
.thenCompose(ignored -> sender.sendNotification(new PushNotification(pushToken, tokenType, PushNotification.NotificationType.NOTIFICATION, null, account, device, false)))
|
||||
.thenAccept(response -> Metrics.counter(BACKGROUND_NOTIFICATION_SENT_COUNTER_NAME,
|
||||
ACCEPTED_TAG, String.valueOf(response.accepted()))
|
||||
.increment())
|
||||
|
@ -321,6 +329,10 @@ public class PushNotificationScheduler implements Managed {
|
|||
|
||||
@VisibleForTesting
|
||||
static String encodeAciAndDeviceId(final Account account, final Device device) {
|
||||
// Note: This does not include a device registration id. If a device is unlinked and a new device is linked with
|
||||
// the original device's id, the new device might get the old device's scheduled push, or the new device might
|
||||
// delay its own push because the old device had a recent push. An extra or delayed background push is harmless,
|
||||
// so this is okay.
|
||||
return account.getUuid() + ":" + device.getId();
|
||||
}
|
||||
|
||||
|
@ -351,15 +363,19 @@ public class PushNotificationScheduler implements Managed {
|
|||
}
|
||||
|
||||
@VisibleForTesting
|
||||
static String getPendingBackgroundApnsNotificationQueueKey(final Account account, final Device device) {
|
||||
return getPendingBackgroundApnsNotificationQueueKey(SlotHash.getSlot(encodeAciAndDeviceId(account, device)));
|
||||
static String getPendingBackgroundNotificationQueueKey(final PushNotification.TokenType tokenType, final Account account, final Device device) {
|
||||
return getPendingBackgroundNotificationQueueKey(tokenType, SlotHash.getSlot(encodeAciAndDeviceId(account, device)));
|
||||
}
|
||||
|
||||
private static String getPendingBackgroundApnsNotificationQueueKey(final int slot) {
|
||||
return PENDING_BACKGROUND_NOTIFICATIONS_KEY_PREFIX + "::{" + RedisClusterUtil.getMinimalHashTag(slot) + "}";
|
||||
private static String getPendingBackgroundNotificationQueueKey(final PushNotification.TokenType tokenType, final int slot) {
|
||||
final String prefix = switch (tokenType) {
|
||||
case APN -> PENDING_BACKGROUND_APN_NOTIFICATIONS_KEY_PREFIX;
|
||||
case FCM -> PENDING_BACKGROUND_FCM_NOTIFICATIONS_KEY_PREFIX;
|
||||
};
|
||||
return prefix + "::{" + RedisClusterUtil.getMinimalHashTag(slot) + "}";
|
||||
}
|
||||
|
||||
private static String getLastBackgroundApnsNotificationTimestampKey(final Account account, final Device device) {
|
||||
private static String getLastBackgroundNotificationTimestampKey(final Account account, final Device device) {
|
||||
return LAST_BACKGROUND_NOTIFICATION_TIMESTAMP_KEY_PREFIX + "::{" + encodeAciAndDeviceId(account, device) + "}";
|
||||
}
|
||||
|
||||
|
@ -376,15 +392,15 @@ public class PushNotificationScheduler implements Managed {
|
|||
Optional<Instant> getLastBackgroundApnsNotificationTimestamp(final Account account, final Device device) {
|
||||
return Optional.ofNullable(
|
||||
pushSchedulingCluster.withCluster(connection ->
|
||||
connection.sync().get(getLastBackgroundApnsNotificationTimestampKey(account, device))))
|
||||
connection.sync().get(getLastBackgroundNotificationTimestampKey(account, device))))
|
||||
.map(timestampString -> Instant.ofEpochMilli(Long.parseLong(timestampString)));
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
Optional<Instant> getNextScheduledBackgroundApnsNotificationTimestamp(final Account account, final Device device) {
|
||||
Optional<Instant> getNextScheduledBackgroundNotificationTimestamp(PushNotification.TokenType tokenType, final Account account, final Device device) {
|
||||
return Optional.ofNullable(
|
||||
pushSchedulingCluster.withCluster(connection ->
|
||||
connection.sync().zscore(getPendingBackgroundApnsNotificationQueueKey(account, device),
|
||||
connection.sync().zscore(getPendingBackgroundNotificationQueueKey(tokenType, account, device),
|
||||
encodeAciAndDeviceId(account, device))))
|
||||
.map(timestamp -> Instant.ofEpochMilli(timestamp.longValue()));
|
||||
}
|
||||
|
@ -407,4 +423,11 @@ public class PushNotificationScheduler implements Managed {
|
|||
return "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
private static String getPushToken(final PushNotification.TokenType tokenType, final Device device) {
|
||||
return switch (tokenType) {
|
||||
case FCM -> device.getGcmId();
|
||||
case APN -> device.getApnId();
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.push;
|
|||
import io.micrometer.core.instrument.Metrics;
|
||||
import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.stream.Collectors;
|
||||
import org.slf4j.Logger;
|
||||
|
@ -60,16 +61,15 @@ public class ReceiptSender {
|
|||
.collect(Collectors.toMap(Device::getId, ignored -> message));
|
||||
|
||||
final Map<Byte, Integer> registrationIdsByDeviceId = destinationAccount.getDevices().stream()
|
||||
.collect(Collectors.toMap(Device::getId, device -> switch (destinationIdentifier.identityType()) {
|
||||
case ACI -> device.getRegistrationId();
|
||||
case PNI -> device.getPhoneNumberIdentityRegistrationId().orElseGet(device::getRegistrationId);
|
||||
}));
|
||||
.collect(Collectors.toMap(Device::getId,
|
||||
device -> device.getRegistrationId(destinationIdentifier.identityType())));
|
||||
|
||||
try {
|
||||
messageSender.sendMessages(destinationAccount,
|
||||
destinationIdentifier,
|
||||
messagesByDeviceId,
|
||||
registrationIdsByDeviceId,
|
||||
Optional.empty(),
|
||||
UserAgentTagUtil.SERVER_UA);
|
||||
} catch (final Exception e) {
|
||||
logger.warn("Could not send delivery receipt", e);
|
||||
|
|
|
@ -38,8 +38,8 @@ public class SchedulingUtil {
|
|||
final LocalTime preferredTime,
|
||||
final Clock clock) {
|
||||
|
||||
final ZonedDateTime candidateNotificationTime = getZoneOffset(account, clock)
|
||||
.map(zoneOffset -> ZonedDateTime.now(zoneOffset).with(preferredTime))
|
||||
final ZonedDateTime candidateNotificationTime = getZoneId(account, clock)
|
||||
.map(zoneId -> ZonedDateTime.now(clock.withZone(zoneId)).with(preferredTime))
|
||||
.orElseGet(() -> {
|
||||
// We couldn't find a reasonable timezone for the account for some reason, so make an educated guess at a
|
||||
// reasonable time to send a notification based on the account's creation time.
|
||||
|
@ -59,7 +59,7 @@ public class SchedulingUtil {
|
|||
}
|
||||
|
||||
@VisibleForTesting
|
||||
static Optional<ZoneOffset> getZoneOffset(final Account account, final Clock clock) {
|
||||
static Optional<ZoneId> getZoneId(final Account account, final Clock clock) {
|
||||
try {
|
||||
final Phonenumber.PhoneNumber phoneNumber = PhoneNumberUtil.getInstance().parse(account.getNumber(), null);
|
||||
|
||||
|
@ -70,7 +70,7 @@ public class SchedulingUtil {
|
|||
return Optional.empty();
|
||||
}
|
||||
|
||||
final List<ZoneOffset> sortedZoneOffsets = timeZonesForNumber
|
||||
final List<ZoneId> sortedZoneOffsets = timeZonesForNumber
|
||||
.stream()
|
||||
.map(id -> {
|
||||
try {
|
||||
|
@ -80,9 +80,6 @@ public class SchedulingUtil {
|
|||
}
|
||||
})
|
||||
.filter(Objects::nonNull)
|
||||
.map(ZoneId::getRules)
|
||||
.distinct()
|
||||
.map(zoneRules -> zoneRules.getOffset(clock.instant()))
|
||||
.sorted()
|
||||
.toList();
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ import com.google.common.annotations.VisibleForTesting;
|
|||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.CompletionException;
|
||||
import java.util.concurrent.Executor;
|
||||
|
@ -49,10 +50,15 @@ public class AccountLockManager {
|
|||
* @param task the task to execute once locks have been acquired
|
||||
* @param lockAcquisitionExecutor the executor on which to run blocking lock acquire/release tasks. this executor
|
||||
* should not use virtual threads.
|
||||
* @throws InterruptedException if interrupted while acquiring a lock
|
||||
*
|
||||
* @return the value returned by the given {@code task}
|
||||
*
|
||||
* @throws Exception if an exception is thrown by the given {@code task}
|
||||
*/
|
||||
public void withLock(final List<UUID> phoneNumberIdentifiers, final Runnable task,
|
||||
final Executor lockAcquisitionExecutor) {
|
||||
public <V> V withLock(final List<UUID> phoneNumberIdentifiers,
|
||||
final Callable<V> task,
|
||||
final Executor lockAcquisitionExecutor) throws Exception {
|
||||
|
||||
if (phoneNumberIdentifiers.isEmpty()) {
|
||||
throw new IllegalArgumentException("List of PNIs to lock must not be empty");
|
||||
}
|
||||
|
@ -75,7 +81,7 @@ public class AccountLockManager {
|
|||
}
|
||||
}, lockAcquisitionExecutor).join();
|
||||
|
||||
task.run();
|
||||
return task.call();
|
||||
} finally {
|
||||
CompletableFuture.runAsync(() -> {
|
||||
for (final LockItem lockItem : lockItems) {
|
||||
|
|
|
@ -11,6 +11,7 @@ import com.fasterxml.jackson.core.JsonProcessingException;
|
|||
import com.fasterxml.jackson.databind.ObjectWriter;
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.common.base.Throwables;
|
||||
import com.google.i18n.phonenumbers.PhoneNumberUtil;
|
||||
import io.micrometer.core.instrument.Metrics;
|
||||
import io.micrometer.core.instrument.Timer;
|
||||
import java.io.IOException;
|
||||
|
@ -36,9 +37,11 @@ import java.util.function.Predicate;
|
|||
import java.util.stream.Collectors;
|
||||
import javax.annotation.Nonnull;
|
||||
import javax.annotation.Nullable;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.signal.libsignal.zkgroup.backups.BackupCredentialType;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.identity.IdentityType;
|
||||
import org.whispersystems.textsecuregcm.util.AsyncTimerUtil;
|
||||
import org.whispersystems.textsecuregcm.util.AttributeValues;
|
||||
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
||||
|
@ -57,6 +60,7 @@ import software.amazon.awssdk.services.dynamodb.model.Delete;
|
|||
import software.amazon.awssdk.services.dynamodb.model.GetItemRequest;
|
||||
import software.amazon.awssdk.services.dynamodb.model.GetItemResponse;
|
||||
import software.amazon.awssdk.services.dynamodb.model.Put;
|
||||
import software.amazon.awssdk.services.dynamodb.model.PutItemRequest;
|
||||
import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
|
||||
import software.amazon.awssdk.services.dynamodb.model.QueryResponse;
|
||||
import software.amazon.awssdk.services.dynamodb.model.ReturnValuesOnConditionCheckFailure;
|
||||
|
@ -87,7 +91,8 @@ public class Accounts {
|
|||
|
||||
static final List<String> ACCOUNT_FIELDS_TO_EXCLUDE_FROM_SERIALIZATION = List.of("uuid", "usernameLinkHandle");
|
||||
|
||||
private static final ObjectWriter ACCOUNT_DDB_JSON_WRITER = SystemMapper.jsonMapper()
|
||||
@VisibleForTesting
|
||||
static final ObjectWriter ACCOUNT_DDB_JSON_WRITER = SystemMapper.jsonMapper()
|
||||
.writer(SystemMapper.excludingField(Account.class, ACCOUNT_FIELDS_TO_EXCLUDE_FROM_SERIALIZATION));
|
||||
|
||||
private static final Timer CREATE_TIMER = Metrics.timer(name(Accounts.class, "create"));
|
||||
|
@ -299,8 +304,11 @@ public class Accounts {
|
|||
final Collection<TransactWriteItem> additionalWriteItems) {
|
||||
|
||||
if (!existingAccount.getUuid().equals(accountToCreate.getUuid()) ||
|
||||
!existingAccount.getNumber().equals(accountToCreate.getNumber())) {
|
||||
!existingAccount.getPhoneNumberIdentifier().equals(accountToCreate.getPhoneNumberIdentifier())) {
|
||||
|
||||
log.error("Reclaimed accounts must match. Old account {}:{}:{}, New account {}:{}:{}",
|
||||
existingAccount.getUuid(), redactPhoneNumber(existingAccount.getNumber()), existingAccount.getPhoneNumberIdentifier(),
|
||||
accountToCreate.getUuid(), redactPhoneNumber(accountToCreate.getNumber()), accountToCreate.getPhoneNumberIdentifier());
|
||||
throw new IllegalArgumentException("reclaimed accounts must match");
|
||||
}
|
||||
|
||||
|
@ -1399,8 +1407,7 @@ public class Accounts {
|
|||
final String tableName,
|
||||
final AttributeValue uuidAttr,
|
||||
final String keyName,
|
||||
final AttributeValue keyValue
|
||||
) {
|
||||
final AttributeValue keyValue) {
|
||||
return TransactWriteItem.builder()
|
||||
.put(Put.builder()
|
||||
.tableName(tableName)
|
||||
|
@ -1465,6 +1472,68 @@ public class Accounts {
|
|||
.build();
|
||||
}
|
||||
|
||||
public CompletableFuture<Void> regenerateConstraints(final Account account) {
|
||||
final List<CompletableFuture<?>> constraintFutures = new ArrayList<>();
|
||||
|
||||
constraintFutures.add(writeConstraint(phoneNumberConstraintTableName,
|
||||
account.getIdentifier(IdentityType.ACI),
|
||||
ATTR_ACCOUNT_E164,
|
||||
AttributeValues.fromString(account.getNumber())));
|
||||
|
||||
constraintFutures.add(writeConstraint(phoneNumberIdentifierConstraintTableName,
|
||||
account.getIdentifier(IdentityType.ACI),
|
||||
ATTR_PNI_UUID,
|
||||
AttributeValues.fromUUID(account.getPhoneNumberIdentifier())));
|
||||
|
||||
account.getUsernameHash().ifPresent(usernameHash ->
|
||||
constraintFutures.add(writeUsernameConstraint(account.getIdentifier(IdentityType.ACI),
|
||||
usernameHash,
|
||||
Optional.empty())));
|
||||
|
||||
account.getUsernameHolds().forEach(usernameHold ->
|
||||
constraintFutures.add(writeUsernameConstraint(account.getIdentifier(IdentityType.ACI),
|
||||
usernameHold.usernameHash(),
|
||||
Optional.of(Instant.ofEpochSecond(usernameHold.expirationSecs())))));
|
||||
|
||||
return CompletableFuture.allOf(constraintFutures.toArray(CompletableFuture[]::new));
|
||||
}
|
||||
|
||||
private CompletableFuture<Void> writeConstraint(
|
||||
final String tableName,
|
||||
final UUID accountIdentifier,
|
||||
final String keyName,
|
||||
final AttributeValue keyValue) {
|
||||
|
||||
return dynamoDbAsyncClient.putItem(PutItemRequest.builder()
|
||||
.tableName(tableName)
|
||||
.item(Map.of(
|
||||
keyName, keyValue,
|
||||
KEY_ACCOUNT_UUID, AttributeValues.fromUUID(accountIdentifier)))
|
||||
.build())
|
||||
.thenRun(Util.NOOP);
|
||||
}
|
||||
|
||||
private CompletableFuture<Void> writeUsernameConstraint(
|
||||
final UUID accountIdentifier,
|
||||
final byte[] usernameHash,
|
||||
final Optional<Instant> maybeExpiration) {
|
||||
|
||||
final Map<String, AttributeValue> item = new HashMap<>(Map.of(
|
||||
UsernameTable.KEY_USERNAME_HASH, AttributeValues.fromByteArray(usernameHash),
|
||||
UsernameTable.ATTR_ACCOUNT_UUID, AttributeValues.fromUUID(accountIdentifier),
|
||||
UsernameTable.ATTR_CONFIRMED, AttributeValues.fromBool(maybeExpiration.isEmpty())
|
||||
));
|
||||
|
||||
maybeExpiration.ifPresent(expiration ->
|
||||
item.put(UsernameTable.ATTR_TTL, AttributeValues.fromLong(expiration.getEpochSecond())));
|
||||
|
||||
return dynamoDbAsyncClient.putItem(PutItemRequest.builder()
|
||||
.tableName(usernamesConstraintTableName)
|
||||
.item(item)
|
||||
.build())
|
||||
.thenRun(Util.NOOP);
|
||||
}
|
||||
|
||||
@Nonnull
|
||||
private static String extractCancellationReasonCodes(final TransactionCanceledException exception) {
|
||||
return exception.cancellationReasons().stream()
|
||||
|
@ -1522,4 +1591,15 @@ public class Accounts {
|
|||
private static boolean isTransactionConflict(final CancellationReason reason) {
|
||||
return TRANSACTION_CONFLICT.equals(reason.code());
|
||||
}
|
||||
|
||||
private static String redactPhoneNumber(final String phoneNumber) {
|
||||
final StringBuilder sb = new StringBuilder();
|
||||
sb.append("+");
|
||||
sb.append(Util.getCountryCode(phoneNumber));
|
||||
sb.append("???");
|
||||
sb.append(StringUtils.length(phoneNumber) < 3
|
||||
? ""
|
||||
: phoneNumber.substring(phoneNumber.length() - 2, phoneNumber.length()));
|
||||
return sb.toString();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -62,7 +62,6 @@ import java.util.stream.Stream;
|
|||
import javax.annotation.Nullable;
|
||||
import javax.crypto.Mac;
|
||||
import javax.crypto.spec.SecretKeySpec;
|
||||
import org.apache.commons.lang3.ObjectUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.signal.libsignal.protocol.IdentityKey;
|
||||
import org.slf4j.Logger;
|
||||
|
@ -274,11 +273,33 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
|
|||
final DeviceSpec primaryDeviceSpec,
|
||||
@Nullable final String userAgent) throws InterruptedException {
|
||||
|
||||
final Account account = new Account();
|
||||
final UUID pni = phoneNumberIdentifiers.getPhoneNumberIdentifier(number).join();
|
||||
|
||||
return createTimer.record(() -> {
|
||||
accountLockManager.withLock(List.of(pni), () -> {
|
||||
try {
|
||||
return accountLockManager.withLock(List.of(pni),
|
||||
() -> create(number, pni, accountAttributes, accountBadges, aciIdentityKey, pniIdentityKey, primaryDeviceSpec, userAgent), accountLockExecutor);
|
||||
} catch (final Exception e) {
|
||||
if (e instanceof RuntimeException runtimeException) {
|
||||
throw runtimeException;
|
||||
}
|
||||
|
||||
logger.error("Unexpected exception while creating account", e);
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private Account create(final String number,
|
||||
final UUID pni,
|
||||
final AccountAttributes accountAttributes,
|
||||
final List<AccountBadge> accountBadges,
|
||||
final IdentityKey aciIdentityKey,
|
||||
final IdentityKey pniIdentityKey,
|
||||
final DeviceSpec primaryDeviceSpec,
|
||||
@Nullable final String userAgent) {
|
||||
|
||||
final Account account = new Account();
|
||||
final Optional<UUID> maybeRecentlyDeletedAccountIdentifier =
|
||||
accounts.findRecentlyDeletedAccountIdentifier(pni);
|
||||
|
||||
|
@ -390,10 +411,8 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
|
|||
accountAttributes.recoveryPassword().ifPresent(registrationRecoveryPassword ->
|
||||
registrationRecoveryPasswordsManager.store(account.getIdentifier(IdentityType.PNI),
|
||||
registrationRecoveryPassword));
|
||||
}, accountLockExecutor);
|
||||
|
||||
return account;
|
||||
});
|
||||
}
|
||||
|
||||
public CompletableFuture<Pair<Account, Device>> addDevice(final Account account, final DeviceSpec deviceSpec, final String linkDeviceToken) {
|
||||
|
@ -580,6 +599,15 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
|
|||
return Optional.of(aci);
|
||||
}
|
||||
|
||||
/**
|
||||
* Unlink a device from the given account. The device will be immediately disconnected if it is
|
||||
* connected to any chat frontend, but it is the caller's responsibility to make sure that the
|
||||
* account's *other* devices are disconnected, either by use of
|
||||
* {@link org.whispersystems.textsecuregcm.auth.LinkedDeviceRefreshRequirementProvider} or
|
||||
* directly by calling {@link DeviceDisconnectionManager#requestDisconnection}.
|
||||
*
|
||||
* @returns the updated Account
|
||||
*/
|
||||
public CompletableFuture<Account> removeDevice(final Account account, final byte deviceId) {
|
||||
if (deviceId == Device.PRIMARY_ID) {
|
||||
throw new IllegalArgumentException("Cannot remove primary device");
|
||||
|
@ -633,26 +661,45 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
|
|||
|
||||
public Account changeNumber(final Account account,
|
||||
final String targetNumber,
|
||||
@Nullable final IdentityKey pniIdentityKey,
|
||||
@Nullable final Map<Byte, ECSignedPreKey> pniSignedPreKeys,
|
||||
@Nullable final Map<Byte, KEMSignedPreKey> pniPqLastResortPreKeys,
|
||||
@Nullable final Map<Byte, Integer> pniRegistrationIds) throws InterruptedException, MismatchedDevicesException {
|
||||
final IdentityKey pniIdentityKey,
|
||||
final Map<Byte, ECSignedPreKey> pniSignedPreKeys,
|
||||
final Map<Byte, KEMSignedPreKey> pniPqLastResortPreKeys,
|
||||
final Map<Byte, Integer> pniRegistrationIds) throws InterruptedException, MismatchedDevicesException {
|
||||
|
||||
final UUID originalPhoneNumberIdentifier = account.getPhoneNumberIdentifier();
|
||||
final UUID targetPhoneNumberIdentifier = phoneNumberIdentifiers.getPhoneNumberIdentifier(targetNumber).join();
|
||||
|
||||
if (originalPhoneNumberIdentifier.equals(targetPhoneNumberIdentifier)) {
|
||||
if (pniIdentityKey != null) {
|
||||
throw new IllegalArgumentException("change number must supply a changed phone number; otherwise use updatePniKeys");
|
||||
}
|
||||
return account;
|
||||
}
|
||||
|
||||
try {
|
||||
return accountLockManager.withLock(List.of(account.getPhoneNumberIdentifier(), targetPhoneNumberIdentifier),
|
||||
() -> changeNumber(account, targetNumber, targetPhoneNumberIdentifier, pniIdentityKey, pniSignedPreKeys, pniPqLastResortPreKeys, pniRegistrationIds), accountLockExecutor);
|
||||
} catch (final Exception e) {
|
||||
if (e instanceof MismatchedDevicesException mismatchedDevicesException) {
|
||||
throw mismatchedDevicesException;
|
||||
} if (e instanceof RuntimeException runtimeException) {
|
||||
throw runtimeException;
|
||||
}
|
||||
|
||||
logger.error("Unexpected exception when changing phone number", e);
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private Account changeNumber(final Account account,
|
||||
final String targetNumber,
|
||||
final UUID targetPhoneNumberIdentifier,
|
||||
final IdentityKey pniIdentityKey,
|
||||
final Map<Byte, ECSignedPreKey> pniSignedPreKeys,
|
||||
final Map<Byte, KEMSignedPreKey> pniPqLastResortPreKeys,
|
||||
final Map<Byte, Integer> pniRegistrationIds) throws MismatchedDevicesException {
|
||||
|
||||
validateDevices(account, pniSignedPreKeys, pniPqLastResortPreKeys, pniRegistrationIds);
|
||||
|
||||
final AtomicReference<Account> updatedAccount = new AtomicReference<>();
|
||||
final UUID originalPhoneNumberIdentifier = account.getPhoneNumberIdentifier();
|
||||
|
||||
accountLockManager.withLock(List.of(account.getPhoneNumberIdentifier(), targetPhoneNumberIdentifier), () -> {
|
||||
redisDelete(account);
|
||||
|
||||
// There are three possible states for accounts associated with the target phone number:
|
||||
|
@ -685,9 +732,9 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
|
|||
.join();
|
||||
|
||||
final Collection<TransactWriteItem> keyWriteItems =
|
||||
buildPniKeyWriteItems(uuid, targetPhoneNumberIdentifier, pniSignedPreKeys, pniPqLastResortPreKeys);
|
||||
buildPniKeyWriteItems(targetPhoneNumberIdentifier, pniSignedPreKeys, pniPqLastResortPreKeys);
|
||||
|
||||
final Account numberChangedAccount = updateWithRetries(
|
||||
return updateWithRetries(
|
||||
account,
|
||||
a -> {
|
||||
setPniKeys(account, pniIdentityKey, pniRegistrationIds);
|
||||
|
@ -696,26 +743,23 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
|
|||
a -> accounts.changeNumber(a, targetNumber, targetPhoneNumberIdentifier, maybeDisplacedUuid, keyWriteItems),
|
||||
() -> accounts.getByAccountIdentifier(uuid).orElseThrow(),
|
||||
AccountChangeValidator.NUMBER_CHANGE_VALIDATOR);
|
||||
|
||||
updatedAccount.set(numberChangedAccount);
|
||||
}, accountLockExecutor);
|
||||
|
||||
return updatedAccount.get();
|
||||
}
|
||||
|
||||
public Account updatePniKeys(final Account account,
|
||||
final IdentityKey pniIdentityKey,
|
||||
final Map<Byte, ECSignedPreKey> pniSignedPreKeys,
|
||||
@Nullable final Map<Byte, KEMSignedPreKey> pniPqLastResortPreKeys,
|
||||
final Map<Byte, KEMSignedPreKey> pniPqLastResortPreKeys,
|
||||
final Map<Byte, Integer> pniRegistrationIds) throws MismatchedDevicesException {
|
||||
|
||||
try {
|
||||
return accountLockManager.withLock(List.of(account.getIdentifier(IdentityType.PNI)), () -> {
|
||||
validateDevices(account, pniSignedPreKeys, pniPqLastResortPreKeys, pniRegistrationIds);
|
||||
|
||||
final UUID aci = account.getIdentifier(IdentityType.ACI);
|
||||
final UUID pni = account.getIdentifier(IdentityType.PNI);
|
||||
|
||||
final Collection<TransactWriteItem> keyWriteItems =
|
||||
buildPniKeyWriteItems(pni, pni, pniSignedPreKeys, pniPqLastResortPreKeys);
|
||||
buildPniKeyWriteItems(pni, pniSignedPreKeys, pniPqLastResortPreKeys);
|
||||
|
||||
return redisDeleteAsync(account)
|
||||
.thenCompose(ignored -> keysManager.deleteSingleUsePreKeys(pni))
|
||||
|
@ -727,44 +771,38 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
|
|||
AccountChangeValidator.GENERAL_CHANGE_VALIDATOR,
|
||||
MAX_UPDATE_ATTEMPTS))
|
||||
.join();
|
||||
}, accountLockExecutor);
|
||||
} catch (final Exception e) {
|
||||
if (e instanceof MismatchedDevicesException mismatchedDevicesException) {
|
||||
throw mismatchedDevicesException;
|
||||
} else if (e instanceof RuntimeException runtimeException) {
|
||||
throw runtimeException;
|
||||
}
|
||||
|
||||
logger.error("Unexpected exception when updating PNI key material", e);
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private Collection<TransactWriteItem> buildPniKeyWriteItems(
|
||||
final UUID enabledDevicesIdentifier,
|
||||
final UUID phoneNumberIdentifier,
|
||||
@Nullable final Map<Byte, ECSignedPreKey> pniSignedPreKeys,
|
||||
@Nullable final Map<Byte, KEMSignedPreKey> pniPqLastResortPreKeys) {
|
||||
final Map<Byte, ECSignedPreKey> pniSignedPreKeys,
|
||||
final Map<Byte, KEMSignedPreKey> pniPqLastResortPreKeys) {
|
||||
|
||||
final List<TransactWriteItem> keyWriteItems = new ArrayList<>();
|
||||
|
||||
if (pniSignedPreKeys != null) {
|
||||
pniSignedPreKeys.forEach((deviceId, signedPreKey) ->
|
||||
keyWriteItems.add(keysManager.buildWriteItemForEcSignedPreKey(phoneNumberIdentifier, deviceId, signedPreKey)));
|
||||
}
|
||||
|
||||
if (pniPqLastResortPreKeys != null) {
|
||||
keysManager.getPqEnabledDevices(enabledDevicesIdentifier)
|
||||
.thenAccept(deviceIds -> deviceIds.stream()
|
||||
.filter(pniPqLastResortPreKeys::containsKey)
|
||||
.map(deviceId -> keysManager.buildWriteItemForLastResortKey(phoneNumberIdentifier,
|
||||
deviceId,
|
||||
pniPqLastResortPreKeys.get(deviceId)))
|
||||
.forEach(keyWriteItems::add))
|
||||
.join();
|
||||
}
|
||||
pniPqLastResortPreKeys.forEach((deviceId, lastResortKey) ->
|
||||
keyWriteItems.add(keysManager.buildWriteItemForLastResortKey(phoneNumberIdentifier, deviceId, lastResortKey)));
|
||||
|
||||
return keyWriteItems;
|
||||
}
|
||||
|
||||
private void setPniKeys(final Account account,
|
||||
@Nullable final IdentityKey pniIdentityKey,
|
||||
@Nullable final Map<Byte, Integer> pniRegistrationIds) {
|
||||
|
||||
if (ObjectUtils.allNull(pniIdentityKey, pniRegistrationIds)) {
|
||||
return;
|
||||
} else if (!ObjectUtils.allNotNull(pniIdentityKey, pniRegistrationIds)) {
|
||||
throw new IllegalArgumentException("PNI identity key and registration IDs must be all null or all non-null");
|
||||
}
|
||||
final IdentityKey pniIdentityKey,
|
||||
final Map<Byte, Integer> pniRegistrationIds) {
|
||||
|
||||
account.getDevices()
|
||||
.forEach(device -> device.setPhoneNumberIdentityRegistrationId(pniRegistrationIds.get(device.getId())));
|
||||
|
@ -773,22 +811,15 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
|
|||
}
|
||||
|
||||
private void validateDevices(final Account account,
|
||||
@Nullable final Map<Byte, ECSignedPreKey> pniSignedPreKeys,
|
||||
@Nullable final Map<Byte, KEMSignedPreKey> pniPqLastResortPreKeys,
|
||||
@Nullable final Map<Byte, Integer> pniRegistrationIds) throws MismatchedDevicesException {
|
||||
if (pniSignedPreKeys == null && pniRegistrationIds == null) {
|
||||
return;
|
||||
} else if (pniSignedPreKeys == null || pniRegistrationIds == null) {
|
||||
throw new IllegalArgumentException("Signed pre-keys and registration IDs must both be null or both be non-null");
|
||||
}
|
||||
final Map<Byte, ECSignedPreKey> pniSignedPreKeys,
|
||||
final Map<Byte, KEMSignedPreKey> pniPqLastResortPreKeys,
|
||||
final Map<Byte, Integer> pniRegistrationIds) throws MismatchedDevicesException {
|
||||
|
||||
// Check that all including primary ID are in signed pre-keys
|
||||
validateCompleteDeviceList(account, pniSignedPreKeys.keySet());
|
||||
|
||||
// Check that all including primary ID are in Pq pre-keys
|
||||
if (pniPqLastResortPreKeys != null) {
|
||||
validateCompleteDeviceList(account, pniPqLastResortPreKeys.keySet());
|
||||
}
|
||||
|
||||
// Check that all devices are accounted for in the map of new PNI registration IDs
|
||||
validateCompleteDeviceList(account, pniRegistrationIds.keySet());
|
||||
|
@ -807,8 +838,7 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
|
|||
extraDeviceIds.removeAll(accountDeviceIds);
|
||||
|
||||
if (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty()) {
|
||||
throw new MismatchedDevicesException(
|
||||
new MismatchedDevices(missingDeviceIds, extraDeviceIds, Collections.emptySet()));
|
||||
throw new MismatchedDevicesException(new MismatchedDevices(missingDeviceIds, extraDeviceIds, Set.of()));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -5,8 +5,10 @@
|
|||
package org.whispersystems.textsecuregcm.storage;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import java.time.Clock;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.Collectors;
|
||||
import javax.annotation.Nullable;
|
||||
import org.apache.commons.lang3.ObjectUtils;
|
||||
|
@ -28,21 +30,26 @@ public class ChangeNumberManager {
|
|||
private static final Logger logger = LoggerFactory.getLogger(ChangeNumberManager.class);
|
||||
private final MessageSender messageSender;
|
||||
private final AccountsManager accountsManager;
|
||||
private final Clock clock;
|
||||
|
||||
public ChangeNumberManager(
|
||||
final MessageSender messageSender,
|
||||
final AccountsManager accountsManager) {
|
||||
final AccountsManager accountsManager,
|
||||
final Clock clock) {
|
||||
|
||||
this.messageSender = messageSender;
|
||||
this.accountsManager = accountsManager;
|
||||
this.clock = clock;
|
||||
}
|
||||
|
||||
public Account changeNumber(final Account account, final String number,
|
||||
@Nullable final IdentityKey pniIdentityKey,
|
||||
@Nullable final Map<Byte, ECSignedPreKey> deviceSignedPreKeys,
|
||||
@Nullable final Map<Byte, KEMSignedPreKey> devicePqLastResortPreKeys,
|
||||
@Nullable final List<IncomingMessage> deviceMessages,
|
||||
@Nullable final Map<Byte, Integer> pniRegistrationIds,
|
||||
@Nullable final String senderUserAgent)
|
||||
public Account changeNumber(final Account account,
|
||||
final String number,
|
||||
final IdentityKey pniIdentityKey,
|
||||
final Map<Byte, ECSignedPreKey> deviceSignedPreKeys,
|
||||
final Map<Byte, KEMSignedPreKey> devicePqLastResortPreKeys,
|
||||
final List<IncomingMessage> deviceMessages,
|
||||
final Map<Byte, Integer> pniRegistrationIds,
|
||||
final String senderUserAgent)
|
||||
throws InterruptedException, MismatchedDevicesException, MessageTooLargeException {
|
||||
|
||||
if (!(ObjectUtils.allNotNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds) ||
|
||||
|
@ -96,7 +103,7 @@ public class ChangeNumberManager {
|
|||
final String senderUserAgent) throws MessageTooLargeException, MismatchedDevicesException {
|
||||
|
||||
try {
|
||||
final long serverTimestamp = System.currentTimeMillis();
|
||||
final long serverTimestamp = clock.millis();
|
||||
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(account.getUuid());
|
||||
|
||||
final Map<Byte, Envelope> messagesByDeviceId = deviceMessages.stream()
|
||||
|
@ -113,10 +120,15 @@ public class ChangeNumberManager {
|
|||
.setEphemeral(false)
|
||||
.build()));
|
||||
|
||||
final Map<Byte, Integer> registrationIdsByDeviceId = account.getDevices().stream()
|
||||
.collect(Collectors.toMap(Device::getId, Device::getRegistrationId));
|
||||
final Map<Byte, Integer> registrationIdsByDeviceId = deviceMessages.stream()
|
||||
.collect(Collectors.toMap(IncomingMessage::destinationDeviceId, IncomingMessage::destinationRegistrationId));
|
||||
|
||||
messageSender.sendMessages(account, serviceIdentifier, messagesByDeviceId, registrationIdsByDeviceId, senderUserAgent);
|
||||
messageSender.sendMessages(account,
|
||||
serviceIdentifier,
|
||||
messagesByDeviceId,
|
||||
registrationIdsByDeviceId,
|
||||
Optional.of(Device.PRIMARY_ID),
|
||||
senderUserAgent);
|
||||
} catch (final RuntimeException e) {
|
||||
logger.warn("Changed number but could not send all device messages on {}", account.getUuid(), e);
|
||||
throw e;
|
||||
|
|
|
@ -20,6 +20,7 @@ import java.util.stream.IntStream;
|
|||
import javax.annotation.Nullable;
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
|
||||
import org.whispersystems.textsecuregcm.identity.IdentityType;
|
||||
import org.whispersystems.textsecuregcm.util.DeviceCapabilityAdapter;
|
||||
import org.whispersystems.textsecuregcm.util.DeviceNameByteArrayAdapter;
|
||||
|
||||
|
@ -64,9 +65,8 @@ public class Device {
|
|||
@JsonProperty
|
||||
private int registrationId;
|
||||
|
||||
@Nullable
|
||||
@JsonProperty("pniRegistrationId")
|
||||
private Integer phoneNumberIdentityRegistrationId;
|
||||
private int phoneNumberIdentityRegistrationId;
|
||||
|
||||
@JsonProperty
|
||||
private long lastSeen;
|
||||
|
@ -208,18 +208,17 @@ public class Device {
|
|||
return getId() == PRIMARY_ID;
|
||||
}
|
||||
|
||||
public int getRegistrationId() {
|
||||
return registrationId;
|
||||
public int getRegistrationId(final IdentityType identityType) {
|
||||
return switch (identityType) {
|
||||
case ACI -> registrationId;
|
||||
case PNI -> phoneNumberIdentityRegistrationId;
|
||||
};
|
||||
}
|
||||
|
||||
public void setRegistrationId(int registrationId) {
|
||||
this.registrationId = registrationId;
|
||||
}
|
||||
|
||||
public OptionalInt getPhoneNumberIdentityRegistrationId() {
|
||||
return phoneNumberIdentityRegistrationId != null ? OptionalInt.of(phoneNumberIdentityRegistrationId) : OptionalInt.empty();
|
||||
}
|
||||
|
||||
public void setPhoneNumberIdentityRegistrationId(final int phoneNumberIdentityRegistrationId) {
|
||||
this.phoneNumberIdentityRegistrationId = phoneNumberIdentityRegistrationId;
|
||||
}
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
|
||||
package org.whispersystems.textsecuregcm.storage;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
|
@ -114,10 +113,6 @@ public class KeysManager {
|
|||
return ecSignedPreKeys.find(identifier, deviceId);
|
||||
}
|
||||
|
||||
public CompletableFuture<List<Byte>> getPqEnabledDevices(final UUID identifier) {
|
||||
return pqLastResortKeys.getDeviceIdsWithKeys(identifier).collectList().toFuture();
|
||||
}
|
||||
|
||||
public CompletableFuture<Integer> getEcCount(final UUID identifier, final byte deviceId) {
|
||||
return ecPreKeys.getCount(identifier, deviceId);
|
||||
}
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue