Compare commits

...

33 Commits

Author SHA1 Message Date
Ravi Khadiwala ae2d98750c Add SecureValueRecoveryBController 2025-07-11 13:39:18 -05:00
Ravi Khadiwala 7d41c1219b Add /v2/svr as an alternative name for /v2/backup 2025-07-11 13:39:18 -05:00
Ravi Khadiwala 65e1f1b3a9 Arm the RemoveOrphanedPreKeyPagesCommand 2025-07-11 13:26:15 -05:00
Ameya Lokare 437b823c84 Update to the latest version of the spam filter 2025-07-09 13:27:04 -07:00
ravi-signal c9f21d5970
Always read from new and old PQ prekey stores, add experiment to start writing to new prekey store 2025-07-09 09:17:17 -05:00
Ravi Khadiwala 80c11e7eda Handle 429s from play API and add subscription docs 2025-07-09 09:15:29 -05:00
Jon Chambers 0745cabc87 Explicitly use synchronous flush mode when clearing Redis databases between tests 2025-07-09 09:15:15 -05:00
Jon Chambers 3e80669f4e Reuse/extend lifetime of Redis singleton resources 2025-07-09 09:15:15 -05:00
Jon Chambers b81cd9ec61 Reuse Redis clusters for the duration of a test run 2025-07-09 09:15:15 -05:00
Jon Chambers da6ed94443 Reuse client resources for lifetime of Redis cluster 2025-07-09 09:15:15 -05:00
Ameya Lokare 96d41b3716 Update to the latest version of the spam filter 2025-07-07 09:16:48 -07:00
Ravi Khadiwala 7dddc4d759 fix an incorrect backup metric 2025-07-07 18:14:53 +02:00
Katherine a87690d817
Include Redis cluster and shard address in circuit breaker log 2025-07-07 12:12:44 -04:00
Ameya Lokare 18ef3da261 Update dependencies 2025-06-30 14:17:03 -07:00
Ameya Lokare f4698dd5b2 Update to the latest version of the spam filter 2025-06-27 12:07:45 -07:00
Adel Lahlou d4322a2ed4 Remove latency based 1:1 call routing 2025-06-27 12:06:43 -07:00
Jon Chambers 7260a9d5b4 Make FoundationDB versions available at runtime 2025-06-27 11:21:50 -04:00
Jon Chambers 12b4ceb4aa Configure FoundationDB service container's database via Docker, removing `fdbcli` dependency 2025-06-27 11:08:58 -04:00
Jon Chambers fa1cd5c263 Install the Maven-fetched FoundationDB client library on GitHub Actions runner 2025-06-27 11:06:04 -04:00
Jon Chambers f8da13912d Fetch the FoundationDB client library as a pre-package step rather than including it in version control 2025-06-27 11:04:53 -04:00
Jon Chambers a3b3bf86ba Add a note about the FoundationDB client library requirement to the README 2025-06-27 11:04:52 -04:00
Jon Chambers a99f7bb87d Add test dependencies for FoundationDB 2025-06-27 11:04:52 -04:00
Jon Chambers d6f14d02dd Add a FoundationDB service container for tests 2025-06-27 11:04:46 -04:00
Jon Chambers d18671eaf9 Add FoundationDB runtime dependencies 2025-06-26 12:13:09 -04:00
Jon Chambers 87c30d00e8 Store compressed envelopes at rest 2025-06-25 15:20:19 -04:00
Jon Chambers c8f45685b8 Expand envelopes on load from storage 2025-06-25 14:31:19 -04:00
Jon Chambers bb90d80d22 Add a utility for compressing/expanding envelopes 2025-06-25 14:31:19 -04:00
Jon Chambers dcc541f86e Add binary representation fields for service IDs/UUIDs 2025-06-25 14:31:19 -04:00
Ravi Khadiwala aaa36fd8f5 Add a crawler for orphaned prekey pages 2025-06-24 13:46:48 -05:00
Ravi Khadiwala 2bb14892af Add paged prekey store 2025-06-24 13:46:48 -05:00
Ameya Lokare 6d8701665e Update to the latest version of the spam filter 2025-06-24 11:46:11 -07:00
Katherine c2b8fdac0d
Only log for an unexpected error from the key transparency service 2025-06-24 14:45:53 -04:00
Katherine 059caa4c57
Implement key transparency endpoints using `simple-grpc` 2025-06-24 14:01:35 -04:00
87 changed files with 3642 additions and 1424 deletions

View File

@ -12,6 +12,13 @@ jobs:
container: ubuntu:22.04
timeout-minutes: 20
services:
foundationdb:
# Note: this should generally match the version of the FoundationDB SERVER deployed in production; it's okay if
# it's a little behind the CLIENT version.
image: foundationdb/foundationdb:7.3.62
options: --name foundationdb
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up JDK 21
@ -26,6 +33,28 @@ jobs:
HOME: /root
- name: Install APT packages
# ca-certificates: required for AWS CRT client
run: apt update && apt install -y ca-certificates
run: |
# Add Docker's official GPG key:
apt update
apt install -y ca-certificates curl
install -m 0755 -d /etc/apt/keyrings
curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc
chmod a+r /etc/apt/keyrings/docker.asc
# Add Docker repository to apt sources:
echo \
"deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu \
$(. /etc/os-release && echo "${UBUNTU_CODENAME:-$VERSION_CODENAME}") stable" | \
tee /etc/apt/sources.list.d/docker.list > /dev/null
# ca-certificates: required for AWS CRT client
apt update && apt install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin ca-certificates
- name: Configure FoundationDB database
run: docker exec foundationdb /usr/bin/fdbcli --exec 'configure new single memory'
- name: Download and install FoundationDB client
run: |
./mvnw -e -B -Pexclude-spam-filter clean prepare-package -DskipTests=true
cp service/target/jib-extra/usr/lib/libfdb_c.x86_64.so /usr/lib/libfdb_c.x86_64.so
ldconfig
- name: Build with Maven
run: ./mvnw -e -B verify
run: ./mvnw -e -B clean verify -DfoundationDb.serviceContainerName=foundationdb

View File

@ -11,6 +11,8 @@ https://signal.org/docs/
How to Build
------------
This project uses [FoundationDB](https://www.foundationdb.org/) and requires the FoundationDB client library to be installed on the host system. With that in place, the server can be built and tested with:
```shell script
$ ./mvnw clean test
```

65
pom.xml
View File

@ -37,46 +37,57 @@
</modules>
<properties>
<aws.sdk2.version>2.31.9</aws.sdk2.version>
<braintree.version>3.40.0</braintree.version>
<aws.sdk2.version>2.31.70</aws.sdk2.version>
<braintree.version>3.42.0</braintree.version>
<commons-csv.version>1.14.0</commons-csv.version>
<commons-io.version>2.18.0</commons-io.version>
<commons-io.version>2.19.0</commons-io.version>
<dropwizard.version>4.0.12</dropwizard.version>
<dropwizard-metrics-datadog.version>1.1.14</dropwizard-metrics-datadog.version>
<!-- can be updated to latest version with Dropwizard 5 (Jetty 12); will then need to disable telemetry -->
<dynamodblocal.version>2.2.1</dynamodblocal.version>
<google-cloud-libraries.version>26.57.0</google-cloud-libraries.version>
<grpc.version>1.70.0</grpc.version> <!-- should be kept in sync with the value from Google libraries-bom -->
<gson.version>2.12.1</gson.version>
<!-- Note: when updating FoundationDB, also include a copy of `libfdb_c.so` from the FoundationDB release at
src/main/jib/usr/lib/libfdb_c.so. We use x86_64 builds without AVX instructions enabled (i.e. FoundationDB versions
with even-numbered patch versions). Also when updating FoundationDB, make sure to update the version of FoundationDB
used by GitHub Actions. -->
<foundationdb.version>7.3.62</foundationdb.version>
<foundationdb.api-version>730</foundationdb.api-version>
<foundationdb.client-library-sha256>bfed237b787fae3cde1222676e6bfbb0d218fc27bf9e903397a7a7aa96fb2d33</foundationdb.client-library-sha256>
<google-cloud-libraries.version>26.62.0</google-cloud-libraries.version>
<grpc.version>1.73.0</grpc.version> <!-- should be kept in sync with the value from Google libraries-bom -->
<gson.version>2.13.1</gson.version>
<!-- several libraries (AWS, Google Cloud) use Apache http components transitively, and we need to align them -->
<httpcore.version>4.4.16</httpcore.version>
<httpclient.version>4.5.14</httpclient.version>
<jackson.version>2.18.3</jackson.version>
<jackson.version>2.19.1</jackson.version>
<junit-pioneer.version>2.3.0</junit-pioneer.version>
<jsr305.version>3.0.2</jsr305.version>
<kotlin.version>2.1.20</kotlin.version>
<kotlin.version>2.2.0</kotlin.version>
<!-- Logback 1.5.14+ has a null pointer bug: https://github.com/qos-ch/logback/issues/929. -->
<logback.version>1.5.13</logback.version>
<logback-access.version>2.0.5</logback-access.version>
<lettuce.version>6.5.5.RELEASE</lettuce.version>
<libphonenumber.version>9.0.2</libphonenumber.version>
<logstash.logback.version>7.3</logstash.logback.version>
<log4j-bom.version>2.24.3</log4j-bom.version>
<logback-access-common.version>2.0.5</logback-access-common.version>
<lettuce.version>6.7.1.RELEASE</lettuce.version>
<libphonenumber.version>9.0.8</libphonenumber.version>
<logstash.logback.version>8.1</logstash.logback.version>
<log4j-bom.version>2.25.0</log4j-bom.version>
<luajava.version>3.5.0</luajava.version>
<micrometer.version>1.14.5</micrometer.version>
<netty.version>4.1.119.Final</netty.version>
<micrometer.version>1.15.1</micrometer.version>
<netty.version>4.1.122.Final</netty.version>
<!-- Must be less than or equal to the value from Google libraries-bom which controls the protobuf runtime version.
See https://protobuf.dev/support/cross-version-runtime-guarantee/. -->
<protoc.version>4.29.4</protoc.version>
<pushy.version>0.15.4</pushy.version>
<reactive.grpc.version>1.2.4</reactive.grpc.version>
<reactor-bom.version>2024.0.4</reactor-bom.version> <!-- 3.7.4, see https://github.com/reactor/reactor#bom-versioning-scheme -->
<reactor-bom.version>2024.0.7</reactor-bom.version> <!-- 3.7.4, see https://github.com/reactor/reactor#bom-versioning-scheme -->
<resilience4j.version>2.3.0</resilience4j.version>
<semver4j.version>3.1.0</semver4j.version>
<simple-grpc.version>0.1.0</simple-grpc.version>
<slf4j.version>2.0.17</slf4j.version>
<stripe.version>23.10.0</stripe.version>
<swagger.version>2.2.27</swagger.version>
<swagger.version>2.2.31</swagger.version>
<testcontainers.version>1.21.2</testcontainers.version>
<!-- image to use in tests that run localstack via docker. -->
<localstack.image>localstack/localstack:3.5.0</localstack.image>
<!-- eclipse-temurin:21.0.6_7-jre-jammy (note: always use the multi-arch manifest *LIST* here) -->
<docker.image.sha256>02fc89fa8766a9ba221e69225f8d1c10bb91885ddbd3c112448e23488ba40ab6</docker.image.sha256>
@ -210,6 +221,11 @@
<artifactId>dropwizard-metrics-datadog</artifactId>
<version>${dropwizard-metrics-datadog.version}</version>
</dependency>
<dependency>
<groupId>org.foundationdb</groupId>
<artifactId>fdb-java</artifactId>
<version>${foundationdb.version}</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
@ -309,7 +325,20 @@
<dependency>
<groupId>ch.qos.logback.access</groupId>
<artifactId>logback-access-common</artifactId>
<version>${logback-access.version}</version>
<version>${logback-access-common.version}</version>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>testcontainers-bom</artifactId>
<version>${testcontainers.version}</version>
<type>pom</type>
<scope>import</scope>
</dependency>
<dependency>
<groupId>earth.adi</groupId>
<artifactId>testcontainers-foundationdb</artifactId>
<version>1.1.0</version>
<scope>test</scope>
</dependency>
</dependencies>
</dependencyManagement>

View File

@ -16,6 +16,9 @@ directoryV2.client.userIdTokenSharedSecret: bbcdefghijklmnopqrstuvwxyz0123456789
svr2.userAuthenticationTokenSharedSecret: abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG= # base64-encoded secret shared with SVR2 to generate auth tokens for Signal users
svr2.userIdTokenSharedSecret: bbcdefghijklmnopqrstuvwxyz0123456789ABCDEFG= # base64-encoded secret shared with SVR2 to generate auth identity tokens for Signal users
svrb.userAuthenticationTokenSharedSecret: abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG= # base64-encoded secret shared with SVRB to generate auth tokens for Signal users
svrb.userIdTokenSharedSecret: bbcdefghijklmnopqrstuvwxyz0123456789ABCDEFG= # base64-encoded secret shared with SVRB to generate auth identity tokens for Signal users
tus.userAuthenticationTokenSharedSecret: abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG=
gcpAttachments.rsaSigningKey: |

View File

@ -138,6 +138,8 @@ dynamoDbTables:
tableName: Example_EC_Signed_Pre_Keys
pqKeys:
tableName: Example_PQ_Keys
pagedPqKeys:
tableName: Example_PQ_Paged_Keys
pqLastResortKeys:
tableName: Example_PQ_Last_Resort_Keys
messages:
@ -174,6 +176,10 @@ dynamoDbTables:
verificationSessions:
tableName: Example_VerificationSessions
pagedSingleUseKEMPreKeyStore:
bucket: preKeyBucket # S3 Bucket name
region: us-west-2 # AWS region
cacheCluster: # Redis server configuration for cache cluster
configurationUri: redis://redis.example.com:6379/
@ -219,6 +225,34 @@ svr2:
AAAAAAAAAAAAAAAAAAAA
-----END CERTIFICATE-----
svrb:
uri: svrb.example.com
userAuthenticationTokenSharedSecret: secret://svrb.userAuthenticationTokenSharedSecret
userIdTokenSharedSecret: secret://svrb.userIdTokenSharedSecret
svrCaCertificates:
- |
-----BEGIN CERTIFICATE-----
ABCDEFGHIJKLMNOPQRSTUVWXYZ/0123456789+abcdefghijklmnopqrstuvwxyz
ABCDEFGHIJKLMNOPQRSTUVWXYZ/0123456789+abcdefghijklmnopqrstuvwxyz
ABCDEFGHIJKLMNOPQRSTUVWXYZ/0123456789+abcdefghijklmnopqrstuvwxyz
ABCDEFGHIJKLMNOPQRSTUVWXYZ/0123456789+abcdefghijklmnopqrstuvwxyz
ABCDEFGHIJKLMNOPQRSTUVWXYZ/0123456789+abcdefghijklmnopqrstuvwxyz
ABCDEFGHIJKLMNOPQRSTUVWXYZ/0123456789+abcdefghijklmnopqrstuvwxyz
ABCDEFGHIJKLMNOPQRSTUVWXYZ/0123456789+abcdefghijklmnopqrstuvwxyz
ABCDEFGHIJKLMNOPQRSTUVWXYZ/0123456789+abcdefghijklmnopqrstuvwxyz
ABCDEFGHIJKLMNOPQRSTUVWXYZ/0123456789+abcdefghijklmnopqrstuvwxyz
ABCDEFGHIJKLMNOPQRSTUVWXYZ/0123456789+abcdefghijklmnopqrstuvwxyz
ABCDEFGHIJKLMNOPQRSTUVWXYZ/0123456789+abcdefghijklmnopqrstuvwxyz
ABCDEFGHIJKLMNOPQRSTUVWXYZ/0123456789+abcdefghijklmnopqrstuvwxyz
ABCDEFGHIJKLMNOPQRSTUVWXYZ/0123456789+abcdefghijklmnopqrstuvwxyz
ABCDEFGHIJKLMNOPQRSTUVWXYZ/0123456789+abcdefghijklmnopqrstuvwxyz
ABCDEFGHIJKLMNOPQRSTUVWXYZ/0123456789+abcdefghijklmnopqrstuvwxyz
ABCDEFGHIJKLMNOPQRSTUVWXYZ/0123456789+abcdefghijklmnopqrstuvwxyz
ABCDEFGHIJKLMNOPQRSTUVWXYZ/0123456789+abcdefghijklmnopqrstuvwxyz
ABCDEFGHIJKLMNOPQRSTUVWXYZ/0123456789+abcdefghijklmnopqrstuvwxyz
AAAAAAAAAAAAAAAAAAAA
-----END CERTIFICATE-----
messageCache: # Redis server configuration for message store cache
persistDelayMinutes: 1
cluster:

View File

@ -51,6 +51,13 @@
<groupId>io.swagger.core.v3</groupId>
<artifactId>swagger-jaxrs2-jakarta</artifactId>
<version>${swagger.version}</version>
<exclusions>
<!-- conflicts with jackson-dataformat-yaml -->
<exclusion>
<groupId>org.yaml</groupId>
<artifactId>snakeyaml</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>jakarta.servlet</groupId>
@ -351,6 +358,11 @@
<artifactId>reactor-grpc-stub</artifactId>
</dependency>
<dependency>
<groupId>org.foundationdb</groupId>
<artifactId>fdb-java</artifactId>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>apache-client</artifactId>
@ -485,6 +497,24 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>localstack</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>earth.adi</groupId>
<artifactId>testcontainers-foundationdb</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.google.auth</groupId>
<artifactId>google-auth-library-oauth2-http</artifactId>
@ -520,6 +550,28 @@
<id>exclude-spam-filter</id>
<build>
<plugins>
<plugin>
<groupId>io.github.download-maven-plugin</groupId>
<artifactId>download-maven-plugin</artifactId>
<version>2.0.0</version>
<executions>
<execution>
<id>install-foundationdb-client-library</id>
<phase>prepare-package</phase>
<goals>
<goal>wget</goal>
</goals>
</execution>
</executions>
<configuration>
<url>https://github.com/apple/foundationdb/releases/download/${foundationdb.version}/libfdb_c.x86_64.so</url>
<outputDirectory>${project.build.directory}/jib-extra/usr/lib</outputDirectory>
<sha256>${foundationdb.client-library-sha256}</sha256>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
@ -641,6 +693,9 @@
<includes>*.yml</includes>
<into>/usr/share/signal/</into>
</path>
<path>
<from>${project.build.directory}/jib-extra</from>
</path>
</paths>
</extraDirectories>
</configuration>
@ -712,6 +767,9 @@
<configuration>
<!-- add-opens: work around PATCH not being a supported method on HttpUrlConnection -->
<argLine>-javaagent:${org.mockito:mockito-core:jar} --add-opens=java.base/java.net=ALL-UNNAMED</argLine>
<systemPropertyVariables>
<localstackImage>${localstack.image}</localstackImage>
</systemPropertyVariables>
</configuration>
</plugin>

View File

@ -0,0 +1,20 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
public class FoundationDbVersion {
private static final String VERSION = "${foundationdb.version}";
private static final int API_VERSION = ${foundationdb.api-version};
public static String getFoundationDbVersion() {
return VERSION;
}
public static int getFoundationDbApiVersion() {
return API_VERSION;
}
}

View File

@ -45,13 +45,14 @@ import org.whispersystems.textsecuregcm.configuration.MessageByteLimitCardinalit
import org.whispersystems.textsecuregcm.configuration.MessageCacheConfiguration;
import org.whispersystems.textsecuregcm.configuration.NoiseTunnelConfiguration;
import org.whispersystems.textsecuregcm.configuration.OneTimeDonationConfiguration;
import org.whispersystems.textsecuregcm.configuration.PagedSingleUseKEMPreKeyStoreConfiguration;
import org.whispersystems.textsecuregcm.configuration.PaymentsServiceConfiguration;
import org.whispersystems.textsecuregcm.configuration.RegistrationServiceClientFactory;
import org.whispersystems.textsecuregcm.configuration.RemoteConfigConfiguration;
import org.whispersystems.textsecuregcm.configuration.ReportMessageConfiguration;
import org.whispersystems.textsecuregcm.configuration.S3ObjectMonitorFactory;
import org.whispersystems.textsecuregcm.configuration.SecureStorageServiceConfiguration;
import org.whispersystems.textsecuregcm.configuration.SecureValueRecovery2Configuration;
import org.whispersystems.textsecuregcm.configuration.SecureValueRecoveryConfiguration;
import org.whispersystems.textsecuregcm.configuration.ShortCodeExpanderConfiguration;
import org.whispersystems.textsecuregcm.configuration.SpamFilterConfiguration;
import org.whispersystems.textsecuregcm.configuration.StripeConfiguration;
@ -155,7 +156,12 @@ public class WhisperServerConfiguration extends Configuration {
@NotNull
@Valid
@JsonProperty
private SecureValueRecovery2Configuration svr2;
private SecureValueRecoveryConfiguration svr2;
@NotNull
@Valid
@JsonProperty
private SecureValueRecoveryConfiguration svrb;
@NotNull
@Valid
@ -257,6 +263,11 @@ public class WhisperServerConfiguration extends Configuration {
@NotNull
private OneTimeDonationConfiguration oneTimeDonations;
@Valid
@JsonProperty
@NotNull
private PagedSingleUseKEMPreKeyStoreConfiguration pagedSingleUseKEMPreKeyStore;
@Valid
@NotNull
@JsonProperty
@ -383,10 +394,14 @@ public class WhisperServerConfiguration extends Configuration {
return pubsub;
}
public SecureValueRecovery2Configuration getSvr2Configuration() {
public SecureValueRecoveryConfiguration getSvr2Configuration() {
return svr2;
}
public SecureValueRecoveryConfiguration getSvrbConfiguration() {
return svrb;
}
public DirectoryV2Configuration getDirectoryV2Configuration() {
return directoryV2;
}
@ -478,6 +493,10 @@ public class WhisperServerConfiguration extends Configuration {
return oneTimeDonations;
}
public PagedSingleUseKEMPreKeyStoreConfiguration getPagedSingleUseKEMPreKeyStore() {
return pagedSingleUseKEMPreKeyStore;
}
public ReportMessageConfiguration getReportMessageConfiguration() {
return reportMessage;
}

View File

@ -124,6 +124,7 @@ import org.whispersystems.textsecuregcm.controllers.RegistrationController;
import org.whispersystems.textsecuregcm.controllers.RemoteConfigController;
import org.whispersystems.textsecuregcm.controllers.SecureStorageController;
import org.whispersystems.textsecuregcm.controllers.SecureValueRecovery2Controller;
import org.whispersystems.textsecuregcm.controllers.SecureValueRecoveryBController;
import org.whispersystems.textsecuregcm.controllers.StickerController;
import org.whispersystems.textsecuregcm.controllers.SubscriptionController;
import org.whispersystems.textsecuregcm.controllers.VerificationController;
@ -225,6 +226,7 @@ import org.whispersystems.textsecuregcm.storage.MessagesCache;
import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.OneTimeDonationsManager;
import org.whispersystems.textsecuregcm.storage.PagedSingleUseKEMPreKeyStore;
import org.whispersystems.textsecuregcm.storage.PersistentTimer;
import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers;
import org.whispersystems.textsecuregcm.storage.Profiles;
@ -235,8 +237,12 @@ import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswords;
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
import org.whispersystems.textsecuregcm.storage.RemoteConfigs;
import org.whispersystems.textsecuregcm.storage.RemoteConfigsManager;
import org.whispersystems.textsecuregcm.storage.RepeatedUseECSignedPreKeyStore;
import org.whispersystems.textsecuregcm.storage.RepeatedUseKEMSignedPreKeyStore;
import org.whispersystems.textsecuregcm.storage.ReportMessageDynamoDb;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.storage.SingleUseECPreKeyStore;
import org.whispersystems.textsecuregcm.storage.SingleUseKEMPreKeyStore;
import org.whispersystems.textsecuregcm.storage.SubscriptionManager;
import org.whispersystems.textsecuregcm.storage.Subscriptions;
import org.whispersystems.textsecuregcm.storage.VerificationSessionManager;
@ -274,6 +280,7 @@ import org.whispersystems.textsecuregcm.workers.RemoveExpiredAccountsCommand;
import org.whispersystems.textsecuregcm.workers.RemoveExpiredBackupsCommand;
import org.whispersystems.textsecuregcm.workers.RemoveExpiredLinkedDevicesCommand;
import org.whispersystems.textsecuregcm.workers.RemoveExpiredUsernameHoldsCommand;
import org.whispersystems.textsecuregcm.workers.RemoveOrphanedPreKeyPagesCommand;
import org.whispersystems.textsecuregcm.workers.ScheduledApnPushNotificationSenderServiceCommand;
import org.whispersystems.textsecuregcm.workers.ServerVersionCommand;
import org.whispersystems.textsecuregcm.workers.SetRequestLoggingEnabledTask;
@ -327,6 +334,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
bootstrap.addCommand(new RemoveExpiredAccountsCommand(Clock.systemUTC()));
bootstrap.addCommand(new RemoveExpiredUsernameHoldsCommand(Clock.systemUTC()));
bootstrap.addCommand(new RemoveExpiredBackupsCommand(Clock.systemUTC()));
bootstrap.addCommand(new RemoveOrphanedPreKeyPagesCommand(Clock.systemUTC()));
bootstrap.addCommand(new BackupMetricsCommand(Clock.systemUTC()));
bootstrap.addCommand(new BackupUsageRecalculationCommand());
bootstrap.addCommand(new RemoveExpiredLinkedDevicesCommand());
@ -363,6 +371,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
MetricsUtil.configureRegistries(config, environment, dynamicConfigurationManager);
ExperimentEnrollmentManager experimentEnrollmentManager = new ExperimentEnrollmentManager(dynamicConfigurationManager);
if (config.getServerFactory() instanceof DefaultServerFactory defaultServerFactory) {
defaultServerFactory.getApplicationConnectors()
.forEach(connectorFactory -> {
@ -425,13 +435,22 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.getDynamoDbTables().getPhoneNumberIdentifiers().getTableName());
Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient,
config.getDynamoDbTables().getProfiles().getTableName());
S3AsyncClient asyncKeysS3Client = S3AsyncClient.builder()
.credentialsProvider(awsCredentialsProvider)
.region(Region.of(config.getPagedSingleUseKEMPreKeyStore().region()))
.build();
KeysManager keysManager = new KeysManager(
dynamoDbAsyncClient,
config.getDynamoDbTables().getEcKeys().getTableName(),
config.getDynamoDbTables().getKemKeys().getTableName(),
config.getDynamoDbTables().getEcSignedPreKeys().getTableName(),
config.getDynamoDbTables().getKemLastResortKeys().getTableName()
);
new SingleUseECPreKeyStore(dynamoDbAsyncClient, config.getDynamoDbTables().getEcKeys().getTableName()),
new SingleUseKEMPreKeyStore(dynamoDbAsyncClient, config.getDynamoDbTables().getKemKeys().getTableName()),
new PagedSingleUseKEMPreKeyStore(
dynamoDbAsyncClient,
asyncKeysS3Client,
config.getDynamoDbTables().getPagedKemKeys().getTableName(),
config.getPagedSingleUseKEMPreKeyStore().bucket()),
new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient, config.getDynamoDbTables().getEcSignedPreKeys().getTableName()),
new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient, config.getDynamoDbTables().getKemLastResortKeys().getTableName()),
experimentEnrollmentManager);
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient,
config.getDynamoDbTables().getMessages().getTableName(),
config.getDynamoDbTables().getMessages().getExpiration(),
@ -554,8 +573,6 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
.maxThreads(2)
.minThreads(2)
.build();
ExecutorService keyTransparencyCallbackExecutor = environment.lifecycle()
.virtualExecutorService(name(getClass(), "keyTransparency-%d"));
ExecutorService googlePlayBillingExecutor = environment.lifecycle()
.virtualExecutorService(name(getClass(), "googlePlayBilling-%d"));
ExecutorService appleAppStoreExecutor = environment.lifecycle()
@ -592,9 +609,9 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.getPaymentsServiceConfiguration());
ExternalServiceCredentialsGenerator svr2CredentialsGenerator = SecureValueRecovery2Controller.credentialsGenerator(
config.getSvr2Configuration());
ExternalServiceCredentialsGenerator svrbCredentialsGenerator = SecureValueRecoveryBController.credentialsGenerator(
config.getSvrbConfiguration());
ExperimentEnrollmentManager experimentEnrollmentManager = new ExperimentEnrollmentManager(
dynamicConfigurationManager);
RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager =
new RegistrationRecoveryPasswordsManager(registrationRecoveryPasswords);
UsernameHashZkProofVerifier usernameHashZkProofVerifier = new UsernameHashZkProofVerifier();
@ -606,8 +623,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.getKeyTransparencyServiceConfiguration().port(),
config.getKeyTransparencyServiceConfiguration().tlsCertificate(),
config.getKeyTransparencyServiceConfiguration().clientCertificate(),
config.getKeyTransparencyServiceConfiguration().clientPrivateKey().value(),
keyTransparencyCallbackExecutor);
config.getKeyTransparencyServiceConfiguration().clientPrivateKey().value());
SecureValueRecovery2Client secureValueRecovery2Client = new SecureValueRecovery2Client(svr2CredentialsGenerator,
secureValueRecovery2ServiceExecutor, secureValueRecoveryServiceRetryExecutor, config.getSvr2Configuration());
SecureStorageClient secureStorageClient = new SecureStorageClient(storageCredentialsGenerator,
@ -1108,6 +1124,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
new RemoteConfigController(remoteConfigsManager, config.getRemoteConfigConfiguration().globalConfig(), clock),
new SecureStorageController(storageCredentialsGenerator),
new SecureValueRecovery2Controller(svr2CredentialsGenerator, accountsManager),
new SecureValueRecoveryBController(svrbCredentialsGenerator),
new StickerController(rateLimiters, config.getCdnConfiguration().credentials().accessKeyId().value(),
config.getCdnConfiguration().credentials().secretAccessKey().value(), config.getCdnConfiguration().region(),
config.getCdnConfiguration().bucket()),

View File

@ -297,7 +297,7 @@ public class BackupsDb {
.tags(tags)
.publishPercentileHistogram()
.register(Metrics.globalRegistry)
.record(mediaCount);
.record(bytesUsed);
// Report that the backup is out of quota if it cannot store a max size media object
final boolean quotaExhausted = bytesUsed >=

View File

@ -1,34 +0,0 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.calls.routing;
import jakarta.validation.constraints.NotNull;
import java.net.InetAddress;
import java.util.List;
import java.util.Map;
public record CallDnsRecords(
@NotNull
Map<String, List<InetAddress>> aByRegion,
@NotNull
Map<String, List<InetAddress>> aaaaByRegion
) {
public String getSummary() {
int numARecords = aByRegion.values().stream().mapToInt(List::size).sum();
int numAAAARecords = aaaaByRegion.values().stream().mapToInt(List::size).sum();
return String.format(
"(A records, %s regions, %s records), (AAAA records, %s regions, %s records)",
aByRegion.size(),
numARecords,
aaaaByRegion.size(),
numAAAARecords
);
}
public static CallDnsRecords empty() {
return new CallDnsRecords(Map.of(), Map.of());
}
}

View File

@ -1,80 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.calls.routing;
import com.fasterxml.jackson.core.StreamReadFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.json.JsonMapper;
import io.dropwizard.lifecycle.Managed;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.S3ObjectMonitorFactory;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.s3.S3ObjectMonitor;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
public class CallDnsRecordsManager implements Supplier<CallDnsRecords>, Managed {
private final S3ObjectMonitor objectMonitor;
private final AtomicReference<CallDnsRecords> callDnsRecords = new AtomicReference<>();
private final Timer refreshTimer;
private static final Logger log = LoggerFactory.getLogger(CallDnsRecordsManager.class);
private static final ObjectMapper objectMapper = JsonMapper.builder()
.enable(StreamReadFeature.INCLUDE_SOURCE_IN_LOCATION)
.build();
public CallDnsRecordsManager(final ScheduledExecutorService executorService,
final AwsCredentialsProvider awsCredentialsProvider, final S3ObjectMonitorFactory configuration) {
this.objectMonitor = configuration.build(awsCredentialsProvider, executorService);
this.callDnsRecords.set(CallDnsRecords.empty());
this.refreshTimer = Metrics.timer(MetricsUtil.name(CallDnsRecordsManager.class, "refresh"));
}
private void handleDatabaseChanged(final InputStream inputStream) {
refreshTimer.record(() -> {
try (final InputStream bufferedInputStream = new BufferedInputStream(inputStream)) {
final CallDnsRecords newRecords = parseRecords(bufferedInputStream);
final CallDnsRecords oldRecords = callDnsRecords.getAndSet(newRecords);
log.info("Replaced dns records, old summary=[{}], new summary=[{}]", oldRecords != null ? oldRecords.getSummary() : "null", newRecords);
} catch (final IOException e) {
log.error("Failed to load Call DNS Records");
}
});
}
static CallDnsRecords parseRecords(InputStream inputStream) throws IOException {
return objectMapper.readValue(inputStream, CallDnsRecords.class);
}
@Override
public void start() throws Exception {
objectMonitor.start(this::handleDatabaseChanged);
}
@Override
public void stop() throws Exception {
objectMonitor.stop();
callDnsRecords.getAndSet(null);
}
@Override
public CallDnsRecords get() {
return this.callDnsRecords.get();
}
}

View File

@ -1,193 +0,0 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.calls.routing;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotNull;
import java.math.BigInteger;
import java.net.Inet4Address;
import java.net.Inet6Address;
import java.net.InetAddress;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Stream;
public class CallRoutingTable {
private final TreeMap<Integer, Map<Integer, List<String>>> ipv4Map;
private final TreeMap<Integer, Map<BigInteger, List<String>>> ipv6Map;
private final Map<GeoKey, List<String>> geoToDatacenter;
public CallRoutingTable(
Map<CidrBlock.IpV4CidrBlock, List<String>> ipv4SubnetToDatacenter,
Map<CidrBlock.IpV6CidrBlock, List<String>> ipv6SubnetToDatacenter,
Map<GeoKey, List<String>> geoToDatacenter
) {
this.ipv4Map = new TreeMap<>();
for (Map.Entry<CidrBlock.IpV4CidrBlock, List<String>> t : ipv4SubnetToDatacenter.entrySet()) {
if (!this.ipv4Map.containsKey(t.getKey().cidrBlockSize())) {
this.ipv4Map.put(t.getKey().cidrBlockSize(), new HashMap<>());
}
this.ipv4Map
.get(t.getKey().cidrBlockSize())
.put(t.getKey().subnet(), t.getValue());
}
this.ipv6Map = new TreeMap<>();
for (Map.Entry<CidrBlock.IpV6CidrBlock, List<String>> t : ipv6SubnetToDatacenter.entrySet()) {
if (!this.ipv6Map.containsKey(t.getKey().cidrBlockSize())) {
this.ipv6Map.put(t.getKey().cidrBlockSize(), new HashMap<>());
}
this.ipv6Map
.get(t.getKey().cidrBlockSize())
.put(t.getKey().subnet(), t.getValue());
}
this.geoToDatacenter = geoToDatacenter;
}
public static CallRoutingTable empty() {
return new CallRoutingTable(Map.of(), Map.of(), Map.of());
}
public enum Protocol {
v4,
v6
}
public record GeoKey(
@NotBlank String continent,
@NotBlank String country,
@NotNull Optional<String> subdivision,
@NotBlank Protocol protocol
) {}
/**
* Returns ordered list of fastest datacenters based on IP & Geo info. Prioritize the results based on subnet.
* Returns at most three, 2 by subnet and 1 by geo. Takes more from either bucket to hit 3.
*/
public List<String> getDatacentersFor(
InetAddress address,
String continent,
String country,
Optional<String> subdivision
) {
final int NUM_DATACENTERS = 3;
if(this.isEmpty()) {
return Collections.emptyList();
}
List<String> dcsBySubnet = getDatacentersBySubnet(address);
List<String> dcsByGeo = getDatacentersByGeo(continent, country, subdivision).stream()
.limit(NUM_DATACENTERS)
.filter(dc ->
(dcsBySubnet.isEmpty() || !dc.equals(dcsBySubnet.getFirst()))
&& (dcsBySubnet.size() < 2 || !dc.equals(dcsBySubnet.get(1)))
).toList();
return Stream.concat(
dcsBySubnet.stream().limit(dcsByGeo.isEmpty() ? NUM_DATACENTERS : NUM_DATACENTERS - 1),
dcsByGeo.stream())
.limit(NUM_DATACENTERS)
.toList();
}
public boolean isEmpty() {
return this.ipv4Map.isEmpty() && this.ipv6Map.isEmpty() && this.geoToDatacenter.isEmpty();
}
/**
* Returns ordered list of fastest datacenters based on ip info. Prioritizes V4 connections.
*/
public List<String> getDatacentersBySubnet(InetAddress address) throws IllegalArgumentException {
if(address instanceof Inet4Address) {
for(Map.Entry<Integer, Map<Integer, List<String>>> t: this.ipv4Map.descendingMap().entrySet()) {
int maskedIp = CidrBlock.IpV4CidrBlock.maskToSize((Inet4Address) address, t.getKey());
if(t.getValue().containsKey(maskedIp)) {
return t.getValue().get(maskedIp);
}
}
} else if (address instanceof Inet6Address) {
for(Map.Entry<Integer, Map<BigInteger, List<String>>> t: this.ipv6Map.descendingMap().entrySet()) {
BigInteger maskedIp = CidrBlock.IpV6CidrBlock.maskToSize((Inet6Address) address, t.getKey());
if(t.getValue().containsKey(maskedIp)) {
return t.getValue().get(maskedIp);
}
}
} else {
throw new IllegalArgumentException("Expected either an Inet4Address or Inet6Address");
}
return Collections.emptyList();
}
/**
* Returns ordered list of fastest datacenters based on geo info. Attempts to match based on subdivision, falls back
* to country based lookup. Does not attempt to look for nearby subdivisions. Prioritizes V4 connections.
*/
public List<String> getDatacentersByGeo(
String continent,
String country,
Optional<String> subdivision
) {
GeoKey v4Key = new GeoKey(continent, country, subdivision, Protocol.v4);
List<String> v4Options = this.geoToDatacenter.getOrDefault(v4Key, Collections.emptyList());
List<String> v4OptionsBackup = v4Options.isEmpty() && subdivision.isPresent() ?
this.geoToDatacenter.getOrDefault(
new GeoKey(continent, country, Optional.empty(), Protocol.v4),
Collections.emptyList())
: Collections.emptyList();
GeoKey v6Key = new GeoKey(continent, country, subdivision, Protocol.v6);
List<String> v6Options = this.geoToDatacenter.getOrDefault(v6Key, Collections.emptyList());
List<String> v6OptionsBackup = v6Options.isEmpty() && subdivision.isPresent() ?
this.geoToDatacenter.getOrDefault(
new GeoKey(continent, country, Optional.empty(), Protocol.v6),
Collections.emptyList())
: Collections.emptyList();
return Stream.of(
v4Options.stream(),
v6Options.stream(),
v4OptionsBackup.stream(),
v6OptionsBackup.stream()
)
.flatMap(Function.identity())
.distinct()
.toList();
}
public String toSummaryString() {
return String.format(
"[Ipv4Table=%s rows, Ipv6Table=%s rows, GeoTable=%s rows]",
ipv4Map.size(),
ipv6Map.size(),
geoToDatacenter.size()
);
}
@Override
public boolean equals(final Object o) {
if (this == o)
return true;
if (o == null || getClass() != o.getClass())
return false;
CallRoutingTable that = (CallRoutingTable) o;
return Objects.equals(ipv4Map, that.ipv4Map) && Objects.equals(ipv6Map, that.ipv6Map) && Objects.equals(
geoToDatacenter, that.geoToDatacenter);
}
@Override
public int hashCode() {
return Objects.hash(ipv4Map, ipv6Map, geoToDatacenter);
}
}

View File

@ -1,73 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.calls.routing;
import io.dropwizard.lifecycle.Managed;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.S3ObjectMonitorFactory;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.s3.S3ObjectMonitor;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
public class CallRoutingTableManager implements Supplier<CallRoutingTable>, Managed {
private final S3ObjectMonitor objectMonitor;
private final AtomicReference<CallRoutingTable> routingTable = new AtomicReference<>();
private final String tableTag;
private final Timer refreshTimer;
private static final Logger log = LoggerFactory.getLogger(CallRoutingTableManager.class);
public CallRoutingTableManager(final ScheduledExecutorService executorService,
final AwsCredentialsProvider awsCredentialsProvider, final S3ObjectMonitorFactory configuration,
final String tableTag) {
this.objectMonitor = configuration.build(awsCredentialsProvider, executorService);
this.tableTag = tableTag;
this.routingTable.set(CallRoutingTable.empty());
this.refreshTimer = Metrics.timer(MetricsUtil.name(CallRoutingTableManager.class, tableTag));
}
private void handleDatabaseChanged(final InputStream inputStream) {
refreshTimer.record(() -> {
try(InputStreamReader reader = new InputStreamReader(inputStream)) {
CallRoutingTable newTable = CallRoutingTableParser.fromJson(reader);
this.routingTable.set(newTable);
log.info("Replaced {} call routing table: {}", tableTag, newTable.toSummaryString());
} catch (final IOException e) {
log.error("Failed to parse and update {} call routing table", tableTag);
}
});
}
@Override
public void start() throws Exception {
objectMonitor.start(this::handleDatabaseChanged);
}
@Override
public void stop() throws Exception {
objectMonitor.stop();
routingTable.getAndSet(null);
}
@Override
public CallRoutingTable get() {
return this.routingTable.get();
}
}

View File

@ -1,185 +0,0 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.calls.routing;
import com.fasterxml.jackson.core.StreamReadFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.json.JsonMapper;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;
final class CallRoutingTableParser {
private final static int IPV4_DEFAULT_BLOCK_SIZE = 24;
private final static int IPV6_DEFAULT_BLOCK_SIZE = 48;
private static final ObjectMapper objectMapper = JsonMapper.builder()
.enable(StreamReadFeature.INCLUDE_SOURCE_IN_LOCATION)
.build();
/** Used for parsing JSON */
private static class RawCallRoutingTable {
public Map<String, List<String>> ipv4GeoToDataCenters = Map.of();
public Map<String, List<String>> ipv6GeoToDataCenters = Map.of();
public Map<String, List<String>> ipv4SubnetsToDatacenters = Map.of();
public Map<String, List<String>> ipv6SubnetsToDatacenters = Map.of();
}
private final static String WHITESPACE_REGEX = "\\s+";
public static CallRoutingTable fromJson(final Reader inputReader) throws IOException {
try (final BufferedReader reader = new BufferedReader(inputReader)) {
RawCallRoutingTable rawTable = objectMapper.readValue(reader, RawCallRoutingTable.class);
Map<CidrBlock.IpV4CidrBlock, List<String>> ipv4SubnetToDatacenter = rawTable.ipv4SubnetsToDatacenters
.entrySet()
.stream()
.collect(Collectors.toUnmodifiableMap(
e -> (CidrBlock.IpV4CidrBlock) CidrBlock.parseCidrBlock(e.getKey(), IPV4_DEFAULT_BLOCK_SIZE),
Map.Entry::getValue
));
Map<CidrBlock.IpV6CidrBlock, List<String>> ipv6SubnetToDatacenter = rawTable.ipv6SubnetsToDatacenters
.entrySet()
.stream()
.collect(Collectors.toUnmodifiableMap(
e -> (CidrBlock.IpV6CidrBlock) CidrBlock.parseCidrBlock(e.getKey(), IPV6_DEFAULT_BLOCK_SIZE),
Map.Entry::getValue
));
Map<CallRoutingTable.GeoKey, List<String>> geoToDatacenter = Stream.concat(
rawTable.ipv4GeoToDataCenters
.entrySet()
.stream()
.map(e -> Map.entry(parseRawGeoKey(e.getKey(), CallRoutingTable.Protocol.v4), e.getValue())),
rawTable.ipv6GeoToDataCenters
.entrySet()
.stream()
.map(e -> Map.entry(parseRawGeoKey(e.getKey(), CallRoutingTable.Protocol.v6), e.getValue()))
).collect(Collectors.toUnmodifiableMap(Map.Entry::getKey, Map.Entry::getValue));
return new CallRoutingTable(
ipv4SubnetToDatacenter,
ipv6SubnetToDatacenter,
geoToDatacenter
);
}
}
private static CallRoutingTable.GeoKey parseRawGeoKey(String rawKey, CallRoutingTable.Protocol protocol) {
String[] splits = rawKey.split("-");
if (splits.length < 2 || splits.length > 3) {
throw new IllegalArgumentException("Invalid raw key");
}
Optional<String> subdivision = splits.length < 3 ? Optional.empty() : Optional.of(splits[2]);
return new CallRoutingTable.GeoKey(splits[0], splits[1], subdivision, protocol);
}
/**
* Parses a call routing table in TSV format. Example below - see tests for more examples:
192.0.2.0/24 northamerica-northeast1
198.51.100.0/24 us-south1
203.0.113.0/24 asia-southeast1
2001:db8:b0a9::/48 us-east4
2001:db8:b0f5::/48 us-central1 northamerica-northeast1 us-east4
2001:db8:9406::/48 us-east1 us-central1
SA-SR-v4 us-east1 us-east4
SA-SR-v6 us-east1 us-south1
SA-UY-v4 southamerica-west1 southamerica-east1 europe-west3
SA-UY-v6 southamerica-west1 europe-west4
SA-VE-v4 us-east1 us-east4 us-south1
SA-VE-v6 us-east1 northamerica-northeast1 us-east4
ZZ-ZZ-v4 asia-south1 europe-southwest1 australia-southeast1
*/
public static CallRoutingTable fromTsv(final Reader inputReader) throws IOException {
try (final BufferedReader reader = new BufferedReader(inputReader)) {
// use maps to silently dedupe CidrBlocks
Map<CidrBlock.IpV4CidrBlock, List<String>> ipv4Map = new HashMap<>();
Map<CidrBlock.IpV6CidrBlock, List<String>> ipv6Map = new HashMap<>();
Map<CallRoutingTable.GeoKey, List<String>> ipGeoTable = new HashMap<>();
String line;
while((line = reader.readLine()) != null) {
if(line.isBlank()) {
continue;
}
List<String> splits = Arrays.stream(line.split(WHITESPACE_REGEX)).filter(s -> !s.isBlank()).toList();
if (splits.size() < 2) {
throw new IllegalStateException("Invalid row, expected some key and list of values");
}
List<String> datacenters = splits.subList(1, splits.size());
switch (guessLineType(splits)) {
case v4 -> {
CidrBlock cidrBlock = CidrBlock.parseCidrBlock(splits.getFirst());
if(!(cidrBlock instanceof CidrBlock.IpV4CidrBlock)) {
throw new IllegalArgumentException("Expected an ipv4 cidr block");
}
ipv4Map.put((CidrBlock.IpV4CidrBlock) cidrBlock, datacenters);
}
case v6 -> {
CidrBlock cidrBlock = CidrBlock.parseCidrBlock(splits.getFirst());
if(!(cidrBlock instanceof CidrBlock.IpV6CidrBlock)) {
throw new IllegalArgumentException("Expected an ipv6 cidr block");
}
ipv6Map.put((CidrBlock.IpV6CidrBlock) cidrBlock, datacenters);
}
case Geo -> {
String[] geo = splits.getFirst().split("-");
if(geo.length < 3) {
throw new IllegalStateException("Geo row key invalid, expected atleast continent, country, and protocol");
}
String continent = geo[0];
String country = geo[1];
Optional<String> subdivision = geo.length > 3 ? Optional.of(geo[2]) : Optional.empty();
CallRoutingTable.Protocol protocol = CallRoutingTable.Protocol.valueOf(geo[geo.length - 1].toLowerCase());
CallRoutingTable.GeoKey tableKey = new CallRoutingTable.GeoKey(
continent,
country,
subdivision,
protocol
);
ipGeoTable.put(tableKey, datacenters);
}
}
}
return new CallRoutingTable(
ipv4Map,
ipv6Map,
ipGeoTable
);
}
}
private static LineType guessLineType(List<String> splits) {
String first = splits.getFirst();
if (first.contains("-")) {
return LineType.Geo;
} else if(first.contains(":")) {
return LineType.v6;
} else if (first.contains(".")) {
return LineType.v4;
}
throw new IllegalArgumentException(String.format("Invalid line, could not determine type from '%s'", first));
}
private enum LineType {
v4, v6, Geo
}
}

View File

@ -1,137 +0,0 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.calls.routing;
import java.math.BigInteger;
import java.net.Inet4Address;
import java.net.Inet6Address;
import java.net.InetAddress;
import java.net.UnknownHostException;
/**
* Can be used to check if an IP is in the CIDR block
*/
public interface CidrBlock {
boolean ipInBlock(InetAddress address);
static CidrBlock parseCidrBlock(String cidrBlock, int defaultBlockSize) {
String[] splits = cidrBlock.split("/");
if(splits.length > 2) {
throw new IllegalArgumentException("Invalid cidr block format, expected {address}/{blocksize}");
}
try {
int blockSize = splits.length == 2 ? Integer.parseInt(splits[1]) : defaultBlockSize;
return parseCidrBlockInner(splits[0], blockSize);
} catch (NumberFormatException e) {
throw new IllegalArgumentException(String.format("Invalid block size specified: '%s'", splits[1]));
}
}
static CidrBlock parseCidrBlock(String cidrBlock) {
String[] splits = cidrBlock.split("/");
if (splits.length != 2) {
throw new IllegalArgumentException("Invalid cidr block format, expected {address}/{blocksize}");
}
try {
int blockSize = Integer.parseInt(splits[1]);
return parseCidrBlockInner(splits[0], blockSize);
} catch (NumberFormatException e) {
throw new IllegalArgumentException(String.format("Invalid block size specified: '%s'", splits[1]));
}
}
private static CidrBlock parseCidrBlockInner(String rawAddress, int blockSize) {
try {
InetAddress address = InetAddress.getByName(rawAddress);
if(address instanceof Inet4Address) {
return IpV4CidrBlock.of((Inet4Address) address, blockSize);
} else if (address instanceof Inet6Address) {
return IpV6CidrBlock.of((Inet6Address) address, blockSize);
} else {
throw new IllegalArgumentException("Must be an ipv4 or ipv6 string");
}
} catch (UnknownHostException e) {
throw new IllegalArgumentException(e);
}
}
record IpV4CidrBlock(int subnet, int subnetMask, int cidrBlockSize) implements CidrBlock {
public static IpV4CidrBlock of(Inet4Address subnet, int cidrBlockSize) {
if(cidrBlockSize > 32 || cidrBlockSize < 0) {
throw new IllegalArgumentException("Invalid cidrBlockSize");
}
int subnetMask = mask(cidrBlockSize);
int maskedIp = ipToInt(subnet) & subnetMask;
return new IpV4CidrBlock(maskedIp, subnetMask, cidrBlockSize);
}
public boolean ipInBlock(InetAddress address) {
if(!(address instanceof Inet4Address)) {
return false;
}
int ip = ipToInt((Inet4Address) address);
return (ip & subnetMask) == subnet;
}
private static int ipToInt(Inet4Address address) {
byte[] octets = address.getAddress();
return (octets[0] & 0xff) << 24 |
(octets[1] & 0xff) << 16 |
(octets[2] & 0xff) << 8 |
octets[3] & 0xff;
}
private static int mask(int cidrBlockSize) {
return (int) (-1L << (32 - cidrBlockSize));
}
public static int maskToSize(Inet4Address address, int cidrBlockSize) {
return ipToInt(address) & mask(cidrBlockSize);
}
}
record IpV6CidrBlock(BigInteger subnet, BigInteger subnetMask, int cidrBlockSize) implements CidrBlock {
private static final BigInteger MINUS_ONE = BigInteger.valueOf(-1);
public static IpV6CidrBlock of(Inet6Address subnet, int cidrBlockSize) {
if(cidrBlockSize > 128 || cidrBlockSize < 0) {
throw new IllegalArgumentException("Invalid cidrBlockSize");
}
BigInteger subnetMask = mask(cidrBlockSize);
BigInteger maskedIp = ipToInt(subnet).and(subnetMask);
return new IpV6CidrBlock(maskedIp, subnetMask, cidrBlockSize);
}
public boolean ipInBlock(InetAddress address) {
if(!(address instanceof Inet6Address)) {
return false;
}
BigInteger ip = ipToInt((Inet6Address) address);
return ip.and(subnetMask).equals(subnet);
}
private static BigInteger ipToInt(Inet6Address ipAddress) {
byte[] octets = ipAddress.getAddress();
assert octets.length == 16;
return new BigInteger(octets);
}
private static BigInteger mask(int cidrBlockSize) {
return MINUS_ONE.shiftLeft(128 - cidrBlockSize);
}
public static BigInteger maskToSize(Inet6Address address, int cidrBlockSize) {
return ipToInt(address).and(mask(cidrBlockSize));
}
}
}

View File

@ -1,16 +0,0 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.calls.routing;
import java.util.List;
import java.util.Optional;
public record TurnServerOptions(
String hostname,
Optional<List<String>> urlsWithIps,
Optional<List<String>> urlsWithHostname
) {
}

View File

@ -60,6 +60,7 @@ public class DynamoDbTables {
private final Table ecSignedPreKeys;
private final Table kemKeys;
private final Table kemLastResortKeys;
private final Table pagedKemKeys;
private final TableWithExpiration messages;
private final TableWithExpiration onetimeDonations;
private final Table phoneNumberIdentifiers;
@ -88,6 +89,7 @@ public class DynamoDbTables {
@JsonProperty("ecSignedPreKeys") final Table ecSignedPreKeys,
@JsonProperty("pqKeys") final Table kemKeys,
@JsonProperty("pqLastResortKeys") final Table kemLastResortKeys,
@JsonProperty("pagedPqKeys") final Table pagedKemKeys,
@JsonProperty("messages") final TableWithExpiration messages,
@JsonProperty("onetimeDonations") final TableWithExpiration onetimeDonations,
@JsonProperty("phoneNumberIdentifiers") final Table phoneNumberIdentifiers,
@ -114,6 +116,7 @@ public class DynamoDbTables {
this.ecKeys = ecKeys;
this.ecSignedPreKeys = ecSignedPreKeys;
this.kemKeys = kemKeys;
this.pagedKemKeys = pagedKemKeys;
this.kemLastResortKeys = kemLastResortKeys;
this.messages = messages;
this.onetimeDonations = onetimeDonations;
@ -202,6 +205,12 @@ public class DynamoDbTables {
return kemKeys;
}
@NotNull
@Valid
public Table getPagedKemKeys() {
return pagedKemKeys;
}
@NotNull
@Valid
public Table getKemLastResortKeys() {

View File

@ -0,0 +1,15 @@
/*
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.configuration;
import jakarta.validation.Valid;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotNull;
public record PagedSingleUseKEMPreKeyStoreConfiguration(
@NotBlank String bucket,
@NotBlank String region) {
}

View File

@ -12,7 +12,7 @@ import java.util.List;
import org.whispersystems.textsecuregcm.configuration.secrets.SecretBytes;
import org.whispersystems.textsecuregcm.util.ExactlySize;
public record SecureValueRecovery2Configuration(
public record SecureValueRecoveryConfiguration(
@NotBlank String uri,
@ExactlySize(32) SecretBytes userAuthenticationTokenSharedSecret,
@ExactlySize(32) SecretBytes userIdTokenSharedSecret,
@ -20,7 +20,7 @@ public record SecureValueRecovery2Configuration(
@NotNull @Valid CircuitBreakerConfiguration circuitBreaker,
@NotNull @Valid RetryConfiguration retry) {
public SecureValueRecovery2Configuration {
public SecureValueRecoveryConfiguration {
if (circuitBreaker == null) {
circuitBreaker = new CircuitBreakerConfiguration();
}

View File

@ -31,8 +31,7 @@ import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletionException;
import org.glassfish.jersey.server.ManagedAsync;
import org.signal.keytransparency.client.AciMonitorRequest;
import org.signal.keytransparency.client.E164MonitorRequest;
import org.signal.keytransparency.client.E164SearchRequest;
@ -48,15 +47,12 @@ import org.whispersystems.textsecuregcm.entities.KeyTransparencySearchResponse;
import org.whispersystems.textsecuregcm.keytransparency.KeyTransparencyServiceClient;
import org.whispersystems.textsecuregcm.limits.RateLimitedByIp;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
@Path("/v1/key-transparency")
@Tag(name = "KeyTransparency")
public class KeyTransparencyController {
private static final Logger LOGGER = LoggerFactory.getLogger(KeyTransparencyController.class);
@VisibleForTesting
static final Duration KEY_TRANSPARENCY_RPC_TIMEOUT = Duration.ofSeconds(15);
private final KeyTransparencyServiceClient keyTransparencyServiceClient;
public KeyTransparencyController(
@ -88,6 +84,7 @@ public class KeyTransparencyController {
@Path("/search")
@RateLimitedByIp(RateLimiters.For.KEY_TRANSPARENCY_SEARCH_PER_IP)
@Produces(MediaType.APPLICATION_JSON)
@ManagedAsync
public KeyTransparencySearchResponse search(
@Auth final Optional<AuthenticatedDevice> authenticatedAccount,
@NotNull @Valid final KeyTransparencySearchRequest request) {
@ -104,19 +101,16 @@ public class KeyTransparencyController {
.build()
));
return keyTransparencyServiceClient.search(
return new KeyTransparencySearchResponse(
keyTransparencyServiceClient.search(
ByteString.copyFrom(request.aci().toCompactByteArray()),
ByteString.copyFrom(request.aciIdentityKey().serialize()),
request.usernameHash().map(ByteString::copyFrom),
maybeE164SearchRequest,
request.lastTreeHeadSize(),
request.distinguishedTreeHeadSize(),
KEY_TRANSPARENCY_RPC_TIMEOUT)
.thenApply(KeyTransparencySearchResponse::new).join();
} catch (final CancellationException exception) {
LOGGER.error("Unexpected cancellation from key transparency service", exception);
throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE, exception);
} catch (final CompletionException exception) {
request.distinguishedTreeHeadSize())
.toByteArray());
} catch (final StatusRuntimeException exception) {
handleKeyTransparencyServiceError(exception);
}
// This is unreachable
@ -140,6 +134,7 @@ public class KeyTransparencyController {
@Path("/monitor")
@RateLimitedByIp(RateLimiters.For.KEY_TRANSPARENCY_MONITOR_PER_IP)
@Produces(MediaType.APPLICATION_JSON)
@ManagedAsync
public KeyTransparencyMonitorResponse monitor(
@Auth final Optional<AuthenticatedDevice> authenticatedAccount,
@NotNull @Valid final KeyTransparencyMonitorRequest request) {
@ -173,13 +168,9 @@ public class KeyTransparencyController {
usernameHashMonitorRequest,
e164MonitorRequest,
request.lastNonDistinguishedTreeHeadSize(),
request.lastDistinguishedTreeHeadSize(),
KEY_TRANSPARENCY_RPC_TIMEOUT).join());
} catch (final CancellationException exception) {
LOGGER.error("Unexpected cancellation from key transparency service", exception);
throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE, exception);
} catch (final CompletionException exception) {
request.lastDistinguishedTreeHeadSize())
.toByteArray());
} catch (final StatusRuntimeException exception) {
handleKeyTransparencyServiceError(exception);
}
// This is unreachable
@ -202,6 +193,7 @@ public class KeyTransparencyController {
@Path("/distinguished")
@RateLimitedByIp(RateLimiters.For.KEY_TRANSPARENCY_DISTINGUISHED_PER_IP)
@Produces(MediaType.APPLICATION_JSON)
@ManagedAsync
public KeyTransparencyDistinguishedKeyResponse getDistinguishedKey(
@Auth final Optional<AuthenticatedDevice> authenticatedAccount,
@ -212,34 +204,28 @@ public class KeyTransparencyController {
requireNotAuthenticated(authenticatedAccount);
try {
return keyTransparencyServiceClient.getDistinguishedKey(lastTreeHeadSize, KEY_TRANSPARENCY_RPC_TIMEOUT)
.thenApply(KeyTransparencyDistinguishedKeyResponse::new)
.join();
} catch (final CancellationException exception) {
LOGGER.error("Unexpected cancellation from key transparency service", exception);
throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE, exception);
} catch (final CompletionException exception) {
return new KeyTransparencyDistinguishedKeyResponse(
keyTransparencyServiceClient.getDistinguishedKey(lastTreeHeadSize)
.toByteArray());
} catch (final StatusRuntimeException exception) {
handleKeyTransparencyServiceError(exception);
}
// This is unreachable
return null;
}
private void handleKeyTransparencyServiceError(final CompletionException exception) {
final Throwable unwrapped = ExceptionUtils.unwrap(exception);
if (unwrapped instanceof StatusRuntimeException e) {
final Status.Code code = e.getStatus().getCode();
final String description = e.getStatus().getDescription();
switch (code) {
case NOT_FOUND -> throw new NotFoundException(description);
case PERMISSION_DENIED -> throw new ForbiddenException(description);
case INVALID_ARGUMENT -> throw new WebApplicationException(description, 422);
default -> throw new ServerErrorException(Response.Status.INTERNAL_SERVER_ERROR, unwrapped);
private void handleKeyTransparencyServiceError(final StatusRuntimeException exception) {
final Status.Code code = exception.getStatus().getCode();
final String description = exception.getStatus().getDescription();
switch (code) {
case NOT_FOUND -> throw new NotFoundException(description);
case PERMISSION_DENIED -> throw new ForbiddenException(description);
case INVALID_ARGUMENT -> throw new WebApplicationException(description, 422);
default -> {
LOGGER.error("Unexpected error calling key transparency service", exception);
throw new ServerErrorException(Response.Status.INTERNAL_SERVER_ERROR, exception);
}
}
LOGGER.error("Unexpected key transparency service failure", unwrapped);
throw new ServerErrorException(Response.Status.INTERNAL_SERVER_ERROR, unwrapped);
}
private void requireNotAuthenticated(final Optional<AuthenticatedDevice> authenticatedAccount) {

View File

@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.controllers;
import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.auth.Auth;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.validation.Valid;
@ -27,7 +28,7 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsSelector;
import org.whispersystems.textsecuregcm.configuration.SecureValueRecovery2Configuration;
import org.whispersystems.textsecuregcm.configuration.SecureValueRecoveryConfiguration;
import org.whispersystems.textsecuregcm.entities.AuthCheckRequest;
import org.whispersystems.textsecuregcm.entities.AuthCheckResponseV2;
import org.whispersystems.textsecuregcm.limits.RateLimitedByIp;
@ -35,18 +36,19 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
@Path("/v2/backup")
@Path("/v2/{name: backup|svr}")
@Tag(name = "Secure Value Recovery")
@Schema(description = "Note: /v2/backup is deprecated. Use /v2/svr instead.")
public class SecureValueRecovery2Controller {
private static final long MAX_AGE_SECONDS = TimeUnit.DAYS.toSeconds(30);
public static ExternalServiceCredentialsGenerator credentialsGenerator(final SecureValueRecovery2Configuration cfg) {
public static ExternalServiceCredentialsGenerator credentialsGenerator(final SecureValueRecoveryConfiguration cfg) {
return credentialsGenerator(cfg, Clock.systemUTC());
}
@VisibleForTesting
public static ExternalServiceCredentialsGenerator credentialsGenerator(final SecureValueRecovery2Configuration cfg, final Clock clock) {
public static ExternalServiceCredentialsGenerator credentialsGenerator(final SecureValueRecoveryConfiguration cfg, final Clock clock) {
return ExternalServiceCredentialsGenerator
.builder(cfg.userAuthenticationTokenSharedSecret())
.withUserDerivationKey(cfg.userIdTokenSharedSecret().value())

View File

@ -0,0 +1,63 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.auth.Auth;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.core.MediaType;
import java.time.Clock;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator;
import org.whispersystems.textsecuregcm.configuration.SecureValueRecoveryConfiguration;
@Path("/v1/svrb")
@Tag(name = "Secure Value Recovery B")
public class SecureValueRecoveryBController {
public static ExternalServiceCredentialsGenerator credentialsGenerator(final SecureValueRecoveryConfiguration cfg) {
return credentialsGenerator(cfg, Clock.systemUTC());
}
@VisibleForTesting
public static ExternalServiceCredentialsGenerator credentialsGenerator(final SecureValueRecoveryConfiguration cfg,
final Clock clock) {
return ExternalServiceCredentialsGenerator
.builder(cfg.userAuthenticationTokenSharedSecret())
.withUserDerivationKey(cfg.userIdTokenSharedSecret().value())
.prependUsername(false)
.withDerivedUsernameTruncateLength(16)
.withClock(clock)
.build();
}
private final ExternalServiceCredentialsGenerator svrbCredentialGenerator;
public SecureValueRecoveryBController(final ExternalServiceCredentialsGenerator svrbCredentialGenerator) {
this.svrbCredentialGenerator = svrbCredentialGenerator;
}
@GET
@Path("/auth")
@Produces(MediaType.APPLICATION_JSON)
@Operation(
summary = "Generate credentials for SVRB",
description = """
Generate SVRB service credentials. Generated credentials have an expiration time of 1 day (subject to change)
"""
)
@ApiResponse(responseCode = "200", description = "`JSON` with generated credentials.", useReturnTypeSchema = true)
@ApiResponse(responseCode = "401", description = "Account authentication check failed.")
public ExternalServiceCredentials getAuth(@Auth final AuthenticatedDevice auth) {
return svrbCredentialGenerator.generateFor(auth.accountIdentifier().toString());
}
}

View File

@ -15,6 +15,7 @@ import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags;
import io.swagger.v3.oas.annotations.ExternalDocumentation;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.headers.Header;
import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
@ -66,6 +67,7 @@ import org.whispersystems.textsecuregcm.configuration.SubscriptionConfiguration;
import org.whispersystems.textsecuregcm.configuration.SubscriptionLevelConfiguration;
import org.whispersystems.textsecuregcm.entities.Badge;
import org.whispersystems.textsecuregcm.entities.PurchasableBadge;
import org.whispersystems.textsecuregcm.mappers.SubscriptionExceptionMapper;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.PaymentTime;
@ -218,6 +220,19 @@ public class SubscriptionController {
@DELETE
@Path("/{subscriberId}")
@Produces(MediaType.APPLICATION_JSON)
@Operation(summary = "Cancel a subscription", description = """
Cancels any current subscription at the end of the current subscription period.
Note: Apple IAP subscriptions do not support server-side cancellation, so this method should only be called after
cancelling a subscription from storekit to keep server data up to date.
""")
@ApiResponse(responseCode = "200", description = "All subscriptions cancelled")
@ApiResponse(responseCode = "403", description = "Account authentication is present")
@ApiResponse(responseCode = "404", description = "subscriberId is not found or malformed")
@ApiResponse(responseCode = "400", description = "The associated subscription is not a type that can be cancelled")
@ApiResponse(responseCode = "429", description = "Too many attempts", headers = @Header(
name = "Retry-After",
description = "If present, a positive integer indicating the number of seconds before a subsequent attempt could succeed"))
public CompletableFuture<Response> deleteSubscriber(
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@PathParam("subscriberId") String subscriberId) throws SubscriptionException {
@ -230,6 +245,16 @@ public class SubscriptionController {
@Path("/{subscriberId}")
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
@Operation(summary = "Create/refresh a subscriber", description = """
Creates a subscriber record if it does not exist, otherwise refreshes its last access time.
Subscribers MUST periodically hit this endpoint to update the access time on the subscription record. Subscribers
SHOULD attempt to make an update call approximately every 3 days. Not accessing this endpoint for an extended
period of time will result in the subscription being canceled.
""")
@ApiResponse(responseCode = "200", description = "The subscriber was successfully created or refreshed")
@ApiResponse(responseCode = "403", description = "subscriberId authentication failure OR account authentication is present")
@ApiResponse(responseCode = "404", description = "subscriberId is malformed")
public CompletableFuture<Response> updateSubscriber(
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@PathParam("subscriberId") String subscriberId) throws SubscriptionException {
@ -429,7 +454,9 @@ public class SubscriptionController {
@ApiResponse(responseCode = "403", description = "subscriberId authentication failure OR account authentication is present")
@ApiResponse(responseCode = "404", description = "No such subscriberId exists or subscriberId is malformed or the specified transaction does not exist")
@ApiResponse(responseCode = "409", description = "subscriberId is already linked to a processor that does not support appstore payments. Delete this subscriberId and use a new one.")
@ApiResponse(responseCode = "429", description = "Rate limit exceeded.")
@ApiResponse(responseCode = "429", description = "Too many attempts", headers = @Header(
name = "Retry-After",
description = "If present, a positive integer indicating the number of seconds before a subsequent attempt could succeed"))
public CompletableFuture<SetSubscriptionLevelSuccessResponse> setAppStoreSubscription(
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@PathParam("subscriberId") String subscriberId,
@ -471,6 +498,9 @@ public class SubscriptionController {
@ApiResponse(responseCode = "403", description = "subscriberId authentication failure OR account authentication is present")
@ApiResponse(responseCode = "404", description = "No such subscriberId exists or subscriberId is malformed or the purchaseToken does not exist")
@ApiResponse(responseCode = "409", description = "subscriberId is already linked to a processor that does not support Play Billing. Delete this subscriberId and use a new one.")
@ApiResponse(responseCode = "429", description = "Too many attempts", headers = @Header(
name = "Retry-After",
description = "If present, a positive integer indicating the number of seconds before a subsequent attempt could succeed"))
public CompletableFuture<SetSubscriptionLevelSuccessResponse> setPlayStoreSubscription(
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@PathParam("subscriberId") String subscriberId,
@ -625,6 +655,9 @@ public class SubscriptionController {
@ApiResponse(responseCode = "200", description = "The subscriberId exists", content = @Content(schema = @Schema(implementation = GetSubscriptionInformationResponse.class)))
@ApiResponse(responseCode = "403", description = "subscriberId authentication failure OR account authentication is present")
@ApiResponse(responseCode = "404", description = "No such subscriberId exists or subscriberId is malformed")
@ApiResponse(responseCode = "429", description = "Too many attempts", headers = @Header(
name = "Retry-After",
description = "If present, a positive integer indicating the number of seconds before a subsequent attempt could succeed"))
public CompletableFuture<Response> getSubscriptionInformation(
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@PathParam("subscriberId") String subscriberId) throws SubscriptionException {
@ -650,16 +683,64 @@ public class SubscriptionController {
.orElseGet(() -> Response.ok(new GetSubscriptionInformationResponse(null, null)).build()));
}
public record GetReceiptCredentialsRequest(@NotEmpty byte[] receiptCredentialRequest) {
public record GetReceiptCredentialsRequest(
@Schema(description = "A ReceiptCredentialRequest encoded in standard base64 with padding")
@NotEmpty byte[] receiptCredentialRequest) {
}
public record GetReceiptCredentialsResponse(@NotEmpty byte[] receiptCredentialResponse) {
public record GetReceiptCredentialsResponse(
@Schema(description = "A ReceiptCredentialResponse encoded in standard base64 with padding")
@NotEmpty byte[] receiptCredentialResponse) {
}
@POST
@Path("/{subscriberId}/receipt_credentials")
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
@Operation(summary = "Create receipt credentials", description = """
Create a receipt from a valid payment invoice that can be used to obtain an entitlement
This request is repeatable so long as the ReceiptCredentialRequest remains the same. Clients should use the same
ReceiptCredentialRequest value until they attempt to redeem the resulting ReceiptCredentialPresentation. After
this point, the ReceiptCredentialRequest MUST NOT be reused or you may not be able to redeem a valid payment
invoice. Clients SHOULD retry requests at this endpoint with the same ReceiptCredentialRequest value until
receiving a response. After receiving a response, clients should then compute the ReceiptCredentialPresentation
and redeem it at the receipt redemption endpoint. Once the first attempt is made there, the same
ReceiptCredentialRequest MUST NOT be used again to request receipt credentials.
Note that you may in fact redeem TWO or more invoices for the same ReceiptCredentialRequest while retrying this
operation if a later invoice gets paid while you are retrying. However, the returned receipt is always for the
latest invoice, so it will have the latest expiration possible and no entitlement time will be lost. The important
thing is not to reuse ReceiptCredentialRequest after you have started attempting to redeem the associated
ReceiptCredentialPresentation. Then you may produce a ReceiptCredentialPresentation for a later invoice that
cannot be redeemed.
Clients MUST validate that the generated receipt credential's level and expiration matches their expectations.
""")
@ApiResponse(responseCode = "200", description = "Successfully created receipt", content = @Content(schema = @Schema(implementation = GetReceiptCredentialsResponse.class)))
@ApiResponse(responseCode = "204", description = "No invoice has been issued for this subscription OR invoice is in 'open' state")
@ApiResponse(responseCode = "400", description = "Bad ReceiptCredentialRequest")
@ApiResponse(responseCode = "402", description = "Invoice is in any state other than 'open' or 'paid'. May include chargeFailure details in body.",
content = @Content(schema = @Schema(
nullable = true,
example = """
{
"chargeFailure": {
"code": "incorrect_account_holder_name",
"message": "The transaction can't be processed because your customer's account information is missing [...]",
"outcomeNetworkStatus": "declined_by_network",
"outcomeReason": "generic_decline",
"outcomeType": "issuer_declined"
}
}
""",
implementation = SubscriptionExceptionMapper.ChargeFailureResponse.class)))
@ApiResponse(responseCode = "403", description = "subscriberId authentication failure OR account authentication is present")
@ApiResponse(responseCode = "404", description = "subscriberId is not found OR malformed OR no subscription setup on the subscriber id")
@ApiResponse(responseCode = "409", description = "latest paid receipt on subscription was already redeemed for a receipt credential but with a different receipt credential request")
@ApiResponse(responseCode = "429", description = "Too many attempts", headers = @Header(
name = "Retry-After",
description = "If present, a positive integer indicating the number of seconds before a subsequent attempt could succeed"))
public CompletableFuture<Response> createSubscriptionReceiptCredentials(
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,

View File

@ -1,64 +0,0 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonValue;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import io.swagger.v3.oas.annotations.media.Schema;
import java.util.Map;
import javax.annotation.Nullable;
import jakarta.validation.constraints.NotNull;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
public record AuthCheckResponseV3(
@Schema(description = """
A dictionary with the auth check results, keyed by the token corresponding token provided in the request.
""")
@NotNull Map<String, Result> matches) {
public record Result(
@Schema(description = "The status of the credential. Either match, no-match, or invalid")
CredentialStatus status,
@Schema(description = """
If the credential was a match, the stored shareSet that can be used to restore a value from SVR. Encoded in
standard un-padded base64.
""", implementation = String.class)
@JsonSerialize(using = ByteArrayAdapter.Serializing.class)
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class)
@Nullable byte[] shareSet) {
public static Result invalid() {
return new Result(CredentialStatus.INVALID, null);
}
public static Result noMatch() {
return new Result(CredentialStatus.NO_MATCH, null);
}
public static Result match(@Nullable final byte[] shareSet) {
return new Result(CredentialStatus.MATCH, shareSet);
}
}
public enum CredentialStatus {
MATCH("match"),
NO_MATCH("no-match"),
INVALID("invalid");
private final String clientCode;
CredentialStatus(final String clientCode) {
this.clientCode = clientCode;
}
@JsonValue
public String clientCode() {
return clientCode;
}
}
}

View File

@ -18,7 +18,7 @@ import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator;
import org.whispersystems.textsecuregcm.configuration.DirectoryV2ClientConfiguration;
import org.whispersystems.textsecuregcm.configuration.PaymentsServiceConfiguration;
import org.whispersystems.textsecuregcm.configuration.SecureValueRecovery2Configuration;
import org.whispersystems.textsecuregcm.configuration.SecureValueRecoveryConfiguration;
enum ExternalServiceDefinitions {
DIRECTORY(ExternalServiceType.EXTERNAL_SERVICE_TYPE_DIRECTORY, (chatConfig, clock) -> {
@ -38,7 +38,17 @@ enum ExternalServiceDefinitions {
.build();
}),
SVR(ExternalServiceType.EXTERNAL_SERVICE_TYPE_SVR, (chatConfig, clock) -> {
final SecureValueRecovery2Configuration cfg = chatConfig.getSvr2Configuration();
final SecureValueRecoveryConfiguration cfg = chatConfig.getSvr2Configuration();
return ExternalServiceCredentialsGenerator
.builder(cfg.userAuthenticationTokenSharedSecret())
.withUserDerivationKey(cfg.userIdTokenSharedSecret().value())
.prependUsername(false)
.withDerivedUsernameTruncateLength(16)
.withClock(clock)
.build();
}),
SVRB(ExternalServiceType.EXTERNAL_SERVICE_TYPE_SVRB, (chatConfig, clock) -> {
final SecureValueRecoveryConfiguration cfg = chatConfig.getSvrbConfiguration();
return ExternalServiceCredentialsGenerator
.builder(cfg.userAuthenticationTokenSharedSecret())
.withUserDerivationKey(cfg.userIdTokenSharedSecret().value())

View File

@ -0,0 +1,140 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import com.google.common.annotations.VisibleForTesting;
import io.grpc.Status;
import org.signal.keytransparency.client.AciMonitorRequest;
import org.signal.keytransparency.client.ConsistencyParameters;
import org.signal.keytransparency.client.DistinguishedRequest;
import org.signal.keytransparency.client.DistinguishedResponse;
import org.signal.keytransparency.client.E164MonitorRequest;
import org.signal.keytransparency.client.E164SearchRequest;
import org.signal.keytransparency.client.MonitorRequest;
import org.signal.keytransparency.client.MonitorResponse;
import org.signal.keytransparency.client.SearchRequest;
import org.signal.keytransparency.client.SearchResponse;
import org.signal.keytransparency.client.SimpleKeyTransparencyQueryServiceGrpc;
import org.signal.keytransparency.client.UsernameHashMonitorRequest;
import org.whispersystems.textsecuregcm.controllers.AccountController;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.keytransparency.KeyTransparencyServiceClient;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
public class KeyTransparencyGrpcService extends
SimpleKeyTransparencyQueryServiceGrpc.KeyTransparencyQueryServiceImplBase {
@VisibleForTesting
static final int COMMITMENT_INDEX_LENGTH = 32;
private final RateLimiters rateLimiters;
private final KeyTransparencyServiceClient client;
public KeyTransparencyGrpcService(final RateLimiters rateLimiters,
final KeyTransparencyServiceClient client) {
this.rateLimiters = rateLimiters;
this.client = client;
}
@Override
public SearchResponse search(final SearchRequest request) throws RateLimitExceededException {
rateLimiters.getKeyTransparencySearchLimiter().validate(RequestAttributesUtil.getRemoteAddress().getHostAddress());
return client.search(validateSearchRequest(request));
}
@Override
public MonitorResponse monitor(final MonitorRequest request) throws RateLimitExceededException {
rateLimiters.getKeyTransparencyMonitorLimiter().validate(RequestAttributesUtil.getRemoteAddress().getHostAddress());
return client.monitor(validateMonitorRequest(request));
}
@Override
public DistinguishedResponse distinguished(final DistinguishedRequest request) throws RateLimitExceededException {
rateLimiters.getKeyTransparencyDistinguishedLimiter().validate(RequestAttributesUtil.getRemoteAddress().getHostAddress());
// A client's very first distinguished request will not have a "last" parameter
if (request.hasLast() && request.getLast() <= 0) {
throw Status.INVALID_ARGUMENT.withDescription("Last tree head size must be positive").asRuntimeException();
}
return client.distinguished(request);
}
private SearchRequest validateSearchRequest(final SearchRequest request) {
if (request.hasE164SearchRequest()) {
final E164SearchRequest e164SearchRequest = request.getE164SearchRequest();
if (e164SearchRequest.getUnidentifiedAccessKey().isEmpty() != e164SearchRequest.getE164().isEmpty()) {
throw Status.INVALID_ARGUMENT.withDescription("Unidentified access key and E164 must be provided together or not at all").asRuntimeException();
}
}
if (!request.getConsistency().hasDistinguished()) {
throw Status.INVALID_ARGUMENT.withDescription("Must provide distinguished tree head size").asRuntimeException();
}
validateConsistencyParameters(request.getConsistency());
return request;
}
private MonitorRequest validateMonitorRequest(final MonitorRequest request) {
final AciMonitorRequest aciMonitorRequest = request.getAci();
try {
AciServiceIdentifier.fromBytes(aciMonitorRequest.getAci().toByteArray());
} catch (IllegalArgumentException e) {
throw Status.INVALID_ARGUMENT.withDescription("Invalid ACI").asRuntimeException();
}
if (aciMonitorRequest.getEntryPosition() <= 0) {
throw Status.INVALID_ARGUMENT.withDescription("Aci entry position must be positive").asRuntimeException();
}
if (aciMonitorRequest.getCommitmentIndex().size() != COMMITMENT_INDEX_LENGTH) {
throw Status.INVALID_ARGUMENT.withDescription("Aci commitment index must be 32 bytes").asRuntimeException();
}
if (request.hasUsernameHash()) {
final UsernameHashMonitorRequest usernameHashMonitorRequest = request.getUsernameHash();
if (usernameHashMonitorRequest.getUsernameHash().isEmpty()) {
throw Status.INVALID_ARGUMENT.withDescription("Username hash cannot be empty").asRuntimeException();
}
if (usernameHashMonitorRequest.getUsernameHash().size() != AccountController.USERNAME_HASH_LENGTH) {
throw Status.INVALID_ARGUMENT.withDescription("Invalid username hash length").asRuntimeException();
}
if (usernameHashMonitorRequest.getEntryPosition() <= 0) {
throw Status.INVALID_ARGUMENT.withDescription("Username hash entry position must be positive").asRuntimeException();
}
if (usernameHashMonitorRequest.getCommitmentIndex().size() != COMMITMENT_INDEX_LENGTH) {
throw Status.INVALID_ARGUMENT.withDescription("Username hash commitment index must be 32 bytes").asRuntimeException();
}
}
if (request.hasE164()) {
final E164MonitorRequest e164MonitorRequest = request.getE164();
if (e164MonitorRequest.getE164().isEmpty()) {
throw Status.INVALID_ARGUMENT.withDescription("E164 cannot be empty").asRuntimeException();
}
if (e164MonitorRequest.getEntryPosition() <= 0) {
throw Status.INVALID_ARGUMENT.withDescription("E164 entry position must be positive").asRuntimeException();
}
if (e164MonitorRequest.getCommitmentIndex().size() != COMMITMENT_INDEX_LENGTH) {
throw Status.INVALID_ARGUMENT.withDescription("E164 commitment index must be 32 bytes").asRuntimeException();
}
}
if (!request.getConsistency().hasDistinguished() || !request.getConsistency().hasLast()) {
throw Status.INVALID_ARGUMENT.withDescription("Must provide distinguished and last tree head sizes").asRuntimeException();
}
validateConsistencyParameters(request.getConsistency());
return request;
}
private static void validateConsistencyParameters(final ConsistencyParameters consistency) {
if (consistency.getDistinguished() <= 0) {
throw Status.INVALID_ARGUMENT.withDescription("Distinguished tree head size must be positive").asRuntimeException();
}
if (consistency.hasLast() && consistency.getLast() <= 0) {
throw Status.INVALID_ARGUMENT.withDescription("Last tree head size must be positive").asRuntimeException();
}
}
}

View File

@ -8,7 +8,6 @@ package org.whispersystems.textsecuregcm.grpc;
import com.google.protobuf.ByteString;
import io.grpc.Status;
import java.util.UUID;
import org.signal.chat.common.IdentityType;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
@ -40,4 +39,12 @@ public class ServiceIdentifierUtil {
.setUuid(UUIDUtil.toByteString(serviceIdentifier.uuid()))
.build();
}
public static ByteString toCompactByteString(final ServiceIdentifier serviceIdentifier) {
return ByteString.copyFrom(serviceIdentifier.toCompactByteArray());
}
public static ServiceIdentifier fromByteString(final ByteString byteString) {
return ServiceIdentifier.fromBytes(byteString.toByteArray());
}
}

View File

@ -1,6 +1,5 @@
package org.whispersystems.textsecuregcm.keytransparency;
import com.google.protobuf.AbstractMessageLite;
import com.google.protobuf.ByteString;
import io.dropwizard.lifecycle.Managed;
import io.grpc.ChannelCredentials;
@ -20,44 +19,43 @@ import java.time.Duration;
import java.time.Instant;
import java.util.Collection;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import org.signal.keytransparency.client.AciMonitorRequest;
import org.signal.keytransparency.client.ConsistencyParameters;
import org.signal.keytransparency.client.DistinguishedRequest;
import org.signal.keytransparency.client.DistinguishedResponse;
import org.signal.keytransparency.client.E164MonitorRequest;
import org.signal.keytransparency.client.E164SearchRequest;
import org.signal.keytransparency.client.KeyTransparencyQueryServiceGrpc;
import org.signal.keytransparency.client.MonitorRequest;
import org.signal.keytransparency.client.MonitorResponse;
import org.signal.keytransparency.client.SearchRequest;
import org.signal.keytransparency.client.SearchResponse;
import org.signal.keytransparency.client.UsernameHashMonitorRequest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.util.CompletableFutureUtil;
public class KeyTransparencyServiceClient implements Managed {
private static final String DAYS_UNTIL_CLIENT_CERTIFICATE_EXPIRATION_GAUGE_NAME =
MetricsUtil.name(KeyTransparencyServiceClient.class, "daysUntilClientCertificateExpiration");
private static final Duration KEY_TRANSPARENCY_RPC_TIMEOUT = Duration.ofSeconds(15);
private static final Logger logger = LoggerFactory.getLogger(KeyTransparencyServiceClient.class);
private final Executor callbackExecutor;
private final String host;
private final int port;
private final ChannelCredentials tlsChannelCredentials;
private ManagedChannel channel;
private KeyTransparencyQueryServiceGrpc.KeyTransparencyQueryServiceFutureStub stub;
private KeyTransparencyQueryServiceGrpc.KeyTransparencyQueryServiceBlockingStub stub;
public KeyTransparencyServiceClient(
final String host,
final int port,
final String tlsCertificate,
final String clientCertificate,
final String clientPrivateKey,
final Executor callbackExecutor
final String clientPrivateKey
) throws IOException {
this.host = host;
this.port = port;
@ -76,7 +74,6 @@ public class KeyTransparencyServiceClient implements Managed {
configureClientCertificateMetrics(clientCertificate);
}
this.callbackExecutor = callbackExecutor;
}
private void configureClientCertificateMetrics(String clientCertificate) {
@ -113,14 +110,13 @@ public class KeyTransparencyServiceClient implements Managed {
}
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public CompletableFuture<byte[]> search(
public SearchResponse search(
final ByteString aci,
final ByteString aciIdentityKey,
final Optional<ByteString> usernameHash,
final Optional<E164SearchRequest> e164SearchRequest,
final Optional<Long> lastTreeHeadSize,
final long distinguishedTreeHeadSize,
final Duration timeout) {
final long distinguishedTreeHeadSize) {
final SearchRequest.Builder searchRequestBuilder = SearchRequest.newBuilder()
.setAci(aci)
.setAciIdentityKey(aciIdentityKey);
@ -133,19 +129,20 @@ public class KeyTransparencyServiceClient implements Managed {
lastTreeHeadSize.ifPresent(consistency::setLast);
searchRequestBuilder.setConsistency(consistency.build());
return search(searchRequestBuilder.build());
}
return CompletableFutureUtil.toCompletableFuture(stub.withDeadline(toDeadline(timeout))
.search(searchRequestBuilder.build()), callbackExecutor)
.thenApply(AbstractMessageLite::toByteArray);
public SearchResponse search(final SearchRequest request) {
return stub.withDeadline(toDeadline(KEY_TRANSPARENCY_RPC_TIMEOUT))
.search(request);
}
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public CompletableFuture<byte[]> monitor(final AciMonitorRequest aciMonitorRequest,
public MonitorResponse monitor(final AciMonitorRequest aciMonitorRequest,
final Optional<UsernameHashMonitorRequest> usernameHashMonitorRequest,
final Optional<E164MonitorRequest> e164MonitorRequest,
final long lastTreeHeadSize,
final long distinguishedTreeHeadSize,
final Duration timeout) {
final long distinguishedTreeHeadSize) {
final MonitorRequest.Builder monitorRequestBuilder = MonitorRequest.newBuilder()
.setAci(aciMonitorRequest)
.setConsistency(ConsistencyParameters.newBuilder()
@ -155,20 +152,26 @@ public class KeyTransparencyServiceClient implements Managed {
usernameHashMonitorRequest.ifPresent(monitorRequestBuilder::setUsernameHash);
e164MonitorRequest.ifPresent(monitorRequestBuilder::setE164);
return CompletableFutureUtil.toCompletableFuture(stub.withDeadline(toDeadline(timeout))
.monitor(monitorRequestBuilder.build()), callbackExecutor)
.thenApply(AbstractMessageLite::toByteArray);
return monitor(monitorRequestBuilder.build());
}
public MonitorResponse monitor(final MonitorRequest request) {
return stub.withDeadline(toDeadline(KEY_TRANSPARENCY_RPC_TIMEOUT))
.monitor(request);
}
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public CompletableFuture<byte[]> getDistinguishedKey(final Optional<Long> lastTreeHeadSize, final Duration timeout) {
public DistinguishedResponse getDistinguishedKey(final Optional<Long> lastTreeHeadSize) {
final DistinguishedRequest request = lastTreeHeadSize.map(
last -> DistinguishedRequest.newBuilder().setLast(last).build())
.orElseGet(DistinguishedRequest::getDefaultInstance);
return CompletableFutureUtil.toCompletableFuture(stub.withDeadline(toDeadline(timeout)).distinguished(request),
callbackExecutor)
.thenApply(AbstractMessageLite::toByteArray);
return distinguished(request);
}
public DistinguishedResponse distinguished(final DistinguishedRequest request) {
return stub.withDeadline(toDeadline(KEY_TRANSPARENCY_RPC_TIMEOUT))
.distinguished(request);
}
private static Deadline toDeadline(final Duration timeout) {
@ -180,7 +183,7 @@ public class KeyTransparencyServiceClient implements Managed {
channel = Grpc.newChannelBuilderForAddress(host, port, tlsChannelCredentials)
.idleTimeout(1, TimeUnit.MINUTES)
.build();
stub = KeyTransparencyQueryServiceGrpc.newFutureStub(channel);
stub = KeyTransparencyQueryServiceGrpc.newBlockingStub(channel);
}
@Override

View File

@ -206,4 +206,16 @@ public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
public RateLimiter getWaitForTransferArchiveLimiter() {
return forDescriptor(For.WAIT_FOR_TRANSFER_ARCHIVE);
}
public RateLimiter getKeyTransparencySearchLimiter() {
return forDescriptor(For.KEY_TRANSPARENCY_SEARCH_PER_IP);
}
public RateLimiter getKeyTransparencyDistinguishedLimiter() {
return forDescriptor(For.KEY_TRANSPARENCY_DISTINGUISHED_PER_IP);
}
public RateLimiter getKeyTransparencyMonitorLimiter() {
return forDescriptor(For.KEY_TRANSPARENCY_MONITOR_PER_IP);
}
}

View File

@ -13,11 +13,14 @@ import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.ext.ExceptionMapper;
import java.util.Map;
import org.whispersystems.textsecuregcm.storage.SubscriptionException;
import org.whispersystems.textsecuregcm.subscriptions.ChargeFailure;
public class SubscriptionExceptionMapper implements ExceptionMapper<SubscriptionException> {
@VisibleForTesting
public static final int PROCESSOR_ERROR_STATUS_CODE = 440;
public record ChargeFailureResponse(String processor, ChargeFailure chargeFailure) {}
@Override
public Response toResponse(final SubscriptionException exception) {
@ -31,17 +34,14 @@ public class SubscriptionExceptionMapper implements ExceptionMapper<Subscription
}
if (exception instanceof SubscriptionException.ProcessorException e) {
return Response.status(PROCESSOR_ERROR_STATUS_CODE)
.entity(Map.of(
"processor", e.getProcessor().name(),
"chargeFailure", e.getChargeFailure()
))
.entity(new ChargeFailureResponse(e.getProcessor().name(), e.getChargeFailure()))
.type(MediaType.APPLICATION_JSON_TYPE)
.build();
}
if (exception instanceof SubscriptionException.ChargeFailurePaymentRequired e) {
return Response
.status(Response.Status.PAYMENT_REQUIRED)
.entity(Map.of("chargeFailure", e.getChargeFailure()))
.entity(new ChargeFailureResponse(e.getProcessor().name(), e.getChargeFailure()))
.type(MediaType.APPLICATION_JSON_TYPE)
.build();
}

View File

@ -248,7 +248,7 @@ public class LettuceShardCircuitBreaker implements NettyCustomizer {
// RedisNoScriptException doesnt indicate a fault the breaker can protect
if (throwable != null && !(throwable instanceof RedisNoScriptException)) {
breaker.onError(durationNanos, TimeUnit.NANOSECONDS, throwable);
logger.warn("Command completed with error", throwable);
logger.warn("Command completed with error for: {}/{}", clusterName, shardAddress, throwable);
} else {
breaker.onSuccess(durationNanos, TimeUnit.NANOSECONDS);
}

View File

@ -21,7 +21,7 @@ import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator;
import org.whispersystems.textsecuregcm.configuration.SecureValueRecovery2Configuration;
import org.whispersystems.textsecuregcm.configuration.SecureValueRecoveryConfiguration;
import org.whispersystems.textsecuregcm.http.FaultTolerantHttpClient;
import org.whispersystems.textsecuregcm.util.HttpUtils;
@ -39,7 +39,7 @@ public class SecureValueRecovery2Client {
public SecureValueRecovery2Client(final ExternalServiceCredentialsGenerator secureValueRecoveryCredentialsGenerator,
final Executor executor, final ScheduledExecutorService retryExecutor,
final SecureValueRecovery2Configuration configuration)
final SecureValueRecoveryConfiguration configuration)
throws CertificateException {
this.secureValueRecoveryCredentialsGenerator = secureValueRecoveryCredentialsGenerator;
this.deleteUri = URI.create(configuration.uri()).resolve(DELETE_PATH);

View File

@ -0,0 +1,24 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import java.time.Instant;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
/**
* The prekey pages stored for a particular device
*
* @param identifier The account identifier or phone number identifier that the keys belong to
* @param deviceId The device identifier
* @param currentPage If present, the active stored page prekeys are being distributed from
* @param pageIdToLastModified The last modified time for all the device's stored pages, keyed by the pageId
*/
public record DeviceKEMPreKeyPages(
UUID identifier, byte deviceId,
Optional<UUID> currentPage,
Map<UUID, Instant> pageIdToLastModified) {}

View File

@ -0,0 +1,106 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import java.util.UUID;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.grpc.ServiceIdentifierUtil;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
/**
* Provides utility methods for "compressing" and "expanding" envelopes. Historically UUID-like fields in envelopes have
* been represented as strings (e.g. "c15f1dfb-ae2c-43a8-9bb9-baba97ac416c"), but <em>could</em> be represented as more
* compact byte arrays instead. Existing clients generally expect string representations (though that should change in
* the near future), but we can use the more compressed forms at rest for more efficient storage and transfer.
*/
public class EnvelopeUtil {
/**
* Converts all "compressible" UUID-like fields in the given envelope to more compact binary representations.
*
* @param envelope the envelope to compress
*
* @return an envelope with string-based UUID-like fields compressed to binary representations
*/
public static MessageProtos.Envelope compress(final MessageProtos.Envelope envelope) {
final MessageProtos.Envelope.Builder builder = envelope.toBuilder();
if (builder.hasSourceServiceId()) {
final ServiceIdentifier sourceServiceId = ServiceIdentifier.valueOf(builder.getSourceServiceId());
builder.setSourceServiceIdBinary(ServiceIdentifierUtil.toCompactByteString(sourceServiceId));
builder.clearSourceServiceId();
}
if (builder.hasDestinationServiceId()) {
final ServiceIdentifier destinationServiceId = ServiceIdentifier.valueOf(builder.getDestinationServiceId());
builder.setDestinationServiceIdBinary(ServiceIdentifierUtil.toCompactByteString(destinationServiceId));
builder.clearDestinationServiceId();
}
if (builder.hasServerGuid()) {
final UUID serverGuid = UUID.fromString(builder.getServerGuid());
builder.setServerGuidBinary(UUIDUtil.toByteString(serverGuid));
builder.clearServerGuid();
}
if (builder.hasUpdatedPni()) {
final UUID updatedPni = UUID.fromString(builder.getUpdatedPni());
builder.setUpdatedPniBinary(UUIDUtil.toByteString(updatedPni));
builder.clearUpdatedPni();
}
return builder.build();
}
/**
* "Expands" all binary representations of UUID-like fields to string representations to meet current client
* expectations.
*
* @param envelope the envelope to expand
*
* @return an envelope with binary representations of UUID-like fields expanded to string representations
*/
public static MessageProtos.Envelope expand(final MessageProtos.Envelope envelope) {
final MessageProtos.Envelope.Builder builder = envelope.toBuilder();
if (builder.hasSourceServiceIdBinary()) {
final ServiceIdentifier sourceServiceId =
ServiceIdentifierUtil.fromByteString(builder.getSourceServiceIdBinary());
builder.setSourceServiceId(sourceServiceId.toServiceIdentifierString());
builder.clearSourceServiceIdBinary();
}
if (builder.hasDestinationServiceIdBinary()) {
final ServiceIdentifier destinationServiceId =
ServiceIdentifierUtil.fromByteString(builder.getDestinationServiceIdBinary());
builder.setDestinationServiceId(destinationServiceId.toServiceIdentifierString());
builder.clearDestinationServiceIdBinary();
}
if (builder.hasServerGuidBinary()) {
final UUID serverGuid = UUIDUtil.fromByteString(builder.getServerGuidBinary());
builder.setServerGuid(serverGuid.toString());
builder.clearServerGuidBinary();
}
if (builder.hasUpdatedPniBinary()) {
final UUID updatedPni = UUIDUtil.fromByteString(builder.getUpdatedPniBinary());
// Note that expanded envelopes include BOTH forms of the `updatedPni` field
builder.setUpdatedPni(updatedPni.toString());
}
return builder.build();
}
}

View File

@ -0,0 +1,136 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import com.google.common.annotations.VisibleForTesting;
import java.nio.ByteBuffer;
import java.util.List;
import org.signal.libsignal.protocol.InvalidKeyException;
import org.signal.libsignal.protocol.kem.KEMPublicKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
class KEMPreKeyPage {
static final byte FORMAT = 1;
// Serialized pages start with a 4 byte magic constant, followed by 3 bytes of 0s and then the format byte
static final int HEADER_MAGIC = 0xC21C6DB8;
static final int HEADER_SIZE = 8;
// Serialize bigendian to produce the serialized page header
private static final long HEADER = ((long) HEADER_MAGIC) << 32L | (long) FORMAT;
// The length of libsignal's serialized KEM public key, which is a single-byte version followed by the public key
private static final int SERIALIZED_PUBKEY_LENGTH = 1569;
private static final int SERIALIZED_SIGNATURE_LENGTH = 64;
private static final int KEY_ID_LENGTH = Long.BYTES;
// The internal prefix byte libsignal uses to indicate a key is of type KEMKeyType.KYBER_1024. Currently, this
// is the only type of key allowed to be written to a prekey page
private static final byte KEM_KEY_TYPE_KYBER_1024 = 0x08;
@VisibleForTesting
static final int SERIALIZED_PREKEY_LENGTH = KEY_ID_LENGTH + SERIALIZED_PUBKEY_LENGTH + SERIALIZED_SIGNATURE_LENGTH;
private KEMPreKeyPage() {}
/**
* Serialize the list of preKeys into a single buffer
*
* @param format the format to serialize as. Currently, the only valid format is {@link KEMPreKeyPage#FORMAT}
* @param preKeys the preKeys to serialize
* @return The serialized buffer and a format to store alongside the buffer
*/
static ByteBuffer serialize(final byte format, final List<KEMSignedPreKey> preKeys) {
if (format != FORMAT) {
throw new IllegalArgumentException("Unknown format: " + format + ", must be " + FORMAT);
}
if (preKeys.isEmpty()) {
throw new IllegalArgumentException("PreKeys cannot be empty");
}
final ByteBuffer buffer = ByteBuffer.allocate(HEADER_SIZE + SERIALIZED_PREKEY_LENGTH * preKeys.size());
buffer.putLong(HEADER);
for (KEMSignedPreKey preKey : preKeys) {
buffer.putLong(preKey.keyId());
final byte[] publicKeyBytes = preKey.serializedPublicKey();
if (publicKeyBytes[0] != KEM_KEY_TYPE_KYBER_1024) {
// 0x08 is libsignal's current KEM key format. If some future version of libsignal supports additional KEM
// keys, we'll have to roll out read support before rolling out write support. Otherwise, we may write keys
// to storage that are not readable by other chat instances.
throw new IllegalArgumentException("Format 1 only supports " + KEM_KEY_TYPE_KYBER_1024 + " public keys");
}
if (publicKeyBytes.length != SERIALIZED_PUBKEY_LENGTH) {
throw new IllegalArgumentException("Unexpected public key length " + publicKeyBytes.length);
}
buffer.put(publicKeyBytes);
if (preKey.signature().length != SERIALIZED_SIGNATURE_LENGTH) {
throw new IllegalArgumentException("prekey signature length must be " + SERIALIZED_SIGNATURE_LENGTH);
}
buffer.put(preKey.signature());
}
buffer.flip();
return buffer;
}
/**
* Deserialize a single {@link KEMSignedPreKey}
*
* @param format The format of the page this buffer is from
* @param buffer The key to deserialize. The position of the buffer should be the start of the key, and the limit of
* the buffer should be the end of the key. After a successful deserialization the position of the
* buffer will be the limit
* @return The deserialized key
* @throws InvalidKeyException
*/
static KEMSignedPreKey deserializeKey(int format, ByteBuffer buffer) throws InvalidKeyException {
if (format != FORMAT) {
throw new IllegalArgumentException("Unknown prekey page format " + format);
}
if (buffer.remaining() != SERIALIZED_PREKEY_LENGTH) {
throw new IllegalArgumentException("PreKeys must be length " + SERIALIZED_PREKEY_LENGTH);
}
final long keyId = buffer.getLong();
final byte[] publicKeyBytes = new byte[SERIALIZED_PUBKEY_LENGTH];
buffer.get(publicKeyBytes);
final KEMPublicKey kemPublicKey = new KEMPublicKey(publicKeyBytes);
final byte[] signature = new byte[SERIALIZED_SIGNATURE_LENGTH];
buffer.get(signature);
return new KEMSignedPreKey(keyId, kemPublicKey, signature);
}
/**
* The location of a specific key within a serialized page
*/
record KeyLocation(int start, int length) {
int getStartInclusive() {
return start;
}
int getEndInclusive() {
return start + length - 1;
}
}
/**
* Get the location of the key at the provided index within a page
*
* @param format The format of the page
* @param index The index of the key to retrieve
* @return An {@link KeyLocation} indicating where within the page the key is
*/
static KeyLocation keyLocation(final int format, final int index) {
if (format != FORMAT) {
throw new IllegalArgumentException("unknown format " + format);
}
final int startOffset = HEADER_SIZE + (index * SERIALIZED_PREKEY_LENGTH);
return new KeyLocation(startOffset, SERIALIZED_PREKEY_LENGTH);
}
}

View File

@ -5,6 +5,7 @@
package org.whispersystems.textsecuregcm.storage;
import io.micrometer.core.instrument.Metrics;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
@ -12,26 +13,37 @@ import java.util.concurrent.CompletableFuture;
import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import reactor.core.publisher.Flux;
import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem;
public class KeysManager {
private final SingleUseECPreKeyStore ecPreKeys;
private final SingleUseKEMPreKeyStore pqPreKeys;
private final PagedSingleUseKEMPreKeyStore pagedPqPreKeys;
private final RepeatedUseECSignedPreKeyStore ecSignedPreKeys;
private final RepeatedUseKEMSignedPreKeyStore pqLastResortKeys;
private final ExperimentEnrollmentManager experimentEnrollmentManager;
public static String PAGED_KEYS_EXPERIMENT_NAME = "pagedPreKeys";
private static final String TAKE_PQ_NAME = MetricsUtil.name(KeysManager.class, "takePq");
public KeysManager(
final DynamoDbAsyncClient dynamoDbAsyncClient,
final String ecTableName,
final String pqTableName,
final String ecSignedPreKeysTableName,
final String pqLastResortTableName) {
this.ecPreKeys = new SingleUseECPreKeyStore(dynamoDbAsyncClient, ecTableName);
this.pqPreKeys = new SingleUseKEMPreKeyStore(dynamoDbAsyncClient, pqTableName);
this.ecSignedPreKeys = new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient, ecSignedPreKeysTableName);
this.pqLastResortKeys = new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient, pqLastResortTableName);
final SingleUseECPreKeyStore ecPreKeys,
final SingleUseKEMPreKeyStore pqPreKeys,
final PagedSingleUseKEMPreKeyStore pagedPqPreKeys,
final RepeatedUseECSignedPreKeyStore ecSignedPreKeys,
final RepeatedUseKEMSignedPreKeyStore pqLastResortKeys,
final ExperimentEnrollmentManager experimentEnrollmentManager) {
this.ecPreKeys = ecPreKeys;
this.pqPreKeys = pqPreKeys;
this.pagedPqPreKeys = pagedPqPreKeys;
this.ecSignedPreKeys = ecSignedPreKeys;
this.pqLastResortKeys = pqLastResortKeys;
this.experimentEnrollmentManager = experimentEnrollmentManager;
}
public TransactWriteItem buildWriteItemForEcSignedPreKey(final UUID identifier,
@ -76,22 +88,31 @@ public class KeysManager {
);
}
public CompletableFuture<Void> storeEcSignedPreKeys(final UUID identifier, final byte deviceId, final ECSignedPreKey ecSignedPreKey) {
public CompletableFuture<Void> storeEcSignedPreKeys(final UUID identifier, final byte deviceId,
final ECSignedPreKey ecSignedPreKey) {
return ecSignedPreKeys.store(identifier, deviceId, ecSignedPreKey);
}
public CompletableFuture<Void> storePqLastResort(final UUID identifier, final byte deviceId, final KEMSignedPreKey lastResortKey) {
public CompletableFuture<Void> storePqLastResort(final UUID identifier, final byte deviceId,
final KEMSignedPreKey lastResortKey) {
return pqLastResortKeys.store(identifier, deviceId, lastResortKey);
}
public CompletableFuture<Void> storeEcOneTimePreKeys(final UUID identifier, final byte deviceId,
final List<ECPreKey> preKeys) {
final List<ECPreKey> preKeys) {
return ecPreKeys.store(identifier, deviceId, preKeys);
}
public CompletableFuture<Void> storeKemOneTimePreKeys(final UUID identifier, final byte deviceId,
final List<KEMSignedPreKey> preKeys) {
return pqPreKeys.store(identifier, deviceId, preKeys);
final List<KEMSignedPreKey> preKeys) {
final boolean enrolledInPagedKeys = experimentEnrollmentManager.isEnrolled(identifier, PAGED_KEYS_EXPERIMENT_NAME);
final CompletableFuture<Void> deleteOtherKeys = enrolledInPagedKeys
? pqPreKeys.delete(identifier, deviceId)
: pagedPqPreKeys.delete(identifier, deviceId);
return deleteOtherKeys.thenCompose(ignored -> enrolledInPagedKeys
? pagedPqPreKeys.store(identifier, deviceId, preKeys)
: pqPreKeys.store(identifier, deviceId, preKeys));
}
public CompletableFuture<Optional<ECPreKey>> takeEC(final UUID identifier, final byte deviceId) {
@ -99,10 +120,36 @@ public class KeysManager {
}
public CompletableFuture<Optional<KEMSignedPreKey>> takePQ(final UUID identifier, final byte deviceId) {
return pqPreKeys.take(identifier, deviceId)
final boolean enrolledInPagedKeys = experimentEnrollmentManager.isEnrolled(identifier, PAGED_KEYS_EXPERIMENT_NAME);
return tagTakePQ(pagedPqPreKeys.take(identifier, deviceId), PQSource.PAGE, enrolledInPagedKeys)
.thenCompose(maybeSingleUsePreKey -> maybeSingleUsePreKey
.map(ignored -> CompletableFuture.completedFuture(maybeSingleUsePreKey))
.orElseGet(() -> tagTakePQ(pqPreKeys.take(identifier, deviceId), PQSource.ROW, enrolledInPagedKeys)))
.thenCompose(maybeSingleUsePreKey -> maybeSingleUsePreKey
.map(singleUsePreKey -> CompletableFuture.completedFuture(maybeSingleUsePreKey))
.orElseGet(() -> pqLastResortKeys.find(identifier, deviceId)));
.orElseGet(() -> tagTakePQ(pqLastResortKeys.find(identifier, deviceId), PQSource.LAST_RESORT, enrolledInPagedKeys)));
}
private enum PQSource {
PAGE,
ROW,
LAST_RESORT
}
private CompletableFuture<Optional<KEMSignedPreKey>> tagTakePQ(CompletableFuture<Optional<KEMSignedPreKey>> prekey, final PQSource source, final boolean enrolledInPagedKeys) {
return prekey.thenApply(maybeSingleUsePreKey -> {
final Optional<String> maybeSourceTag = maybeSingleUsePreKey
// If we found a PK, use this source tag
.map(ignore -> source.name())
// If we didn't and this is our last resort, we didn't find a PK
.or(() -> source == PQSource.LAST_RESORT ? Optional.of("absent") : Optional.empty());
maybeSourceTag.ifPresent(sourceTag -> {
Metrics.counter(TAKE_PQ_NAME,
"source", sourceTag,
"enrolled", Boolean.toString(enrolledInPagedKeys))
.increment();
});
return maybeSingleUsePreKey;
});
}
public CompletableFuture<Optional<KEMSignedPreKey>> getLastResort(final UUID identifier, final byte deviceId) {
@ -118,20 +165,48 @@ public class KeysManager {
}
public CompletableFuture<Integer> getPqCount(final UUID identifier, final byte deviceId) {
return pqPreKeys.getCount(identifier, deviceId);
return pagedPqPreKeys.getCount(identifier, deviceId).thenCompose(count -> count == 0
? pqPreKeys.getCount(identifier, deviceId)
: CompletableFuture.completedFuture(count));
}
public CompletableFuture<Void> deleteSingleUsePreKeys(final UUID identifier) {
return CompletableFuture.allOf(
ecPreKeys.delete(identifier),
pqPreKeys.delete(identifier)
pqPreKeys.delete(identifier),
pagedPqPreKeys.delete(identifier)
);
}
public CompletableFuture<Void> deleteSingleUsePreKeys(final UUID accountUuid, final byte deviceId) {
return CompletableFuture.allOf(
ecPreKeys.delete(accountUuid, deviceId),
pqPreKeys.delete(accountUuid, deviceId)
ecPreKeys.delete(accountUuid, deviceId),
pqPreKeys.delete(accountUuid, deviceId),
pagedPqPreKeys.delete(accountUuid, deviceId)
);
}
/**
* List all the current remotely stored prekey pages across all devices. Pages that are no longer in use can be
* removed with {@link #pruneDeadPage}
*
* @param lookupConcurrency the number of concurrent lookup operations to perform when populating list results
* @return All stored prekey pages
*/
public Flux<DeviceKEMPreKeyPages> listStoredKEMPreKeyPages(int lookupConcurrency) {
return pagedPqPreKeys.listStoredPages(lookupConcurrency);
}
/**
* Remove a prekey page that is no longer in use. A page should only be removed if it is not the active page and
* it has no chance of being updated to be.
*
* @param identifier The owner of the dead page
* @param deviceId The device of the dead page
* @param pageId The dead page to remove from storage
* @return A future that completes when the page has been removed
*/
public CompletableFuture<Void> pruneDeadPage(final UUID identifier, final byte deviceId, final UUID pageId) {
return pagedPqPreKeys.deleteBundleFromS3(identifier, deviceId, pageId);
}
}

View File

@ -15,8 +15,6 @@ import io.lettuce.core.Range;
import io.lettuce.core.ScoredValue;
import io.lettuce.core.ZAddArgs;
import io.lettuce.core.cluster.SlotHash;
import io.lettuce.core.cluster.models.partitions.ClusterPartitionParser;
import io.lettuce.core.cluster.models.partitions.Partitions;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag;
@ -260,7 +258,7 @@ public class MessagesCache {
for (final byte[] bytes : serialized) {
try {
final MessageProtos.Envelope envelope = MessageProtos.Envelope.parseFrom(bytes);
final MessageProtos.Envelope envelope = parseEnvelope(bytes);
removedMessages.add(RemovedMessage.fromEnvelope(envelope));
if (envelope.hasSharedMrmKey()) {
serviceIdentifierToMrmKeys.computeIfAbsent(
@ -387,7 +385,7 @@ public class MessagesCache {
for (int i = 0; i < queueItems.size() - 1; i += 2) {
try {
final MessageProtos.Envelope message = MessageProtos.Envelope.parseFrom(queueItems.get(i));
final MessageProtos.Envelope message = parseEnvelope(queueItems.get(i));
final Mono<MessageProtos.Envelope> messageMono;
if (message.hasSharedMrmKey()) {
@ -606,7 +604,7 @@ public class MessagesCache {
return serializedMessages
.mapNotNull(message -> {
try {
return MessageProtos.Envelope.parseFrom(message);
return parseEnvelope(message);
} catch (InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
return null;
@ -646,7 +644,7 @@ public class MessagesCache {
final List<String> processedMessages = new ArrayList<>(messagesToProcess.size());
for (byte[] serialized : messagesToProcess) {
try {
final MessageProtos.Envelope message = MessageProtos.Envelope.parseFrom(serialized);
final MessageProtos.Envelope message = parseEnvelope(serialized);
processedMessages.add(message.getServerGuid());
@ -754,4 +752,10 @@ public class MessagesCache {
static byte getDeviceIdFromQueueName(final String queueName) {
return Byte.parseByte(queueName.substring(queueName.lastIndexOf("::") + 2, queueName.lastIndexOf('}')));
}
private static MessageProtos.Envelope parseEnvelope(final byte[] envelopeBytes)
throws InvalidProtocolBufferException {
return EnvelopeUtil.expand(MessageProtos.Envelope.parseFrom(envelopeBytes));
}
}

View File

@ -57,7 +57,7 @@ class MessagesCacheInsertScript {
);
final List<byte[]> args = new ArrayList<>(Arrays.asList(
envelope.toByteArray(), // message
EnvelopeUtil.compress(envelope).toByteArray(), // message
String.valueOf(envelope.getServerTimestamp()).getBytes(StandardCharsets.UTF_8), // currentTime
envelope.getServerGuid().getBytes(StandardCharsets.UTF_8), // guid
NEW_MESSAGE_EVENT_BYTES // eventPayload

View File

@ -105,7 +105,7 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
.put(KEY_SORT, convertSortKey(message.getServerTimestamp(), messageUuid))
.put(LOCAL_INDEX_MESSAGE_UUID_KEY_SORT, convertLocalIndexMessageUuidSortKey(messageUuid))
.put(KEY_TTL, AttributeValues.fromLong(getTtlForMessage(message)))
.put(KEY_ENVELOPE_BYTES, AttributeValue.builder().b(SdkBytes.fromByteArray(message.toByteArray())).build());
.put(KEY_ENVELOPE_BYTES, AttributeValue.builder().b(SdkBytes.fromByteArray(EnvelopeUtil.compress(message).toByteArray())).build());
writeItems.add(WriteRequest.builder().putRequest(PutRequest.builder()
.item(item.build())
@ -227,7 +227,7 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
static MessageProtos.Envelope convertItemToEnvelope(final Map<String, AttributeValue> item)
throws InvalidProtocolBufferException {
return MessageProtos.Envelope.parseFrom(item.get(KEY_ENVELOPE_BYTES).b().asByteArray());
return EnvelopeUtil.expand(MessageProtos.Envelope.parseFrom(item.get(KEY_ENVELOPE_BYTES).b().asByteArray()));
}
private long getTtlForMessage(MessageProtos.Envelope message) {

View File

@ -0,0 +1,430 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.time.Instant;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.signal.libsignal.protocol.InvalidKeyException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.Util;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest;
import software.amazon.awssdk.services.dynamodb.model.GetItemRequest;
import software.amazon.awssdk.services.dynamodb.model.PutItemRequest;
import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
import software.amazon.awssdk.services.dynamodb.model.ReturnValue;
import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.ListObjectsV2Request;
import software.amazon.awssdk.services.s3.model.ListObjectsV2Response;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
import software.amazon.awssdk.services.s3.model.S3Object;
/**
* @implNote This version of a {@link SingleUsePreKeyStore} store bundles prekeys into "pages", which are stored in on
* an object store and referenced via dynamodb. Each device may only have a single active page at a time. Crashes or
* errors may leave orphaned pages which are no longer referenced by the database. A background process must
* periodically check for orphaned pages and remove them.
* @see SingleUsePreKeyStore
*/
public class PagedSingleUseKEMPreKeyStore {
private static final Logger log = LoggerFactory.getLogger(PagedSingleUseKEMPreKeyStore.class);
private final DynamoDbAsyncClient dynamoDbAsyncClient;
private final S3AsyncClient s3AsyncClient;
private final String tableName;
private final String bucketName;
private final Timer getKeyCountTimer = Metrics.timer(name(getClass(), "getCount"));
private final Timer storeKeyBatchTimer = Metrics.timer(name(getClass(), "storeKeyBatch"));
private final Timer deleteForDeviceTimer = Metrics.timer(name(getClass(), "deleteForDevice"));
private final Timer deleteForAccountTimer = Metrics.timer(name(getClass(), "deleteForAccount"));
final DistributionSummary availableKeyCountDistributionSummary = DistributionSummary
.builder(name(getClass(), "availableKeyCount"))
.publishPercentileHistogram()
.register(Metrics.globalRegistry);
private final String takeKeyTimerName = name(getClass(), "takeKey");
private static final String KEY_PRESENT_TAG_NAME = "keyPresent";
static final String KEY_ACCOUNT_UUID = "U";
static final String KEY_DEVICE_ID = "D";
static final String ATTR_PAGE_ID = "ID";
static final String ATTR_PAGE_IDX = "I";
static final String ATTR_PAGE_NUM_KEYS = "N";
static final String ATTR_PAGE_FORMAT_VERSION = "F";
public PagedSingleUseKEMPreKeyStore(
final DynamoDbAsyncClient dynamoDbAsyncClient,
final S3AsyncClient s3AsyncClient,
final String tableName,
final String bucketName) {
this.s3AsyncClient = s3AsyncClient;
this.dynamoDbAsyncClient = dynamoDbAsyncClient;
this.tableName = tableName;
this.bucketName = bucketName;
}
/**
* Stores a batch of single-use pre-keys for a specific device. All previously-stored keys for the device are cleared
* before storing new keys.
*
* @param identifier the identifier for the account/identity with which the target device is associated
* @param deviceId the identifier for the device within the given account/identity
* @param preKeys a collection of single-use pre-keys to store for the target device
* @return a future that completes when all previously-stored keys have been removed and the given collection of
* pre-keys has been stored in its place
*/
public CompletableFuture<Void> store(
final UUID identifier, final byte deviceId, final List<KEMSignedPreKey> preKeys) {
final Timer.Sample sample = Timer.start();
final List<KEMSignedPreKey> sorted = preKeys.stream().sorted(Comparator.comparing(KEMSignedPreKey::keyId)).toList();
final int bundleFormat = KEMPreKeyPage.FORMAT;
final ByteBuffer bundle = KEMPreKeyPage.serialize(KEMPreKeyPage.FORMAT, sorted);
// Write the bundle to S3, then update the database. Delete the S3 object that was in the database before. This can
// leave orphans in S3 if we fail to update after writing to S3, or fail to delete the old page. However, it can
// never leave a broken pointer in the database. To keep this invariant, we must make sure to generate a new
// name for the page any time we were to retry this entire operation.
return writeBundleToS3(identifier, deviceId, bundle)
.thenCompose(pageId -> dynamoDbAsyncClient.putItem(PutItemRequest.builder()
.tableName(tableName)
.item(Map.of(
KEY_ACCOUNT_UUID, AttributeValues.fromUUID(identifier),
KEY_DEVICE_ID, AttributeValues.fromInt(deviceId),
ATTR_PAGE_ID, AttributeValues.fromUUID(pageId),
ATTR_PAGE_IDX, AttributeValues.fromInt(0),
ATTR_PAGE_NUM_KEYS, AttributeValues.fromInt(sorted.size()),
ATTR_PAGE_FORMAT_VERSION, AttributeValues.fromInt(bundleFormat)
))
.returnValues(ReturnValue.ALL_OLD)
.build()))
.thenCompose(response -> {
if (response.hasAttributes()) {
final UUID pageId = AttributeValues.getUUID(response.attributes(), ATTR_PAGE_ID, null);
if (pageId == null) {
log.error("Replaced record: {} with no pageId", response.attributes());
return CompletableFuture.completedFuture(null);
}
return deleteBundleFromS3(identifier, deviceId, pageId);
} else {
return CompletableFuture.completedFuture(null);
}
})
.whenComplete((result, error) -> sample.stop(storeKeyBatchTimer));
}
/**
* Attempts to retrieve a single-use pre-key for a specific device. Keys may only be returned by this method at most
* once; once the key is returned, it is removed from the key store and subsequent calls to this method will never
* return the same key.
*
* @param identifier the identifier for the account/identity with which the target device is associated
* @param deviceId the identifier for the device within the given account/identity
* @return a future that yields a single-use pre-key if one is available or empty if no single-use pre-keys are
* available for the target device
*/
public CompletableFuture<Optional<KEMSignedPreKey>> take(final UUID identifier, final byte deviceId) {
final Timer.Sample sample = Timer.start();
return dynamoDbAsyncClient.updateItem(UpdateItemRequest.builder()
.tableName(tableName)
.key(Map.of(
KEY_ACCOUNT_UUID, AttributeValues.fromUUID(identifier),
KEY_DEVICE_ID, AttributeValues.fromInt(deviceId)))
.updateExpression("SET #index = #index + :one")
.conditionExpression("#id = :id AND #index < #numkeys")
.expressionAttributeNames(Map.of(
"#id", KEY_ACCOUNT_UUID,
"#index", ATTR_PAGE_IDX,
"#numkeys", ATTR_PAGE_NUM_KEYS))
.expressionAttributeValues(Map.of(
":one", AttributeValues.n(1),
":id", AttributeValues.fromUUID(identifier)))
.returnValues(ReturnValue.ALL_OLD)
.build())
.thenCompose(updateItemResponse -> {
if (!updateItemResponse.hasAttributes()) {
throw new IllegalStateException("update succeeded but did not return an item");
}
final int index = AttributeValues.getInt(updateItemResponse.attributes(), ATTR_PAGE_IDX, -1);
final UUID pageId = AttributeValues.getUUID(updateItemResponse.attributes(), ATTR_PAGE_ID, null);
final int format = AttributeValues.getInt(updateItemResponse.attributes(), ATTR_PAGE_FORMAT_VERSION, -1);
if (index < 0 || format < 0 || pageId == null) {
throw new CompletionException(
new IOException("unexpected page descriptor " + updateItemResponse.attributes()));
}
return readPreKeyAtIndexFromS3(identifier, deviceId, pageId, format, index).thenApply(Optional::of);
})
// If this check fails, it means that the item did not exist, or its index was already at the last key. Either
// way, there are no keys left so we return empty
.exceptionally(ExceptionUtils.exceptionallyHandler(
ConditionalCheckFailedException.class,
e -> Optional.empty()))
.whenComplete((maybeKey, throwable) ->
sample.stop(Metrics.timer(
takeKeyTimerName,
KEY_PRESENT_TAG_NAME, String.valueOf(maybeKey != null && maybeKey.isPresent()))));
}
/**
* Returns the number of single-use pre-keys available for a given device.
*
* @param identifier the identifier for the account/identity with which the target device is associated
* @param deviceId the identifier for the device within the given account/identity
* @return a future that yields the approximate number of single-use pre-keys currently available for the target
* device
*/
public CompletableFuture<Integer> getCount(final UUID identifier, final byte deviceId) {
final Timer.Sample sample = Timer.start();
return dynamoDbAsyncClient.getItem(GetItemRequest.builder()
.tableName(tableName)
.key(Map.of(
KEY_ACCOUNT_UUID, AttributeValues.fromUUID(identifier),
KEY_DEVICE_ID, AttributeValues.fromInt(deviceId)))
.consistentRead(true)
.projectionExpression("#total, #index")
.expressionAttributeNames(Map.of(
"#total", ATTR_PAGE_NUM_KEYS,
"#index", ATTR_PAGE_IDX))
.build())
.thenApply(getResponse -> {
if (!getResponse.hasItem()) {
return 0;
}
final int numKeys = AttributeValues.getInt(getResponse.item(), ATTR_PAGE_NUM_KEYS, -1);
final int index = AttributeValues.getInt(getResponse.item(), ATTR_PAGE_IDX, -1);
if (numKeys < 0 || index < 0 || index > numKeys) {
log.error("unexpected index/length in page descriptor: {}", getResponse.item());
return 0;
}
return numKeys - index;
})
.whenComplete((keyCount, throwable) -> {
sample.stop(getKeyCountTimer);
if (throwable == null && keyCount != null) {
availableKeyCountDistributionSummary.record(keyCount);
}
});
}
/**
* Removes all single-use pre-keys for all devices associated with the given account/identity.
*
* @param identifier the identifier for the account/identity for which to remove single-use pre-keys
* @return a future that completes when all single-use pre-keys have been removed for all devices associated with the
* given account/identity
*/
public CompletableFuture<Void> delete(final UUID identifier) {
final Timer.Sample sample = Timer.start();
return deleteItems(identifier, Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder()
.tableName(tableName)
.keyConditionExpression("#uuid = :uuid")
.projectionExpression("#uuid,#deviceid,#pageid")
.expressionAttributeNames(Map.of(
"#uuid", KEY_ACCOUNT_UUID,
"#deviceid", KEY_DEVICE_ID,
"#pageid", ATTR_PAGE_ID))
.expressionAttributeValues(Map.of(":uuid", AttributeValues.fromUUID(identifier)))
.consistentRead(true)
.build())
.items()))
.thenRun(() -> sample.stop(deleteForAccountTimer));
}
/**
* Removes all single-use pre-keys for a specific device.
*
* @param identifier the identifier for the account/identity with which the target device is associated
* @param deviceId the identifier for the device within the given account/identity
* @return a future that completes when all single-use pre-keys have been removed for the target device
*/
public CompletableFuture<Void> delete(final UUID identifier, final byte deviceId) {
final Timer.Sample sample = Timer.start();
return dynamoDbAsyncClient.getItem(GetItemRequest.builder()
.tableName(tableName)
.key(Map.of(
KEY_ACCOUNT_UUID, AttributeValues.fromUUID(identifier),
KEY_DEVICE_ID, AttributeValues.fromInt(deviceId)))
.consistentRead(true)
.projectionExpression("#uuid,#deviceid,#pageid")
.expressionAttributeNames(Map.of(
"#uuid", KEY_ACCOUNT_UUID,
"#deviceid", KEY_DEVICE_ID,
"#pageid", ATTR_PAGE_ID))
.build())
.thenCompose(getItemResponse -> deleteItems(identifier, getItemResponse.hasItem()
? Flux.just(getItemResponse.item())
: Flux.empty()))
.thenRun(() -> sample.stop(deleteForDeviceTimer));
}
public Flux<DeviceKEMPreKeyPages> listStoredPages(int lookupConcurrency) {
return Flux
.from(s3AsyncClient.listObjectsV2Paginator(ListObjectsV2Request.builder()
.bucket(bucketName)
.build()))
.flatMapIterable(ListObjectsV2Response::contents)
.map(PagedSingleUseKEMPreKeyStore::parseS3Key)
.bufferUntilChanged(Function.identity(), S3PageKey::fromSameDevice)
.flatMapSequential(pages -> {
final UUID identifier = pages.getFirst().identifier();
final byte deviceId = pages.getFirst().deviceId();
return Mono.fromCompletionStage(() -> dynamoDbAsyncClient.getItem(GetItemRequest.builder()
.tableName(tableName)
.key(Map.of(
KEY_ACCOUNT_UUID, AttributeValues.fromUUID(identifier),
KEY_DEVICE_ID, AttributeValues.fromInt(deviceId)))
// Make sure we get the most up to date pageId to minimize cases where we see a new page in S3 but
// view a stale dynamodb record
.consistentRead(true)
.projectionExpression("#uuid,#deviceid,#pageid")
.expressionAttributeNames(Map.of(
"#uuid", KEY_ACCOUNT_UUID,
"#deviceid", KEY_DEVICE_ID,
"#pageid", ATTR_PAGE_ID))
.build())
.thenApply(getItemResponse -> new DeviceKEMPreKeyPages(
identifier,
deviceId,
Optional.ofNullable(AttributeValues.getUUID(getItemResponse.item(), ATTR_PAGE_ID, null)),
pages.stream().collect(Collectors.toMap(S3PageKey::pageId, S3PageKey::lastModified)))));
}, lookupConcurrency);
}
private CompletableFuture<Void> deleteItems(final UUID identifier,
final Flux<Map<String, AttributeValue>> items) {
return items
.flatMap(item -> {
final UUID aci = AttributeValues.getUUID(item, KEY_ACCOUNT_UUID, null);
final byte deviceId = (byte) AttributeValues.getInt(item, KEY_DEVICE_ID, -1);
final UUID pageId = AttributeValues.getUUID(item, ATTR_PAGE_ID, null);
if (aci == null || deviceId < 0 || pageId == null) {
log.error("can't delete page from unexpected page descriptor {}", item);
}
return Mono.fromFuture(deleteBundleFromS3(aci, deviceId, pageId))
.thenReturn(Map.of(
KEY_ACCOUNT_UUID, AttributeValues.fromUUID(identifier),
KEY_DEVICE_ID, AttributeValues.fromInt(deviceId)));
})
.flatMap(itemToDelete -> Mono.fromFuture(() -> dynamoDbAsyncClient.deleteItem(DeleteItemRequest.builder()
.tableName(tableName)
.key(itemToDelete)
.build())))
.then()
.toFuture()
.thenRun(Util.NOOP);
}
private static String s3Key(final UUID identifier, final byte deviceId, final UUID pageId) {
return String.format("%s/%s/%s", identifier, deviceId, pageId);
}
private record S3PageKey(UUID identifier, byte deviceId, UUID pageId, Instant lastModified) {
boolean fromSameDevice(final S3PageKey other) {
return deviceId == other.deviceId && identifier.equals(other.identifier);
}
}
private static S3PageKey parseS3Key(final S3Object page) {
try {
final String[] parts = page.key().split("/", 3);
if (parts.length != 3 || parts[2].contains("/")) {
throw new IllegalArgumentException("wrong number of path components");
}
return new S3PageKey(
UUID.fromString(parts[0]),
Byte.parseByte(parts[1]),
UUID.fromString(parts[2]), page.lastModified());
} catch (IllegalArgumentException e) {
throw new IllegalArgumentException("invalid s3 page key: " + page.key(), e);
}
}
private CompletableFuture<UUID> writeBundleToS3(final UUID identifier, final byte deviceId,
final ByteBuffer bundle) {
final UUID pageId = UUID.randomUUID();
return s3AsyncClient.putObject(PutObjectRequest.builder()
.bucket(bucketName)
.key(s3Key(identifier, deviceId, pageId)).build(),
AsyncRequestBody.fromByteBuffer(bundle))
.thenApply(ignoredResponse -> pageId);
}
CompletableFuture<Void> deleteBundleFromS3(final UUID identifier, final byte deviceId, final UUID pageId) {
return s3AsyncClient.deleteObject(DeleteObjectRequest.builder()
.bucket(bucketName)
.key(s3Key(identifier, deviceId, pageId))
.build())
.thenRun(Util.NOOP);
}
private CompletableFuture<KEMSignedPreKey> readPreKeyAtIndexFromS3(
final UUID identifier, final byte deviceId, final UUID pageId, final int format, final int index) {
final KEMPreKeyPage.KeyLocation keyLocation = KEMPreKeyPage.keyLocation(format, index);
return s3AsyncClient.getObject(GetObjectRequest.builder()
.bucket(bucketName)
.key(s3Key(identifier, deviceId, pageId))
// An RFC9110 range header, inclusive on both ends
// https://www.rfc-editor.org/rfc/rfc9110.html#section-14.1.2
.range("bytes=%s-%s".formatted(keyLocation.getStartInclusive(), keyLocation.getEndInclusive()))
.build(), AsyncResponseTransformer.toBytes())
.thenApply(bytes -> {
final ByteBuffer serialized = bytes.asByteBuffer();
if (serialized.remaining() != keyLocation.length()) {
log.error("Unexpected ranged read response, requested {} got {} for offset {} in page {}",
keyLocation.length(), serialized.remaining(), keyLocation, s3Key(identifier, deviceId, pageId));
throw new CompletionException(new IOException("Invalid response to ranged read"));
}
try {
return KEMPreKeyPage.deserializeKey(format, bytes.asByteBuffer());
} catch (InvalidKeyException e) {
throw new CompletionException(new IOException(e));
}
});
}
}

View File

@ -19,7 +19,7 @@ import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
public class SingleUseECPreKeyStore extends SingleUsePreKeyStore<ECPreKey> {
private static final String PARSE_BYTE_ARRAY_COUNTER_NAME = name(SingleUseECPreKeyStore.class, "parseByteArray");
protected SingleUseECPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) {
public SingleUseECPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) {
super(dynamoDbAsyncClient, tableName);
}

View File

@ -16,7 +16,7 @@ import java.util.UUID;
public class SingleUseKEMPreKeyStore extends SingleUsePreKeyStore<KEMSignedPreKey> {
protected SingleUseKEMPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) {
public SingleUseKEMPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) {
super(dynamoDbAsyncClient, tableName);
}

View File

@ -102,16 +102,23 @@ public class SubscriptionException extends Exception {
public static class ChargeFailurePaymentRequired extends SubscriptionException {
private final PaymentProvider processor;
private final ChargeFailure chargeFailure;
public ChargeFailurePaymentRequired(final ChargeFailure chargeFailure) {
public ChargeFailurePaymentRequired(final PaymentProvider processor, final ChargeFailure chargeFailure) {
super(null, null);
this.processor = processor;
this.chargeFailure = chargeFailure;
}
public PaymentProvider getProcessor() {
return processor;
}
public ChargeFailure getChargeFailure() {
return chargeFailure;
}
}
public static class ProcessorException extends SubscriptionException {

View File

@ -633,7 +633,7 @@ public class BraintreeManager implements CustomerAwareSubscriptionPaymentProcess
if (subscriptionStatus.equals(SubscriptionStatus.ACTIVE) || subscriptionStatus.equals(SubscriptionStatus.PAST_DUE)) {
throw ExceptionUtils.wrap(new SubscriptionException.ReceiptRequestedForOpenPayment());
}
throw ExceptionUtils.wrap(new SubscriptionException.ChargeFailurePaymentRequired(createChargeFailure(transaction)));
throw ExceptionUtils.wrap(new SubscriptionException.ChargeFailurePaymentRequired(getProvider(), createChargeFailure(transaction)));
}
final Instant paidAt = transaction.getSubscriptionDetails().getBillingPeriodStartDate().toInstant();

View File

@ -43,6 +43,7 @@ import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.storage.PaymentTime;
import org.whispersystems.textsecuregcm.storage.SubscriptionException;
@ -362,10 +363,14 @@ public class GooglePlayBillingManager implements SubscriptionPaymentProcessor {
|| e.getStatusCode() == Response.Status.GONE.getStatusCode()) {
throw ExceptionUtils.wrap(new SubscriptionException.NotFound());
}
if (e.getStatusCode() == Response.Status.TOO_MANY_REQUESTS.getStatusCode()) {
throw ExceptionUtils.wrap(new RateLimitExceededException(null));
}
final String details = e instanceof GoogleJsonResponseException
? ((GoogleJsonResponseException) e).getDetails().toString()
: "";
logger.warn("Unexpected HTTP status code {} from androidpublisher: {}", e.getStatusCode(), details, e);
logger.warn("Unexpected HTTP status code {} from androidpublisher: {}", e.getStatusCode(), details);
throw ExceptionUtils.wrap(e);
}));
}

View File

@ -645,7 +645,8 @@ public class StripeManager implements CustomerAwareSubscriptionPaymentProcessor
// If the charge object has a failure reason we can present to the user, create a detailed exception
.filter(charge -> charge.getFailureCode() != null || charge.getFailureMessage() != null)
.<SubscriptionException> map(charge -> new SubscriptionException.ChargeFailurePaymentRequired(createChargeFailure(charge)))
.<SubscriptionException> map(charge ->
new SubscriptionException.ChargeFailurePaymentRequired(getProvider(), createChargeFailure(charge)))
// Otherwise, return a generic payment required error
.orElseGet(() -> new SubscriptionException.PaymentRequired())));

View File

@ -33,6 +33,7 @@ import org.whispersystems.textsecuregcm.backup.Cdn3RemoteStorageManager;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.SecureStorageController;
import org.whispersystems.textsecuregcm.controllers.SecureValueRecovery2Controller;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.MicrometerAwsSdkMetricPublisher;
@ -57,13 +58,18 @@ import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.storage.MessagesCache;
import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PagedSingleUseKEMPreKeyStore;
import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers;
import org.whispersystems.textsecuregcm.storage.Profiles;
import org.whispersystems.textsecuregcm.storage.ProfilesManager;
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswords;
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
import org.whispersystems.textsecuregcm.storage.RepeatedUseECSignedPreKeyStore;
import org.whispersystems.textsecuregcm.storage.RepeatedUseKEMSignedPreKeyStore;
import org.whispersystems.textsecuregcm.storage.ReportMessageDynamoDb;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.storage.SingleUseECPreKeyStore;
import org.whispersystems.textsecuregcm.storage.SingleUseKEMPreKeyStore;
import org.whispersystems.textsecuregcm.util.ManagedAwsCrt;
import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers;
@ -117,6 +123,9 @@ record CommandDependencies(
new DynamicConfigurationManager<>(
configuration.getDynamicConfig().build(awsCredentialsProvider, dynamicConfigurationExecutor), DynamicConfiguration.class);
dynamicConfigurationManager.start();
ExperimentEnrollmentManager experimentEnrollmentManager =
new ExperimentEnrollmentManager(dynamicConfigurationManager);
final ClientResources.Builder redisClientResourcesBuilder = ClientResources.builder();
FaultTolerantRedisClusterClient cacheCluster = configuration.getCacheClusterConfiguration()
@ -204,13 +213,23 @@ record CommandDependencies(
configuration.getDynamoDbTables().getPhoneNumberIdentifiers().getTableName());
Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getProfiles().getTableName());
S3AsyncClient asyncKeysS3Client = S3AsyncClient.builder()
.credentialsProvider(awsCredentialsProvider)
.region(Region.of(configuration.getPagedSingleUseKEMPreKeyStore().region()))
.build();
PagedSingleUseKEMPreKeyStore pagedSingleUseKEMPreKeyStore = new PagedSingleUseKEMPreKeyStore(
dynamoDbAsyncClient, asyncKeysS3Client,
configuration.getDynamoDbTables().getPagedKemKeys().getTableName(),
configuration.getPagedSingleUseKEMPreKeyStore().bucket());
KeysManager keys = new KeysManager(
dynamoDbAsyncClient,
configuration.getDynamoDbTables().getEcKeys().getTableName(),
configuration.getDynamoDbTables().getKemKeys().getTableName(),
configuration.getDynamoDbTables().getEcSignedPreKeys().getTableName(),
configuration.getDynamoDbTables().getKemLastResortKeys().getTableName()
);
new SingleUseECPreKeyStore(dynamoDbAsyncClient, configuration.getDynamoDbTables().getEcKeys().getTableName()),
new SingleUseKEMPreKeyStore(dynamoDbAsyncClient, configuration.getDynamoDbTables().getKemKeys().getTableName()),
pagedSingleUseKEMPreKeyStore,
new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient,
configuration.getDynamoDbTables().getEcSignedPreKeys().getTableName()),
new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient,
configuration.getDynamoDbTables().getKemLastResortKeys().getTableName()),
experimentEnrollmentManager);
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getMessages().getTableName(),
configuration.getDynamoDbTables().getMessages().getExpiration(),

View File

@ -0,0 +1,143 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.workers;
import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.core.Application;
import io.dropwizard.core.setup.Environment;
import io.micrometer.core.instrument.Metrics;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
import java.util.stream.Stream;
import net.sourceforge.argparse4j.inf.Namespace;
import net.sourceforge.argparse4j.inf.Subparser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.storage.DeviceKEMPreKeyPages;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.util.retry.Retry;
public class RemoveOrphanedPreKeyPagesCommand extends AbstractCommandWithDependencies {
private final Logger logger = LoggerFactory.getLogger(getClass());
private static final String PAGE_CONSIDERED_COUNTER_NAME = MetricsUtil.name(RemoveOrphanedPreKeyPagesCommand.class,
"pageConsidered");
@VisibleForTesting
static final String DRY_RUN_ARGUMENT = "dry-run";
@VisibleForTesting
static final String CONCURRENCY_ARGUMENT = "concurrency";
private static final int DEFAULT_CONCURRENCY = 10;
@VisibleForTesting
static final String MINIMUM_ORPHAN_AGE_ARGUMENT = "orphan-age";
private static final Duration DEFAULT_MINIMUM_ORPHAN_AGE = Duration.ofDays(7);
private final Clock clock;
public RemoveOrphanedPreKeyPagesCommand(final Clock clock) {
super(new Application<>() {
@Override
public void run(final WhisperServerConfiguration configuration, final Environment environment) {
}
}, "remove-orphaned-pre-key-pages", "Remove pre-key pages that are unreferenced");
this.clock = clock;
}
@Override
public void configure(final Subparser subparser) {
super.configure(subparser);
subparser.addArgument("--concurrency")
.type(Integer.class)
.dest(CONCURRENCY_ARGUMENT)
.required(false)
.setDefault(DEFAULT_CONCURRENCY)
.help("The maximum number of parallel dynamodb operations to process concurrently");
subparser.addArgument("--dry-run")
.type(Boolean.class)
.dest(DRY_RUN_ARGUMENT)
.required(false)
.setDefault(true)
.help("If true, don't actually remove orphaned pre-key pages");
subparser.addArgument("--minimum-orphan-age")
.type(String.class)
.dest(MINIMUM_ORPHAN_AGE_ARGUMENT)
.required(false)
.setDefault(DEFAULT_MINIMUM_ORPHAN_AGE.toString())
.help("Only remove orphans that are at least this old. Provide as an ISO-8601 duration string");
}
@Override
protected void run(final Environment environment, final Namespace namespace,
final WhisperServerConfiguration configuration, final CommandDependencies commandDependencies) throws Exception {
final int concurrency = Objects.requireNonNull(namespace.getInt(CONCURRENCY_ARGUMENT));
final boolean dryRun = Objects.requireNonNull(namespace.getBoolean(DRY_RUN_ARGUMENT));
final Duration orphanAgeMinimum =
Duration.parse(Objects.requireNonNull(namespace.getString(MINIMUM_ORPHAN_AGE_ARGUMENT)));
final Instant olderThan = clock.instant().minus(orphanAgeMinimum);
logger.info("Crawling preKey page store with concurrency={}, processors={}, dryRun={}. Removing orphans written before={}",
concurrency,
Runtime.getRuntime().availableProcessors(),
dryRun,
olderThan);
final KeysManager keysManager = commandDependencies.keysManager();
final int deletedPages = keysManager.listStoredKEMPreKeyPages(concurrency)
.flatMap(storedPages -> Flux.fromStream(getDetetablePages(storedPages, olderThan))
.concatMap(pageId -> dryRun
? Mono.just(0)
: Mono.fromCompletionStage(() ->
keysManager.pruneDeadPage(storedPages.identifier(), storedPages.deviceId(), pageId))
.retryWhen(Retry.backoff(3, Duration.ofSeconds(1)))
.thenReturn(1)), concurrency)
.reduce(0, Integer::sum)
.block();
logger.info("Deleted {} orphaned pages", deletedPages);
}
private static Stream<UUID> getDetetablePages(final DeviceKEMPreKeyPages storedPages, final Instant olderThan) {
return storedPages.pageIdToLastModified()
.entrySet()
.stream()
.filter(page -> {
final UUID pageId = page.getKey();
final Instant lastModified = page.getValue();
return shouldDeletePage(storedPages.currentPage(), pageId, olderThan, lastModified);
})
.map(Map.Entry::getKey);
}
@VisibleForTesting
static boolean shouldDeletePage(
final Optional<UUID> currentPage, final UUID page,
final Instant deleteBefore, final Instant lastModified) {
final boolean isCurrentPageForDevice = currentPage.map(uuid -> uuid.equals(page)).orElse(false);
final boolean isStale = lastModified.isBefore(deleteBefore);
Metrics.counter(PAGE_CONSIDERED_COUNTER_NAME,
"isCurrentPageForDevice", Boolean.toString(isCurrentPageForDevice),
"stale", Boolean.toString(isStale))
.increment();
return !isCurrentPageForDevice && isStale;
}
}

View File

@ -10,6 +10,8 @@ option java_package = "org.signal.keytransparency.client";
package kt_query;
import "org/signal/chat/require.proto";
/**
* An external-facing, read-only key transparency service used by Signal's chat server
* to look up and monitor identifiers.
@ -19,8 +21,13 @@ package kt_query;
* - A username hash which also maps to an ACI
* Separately, the log also stores and periodically updates a fixed value known as the `distinguished` key.
* Clients use the verified tree head from looking up this key for future calls to the Search and Monitor endpoints.
*
* Note that this service definition is used in two different contexts:
* 1. Implementing the endpoints with rate-limiting and request validation
* 2. Using the generated client stub to forward requests to the remote key transparency service
*/
service KeyTransparencyQueryService {
option (org.signal.chat.require.auth) = AUTH_ONLY_ANONYMOUS;
/**
* An endpoint used by clients to retrieve the most recent distinguished tree
* head, which should be used to derive consistency parameters for
@ -44,15 +51,15 @@ message SearchRequest {
/**
* The ACI to look up in the log.
*/
bytes aci = 1;
bytes aci = 1 [(org.signal.chat.require.exactlySize) = 16];
/**
* The ACI identity key that the client thinks the ACI maps to in the log.
*/
bytes aci_identity_key = 2;
bytes aci_identity_key = 2 [(org.signal.chat.require.nonEmpty) = true];
/**
* The username hash to look up in the log.
*/
optional bytes username_hash = 3;
optional bytes username_hash = 3 [(org.signal.chat.require.exactlySize) = 0, (org.signal.chat.require.exactlySize) = 32];
/**
* The E164 to look up in the log along with associated data.
*/
@ -60,7 +67,7 @@ message SearchRequest {
/**
* The tree head size(s) to prove consistency against.
*/
ConsistencyParameters consistency = 5;
ConsistencyParameters consistency = 5 [(org.signal.chat.require.present) = true];
}
/**
@ -70,7 +77,7 @@ message E164SearchRequest {
/**
* The E164 that the client wishes to look up in the transparency log.
*/
string e164 = 1;
optional string e164 = 1 [(org.signal.chat.require.e164) = true];
/**
* The unidentified access key of the account associated with the provided E164.
*/
@ -328,28 +335,28 @@ message PrefixSearchResult {
}
message MonitorRequest {
AciMonitorRequest aci = 1;
AciMonitorRequest aci = 1 [(org.signal.chat.require.present) = true];
optional UsernameHashMonitorRequest username_hash = 2;
optional E164MonitorRequest e164 = 3;
ConsistencyParameters consistency = 4;
ConsistencyParameters consistency = 4 [(org.signal.chat.require.present) = true];
}
message AciMonitorRequest {
bytes aci = 1;
bytes aci = 1 [(org.signal.chat.require.exactlySize) = 16];
uint64 entry_position = 2;
bytes commitment_index = 3;
bytes commitment_index = 3 [(org.signal.chat.require.exactlySize) = 32];
}
message UsernameHashMonitorRequest {
bytes username_hash = 1;
bytes username_hash = 1 [(org.signal.chat.require.exactlySize) = 0, (org.signal.chat.require.exactlySize) = 32];
uint64 entry_position = 2;
bytes commitment_index = 3;
bytes commitment_index = 3 [(org.signal.chat.require.exactlySize) = 0, (org.signal.chat.require.exactlySize) = 32];
}
message E164MonitorRequest {
string e164 = 1;
optional string e164 = 1 [(org.signal.chat.require.e164) = true];
uint64 entry_position = 2;
bytes commitment_index = 3;
bytes commitment_index = 3 [(org.signal.chat.require.exactlySize) = 0, (org.signal.chat.require.exactlySize) = 32];
}
message MonitorProof {

View File

@ -35,7 +35,11 @@ message Envelope {
optional bool story = 16; // indicates that the content is a story.
optional bytes report_spam_token = 17; // token sent when reporting spam
optional bytes shared_mrm_key = 18; // indicates content should be fetched from multi-recipient message datastore
// next: 19
optional bytes source_service_id_binary = 19; // service ID binary (i.e. 16 byte UUID for ACI, 1 byte prefix + 16 byte UUID for PNI)
optional bytes destination_service_id_binary = 20; // service ID binary (i.e. 16 byte UUID for ACI, 1 byte prefix + 16 byte UUID for PNI)
optional bytes server_guid_binary = 21; // 16-byte UUID
optional bytes updated_pni_binary = 22; // 16-byte UUID
// next: 22
}
message ProvisioningAddress {

View File

@ -53,6 +53,7 @@ enum ExternalServiceType {
EXTERNAL_SERVICE_TYPE_PAYMENTS = 2;
EXTERNAL_SERVICE_TYPE_STORAGE = 3;
EXTERNAL_SERVICE_TYPE_SVR = 4;
EXTERNAL_SERVICE_TYPE_SVRB = 5;
}
message GetExternalServiceCredentialsRequest {

View File

@ -33,9 +33,8 @@ public class LocalFaultTolerantRedisClientFactory implements FaultTolerantRedisC
if (shutdownHookConfigured.compareAndSet(false, true)) {
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
try {
REDIS_SERVER_EXTENSION.afterEach(null);
REDIS_SERVER_EXTENSION.afterAll(null);
} catch (Exception e) {
REDIS_SERVER_EXTENSION.close();
} catch (Throwable e) {
throw new RuntimeException(e);
}
}));

View File

@ -34,8 +34,8 @@ public class LocalFaultTolerantRedisClusterFactory implements FaultTolerantRedis
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
try {
redisClusterExtension.afterEach(null);
redisClusterExtension.afterAll(null);
} catch (Exception e) {
redisClusterExtension.close();
} catch (Throwable e) {
throw new RuntimeException(e);
}
}));

View File

@ -35,12 +35,8 @@ import jakarta.ws.rs.client.WebTarget;
import jakarta.ws.rs.core.Response;
import java.io.UncheckedIOException;
import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
@ -54,8 +50,10 @@ import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.ArgumentCaptor;
import org.signal.keytransparency.client.CondensedTreeSearchResponse;
import org.signal.keytransparency.client.DistinguishedResponse;
import org.signal.keytransparency.client.E164SearchRequest;
import org.signal.keytransparency.client.FullTreeHead;
import org.signal.keytransparency.client.MonitorResponse;
import org.signal.keytransparency.client.SearchProof;
import org.signal.keytransparency.client.SearchResponse;
import org.signal.keytransparency.client.UpdateValue;
@ -81,16 +79,16 @@ import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider;
@ExtendWith(DropwizardExtensionsSupport.class)
public class KeyTransparencyControllerTest {
private static final String NUMBER = PhoneNumberUtil.getInstance().format(
public static final String NUMBER = PhoneNumberUtil.getInstance().format(
PhoneNumberUtil.getInstance().getExampleNumber("US"),
PhoneNumberUtil.PhoneNumberFormat.E164);
private static final AciServiceIdentifier ACI = new AciServiceIdentifier(UUID.randomUUID());
private static final byte[] USERNAME_HASH = TestRandomUtil.nextBytes(20);
public static final AciServiceIdentifier ACI = new AciServiceIdentifier(UUID.randomUUID());
public static final byte[] USERNAME_HASH = TestRandomUtil.nextBytes(20);
private static final TestRemoteAddressFilterProvider TEST_REMOTE_ADDRESS_FILTER_PROVIDER
= new TestRemoteAddressFilterProvider("127.0.0.1");
private static final IdentityKey ACI_IDENTITY_KEY = new IdentityKey(Curve.generateKeyPair().getPublicKey());
public static final IdentityKey ACI_IDENTITY_KEY = new IdentityKey(Curve.generateKeyPair().getPublicKey());
private static final byte[] COMMITMENT_INDEX = new byte[32];
private static final byte[] UNIDENTIFIED_ACCESS_KEY = new byte[16];
public static final byte[] UNIDENTIFIED_ACCESS_KEY = new byte[16];
private final KeyTransparencyServiceClient keyTransparencyServiceClient = mock(KeyTransparencyServiceClient.class);
private static final RateLimiters rateLimiters = mock(RateLimiters.class);
private static final RateLimiter searchRatelimiter = mock(RateLimiter.class);
@ -141,8 +139,8 @@ public class KeyTransparencyControllerTest {
e164.ifPresent(ignored -> searchResponseBuilder.setE164(CondensedTreeSearchResponse.getDefaultInstance()));
usernameHash.ifPresent(ignored -> searchResponseBuilder.setUsernameHash(CondensedTreeSearchResponse.getDefaultInstance()));
when(keyTransparencyServiceClient.search(any(), any(), any(), any(), any(), anyLong(), any()))
.thenReturn(CompletableFuture.completedFuture(searchResponseBuilder.build().toByteArray()));
when(keyTransparencyServiceClient.search(any(), any(), any(), any(), any(), anyLong()))
.thenReturn(searchResponseBuilder.build());
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/key-transparency/search")
@ -167,8 +165,7 @@ public class KeyTransparencyControllerTest {
ArgumentCaptor<Optional<E164SearchRequest>> e164Argument = ArgumentCaptor.forClass(Optional.class);
verify(keyTransparencyServiceClient).search(aciArgument.capture(), aciIdentityKeyArgument.capture(),
usernameHashArgument.capture(), e164Argument.capture(), eq(Optional.of(3L)), eq(4L),
eq(KeyTransparencyController.KEY_TRANSPARENCY_RPC_TIMEOUT));
usernameHashArgument.capture(), e164Argument.capture(), eq(Optional.of(3L)), eq(4L));
assertArrayEquals(ACI.toCompactByteArray(), aciArgument.getValue().toByteArray());
assertArrayEquals(ACI_IDENTITY_KEY.serialize(), aciIdentityKeyArgument.getValue().toByteArray());
@ -218,8 +215,8 @@ public class KeyTransparencyControllerTest {
@ParameterizedTest
@MethodSource
void searchGrpcErrors(final Status grpcStatus, final int httpStatus) {
when(keyTransparencyServiceClient.search(any(), any(), any(), any(), any(), anyLong(), any()))
.thenReturn(CompletableFuture.failedFuture(new CompletionException(new StatusRuntimeException(grpcStatus))));
when(keyTransparencyServiceClient.search(any(), any(), any(), any(), any(), anyLong()))
.thenThrow(new StatusRuntimeException(grpcStatus));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/key-transparency/search")
@ -228,7 +225,7 @@ public class KeyTransparencyControllerTest {
Entity.json(createRequestJson(new KeyTransparencySearchRequest(ACI, Optional.empty(), Optional.empty(),
ACI_IDENTITY_KEY, Optional.empty(), Optional.empty(), 4L))))) {
assertEquals(httpStatus, response.getStatus());
verify(keyTransparencyServiceClient, times(1)).search(any(), any(), any(), any(), any(), anyLong(), any());
verify(keyTransparencyServiceClient, times(1)).search(any(), any(), any(), any(), any(), anyLong());
}
}
@ -295,8 +292,8 @@ public class KeyTransparencyControllerTest {
@Test
void monitorSuccess() {
when(keyTransparencyServiceClient.monitor(any(), any(), any(), anyLong(), anyLong(), any()))
.thenReturn(CompletableFuture.completedFuture(TestRandomUtil.nextBytes(16)));
when(keyTransparencyServiceClient.monitor(any(), any(), any(), anyLong(), anyLong()))
.thenReturn(MonitorResponse.getDefaultInstance());
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/key-transparency/monitor")
@ -314,7 +311,7 @@ public class KeyTransparencyControllerTest {
assertNotNull(keyTransparencyMonitorResponse.serializedResponse());
verify(keyTransparencyServiceClient, times(1)).monitor(
any(), any(), any(), eq(3L), eq(4L), eq(KeyTransparencyController.KEY_TRANSPARENCY_RPC_TIMEOUT));
any(), any(), any(), eq(3L), eq(4L));
}
}
@ -337,8 +334,8 @@ public class KeyTransparencyControllerTest {
@ParameterizedTest
@MethodSource
void monitorGrpcErrors(final Status grpcStatus, final int httpStatus) {
when(keyTransparencyServiceClient.monitor(any(), any(), any(), anyLong(), anyLong(), any()))
.thenReturn(CompletableFuture.failedFuture(new CompletionException(new StatusRuntimeException(grpcStatus))));
when(keyTransparencyServiceClient.monitor(any(), any(), any(), anyLong(), anyLong()))
.thenThrow(new StatusRuntimeException(grpcStatus));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/key-transparency/monitor")
@ -349,7 +346,7 @@ public class KeyTransparencyControllerTest {
new KeyTransparencyMonitorRequest.AciMonitor(ACI, 3, COMMITMENT_INDEX),
Optional.empty(), Optional.empty(), 3L, 4L))))) {
assertEquals(httpStatus, response.getStatus());
verify(keyTransparencyServiceClient, times(1)).monitor(any(), any(), any(), anyLong(), anyLong(), any());
verify(keyTransparencyServiceClient, times(1)).monitor(any(), any(), any(), anyLong(), anyLong());
}
}
@ -500,8 +497,8 @@ public class KeyTransparencyControllerTest {
@ParameterizedTest
@CsvSource(", 1")
void distinguishedSuccess(@Nullable Long lastTreeHeadSize) {
when(keyTransparencyServiceClient.getDistinguishedKey(any(), any()))
.thenReturn(CompletableFuture.completedFuture(TestRandomUtil.nextBytes(16)));
when(keyTransparencyServiceClient.getDistinguishedKey(any()))
.thenReturn(DistinguishedResponse.getDefaultInstance());
WebTarget webTarget = resources.getJerseyTest()
.target("/v1/key-transparency/distinguished");
@ -518,8 +515,7 @@ public class KeyTransparencyControllerTest {
assertNotNull(distinguishedKeyResponse.serializedResponse());
verify(keyTransparencyServiceClient, times(1))
.getDistinguishedKey(eq(Optional.ofNullable(lastTreeHeadSize)),
eq(KeyTransparencyController.KEY_TRANSPARENCY_RPC_TIMEOUT));
.getDistinguishedKey(eq(Optional.ofNullable(lastTreeHeadSize)));
}
}
@ -538,15 +534,15 @@ public class KeyTransparencyControllerTest {
@ParameterizedTest
@MethodSource
void distinguishedGrpcErrors(final Status grpcStatus, final int httpStatus) {
when(keyTransparencyServiceClient.getDistinguishedKey(any(), any()))
.thenReturn(CompletableFuture.failedFuture(new CompletionException(new StatusRuntimeException(grpcStatus))));
when(keyTransparencyServiceClient.getDistinguishedKey(any()))
.thenThrow(new StatusRuntimeException(grpcStatus));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/key-transparency/distinguished")
.request();
try (Response response = request.get()) {
assertEquals(httpStatus, response.getStatus());
verify(keyTransparencyServiceClient).getDistinguishedKey(any(), any());
verify(keyTransparencyServiceClient).getDistinguishedKey(any());
}
}
@ -561,8 +557,8 @@ public class KeyTransparencyControllerTest {
@Test
void distinguishedInvalidRequest() {
when(keyTransparencyServiceClient.getDistinguishedKey(any(), any()))
.thenReturn(CompletableFuture.completedFuture(TestRandomUtil.nextBytes(16)));
when(keyTransparencyServiceClient.getDistinguishedKey(any()))
.thenReturn(DistinguishedResponse.getDefaultInstance());
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/key-transparency/distinguished")

View File

@ -6,28 +6,43 @@
package org.whispersystems.textsecuregcm.controllers;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.mock;
import static org.whispersystems.textsecuregcm.util.MockUtils.randomSecretBytes;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import jakarta.ws.rs.client.Entity;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
import java.util.Map;
import java.util.stream.Collectors;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mockito;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator;
import org.whispersystems.textsecuregcm.configuration.SecureValueRecovery2Configuration;
import org.whispersystems.textsecuregcm.configuration.SecureValueRecoveryConfiguration;
import org.whispersystems.textsecuregcm.entities.AuthCheckRequest;
import org.whispersystems.textsecuregcm.entities.AuthCheckResponseV2;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.MutableClock;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
@ExtendWith(DropwizardExtensionsSupport.class)
public class SecureValueRecovery2ControllerTest extends SecureValueRecoveryControllerBaseTest {
public class SecureValueRecovery2ControllerTest {
private static final SecureValueRecovery2Configuration CFG = new SecureValueRecovery2Configuration(
private static final SecureValueRecoveryConfiguration CFG = new SecureValueRecoveryConfiguration(
"",
randomSecretBytes(32),
randomSecretBytes(32),
@ -52,19 +67,305 @@ public class SecureValueRecovery2ControllerTest extends SecureValueRecoveryContr
.addResource(CONTROLLER)
.build();
protected SecureValueRecovery2ControllerTest() {
super("/v2", ACCOUNTS_MANAGER, CLOCK, RESOURCES, CREDENTIAL_GENERATOR);
@Nested
class WithBackupsPrefix extends SecureValueRecoveryControllerBaseTest {
protected WithBackupsPrefix() {
super("/v2/backup");
}
}
@Override
Map<String, CheckStatus> parseCheckResponse(final Response response) {
final AuthCheckResponseV2 authCheckResponseV2 = response.readEntity(AuthCheckResponseV2.class);
return authCheckResponseV2.matches().entrySet().stream().collect(Collectors.toMap(
Map.Entry::getKey, e -> switch (e.getValue()) {
case MATCH -> CheckStatus.MATCH;
case INVALID -> CheckStatus.INVALID;
case NO_MATCH -> CheckStatus.NO_MATCH;
}
));
@Nested
class WithSvr2Prefix extends SecureValueRecoveryControllerBaseTest {
protected WithSvr2Prefix() {
super("/v2/svr");
}
}
static abstract class SecureValueRecoveryControllerBaseTest {
private static final UUID USER_1 = UUID.randomUUID();
private static final UUID USER_2 = UUID.randomUUID();
private static final UUID USER_3 = UUID.randomUUID();
private static final String E164_VALID = "+18005550123";
private static final String E164_INVALID = "1(800)555-0123";
private final String pathPrefix;
@BeforeEach
public void before() throws Exception {
Mockito.reset(ACCOUNTS_MANAGER);
Mockito.when(ACCOUNTS_MANAGER.getByE164(E164_VALID)).thenReturn(Optional.of(account(USER_1)));
}
protected SecureValueRecoveryControllerBaseTest(final String pathPrefix) {
this.pathPrefix = pathPrefix;
}
enum CheckStatus {
MATCH,
NO_MATCH,
INVALID
}
private Map<String, CheckStatus> parseCheckResponse(final Response response) {
final AuthCheckResponseV2 authCheckResponseV2 = response.readEntity(AuthCheckResponseV2.class);
return authCheckResponseV2.matches().entrySet().stream().collect(Collectors.toMap(
Map.Entry::getKey, e -> switch (e.getValue()) {
case MATCH -> CheckStatus.MATCH;
case INVALID -> CheckStatus.INVALID;
case NO_MATCH -> CheckStatus.NO_MATCH;
}
));
}
@Test
public void testOneMatch() {
validate(Map.of(
token(USER_1, day(1)), CheckStatus.MATCH,
token(USER_2, day(1)), CheckStatus.NO_MATCH,
token(USER_3, day(1)), CheckStatus.NO_MATCH
), day(2));
}
@Test
public void testNoMatch() {
validate(Map.of(
token(USER_2, day(1)), CheckStatus.NO_MATCH,
token(USER_3, day(1)), CheckStatus.NO_MATCH
), day(2));
}
@Test
public void testSomeInvalid() {
final ExternalServiceCredentials user1Cred = credentials(USER_1, day(1));
final ExternalServiceCredentials user2Cred = credentials(USER_2, day(1));
final ExternalServiceCredentials user3Cred = credentials(USER_3, day(1));
final String fakeToken = token(new ExternalServiceCredentials(user2Cred.username(), user3Cred.password()));
validate(Map.of(
token(user1Cred), CheckStatus.MATCH,
token(user2Cred), CheckStatus.NO_MATCH,
fakeToken, CheckStatus.INVALID
), day(2));
}
@Test
public void testSomeExpired() {
validate(Map.of(
token(USER_1, day(100)), CheckStatus.MATCH,
token(USER_2, day(100)), CheckStatus.NO_MATCH,
token(USER_3, day(10)), CheckStatus.INVALID,
token(USER_3, day(20)), CheckStatus.INVALID
), day(110));
}
@Test
public void testSomeHaveNewerVersions() {
validate(Map.of(
token(USER_1, day(10)), CheckStatus.INVALID,
token(USER_1, day(20)), CheckStatus.MATCH,
token(USER_2, day(10)), CheckStatus.NO_MATCH,
token(USER_3, day(20)), CheckStatus.NO_MATCH,
token(USER_3, day(10)), CheckStatus.INVALID
), day(25));
}
private void validate(
final Map<String, CheckStatus> expected,
final long nowMillis) {
CLOCK.setTimeMillis(nowMillis);
final AuthCheckRequest request = new AuthCheckRequest(E164_VALID, List.copyOf(expected.keySet()));
final Response response = RESOURCES.getJerseyTest().target(pathPrefix + "/auth/check")
.request()
.post(Entity.entity(request, MediaType.APPLICATION_JSON));
try (response) {
assertEquals(200, response.getStatus());
final Map<String, CheckStatus> res = parseCheckResponse(response);
assertEquals(expected, res);
}
}
@Test
public void testHttpResponseCodeSuccess() {
final Map<String, CheckStatus> expected = Map.of(
token(USER_1, day(10)), CheckStatus.INVALID,
token(USER_1, day(20)), CheckStatus.MATCH,
token(USER_2, day(10)), CheckStatus.NO_MATCH,
token(USER_3, day(20)), CheckStatus.NO_MATCH,
token(USER_3, day(10)), CheckStatus.INVALID
);
CLOCK.setTimeMillis(day(25));
final AuthCheckRequest in = new AuthCheckRequest(E164_VALID, List.copyOf(expected.keySet()));
final Response response = RESOURCES.getJerseyTest()
.target(pathPrefix + "/auth/check")
.request()
.post(Entity.entity(in, MediaType.APPLICATION_JSON));
try (response) {
assertEquals(200, response.getStatus());
assertEquals(expected, parseCheckResponse(response));
}
}
@Test
public void testHttpResponseCodeWhenInvalidNumber() {
final AuthCheckRequest in = new AuthCheckRequest(E164_INVALID, Collections.singletonList("1"));
final Response response = RESOURCES.getJerseyTest()
.target(pathPrefix + "/auth/check")
.request()
.post(Entity.entity(in, MediaType.APPLICATION_JSON));
try (response) {
assertEquals(422, response.getStatus());
}
}
@Test
public void testHttpResponseCodeWhenTooManyTokens() {
final AuthCheckRequest inOkay = new AuthCheckRequest(E164_VALID, List.of(
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"
));
final AuthCheckRequest inTooMany = new AuthCheckRequest(E164_VALID, List.of(
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11"
));
final AuthCheckRequest inNoTokens = new AuthCheckRequest(E164_VALID, Collections.emptyList());
final Response responseOkay = RESOURCES.getJerseyTest()
.target(pathPrefix + "/auth/check")
.request()
.post(Entity.entity(inOkay, MediaType.APPLICATION_JSON));
final Response responseError1 = RESOURCES.getJerseyTest()
.target(pathPrefix + "/auth/check")
.request()
.post(Entity.entity(inTooMany, MediaType.APPLICATION_JSON));
final Response responseError2 = RESOURCES.getJerseyTest()
.target(pathPrefix + "/auth/check")
.request()
.post(Entity.entity(inNoTokens, MediaType.APPLICATION_JSON));
try (responseOkay; responseError1; responseError2) {
assertEquals(200, responseOkay.getStatus());
assertEquals(422, responseError1.getStatus());
assertEquals(422, responseError2.getStatus());
}
}
@Test
public void testHttpResponseCodeWhenPasswordsMissing() {
final Response response = RESOURCES.getJerseyTest()
.target(pathPrefix + "/auth/check")
.request()
.post(Entity.entity("""
{
"number": "123"
}
""", MediaType.APPLICATION_JSON));
try (response) {
assertEquals(422, response.getStatus());
}
}
@Test
public void testHttpResponseCodeWhenNumberMissing() {
final Response response = RESOURCES.getJerseyTest()
.target(pathPrefix + "/auth/check")
.request()
.post(Entity.entity("""
{
"passwords": ["aaa:bbb"]
}
""", MediaType.APPLICATION_JSON));
try (response) {
assertEquals(422, response.getStatus());
}
}
@Test
public void testHttpResponseCodeWhenExtraFields() {
final Response response = RESOURCES.getJerseyTest()
.target(pathPrefix + "/auth/check")
.request()
.post(Entity.entity("""
{
"number": "+18005550123",
"passwords": ["aaa:bbb"],
"unexpected": "value"
}
""", MediaType.APPLICATION_JSON));
try (response) {
assertEquals(200, response.getStatus());
}
}
@Test
public void testAcceptsPasswordsOrTokens() {
final Response passwordsResponse = RESOURCES.getJerseyTest()
.target(pathPrefix + "/auth/check")
.request()
.post(Entity.entity("""
{
"number": "+18005550123",
"passwords": ["aaa:bbb"]
}
""", MediaType.APPLICATION_JSON));
try (passwordsResponse) {
assertEquals(200, passwordsResponse.getStatus());
}
final Response tokensResponse = RESOURCES.getJerseyTest()
.target(pathPrefix + "/auth/check")
.request()
.post(Entity.entity("""
{
"number": "+18005550123",
"tokens": ["aaa:bbb"]
}
""", MediaType.APPLICATION_JSON));
try (tokensResponse) {
assertEquals(200, tokensResponse.getStatus());
}
}
@Test
public void testHttpResponseCodeWhenNotAJson() {
final Response response = RESOURCES.getJerseyTest()
.target(pathPrefix + "/auth/check")
.request()
.post(Entity.entity("random text", MediaType.APPLICATION_JSON));
try (response) {
assertEquals(400, response.getStatus());
}
}
private String token(final UUID uuid, final long timeMillis) {
return token(credentials(uuid, timeMillis));
}
private static String token(final ExternalServiceCredentials credentials) {
return credentials.username() + ":" + credentials.password();
}
private ExternalServiceCredentials credentials(final UUID uuid, final long timeMillis) {
CLOCK.setTimeMillis(timeMillis);
return CREDENTIAL_GENERATOR.generateForUuid(uuid);
}
private static long day(final int n) {
return TimeUnit.DAYS.toMillis(n);
}
private static Account account(final UUID uuid) {
final Account a = new Account();
a.setUuid(uuid);
return a;
}
}
}

View File

@ -0,0 +1,70 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import static org.assertj.core.api.Assertions.assertThat;
import static org.whispersystems.textsecuregcm.util.MockUtils.randomSecretBytes;
import io.dropwizard.auth.AuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator;
import org.whispersystems.textsecuregcm.configuration.SecureValueRecoveryConfiguration;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.MutableClock;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import java.time.Instant;
import java.util.HexFormat;
@ExtendWith(DropwizardExtensionsSupport.class)
public class SecureValueRecoveryBControllerTest {
private static final SecureValueRecoveryConfiguration CFG = new SecureValueRecoveryConfiguration(
"",
randomSecretBytes(32),
randomSecretBytes(32),
null,
null,
null
);
private static final MutableClock CLOCK = new MutableClock();
private static final ExternalServiceCredentialsGenerator CREDENTIAL_GENERATOR =
SecureValueRecoveryBController.credentialsGenerator(CFG, CLOCK);
private static final SecureValueRecoveryBController CONTROLLER =
new SecureValueRecoveryBController(CREDENTIAL_GENERATOR);
private static final ResourceExtension RESOURCES = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new AuthValueFactoryProvider.Binder<>(AuthenticatedDevice.class))
.setMapper(SystemMapper.jsonMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(CONTROLLER)
.build();
@Test
public void testGetCredentials() {
CLOCK.setTimeInstant(Instant.ofEpochSecond(123));
final ExternalServiceCredentials creds = RESOURCES.getJerseyTest()
.target("/v1/svrb/auth")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(ExternalServiceCredentials.class);
assertThat(HexFormat.of().parseHex(creds.username())).hasSize(16);
System.out.println(creds.password());
final String[] split = creds.password().split(":", 2);
assertThat(Long.parseLong(split[0])).isEqualTo(123);
}
}

View File

@ -1,324 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import static org.junit.jupiter.api.Assertions.assertEquals;
import io.dropwizard.testing.junit5.ResourceExtension;
import jakarta.ws.rs.client.Entity;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator;
import org.whispersystems.textsecuregcm.entities.AuthCheckRequest;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.util.MutableClock;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
abstract class SecureValueRecoveryControllerBaseTest {
private static final UUID USER_1 = UUID.randomUUID();
private static final UUID USER_2 = UUID.randomUUID();
private static final UUID USER_3 = UUID.randomUUID();
private static final String E164_VALID = "+18005550123";
private static final String E164_INVALID = "1(800)555-0123";
private final String pathPrefix;
private final ResourceExtension resourceExtension;
private final AccountsManager mockAccountsManager;
private final ExternalServiceCredentialsGenerator credentialsGenerator;
private final MutableClock clock;
@BeforeEach
public void before() throws Exception {
Mockito.when(mockAccountsManager.getByE164(E164_VALID)).thenReturn(Optional.of(account(USER_1)));
}
protected SecureValueRecoveryControllerBaseTest(
final String pathPrefix,
final AccountsManager mockAccountsManager,
final MutableClock mutableClock,
final ResourceExtension resourceExtension,
final ExternalServiceCredentialsGenerator credentialsGenerator) {
this.pathPrefix = pathPrefix;
this.resourceExtension = resourceExtension;
this.mockAccountsManager = mockAccountsManager;
this.credentialsGenerator = credentialsGenerator;
this.clock = mutableClock;
}
enum CheckStatus {
MATCH,
NO_MATCH,
INVALID
}
abstract Map<String, CheckStatus> parseCheckResponse(Response response);
@Test
public void testOneMatch() throws Exception {
validate(Map.of(
token(USER_1, day(1)), CheckStatus.MATCH,
token(USER_2, day(1)), CheckStatus.NO_MATCH,
token(USER_3, day(1)), CheckStatus.NO_MATCH
), day(2));
}
@Test
public void testNoMatch() throws Exception {
validate(Map.of(
token(USER_2, day(1)), CheckStatus.NO_MATCH,
token(USER_3, day(1)), CheckStatus.NO_MATCH
), day(2));
}
@Test
public void testSomeInvalid() throws Exception {
final ExternalServiceCredentials user1Cred = credentials(USER_1, day(1));
final ExternalServiceCredentials user2Cred = credentials(USER_2, day(1));
final ExternalServiceCredentials user3Cred = credentials(USER_3, day(1));
final String fakeToken = token(new ExternalServiceCredentials(user2Cred.username(), user3Cred.password()));
validate(Map.of(
token(user1Cred), CheckStatus.MATCH,
token(user2Cred), CheckStatus.NO_MATCH,
fakeToken, CheckStatus.INVALID
), day(2));
}
@Test
public void testSomeExpired() throws Exception {
validate(Map.of(
token(USER_1, day(100)), CheckStatus.MATCH,
token(USER_2, day(100)), CheckStatus.NO_MATCH,
token(USER_3, day(10)), CheckStatus.INVALID,
token(USER_3, day(20)), CheckStatus.INVALID
), day(110));
}
@Test
public void testSomeHaveNewerVersions() throws Exception {
validate(Map.of(
token(USER_1, day(10)), CheckStatus.INVALID,
token(USER_1, day(20)), CheckStatus.MATCH,
token(USER_2, day(10)), CheckStatus.NO_MATCH,
token(USER_3, day(20)), CheckStatus.NO_MATCH,
token(USER_3, day(10)), CheckStatus.INVALID
), day(25));
}
private void validate(
final Map<String, CheckStatus> expected,
final long nowMillis) throws Exception {
clock.setTimeMillis(nowMillis);
final AuthCheckRequest request = new AuthCheckRequest(E164_VALID, List.copyOf(expected.keySet()));
final Response response = resourceExtension.getJerseyTest().target(pathPrefix + "/backup/auth/check")
.request()
.post(Entity.entity(request, MediaType.APPLICATION_JSON));
try (response) {
assertEquals(200, response.getStatus());
final Map<String, CheckStatus> res = parseCheckResponse(response);
assertEquals(expected, res);
}
}
@Test
public void testHttpResponseCodeSuccess() throws Exception {
final Map<String, CheckStatus> expected = Map.of(
token(USER_1, day(10)), CheckStatus.INVALID,
token(USER_1, day(20)), CheckStatus.MATCH,
token(USER_2, day(10)), CheckStatus.NO_MATCH,
token(USER_3, day(20)), CheckStatus.NO_MATCH,
token(USER_3, day(10)), CheckStatus.INVALID
);
clock.setTimeMillis(day(25));
final AuthCheckRequest in = new AuthCheckRequest(E164_VALID, List.copyOf(expected.keySet()));
final Response response = resourceExtension.getJerseyTest()
.target(pathPrefix + "/backup/auth/check")
.request()
.post(Entity.entity(in, MediaType.APPLICATION_JSON));
try (response) {
assertEquals(200, response.getStatus());
assertEquals(expected, parseCheckResponse(response));
}
}
@Test
public void testHttpResponseCodeWhenInvalidNumber() throws Exception {
final AuthCheckRequest in = new AuthCheckRequest(E164_INVALID, Collections.singletonList("1"));
final Response response = resourceExtension.getJerseyTest()
.target(pathPrefix + "/backup/auth/check")
.request()
.post(Entity.entity(in, MediaType.APPLICATION_JSON));
try (response) {
assertEquals(422, response.getStatus());
}
}
@Test
public void testHttpResponseCodeWhenTooManyTokens() throws Exception {
final AuthCheckRequest inOkay = new AuthCheckRequest(E164_VALID, List.of(
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"
));
final AuthCheckRequest inTooMany = new AuthCheckRequest(E164_VALID, List.of(
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11"
));
final AuthCheckRequest inNoTokens = new AuthCheckRequest(E164_VALID, Collections.emptyList());
final Response responseOkay = resourceExtension.getJerseyTest()
.target(pathPrefix + "/backup/auth/check")
.request()
.post(Entity.entity(inOkay, MediaType.APPLICATION_JSON));
final Response responseError1 = resourceExtension.getJerseyTest()
.target(pathPrefix + "/backup/auth/check")
.request()
.post(Entity.entity(inTooMany, MediaType.APPLICATION_JSON));
final Response responseError2 = resourceExtension.getJerseyTest()
.target(pathPrefix + "/backup/auth/check")
.request()
.post(Entity.entity(inNoTokens, MediaType.APPLICATION_JSON));
try (responseOkay; responseError1; responseError2) {
assertEquals(200, responseOkay.getStatus());
assertEquals(422, responseError1.getStatus());
assertEquals(422, responseError2.getStatus());
}
}
@Test
public void testHttpResponseCodeWhenPasswordsMissing() throws Exception {
final Response response = resourceExtension.getJerseyTest()
.target(pathPrefix + "/backup/auth/check")
.request()
.post(Entity.entity("""
{
"number": "123"
}
""", MediaType.APPLICATION_JSON));
try (response) {
assertEquals(422, response.getStatus());
}
}
@Test
public void testHttpResponseCodeWhenNumberMissing() throws Exception {
final Response response = resourceExtension.getJerseyTest()
.target(pathPrefix + "/backup/auth/check")
.request()
.post(Entity.entity("""
{
"passwords": ["aaa:bbb"]
}
""", MediaType.APPLICATION_JSON));
try (response) {
assertEquals(422, response.getStatus());
}
}
@Test
public void testHttpResponseCodeWhenExtraFields() throws Exception {
final Response response = resourceExtension.getJerseyTest()
.target(pathPrefix + "/backup/auth/check")
.request()
.post(Entity.entity("""
{
"number": "+18005550123",
"passwords": ["aaa:bbb"],
"unexpected": "value"
}
""", MediaType.APPLICATION_JSON));
try (response) {
assertEquals(200, response.getStatus());
}
}
@Test
public void testAcceptsPasswordsOrTokens() {
final Response passwordsResponse = resourceExtension.getJerseyTest()
.target(pathPrefix + "/backup/auth/check")
.request()
.post(Entity.entity("""
{
"number": "+18005550123",
"passwords": ["aaa:bbb"]
}
""", MediaType.APPLICATION_JSON));
try (passwordsResponse) {
assertEquals(200, passwordsResponse.getStatus());
}
final Response tokensResponse = resourceExtension.getJerseyTest()
.target(pathPrefix + "/backup/auth/check")
.request()
.post(Entity.entity("""
{
"number": "+18005550123",
"tokens": ["aaa:bbb"]
}
""", MediaType.APPLICATION_JSON));
try (tokensResponse) {
assertEquals(200, tokensResponse.getStatus());
}
}
@Test
public void testHttpResponseCodeWhenNotAJson() throws Exception {
final Response response = resourceExtension.getJerseyTest()
.target(pathPrefix + "/backup/auth/check")
.request()
.post(Entity.entity("random text", MediaType.APPLICATION_JSON));
try (response) {
assertEquals(400, response.getStatus());
}
}
String token(final UUID uuid, final long timeMillis) {
return token(credentials(uuid, timeMillis));
}
static String token(final ExternalServiceCredentials credentials) {
return credentials.username() + ":" + credentials.password();
}
private ExternalServiceCredentials credentials(final UUID uuid, final long timeMillis) {
clock.setTimeMillis(timeMillis);
return credentialsGenerator.generateForUuid(uuid);
}
static long day(final int n) {
return TimeUnit.DAYS.toMillis(n);
}
private static Account account(final UUID uuid) {
final Account a = new Account();
a.setUuid(uuid);
return a;
}
}

View File

@ -992,6 +992,44 @@ class SubscriptionControllerTest {
verify(PLAY_MANAGER, times(1)).cancelAllActiveSubscriptions(oldPurchaseToken);
}
@Test
void createReceiptChargeFailure() throws InvalidInputException, VerificationFailedException {
final byte[] subscriberUserAndKey = new byte[32];
Arrays.fill(subscriberUserAndKey, (byte) 1);
final String subscriberId = Base64.getEncoder().encodeToString(subscriberUserAndKey);
when(CLOCK.instant()).thenReturn(Instant.now());
when(SUBSCRIPTIONS.get(any(), any()))
.thenReturn(CompletableFuture.completedFuture(Subscriptions.GetResult.found(Subscriptions.Record.from(
Arrays.copyOfRange(subscriberUserAndKey, 0, 16),
Map.of(Subscriptions.KEY_PASSWORD, b(new byte[16]),
Subscriptions.KEY_CREATED_AT, n(Instant.now().getEpochSecond()),
Subscriptions.KEY_ACCESSED_AT, n(Instant.now().getEpochSecond()),
Subscriptions.KEY_PROCESSOR_ID_CUSTOMER_ID,
b(new ProcessorCustomer("customer", PaymentProvider.STRIPE).toDynamoBytes()),
Subscriptions.KEY_SUBSCRIPTION_ID, s("subscriptionId"))))));
when(STRIPE_MANAGER.getReceiptItem(any()))
.thenReturn(CompletableFuture.failedFuture(new SubscriptionException.ChargeFailurePaymentRequired(
PaymentProvider.STRIPE,
new ChargeFailure("card_declined", "Insufficient funds", null, null, null))));
final ReceiptCredentialRequest receiptRequest = new ClientZkReceiptOperations(
ServerSecretParams.generate().getPublicParams()).createReceiptCredentialRequestContext(
new ReceiptSerial(new byte[ReceiptSerial.SIZE])).getRequest();
final Response response = RESOURCE_EXTENSION
.target(String.format("/v1/subscription/%s/receipt_credentials", subscriberId))
.request()
.post(Entity.json(new SubscriptionController.GetReceiptCredentialsRequest(receiptRequest.serialize())));
assertThat(response.getStatus()).isEqualTo(402);
final Map responseMap = response.readEntity(Map.class);
assertThat(responseMap.get("processor")).isEqualTo("STRIPE");
assertThat(responseMap.get("chargeFailure")).asInstanceOf(
InstanceOfAssertFactories.map(String.class, Object.class))
.extracting("code")
.isEqualTo("card_declined");
}
@ParameterizedTest
@CsvSource({"5, P45D", "201, P13D"})
public void createReceiptCredential(long level, Duration expectedExpirationWindow)

View File

@ -0,0 +1,305 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import com.google.protobuf.ByteString;
import io.grpc.Channel;
import io.grpc.Status;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.signal.keytransparency.client.AciMonitorRequest;
import org.signal.keytransparency.client.ConsistencyParameters;
import org.signal.keytransparency.client.DistinguishedRequest;
import org.signal.keytransparency.client.DistinguishedResponse;
import org.signal.keytransparency.client.E164MonitorRequest;
import org.signal.keytransparency.client.E164SearchRequest;
import org.signal.keytransparency.client.KeyTransparencyQueryServiceGrpc;
import org.signal.keytransparency.client.MonitorRequest;
import org.signal.keytransparency.client.MonitorResponse;
import org.signal.keytransparency.client.SearchRequest;
import org.signal.keytransparency.client.SearchResponse;
import org.signal.keytransparency.client.UsernameHashMonitorRequest;
import org.signal.libsignal.protocol.IdentityKey;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.keytransparency.KeyTransparencyServiceClient;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import reactor.core.publisher.Mono;
import java.time.Duration;
import java.util.Optional;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.controllers.KeyTransparencyControllerTest.ACI;
import static org.whispersystems.textsecuregcm.controllers.KeyTransparencyControllerTest.ACI_IDENTITY_KEY;
import static org.whispersystems.textsecuregcm.controllers.KeyTransparencyControllerTest.NUMBER;
import static org.whispersystems.textsecuregcm.controllers.KeyTransparencyControllerTest.UNIDENTIFIED_ACCESS_KEY;
import static org.whispersystems.textsecuregcm.controllers.KeyTransparencyControllerTest.USERNAME_HASH;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
import static org.whispersystems.textsecuregcm.grpc.KeyTransparencyGrpcService.COMMITMENT_INDEX_LENGTH;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class KeyTransparencyGrpcServiceTest extends SimpleBaseGrpcTest<KeyTransparencyGrpcService, KeyTransparencyQueryServiceGrpc.KeyTransparencyQueryServiceBlockingStub>{
@Mock
private KeyTransparencyServiceClient keyTransparencyServiceClient;
@Mock
private RateLimiter rateLimiter;
@Override
protected KeyTransparencyGrpcService createServiceBeforeEachTest() {
final RateLimiters rateLimiters = mock(RateLimiters.class);
when(rateLimiters.getKeyTransparencySearchLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getKeyTransparencyDistinguishedLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getKeyTransparencyMonitorLimiter()).thenReturn(rateLimiter);
return new KeyTransparencyGrpcService(rateLimiters, keyTransparencyServiceClient);
}
@Override
protected KeyTransparencyQueryServiceGrpc.KeyTransparencyQueryServiceBlockingStub createStub(final Channel channel) {
return KeyTransparencyQueryServiceGrpc.newBlockingStub(channel);
}
@Test
void searchSuccess() throws RateLimitExceededException {
when(keyTransparencyServiceClient.search(any())).thenReturn(SearchResponse.getDefaultInstance());
Mockito.doNothing().when(rateLimiter).validate(any(String.class));
final SearchRequest request = SearchRequest.newBuilder()
.setAci(ByteString.copyFrom(ACI.toCompactByteArray()))
.setAciIdentityKey(ByteString.copyFrom(ACI_IDENTITY_KEY.serialize()))
.setConsistency(ConsistencyParameters.newBuilder()
.setDistinguished(10)
.build())
.build();
assertDoesNotThrow(() -> unauthenticatedServiceStub().search(request));
verify(keyTransparencyServiceClient, times(1)).search(eq(request));
}
@ParameterizedTest
@MethodSource
void searchInvalidRequest(final Optional<byte[]> aciServiceIdentifier,
final Optional<IdentityKey> aciIdentityKey,
final Optional<String> e164,
final Optional<byte[]> unidentifiedAccessKey,
final Optional<byte[]> usernameHash,
final Optional<Long> lastTreeHeadSize,
final Optional<Long> distinguishedTreeHeadSize) {
final SearchRequest.Builder requestBuilder = SearchRequest.newBuilder();
aciServiceIdentifier.ifPresent(v -> requestBuilder.setAci(ByteString.copyFrom(v)));
aciIdentityKey.ifPresent(v -> requestBuilder.setAciIdentityKey(ByteString.copyFrom(v.serialize())));
usernameHash.ifPresent(v -> requestBuilder.setUsernameHash(ByteString.copyFrom(v)));
final E164SearchRequest.Builder e164RequestBuilder = E164SearchRequest.newBuilder();
e164.ifPresent(e164RequestBuilder::setE164);
unidentifiedAccessKey.ifPresent(v -> e164RequestBuilder.setUnidentifiedAccessKey(ByteString.copyFrom(v)));
requestBuilder.setE164SearchRequest(e164RequestBuilder.build());
final ConsistencyParameters.Builder consistencyBuilder = ConsistencyParameters.newBuilder();
distinguishedTreeHeadSize.ifPresent(consistencyBuilder::setDistinguished);
lastTreeHeadSize.ifPresent(consistencyBuilder::setLast);
requestBuilder.setConsistency(consistencyBuilder.build());
assertStatusException(Status.INVALID_ARGUMENT, () -> unauthenticatedServiceStub().search(requestBuilder.build()));
verifyNoInteractions(keyTransparencyServiceClient);
}
private static Stream<Arguments> searchInvalidRequest() {
byte[] aciBytes = ACI.toCompactByteArray();
return Stream.of(
Arguments.argumentSet("Empty ACI", Optional.empty(), Optional.of(ACI_IDENTITY_KEY), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.of(4L)),
Arguments.argumentSet("Null ACI identity key", Optional.of(aciBytes), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.of(4L)),
Arguments.argumentSet("Invalid ACI", Optional.of(new byte[15]), Optional.of(ACI_IDENTITY_KEY), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.of(4L)),
Arguments.argumentSet("Non-positive consistency.last", Optional.of(aciBytes), Optional.of(ACI_IDENTITY_KEY), Optional.empty(), Optional.empty(), Optional.empty(), Optional.of(0L), Optional.of(4L)),
Arguments.argumentSet("consistency.distinguished not provided",Optional.of(aciBytes), Optional.of(ACI_IDENTITY_KEY), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()),
Arguments.argumentSet("Non-positive consistency.distinguished",Optional.of(aciBytes), Optional.of(ACI_IDENTITY_KEY), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.of(0L)),
Arguments.argumentSet("E164 can't be provided without an unidentified access key", Optional.of(aciBytes), Optional.of(ACI_IDENTITY_KEY), Optional.of(NUMBER), Optional.empty(), Optional.empty(), Optional.empty(), Optional.of(4L)),
Arguments.argumentSet("Unidentified access key can't be provided without E164", Optional.of(aciBytes), Optional.of(ACI_IDENTITY_KEY), Optional.empty(), Optional.of(UNIDENTIFIED_ACCESS_KEY), Optional.empty(), Optional.empty(), Optional.of(4L)),
Arguments.argumentSet("Invalid username hash", Optional.of(aciBytes), Optional.of(ACI_IDENTITY_KEY), Optional.empty(), Optional.empty(), Optional.of(new byte[19]), Optional.empty(), Optional.of(4L))
);
}
@Test
void searchRatelimited() throws RateLimitExceededException {
final Duration retryAfterDuration = Duration.ofMinutes(7);
Mockito.doThrow(new RateLimitExceededException(retryAfterDuration)).when(rateLimiter).validate(any(String.class));
final SearchRequest request = SearchRequest.newBuilder()
.setAci(ByteString.copyFrom(ACI.toCompactByteArray()))
.setAciIdentityKey(ByteString.copyFrom(ACI_IDENTITY_KEY.serialize()))
.setConsistency(ConsistencyParameters.newBuilder()
.setDistinguished(10)
.build())
.build();
assertRateLimitExceeded(retryAfterDuration, () -> unauthenticatedServiceStub().search(request));
verifyNoInteractions(keyTransparencyServiceClient);
}
@Test
void monitorSuccess() {
when(keyTransparencyServiceClient.monitor(any())).thenReturn(MonitorResponse.getDefaultInstance());
when(rateLimiter.validateReactive(any(String.class)))
.thenReturn(Mono.empty());
final AciMonitorRequest aciMonitorRequest = AciMonitorRequest.newBuilder()
.setAci(ByteString.copyFrom(ACI.toCompactByteArray()))
.setCommitmentIndex(ByteString.copyFrom(new byte[COMMITMENT_INDEX_LENGTH]))
.setEntryPosition(10)
.build();
final MonitorRequest request = MonitorRequest.newBuilder()
.setAci(aciMonitorRequest)
.setConsistency(ConsistencyParameters.newBuilder()
.setDistinguished(10)
.setLast(10)
.build())
.build();
assertDoesNotThrow(() -> unauthenticatedServiceStub().monitor(request));
verify(keyTransparencyServiceClient, times(1)).monitor(eq(request));
}
@ParameterizedTest
@MethodSource
void monitorInvalidRequest(final Optional<AciMonitorRequest> aciMonitorRequest,
final Optional<E164MonitorRequest> e164MonitorRequest,
final Optional<UsernameHashMonitorRequest> usernameHashMonitorRequest,
final Optional<Long> lastTreeHeadSize,
final Optional<Long> distinguishedTreeHeadSize) {
final MonitorRequest.Builder requestBuilder = MonitorRequest.newBuilder();
aciMonitorRequest.ifPresent(requestBuilder::setAci);
e164MonitorRequest.ifPresent(requestBuilder::setE164);
usernameHashMonitorRequest.ifPresent(requestBuilder::setUsernameHash);
final ConsistencyParameters.Builder consistencyBuilder = ConsistencyParameters.newBuilder();
lastTreeHeadSize.ifPresent(consistencyBuilder::setLast);
distinguishedTreeHeadSize.ifPresent(consistencyBuilder::setDistinguished);
requestBuilder.setConsistency(consistencyBuilder.build());
assertStatusException(Status.INVALID_ARGUMENT, () -> unauthenticatedServiceStub().monitor(requestBuilder.build()));
}
private static Stream<Arguments> monitorInvalidRequest() {
final Optional<AciMonitorRequest> validAciMonitorRequest = Optional.of(constructAciMonitorRequest(ACI.toCompactByteArray(), new byte[32], 10));
return Stream.of(
Arguments.argumentSet("ACI monitor request can't be unset", Optional.empty(), Optional.empty(), Optional.empty(), Optional.of(4L), Optional.of(4L)),
Arguments.argumentSet("ACI can't be empty",Optional.of(AciMonitorRequest.newBuilder().build()), Optional.empty(), Optional.empty(), Optional.of(4L), Optional.of(4L)),
Arguments.argumentSet("Empty ACI on ACI monitor request",Optional.of(constructAciMonitorRequest(new byte[0], new byte[32], 10)), Optional.empty(), Optional.empty(), Optional.of(4L), Optional.of(4L)),
Arguments.argumentSet("Invalid ACI", Optional.of(constructAciMonitorRequest(new byte[15], new byte[32], 10)), Optional.empty(), Optional.empty(), Optional.of(4L), Optional.of(4L)),
Arguments.argumentSet("Invalid commitment index on ACI monitor request", Optional.of(constructAciMonitorRequest(ACI.toCompactByteArray(), new byte[31], 10)), Optional.empty(), Optional.empty(), Optional.of(4L), Optional.of(4L)),
Arguments.argumentSet("Invalid entry position on ACI monitor request", Optional.of(constructAciMonitorRequest(ACI.toCompactByteArray(), new byte[32], 0)), Optional.empty(), Optional.empty(), Optional.of(4L), Optional.of(4L)),
Arguments.argumentSet("E164 can't be blank", validAciMonitorRequest, Optional.of(constructE164MonitorRequest("", new byte[32], 10)), Optional.empty(), Optional.of(4L), Optional.of(4L)),
Arguments.argumentSet("Invalid commitment index on E164 monitor request", validAciMonitorRequest, Optional.of(constructE164MonitorRequest(NUMBER, new byte[31], 10)), Optional.empty(), Optional.of(4L), Optional.of(4L)),
Arguments.argumentSet("Invalid entry position on E164 monitor request", validAciMonitorRequest, Optional.of(constructE164MonitorRequest(NUMBER, new byte[32], 0)), Optional.empty(), Optional.of(4L), Optional.of(4L)),
Arguments.argumentSet("Username hash can't be empty", validAciMonitorRequest, Optional.empty(), Optional.of(constructUsernameHashMonitorRequest(new byte[0], new byte[32], 10)), Optional.of(4L), Optional.of(4L)),
Arguments.argumentSet("Invalid username hash length", validAciMonitorRequest, Optional.empty(), Optional.of(constructUsernameHashMonitorRequest(new byte[31], new byte[32], 10)), Optional.of(4L), Optional.of(4L)),
Arguments.argumentSet("Invalid commitment index on username hash monitor request", validAciMonitorRequest, Optional.empty(), Optional.of(constructUsernameHashMonitorRequest(USERNAME_HASH, new byte[31], 10)), Optional.of(4L), Optional.of(4L)),
Arguments.argumentSet("Invalid entry position on username hash monitor request", validAciMonitorRequest, Optional.empty(), Optional.of(constructUsernameHashMonitorRequest(USERNAME_HASH, new byte[32], 0)), Optional.of(4L), Optional.of(4L)),
Arguments.argumentSet("consistency.last must be provided", validAciMonitorRequest, Optional.empty(), Optional.empty(), Optional.empty(), Optional.of(4L),
Arguments.argumentSet("consistency.last must be positive", validAciMonitorRequest, Optional.empty(), Optional.empty(), Optional.of(0L), Optional.of(4L)),
Arguments.argumentSet("consistency.distinguished must be provided", validAciMonitorRequest, Optional.empty(), Optional.empty(), Optional.of(4L)), Optional.empty()),
Arguments.argumentSet("consistency.distinguished must be positive", validAciMonitorRequest, Optional.empty(), Optional.empty(), Optional.of(4L), Optional.of(0L))
);
}
@Test
void monitorRatelimited() throws RateLimitExceededException {
final Duration retryAfterDuration = Duration.ofMinutes(7);
Mockito.doThrow(new RateLimitExceededException(retryAfterDuration)).when(rateLimiter).validate(any(String.class));
final AciMonitorRequest aciMonitorRequest = AciMonitorRequest.newBuilder()
.setAci(ByteString.copyFrom(ACI.toCompactByteArray()))
.setCommitmentIndex(ByteString.copyFrom(new byte[COMMITMENT_INDEX_LENGTH]))
.setEntryPosition(10)
.build();
final MonitorRequest request = MonitorRequest.newBuilder()
.setAci(aciMonitorRequest)
.setConsistency(ConsistencyParameters.newBuilder()
.setDistinguished(10)
.setLast(10)
.build())
.build();
assertRateLimitExceeded(retryAfterDuration, () -> unauthenticatedServiceStub().monitor(request));
verifyNoInteractions(keyTransparencyServiceClient);
}
@Test
void distinguishedSuccess() {
when(keyTransparencyServiceClient.distinguished(any())).thenReturn(DistinguishedResponse.getDefaultInstance());
when(rateLimiter.validateReactive(any(String.class)))
.thenReturn(Mono.empty());
final DistinguishedRequest request = DistinguishedRequest.newBuilder().build();
assertDoesNotThrow(() -> unauthenticatedServiceStub().distinguished(request));
verify(keyTransparencyServiceClient, times(1)).distinguished(eq(request));
}
@Test
void distinguishedInvalidRequest() {
final DistinguishedRequest request = DistinguishedRequest.newBuilder()
.setLast(0)
.build();
assertStatusException(Status.INVALID_ARGUMENT, () -> unauthenticatedServiceStub().distinguished(request));
verifyNoInteractions(keyTransparencyServiceClient);
}
@Test
void distinguishedRatelimited() throws RateLimitExceededException {
final Duration retryAfterDuration = Duration.ofMinutes(7);
Mockito.doThrow(new RateLimitExceededException(retryAfterDuration)).when(rateLimiter).validate(any(String.class));
final DistinguishedRequest request = DistinguishedRequest.newBuilder()
.setLast(10)
.build();
assertRateLimitExceeded(retryAfterDuration, () -> unauthenticatedServiceStub().distinguished(request));
verifyNoInteractions(keyTransparencyServiceClient);
}
private static AciMonitorRequest constructAciMonitorRequest(final byte[] aci, final byte[] commitmentIndex, final long entryPosition) {
return AciMonitorRequest.newBuilder()
.setAci(ByteString.copyFrom(aci))
.setCommitmentIndex(ByteString.copyFrom(commitmentIndex))
.setEntryPosition(entryPosition)
.build();
}
private static E164MonitorRequest constructE164MonitorRequest(final String e164, final byte[] commitmentIndex, final long entryPosition) {
return E164MonitorRequest.newBuilder()
.setE164(e164)
.setCommitmentIndex(ByteString.copyFrom(commitmentIndex))
.setEntryPosition(entryPosition)
.build();
}
private static UsernameHashMonitorRequest constructUsernameHashMonitorRequest(final byte[] usernameHash, final byte[] commitmentIndex, final long entryPosition) {
return UsernameHashMonitorRequest.newBuilder()
.setUsernameHash(ByteString.copyFrom(usernameHash))
.setCommitmentIndex(ByteString.copyFrom(commitmentIndex))
.setEntryPosition(entryPosition)
.build();
}
}

View File

@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.redis;
import static org.junit.jupiter.api.Assumptions.assumeFalse;
import io.github.resilience4j.circuitbreaker.CallNotPermittedException;
import io.lettuce.core.FlushMode;
import io.lettuce.core.RedisClient;
import io.lettuce.core.RedisException;
import io.lettuce.core.RedisURI;
@ -21,7 +22,6 @@ import java.net.ServerSocket;
import java.time.Duration;
import java.util.Arrays;
import java.util.List;
import org.junit.jupiter.api.extension.AfterAllCallback;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
@ -32,18 +32,18 @@ import org.whispersystems.textsecuregcm.util.RedisClusterUtil;
import redis.embedded.RedisServer;
import redis.embedded.exceptions.EmbeddedRedisException;
public class RedisClusterExtension implements BeforeAllCallback, BeforeEachCallback, AfterAllCallback,
AfterEachCallback {
public class RedisClusterExtension implements BeforeAllCallback, BeforeEachCallback, AfterEachCallback,
ExtensionContext.Store.CloseableResource {
private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(2);
private static final int NODE_COUNT = 2;
private static final RedisServer[] CLUSTER_NODES = new RedisServer[NODE_COUNT];
private static ClientResources redisClientResources;
private final Duration timeout;
private final RetryConfiguration retryConfiguration;
private FaultTolerantRedisClusterClient redisCluster;
private ClientResources redisClientResources;
public RedisClusterExtension(final Duration timeout, final RetryConfiguration retryConfiguration) {
this.timeout = timeout;
@ -56,35 +56,42 @@ public class RedisClusterExtension implements BeforeAllCallback, BeforeEachCallb
}
@Override
public void afterAll(final ExtensionContext context) throws Exception {
for (final RedisServer node : CLUSTER_NODES) {
node.stop();
public void close() throws Throwable {
if (redisClientResources != null) {
redisClientResources.shutdown().get();
for (final RedisServer node : CLUSTER_NODES) {
node.stop();
}
}
redisClientResources = null;
}
@Override
public void afterEach(final ExtensionContext context) throws Exception {
redisCluster.shutdown();
redisClientResources.shutdown().get();
}
@Override
public void beforeAll(final ExtensionContext context) throws Exception {
assumeFalse(System.getProperty("os.name").equalsIgnoreCase("windows"));
for (int i = 0; i < NODE_COUNT; i++) {
// We're occasionally seeing redis server startup failing due to the bind address being already in use.
// To mitigate that, we're going to just retry a couple of times before failing the test.
CLUSTER_NODES[i] = startWithRetries(3);
}
if (redisClientResources == null) {
redisClientResources = ClientResources.builder().build();
assembleCluster(CLUSTER_NODES);
for (int i = 0; i < NODE_COUNT; i++) {
// We're occasionally seeing redis server startup failing due to the bind address being already in use.
// To mitigate that, we're going to just retry a couple of times before failing the test.
CLUSTER_NODES[i] = startWithRetries(3);
}
assembleCluster(CLUSTER_NODES);
}
}
@Override
public void beforeEach(final ExtensionContext context) throws Exception {
redisClientResources = ClientResources.builder().build();
final CircuitBreakerConfiguration circuitBreakerConfig = new CircuitBreakerConfiguration();
circuitBreakerConfig.setWaitDurationInOpenState(Duration.ofMillis(500));
redisCluster = new FaultTolerantRedisClusterClient("test-cluster",
@ -120,7 +127,7 @@ public class RedisClusterExtension implements BeforeAllCallback, BeforeEachCallb
}
});
redisCluster.useCluster(connection -> connection.sync().flushall());
redisCluster.useCluster(connection -> connection.sync().flushall(FlushMode.SYNC));
}
public static List<RedisURI> getRedisURIs() {

View File

@ -7,13 +7,12 @@ package org.whispersystems.textsecuregcm.redis;
import static org.junit.jupiter.api.Assumptions.assumeFalse;
import io.lettuce.core.FlushMode;
import io.lettuce.core.RedisURI;
import io.lettuce.core.resource.ClientResources;
import java.io.IOException;
import java.net.ServerSocket;
import java.time.Duration;
import io.lettuce.core.resource.ClientResources;
import org.junit.jupiter.api.extension.AfterAllCallback;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
@ -22,11 +21,12 @@ import org.whispersystems.textsecuregcm.configuration.RetryConfiguration;
import redis.embedded.RedisServer;
import redis.embedded.exceptions.EmbeddedRedisException;
public class RedisServerExtension implements BeforeAllCallback, BeforeEachCallback, AfterAllCallback, AfterEachCallback {
public class RedisServerExtension implements BeforeAllCallback, BeforeEachCallback, ExtensionContext.Store.CloseableResource {
private static RedisServer redisServer;
private static ClientResources redisClientResources;
private FaultTolerantRedisClient faultTolerantRedisClient;
private ClientResources redisClientResources;
public static class RedisServerExtensionBuilder {
@ -46,14 +46,18 @@ public class RedisServerExtension implements BeforeAllCallback, BeforeEachCallba
public void beforeAll(final ExtensionContext context) throws Exception {
assumeFalse(System.getProperty("os.name").equalsIgnoreCase("windows"));
redisServer = RedisServer.builder()
.setting("appendonly no")
.setting("save \"\"")
.setting("dir " + System.getProperty("java.io.tmpdir"))
.port(getAvailablePort())
.build();
if (redisServer == null) {
redisServer = RedisServer.builder()
.setting("appendonly no")
.setting("save \"\"")
.setting("dir " + System.getProperty("java.io.tmpdir"))
.port(getAvailablePort())
.build();
startWithRetries(3);
redisClientResources = ClientResources.builder().build();
startWithRetries(3);
}
}
public static RedisURI getRedisURI() {
@ -62,7 +66,6 @@ public class RedisServerExtension implements BeforeAllCallback, BeforeEachCallba
@Override
public void beforeEach(final ExtensionContext context) {
redisClientResources = ClientResources.builder().build();
final CircuitBreakerConfiguration circuitBreakerConfig = new CircuitBreakerConfiguration();
circuitBreakerConfig.setWaitDurationInOpenState(Duration.ofMillis(500));
faultTolerantRedisClient = new FaultTolerantRedisClient("test-redis-client",
@ -72,19 +75,18 @@ public class RedisServerExtension implements BeforeAllCallback, BeforeEachCallba
circuitBreakerConfig,
new RetryConfiguration());
faultTolerantRedisClient.useConnection(connection -> connection.sync().flushall());
faultTolerantRedisClient.useConnection(connection -> connection.sync().flushall(FlushMode.SYNC));
}
@Override
public void afterEach(final ExtensionContext context) throws InterruptedException {
redisClientResources.shutdown().await();
}
@Override
public void afterAll(final ExtensionContext context) {
public void close() throws Throwable {
if (redisServer != null) {
redisClientResources.shutdown().await();
redisServer.stop();
}
redisClientResources = null;
redisServer = null;
}
public FaultTolerantRedisClient getRedisClient() {

View File

@ -32,7 +32,7 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator;
import org.whispersystems.textsecuregcm.configuration.SecureValueRecovery2Configuration;
import org.whispersystems.textsecuregcm.configuration.SecureValueRecoveryConfiguration;
class SecureValueRecovery2ClientTest {
@ -55,7 +55,7 @@ class SecureValueRecovery2ClientTest {
httpExecutor = Executors.newSingleThreadExecutor();
retryExecutor = Executors.newSingleThreadScheduledExecutor();
final SecureValueRecovery2Configuration config = new SecureValueRecovery2Configuration(
final SecureValueRecoveryConfiguration config = new SecureValueRecoveryConfiguration(
"http://localhost:" + wireMock.getPort(),
randomSecretBytes(32),
randomSecretBytes(32),

View File

@ -44,12 +44,14 @@ import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
public class AccountCreationDeletionIntegrationTest {
@ -65,12 +67,16 @@ public class AccountCreationDeletionIntegrationTest {
DynamoDbExtensionSchema.Tables.USERNAMES,
DynamoDbExtensionSchema.Tables.EC_KEYS,
DynamoDbExtensionSchema.Tables.PQ_KEYS,
DynamoDbExtensionSchema.Tables.PAGED_PQ_KEYS,
DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS,
DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS);
@RegisterExtension
static final RedisClusterExtension CACHE_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@RegisterExtension
static final S3LocalStackExtension S3_EXTENSION = new S3LocalStackExtension("testbucket");
private static final Clock CLOCK = Clock.fixed(Instant.now(), ZoneId.systemDefault());
private ScheduledExecutorService executor;
@ -90,13 +96,19 @@ public class AccountCreationDeletionIntegrationTest {
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
final DynamoDbAsyncClient dynamoDbAsyncClient = DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient();
keysManager = new KeysManager(
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
DynamoDbExtensionSchema.Tables.EC_KEYS.tableName(),
DynamoDbExtensionSchema.Tables.PQ_KEYS.tableName(),
DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName(),
DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName()
);
new SingleUseECPreKeyStore(dynamoDbAsyncClient, DynamoDbExtensionSchema.Tables.EC_KEYS.tableName()),
new SingleUseKEMPreKeyStore(dynamoDbAsyncClient, DynamoDbExtensionSchema.Tables.PQ_KEYS.tableName()),
new PagedSingleUseKEMPreKeyStore(dynamoDbAsyncClient,
S3_EXTENSION.getS3Client(),
DynamoDbExtensionSchema.Tables.PAGED_PQ_KEYS.tableName(),
S3_EXTENSION.getBucketName()),
new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient,
DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName()),
new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient,
DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName()),
mock(ExperimentEnrollmentManager.class));
final ClientPublicKeys clientPublicKeys = new ClientPublicKeys(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
DynamoDbExtensionSchema.Tables.CLIENT_PUBLIC_KEYS.tableName());

View File

@ -36,6 +36,7 @@ import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
@ -44,6 +45,7 @@ import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
class AccountsManagerChangeNumberIntegrationTest {
@ -59,12 +61,16 @@ class AccountsManagerChangeNumberIntegrationTest {
Tables.USERNAMES,
Tables.EC_KEYS,
Tables.PQ_KEYS,
Tables.PAGED_PQ_KEYS,
Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS,
Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS);
@RegisterExtension
static final RedisClusterExtension CACHE_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@RegisterExtension
static final S3LocalStackExtension S3_EXTENSION = new S3LocalStackExtension("testbucket");
private KeysManager keysManager;
private DisconnectionRequestManager disconnectionRequestManager;
private ScheduledExecutorService executor;
@ -81,13 +87,19 @@ class AccountsManagerChangeNumberIntegrationTest {
DynamicConfiguration dynamicConfiguration = new DynamicConfiguration();
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
final DynamoDbAsyncClient dynamoDbAsyncClient = DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient();
keysManager = new KeysManager(
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
Tables.EC_KEYS.tableName(),
Tables.PQ_KEYS.tableName(),
Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName(),
Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName()
);
new SingleUseECPreKeyStore(dynamoDbAsyncClient, DynamoDbExtensionSchema.Tables.EC_KEYS.tableName()),
new SingleUseKEMPreKeyStore(dynamoDbAsyncClient, DynamoDbExtensionSchema.Tables.PQ_KEYS.tableName()),
new PagedSingleUseKEMPreKeyStore(dynamoDbAsyncClient,
S3_EXTENSION.getS3Client(),
DynamoDbExtensionSchema.Tables.PAGED_PQ_KEYS.tableName(),
S3_EXTENSION.getBucketName()),
new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient,
DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName()),
new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient,
DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName()),
mock(ExperimentEnrollmentManager.class));
final ClientPublicKeys clientPublicKeys = new ClientPublicKeys(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
DynamoDbExtensionSchema.Tables.CLIENT_PUBLIC_KEYS.tableName());

View File

@ -73,6 +73,7 @@ class AccountsManagerConcurrentModificationIntegrationTest {
Tables.DELETED_ACCOUNTS,
Tables.EC_KEYS,
Tables.PQ_KEYS,
Tables.PAGED_PQ_KEYS,
Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS,
Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS);

View File

@ -38,6 +38,7 @@ import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.Mockito;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
@ -47,6 +48,7 @@ import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.PutItemRequest;
import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;
@ -72,12 +74,16 @@ class AccountsManagerUsernameIntegrationTest {
Tables.PNI_ASSIGNMENTS,
Tables.EC_KEYS,
Tables.PQ_KEYS,
Tables.PAGED_PQ_KEYS,
Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS,
Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS);
@RegisterExtension
static RedisClusterExtension CACHE_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@RegisterExtension
static final S3LocalStackExtension S3_EXTENSION = new S3LocalStackExtension("testbucket");
private AccountsManager accountsManager;
private Accounts accounts;
@ -94,13 +100,19 @@ class AccountsManagerUsernameIntegrationTest {
DynamicConfiguration dynamicConfiguration = new DynamicConfiguration();
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
final DynamoDbAsyncClient dynamoDbAsyncClient = DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient();
final KeysManager keysManager = new KeysManager(
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
Tables.EC_KEYS.tableName(),
Tables.PQ_KEYS.tableName(),
Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName(),
Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName()
);
new SingleUseECPreKeyStore(dynamoDbAsyncClient, DynamoDbExtensionSchema.Tables.EC_KEYS.tableName()),
new SingleUseKEMPreKeyStore(dynamoDbAsyncClient, DynamoDbExtensionSchema.Tables.PQ_KEYS.tableName()),
new PagedSingleUseKEMPreKeyStore(dynamoDbAsyncClient,
S3_EXTENSION.getS3Client(),
DynamoDbExtensionSchema.Tables.PAGED_PQ_KEYS.tableName(),
S3_EXTENSION.getBucketName()),
new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient,
DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName()),
new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient,
DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName()),
mock(ExperimentEnrollmentManager.class));
accounts = Mockito.spy(new Accounts(
Clock.systemUTC(),

View File

@ -36,6 +36,7 @@ import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.DeviceInfo;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.redis.RedisServerExtension;
@ -45,6 +46,7 @@ import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.TestClock;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
public class AddRemoveDeviceIntegrationTest {
@ -61,6 +63,7 @@ public class AddRemoveDeviceIntegrationTest {
DynamoDbExtensionSchema.Tables.USERNAMES,
DynamoDbExtensionSchema.Tables.EC_KEYS,
DynamoDbExtensionSchema.Tables.PQ_KEYS,
DynamoDbExtensionSchema.Tables.PAGED_PQ_KEYS,
DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS,
DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS);
@ -70,6 +73,9 @@ public class AddRemoveDeviceIntegrationTest {
@RegisterExtension
static final RedisServerExtension PUBSUB_SERVER_EXTENSION = RedisServerExtension.builder().build();
@RegisterExtension
static final S3LocalStackExtension S3_EXTENSION = new S3LocalStackExtension("testbucket");
private ExecutorService accountLockExecutor;
private ScheduledExecutorService messagePollExecutor;
@ -89,13 +95,19 @@ public class AddRemoveDeviceIntegrationTest {
clock = TestClock.pinned(Instant.now());
final DynamoDbAsyncClient dynamoDbAsyncClient = DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient();
keysManager = new KeysManager(
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
DynamoDbExtensionSchema.Tables.EC_KEYS.tableName(),
DynamoDbExtensionSchema.Tables.PQ_KEYS.tableName(),
DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName(),
DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName()
);
new SingleUseECPreKeyStore(dynamoDbAsyncClient, DynamoDbExtensionSchema.Tables.EC_KEYS.tableName()),
new SingleUseKEMPreKeyStore(dynamoDbAsyncClient, DynamoDbExtensionSchema.Tables.PQ_KEYS.tableName()),
new PagedSingleUseKEMPreKeyStore(dynamoDbAsyncClient,
S3_EXTENSION.getS3Client(),
DynamoDbExtensionSchema.Tables.PAGED_PQ_KEYS.tableName(),
S3_EXTENSION.getBucketName()),
new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient,
DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName()),
new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient,
DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName()),
mock(ExperimentEnrollmentManager.class));
final ClientPublicKeys clientPublicKeys = new ClientPublicKeys(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
DynamoDbExtensionSchema.Tables.CLIENT_PUBLIC_KEYS.tableName());

View File

@ -143,6 +143,20 @@ public final class DynamoDbExtensionSchema {
.build()),
List.of(), List.of()),
PAGED_PQ_KEYS("paged_pq_keys_test",
PagedSingleUseKEMPreKeyStore.KEY_ACCOUNT_UUID,
PagedSingleUseKEMPreKeyStore.KEY_DEVICE_ID,
List.of(
AttributeDefinition.builder()
.attributeName(PagedSingleUseKEMPreKeyStore.KEY_ACCOUNT_UUID)
.attributeType(ScalarAttributeType.B)
.build(),
AttributeDefinition.builder()
.attributeName(PagedSingleUseKEMPreKeyStore.KEY_DEVICE_ID)
.attributeType(ScalarAttributeType.N)
.build()),
List.of(), List.of()),
PUSH_NOTIFICATION_EXPERIMENT_SAMPLES("push_notification_experiment_samples_test",
PushNotificationExperimentSamples.KEY_EXPERIMENT_NAME,
PushNotificationExperimentSamples.ATTR_ACI_AND_DEVICE_ID,

View File

@ -0,0 +1,114 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import com.google.protobuf.ByteString;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.grpc.ServiceIdentifierUtil;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import java.util.UUID;
import java.util.concurrent.ThreadLocalRandom;
import static org.junit.jupiter.api.Assertions.*;
class EnvelopeUtilTest {
@Test
void compressExpand() {
{
final MessageProtos.Envelope compressibleFieldsNullMessage = generateRandomMessageBuilder().build();
final MessageProtos.Envelope compressed = EnvelopeUtil.compress(compressibleFieldsNullMessage);
assertFalse(compressed.hasSourceServiceId());
assertFalse(compressed.hasSourceServiceIdBinary());
assertFalse(compressed.hasDestinationServiceId());
assertFalse(compressed.hasDestinationServiceIdBinary());
assertFalse(compressed.hasServerGuid());
assertFalse(compressed.hasServerGuidBinary());
assertFalse(compressed.hasUpdatedPni());
assertFalse(compressed.hasUpdatedPniBinary());
final MessageProtos.Envelope expanded = EnvelopeUtil.expand(compressed);
assertFalse(expanded.hasSourceServiceId());
assertFalse(expanded.hasSourceServiceIdBinary());
assertFalse(expanded.hasDestinationServiceId());
assertFalse(expanded.hasDestinationServiceIdBinary());
assertFalse(expanded.hasServerGuid());
assertFalse(expanded.hasServerGuidBinary());
assertFalse(compressed.hasUpdatedPni());
assertFalse(compressed.hasUpdatedPniBinary());
}
{
final ServiceIdentifier sourceServiceId = generateRandomServiceIdentifier();
final ServiceIdentifier destinationServiceId = generateRandomServiceIdentifier();
final UUID serverGuid = UUID.randomUUID();
final UUID updatedPni = UUID.randomUUID();
final MessageProtos.Envelope compressibleFieldsExpandedMessage = generateRandomMessageBuilder()
.setSourceServiceId(sourceServiceId.toServiceIdentifierString())
.setDestinationServiceId(destinationServiceId.toServiceIdentifierString())
.setServerGuid(serverGuid.toString())
.setUpdatedPni(updatedPni.toString())
.build();
final MessageProtos.Envelope compressed = EnvelopeUtil.compress(compressibleFieldsExpandedMessage);
assertFalse(compressed.hasSourceServiceId());
assertEquals(ServiceIdentifierUtil.toCompactByteString(sourceServiceId), compressed.getSourceServiceIdBinary());
assertFalse(compressed.hasDestinationServiceId());
assertEquals(ServiceIdentifierUtil.toCompactByteString(destinationServiceId), compressed.getDestinationServiceIdBinary());
assertFalse(compressed.hasServerGuid());
assertEquals(UUIDUtil.toByteString(serverGuid), compressed.getServerGuidBinary());
assertFalse(compressed.hasUpdatedPni());
assertEquals(UUIDUtil.toByteString(updatedPni), compressed.getUpdatedPniBinary());
assertEquals(compressed, EnvelopeUtil.compress(compressed), "Double compression should make no changes");
final MessageProtos.Envelope expanded = EnvelopeUtil.expand(compressed);
assertEquals(sourceServiceId.toServiceIdentifierString(), expanded.getSourceServiceId());
assertFalse(expanded.hasSourceServiceIdBinary());
assertEquals(destinationServiceId.toServiceIdentifierString(), expanded.getDestinationServiceId());
assertFalse(expanded.hasDestinationServiceIdBinary());
assertEquals(serverGuid.toString(), expanded.getServerGuid());
assertFalse(expanded.hasServerGuidBinary());
assertEquals(updatedPni.toString(), expanded.getUpdatedPni());
assertEquals(UUIDUtil.toByteString(updatedPni), expanded.getUpdatedPniBinary());
assertEquals(expanded, EnvelopeUtil.expand(expanded), "Double expansion should make no changes");
// Expanded envelopes include both representations of the `updatedPni` field
assertEquals(compressibleFieldsExpandedMessage.toBuilder().setUpdatedPniBinary(UUIDUtil.toByteString(updatedPni)).build(),
expanded);
}
}
private static ServiceIdentifier generateRandomServiceIdentifier() {
final IdentityType identityType = ThreadLocalRandom.current().nextBoolean() ? IdentityType.ACI : IdentityType.PNI;
return switch (identityType) {
case ACI -> new AciServiceIdentifier(UUID.randomUUID());
case PNI -> new PniServiceIdentifier(UUID.randomUUID());
};
}
private MessageProtos.Envelope.Builder generateRandomMessageBuilder() {
return MessageProtos.Envelope.newBuilder()
.setClientTimestamp(ThreadLocalRandom.current().nextLong())
.setServerTimestamp(ThreadLocalRandom.current().nextLong())
.setContent(ByteString.copyFrom(TestRandomUtil.nextBytes(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT);
}
}

View File

@ -0,0 +1,19 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import com.apple.foundationdb.Database;
import com.apple.foundationdb.FDB;
import java.io.IOException;
interface FoundationDbDatabaseLifecycleManager {
void initializeDatabase(final FDB fdb) throws IOException;
Database getDatabase();
void closeDatabase();
}

View File

@ -0,0 +1,43 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import com.apple.foundationdb.Database;
import com.apple.foundationdb.FDB;
import java.io.IOException;
import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
class FoundationDbExtension implements BeforeAllCallback, ExtensionContext.Store.CloseableResource {
private static FoundationDbDatabaseLifecycleManager databaseLifecycleManager;
@Override
public void beforeAll(final ExtensionContext context) throws IOException {
if (databaseLifecycleManager == null) {
final String serviceContainerName = System.getProperty("foundationDb.serviceContainerName");
databaseLifecycleManager = serviceContainerName != null
? new ServiceContainerFoundationDbDatabaseLifecycleManager(serviceContainerName)
: new TestcontainersFoundationDbDatabaseLifecycleManager();
databaseLifecycleManager.initializeDatabase(FDB.selectAPIVersion(FoundationDbVersion.getFoundationDbApiVersion()));
context.getRoot().getStore(ExtensionContext.Namespace.GLOBAL).put(getClass().getName(), this);
}
}
public Database getDatabase() {
return databaseLifecycleManager.getDatabase();
}
@Override
public void close() throws Throwable {
if (databaseLifecycleManager != null) {
databaseLifecycleManager.closeDatabase();
}
}
}

View File

@ -0,0 +1,34 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import java.nio.charset.StandardCharsets;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
public class FoundationDbTest {
@RegisterExtension
static FoundationDbExtension FOUNDATION_DB_EXTENSION = new FoundationDbExtension();
@Test
void setGetValue() {
final byte[] key = "test".getBytes(StandardCharsets.UTF_8);
final byte[] value = TestRandomUtil.nextBytes(16);
FOUNDATION_DB_EXTENSION.getDatabase().run(transaction -> {
transaction.set(key, value);
return null;
});
final byte[] retrievedValue = FOUNDATION_DB_EXTENSION.getDatabase().run(transaction -> transaction.get(key).join());
assertArrayEquals(value, retrievedValue);
}
}

View File

@ -0,0 +1,102 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.junit.jupiter.api.Test;
import org.signal.libsignal.protocol.InvalidKeyException;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
class KEMPreKeyPageTest {
private static final ECKeyPair IDENTITY_KEY_PAIR = Curve.generateKeyPair();
@Test
void serializeSinglePreKey() {
final ByteBuffer page = KEMPreKeyPage.serialize(KEMPreKeyPage.FORMAT, List.of(generatePreKey(5)));
final int actualMagic = page.getInt();
assertEquals(KEMPreKeyPage.HEADER_MAGIC, actualMagic);
final int version = page.getInt();
assertEquals(version, 1);
assertEquals(KEMPreKeyPage.SERIALIZED_PREKEY_LENGTH, page.remaining());
}
@Test
void emptyPreKeys() {
assertThrows(IllegalArgumentException.class, () -> KEMPreKeyPage.serialize(KEMPreKeyPage.FORMAT, Collections.emptyList()));
}
@Test
void roundTripSingleton() throws InvalidKeyException {
final KEMSignedPreKey preKey = generatePreKey(5);
final ByteBuffer buffer = KEMPreKeyPage.serialize(KEMPreKeyPage.FORMAT, List.of(preKey));
final long serializedLength = buffer.remaining();
assertEquals(KEMPreKeyPage.HEADER_SIZE + KEMPreKeyPage.SERIALIZED_PREKEY_LENGTH, serializedLength);
final KEMPreKeyPage.KeyLocation keyLocation = KEMPreKeyPage.keyLocation(1, 0);
assertEquals(KEMPreKeyPage.HEADER_SIZE, keyLocation.getStartInclusive());
assertEquals(serializedLength, KEMPreKeyPage.HEADER_SIZE + keyLocation.length());
buffer.position(keyLocation.getStartInclusive());
final KEMSignedPreKey deserializedPreKey = KEMPreKeyPage.deserializeKey(1, buffer);
assertEquals(5L, deserializedPreKey.keyId());
assertEquals(preKey, deserializedPreKey);
}
@Test
void roundTripMultiple() throws InvalidKeyException {
final List<KEMSignedPreKey> keys = Arrays.asList(generatePreKey(1), generatePreKey(2), generatePreKey(5));
final ByteBuffer page = KEMPreKeyPage.serialize(KEMPreKeyPage.FORMAT, keys);
assertEquals(KEMPreKeyPage.HEADER_SIZE + KEMPreKeyPage.SERIALIZED_PREKEY_LENGTH * 3, page.remaining());
for (int i = 0; i < keys.size(); i++) {
final KEMPreKeyPage.KeyLocation keyLocation = KEMPreKeyPage.keyLocation(1, i);
assertEquals(
KEMPreKeyPage.HEADER_SIZE + KEMPreKeyPage.SERIALIZED_PREKEY_LENGTH * i,
keyLocation.getStartInclusive());
final ByteBuffer buf = page.slice(keyLocation.getStartInclusive(), keyLocation.length());
final KEMSignedPreKey actual = KEMPreKeyPage.deserializeKey(1, buf);
assertEquals(keys.get(i), actual);
}
}
@Test
void wrongFormat() {
assertThrows(IllegalArgumentException.class, () ->
KEMPreKeyPage.deserializeKey(2,
ByteBuffer.allocate(KEMPreKeyPage.HEADER_SIZE + KEMPreKeyPage.SERIALIZED_PREKEY_LENGTH)));
}
@Test
void wrongSize() {
assertThrows(IllegalArgumentException.class, () -> KEMPreKeyPage.deserializeKey(1, ByteBuffer.allocate(100)));
}
@Test
void negativeKeyId() throws InvalidKeyException {
final KEMSignedPreKey preKey = generatePreKey(-1);
ByteBuffer page = KEMPreKeyPage.serialize(KEMPreKeyPage.FORMAT, List.of(preKey));
page.position(KEMPreKeyPage.HEADER_SIZE);
KEMSignedPreKey deserializedPreKey = KEMPreKeyPage.deserializeKey(1, page);
assertEquals(-1L, deserializedPreKey.keyId());
}
private static KEMSignedPreKey generatePreKey(long keyId) {
return KeysHelper.signedKEMPreKey((int) keyId, IDENTITY_KEY_PAIR);
}
}

View File

@ -8,28 +8,45 @@ package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.when;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
class KeysManagerTest {
private KeysManager keysManager;
private ExperimentEnrollmentManager experimentEnrollmentManager;
private SingleUseKEMPreKeyStore singleUseKEMPreKeyStore;
private PagedSingleUseKEMPreKeyStore pagedSingleUseKEMPreKeyStore;
@RegisterExtension
static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(
Tables.EC_KEYS, Tables.PQ_KEYS, Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS, Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS);
Tables.EC_KEYS, Tables.PQ_KEYS, Tables.PAGED_PQ_KEYS,
Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS, Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS);
@RegisterExtension
static final S3LocalStackExtension S3_EXTENSION = new S3LocalStackExtension("testbucket");
private static final UUID ACCOUNT_UUID = UUID.randomUUID();
private static final byte DEVICE_ID = 1;
@ -38,13 +55,21 @@ class KeysManagerTest {
@BeforeEach
void setup() {
final DynamoDbAsyncClient dynamoDbAsyncClient = DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient();
experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class);
singleUseKEMPreKeyStore = new SingleUseKEMPreKeyStore(dynamoDbAsyncClient, Tables.PQ_KEYS.tableName());
pagedSingleUseKEMPreKeyStore = new PagedSingleUseKEMPreKeyStore(dynamoDbAsyncClient,
S3_EXTENSION.getS3Client(),
DynamoDbExtensionSchema.Tables.PAGED_PQ_KEYS.tableName(),
S3_EXTENSION.getBucketName());
keysManager = new KeysManager(
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
Tables.EC_KEYS.tableName(),
Tables.PQ_KEYS.tableName(),
Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName(),
Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName()
);
new SingleUseECPreKeyStore(dynamoDbAsyncClient, Tables.EC_KEYS.tableName()),
singleUseKEMPreKeyStore,
pagedSingleUseKEMPreKeyStore,
new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient, Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName()),
new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient, Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName()),
experimentEnrollmentManager);
}
@Test
@ -60,18 +85,58 @@ class KeysManagerTest {
"Repeatedly storing same key should have no effect");
}
@Test
void storeKemOneTimePreKeys() {
@ParameterizedTest
@ValueSource(booleans = {true, false})
void storeKemOneTimePreKeysClearsOld(boolean inPagedExperiment) {
final List<KEMSignedPreKey> oldPreKeys = List.of(generateTestKEMSignedPreKey(1));
// Leave a key in the 'other' key store
(inPagedExperiment
? singleUseKEMPreKeyStore.store(ACCOUNT_UUID, DEVICE_ID, oldPreKeys)
: pagedSingleUseKEMPreKeyStore.store(ACCOUNT_UUID, DEVICE_ID, oldPreKeys))
.join();
when(experimentEnrollmentManager.isEnrolled(ACCOUNT_UUID, KeysManager.PAGED_KEYS_EXPERIMENT_NAME))
.thenReturn(inPagedExperiment);
final List<KEMSignedPreKey> newPreKeys = List.of(generateTestKEMSignedPreKey(2));
keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, newPreKeys).join();
final int expectedPagedKeyCount = inPagedExperiment ? 1 : 0;
final int expectedUnpagedKeyCount = 1 - expectedPagedKeyCount;
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(expectedPagedKeyCount, pagedSingleUseKEMPreKeyStore.getCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(expectedUnpagedKeyCount, singleUseKEMPreKeyStore.getCount(ACCOUNT_UUID, DEVICE_ID).join());
final KEMSignedPreKey key = keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID).join().orElseThrow();
assertEquals(2, key.keyId());
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void storeKemOneTimePreKeys(boolean inPagedExperiment) {
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Initial pre-key count for an account should be zero");
keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestKEMSignedPreKey(1))).join();
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
when(experimentEnrollmentManager.isEnrolled(ACCOUNT_UUID, KeysManager.PAGED_KEYS_EXPERIMENT_NAME))
.thenReturn(inPagedExperiment);
final int expectedPagedKeyCount = inPagedExperiment ? 1 : 0;
final int expectedUnpagedKeyCount = 1 - expectedPagedKeyCount;
keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestKEMSignedPreKey(1))).join();
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(expectedPagedKeyCount, pagedSingleUseKEMPreKeyStore.getCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(expectedUnpagedKeyCount, singleUseKEMPreKeyStore.getCount(ACCOUNT_UUID, DEVICE_ID).join());
keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestKEMSignedPreKey(1))).join();
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(expectedPagedKeyCount, pagedSingleUseKEMPreKeyStore.getCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(expectedUnpagedKeyCount, singleUseKEMPreKeyStore.getCount(ACCOUNT_UUID, DEVICE_ID).join());
}
@Test
void storeEcSignedPreKeys() {
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isEmpty());
@ -121,9 +186,24 @@ class KeysManagerTest {
}
@Test
void testDeleteSingleUsePreKeysByAccount() {
void takeWithExistingExperimentalKey() {
// Put a key in the new store, even though we're not in the experiment. This simulates a take when operating
// in mixed mode on experiment rollout
pagedSingleUseKEMPreKeyStore.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestKEMSignedPreKey(1))).join();
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(1, keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID).join().orElseThrow().keyId());
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testDeleteSingleUsePreKeysByAccount(final boolean inPagedExperiment) {
int keyId = 1;
when(experimentEnrollmentManager.isEnrolled(ACCOUNT_UUID, KeysManager.PAGED_KEYS_EXPERIMENT_NAME))
.thenReturn(inPagedExperiment);
for (byte deviceId : new byte[] {DEVICE_ID, DEVICE_ID + 1}) {
keysManager.storeEcOneTimePreKeys(ACCOUNT_UUID, deviceId, List.of(generateTestPreKey(keyId++))).join();
keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, deviceId, List.of(generateTestKEMSignedPreKey(keyId++))).join();
@ -148,10 +228,14 @@ class KeysManagerTest {
}
}
@Test
void testDeleteSingleUsePreKeysByAccountAndDevice() {
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testDeleteSingleUsePreKeysByAccountAndDevice(final boolean inPagedExperiment) {
int keyId = 1;
when(experimentEnrollmentManager.isEnrolled(ACCOUNT_UUID, KeysManager.PAGED_KEYS_EXPERIMENT_NAME))
.thenReturn(inPagedExperiment);
for (byte deviceId : new byte[] {DEVICE_ID, DEVICE_ID + 1}) {
keysManager.storeEcOneTimePreKeys(ACCOUNT_UUID, deviceId, List.of(generateTestPreKey(keyId++))).join();
keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, deviceId, List.of(generateTestKEMSignedPreKey(keyId++))).join();

View File

@ -50,8 +50,8 @@ class MessagesCacheGetItemsScriptTest {
assertNotNull(messageAndScores);
assertEquals(2, messageAndScores.size());
final MessageProtos.Envelope resultEnvelope = MessageProtos.Envelope.parseFrom(
messageAndScores.getFirst());
final MessageProtos.Envelope resultEnvelope =
EnvelopeUtil.expand(MessageProtos.Envelope.parseFrom(messageAndScores.getFirst()));
assertEquals(serverGuid, resultEnvelope.getServerGuid());
}

View File

@ -43,7 +43,7 @@ class MessagesCacheInsertScriptTest {
insertScript.executeAsync(destinationUuid, deviceId, envelope1);
assertEquals(List.of(envelope1), getStoredMessages(destinationUuid, deviceId));
assertEquals(List.of(EnvelopeUtil.compress(envelope1)), getStoredMessages(destinationUuid, deviceId));
final MessageProtos.Envelope envelope2 = MessageProtos.Envelope.newBuilder()
.setServerTimestamp(Instant.now().getEpochSecond())
@ -52,11 +52,13 @@ class MessagesCacheInsertScriptTest {
insertScript.executeAsync(destinationUuid, deviceId, envelope2);
assertEquals(List.of(envelope1, envelope2), getStoredMessages(destinationUuid, deviceId));
assertEquals(List.of(EnvelopeUtil.compress(envelope1), EnvelopeUtil.compress(envelope2)),
getStoredMessages(destinationUuid, deviceId));
insertScript.executeAsync(destinationUuid, deviceId, envelope1);
assertEquals(List.of(envelope1, envelope2), getStoredMessages(destinationUuid, deviceId),
assertEquals(List.of(EnvelopeUtil.compress(envelope1), EnvelopeUtil.compress(envelope2)),
getStoredMessages(destinationUuid, deviceId),
"Messages with same GUID should be deduplicated");
}

View File

@ -15,6 +15,7 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
class MessagesCacheRemoveByGuidScriptTest {
@ -44,9 +45,8 @@ class MessagesCacheRemoveByGuidScriptTest {
assertEquals(1, removedMessages.size());
final MessageProtos.Envelope resultMessage = MessageProtos.Envelope.parseFrom(
removedMessages.getFirst());
final MessageProtos.Envelope resultMessage = MessageProtos.Envelope.parseFrom(removedMessages.getFirst());
assertEquals(serverGuid, UUID.fromString(resultMessage.getServerGuid()));
assertEquals(serverGuid, UUIDUtil.fromByteString(resultMessage.getServerGuidBinary()));
}
}

View File

@ -0,0 +1,268 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.IntStream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import reactor.core.publisher.Flux;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.services.s3.model.ListObjectsV2Request;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
import software.amazon.awssdk.services.s3.model.S3Object;
class PagedSingleUseKEMPreKeyStoreTest {
private static final int KEY_COUNT = 100;
private static final ECKeyPair IDENTITY_KEY_PAIR = Curve.generateKeyPair();
private static final String BUCKET_NAME = "testbucket";
private PagedSingleUseKEMPreKeyStore keyStore;
@RegisterExtension
static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(
DynamoDbExtensionSchema.Tables.PAGED_PQ_KEYS);
@RegisterExtension
static final S3LocalStackExtension S3_EXTENSION = new S3LocalStackExtension(BUCKET_NAME);
@BeforeEach
void setUp() {
keyStore = new PagedSingleUseKEMPreKeyStore(
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
S3_EXTENSION.getS3Client(),
DynamoDbExtensionSchema.Tables.PAGED_PQ_KEYS.tableName(),
BUCKET_NAME);
}
@Test
void storeTake() {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = 1;
assertEquals(Optional.empty(), keyStore.take(accountIdentifier, deviceId).join());
final List<KEMSignedPreKey> preKeys = generateRandomPreKeys();
assertDoesNotThrow(() -> keyStore.store(accountIdentifier, deviceId, preKeys).join());
final List<KEMSignedPreKey> sortedPreKeys = preKeys.stream()
.sorted(Comparator.comparing(KEMSignedPreKey::keyId))
.toList();
assertEquals(Optional.of(sortedPreKeys.get(0)), keyStore.take(accountIdentifier, deviceId).join());
assertEquals(Optional.of(sortedPreKeys.get(1)), keyStore.take(accountIdentifier, deviceId).join());
}
@Test
void storeTwice() {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = 1;
final List<KEMSignedPreKey> preKeys1 = generateRandomPreKeys();
keyStore.store(accountIdentifier, deviceId, preKeys1).join();
List<String> oldPages = listPages(accountIdentifier).stream().map(S3Object::key).toList();
assertEquals(1, oldPages.size());
final List<KEMSignedPreKey> preKeys2 = generateRandomPreKeys();
keyStore.store(accountIdentifier, deviceId, preKeys2).join();
List<String> newPages = listPages(accountIdentifier).stream().map(S3Object::key).toList();
assertEquals(1, newPages.size());
assertNotEquals(oldPages.getFirst(), newPages.getFirst());
assertEquals(
preKeys2.stream().sorted(Comparator.comparing(KEMSignedPreKey::keyId)).toList(),
IntStream.range(0, preKeys2.size())
.mapToObj(i -> keyStore.take(accountIdentifier, deviceId).join())
.map(Optional::orElseThrow)
.toList());
assertTrue(keyStore.take(accountIdentifier, deviceId).join().isEmpty());
}
@Test
void takeAll() {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = 1;
final List<KEMSignedPreKey> preKeys = generateRandomPreKeys();
assertDoesNotThrow(() -> keyStore.store(accountIdentifier, deviceId, preKeys).join());
final List<KEMSignedPreKey> sortedPreKeys = preKeys.stream()
.sorted(Comparator.comparing(KEMSignedPreKey::keyId))
.toList();
for (int i = 0; i < KEY_COUNT; i++) {
assertEquals(Optional.of(sortedPreKeys.get(i)), keyStore.take(accountIdentifier, deviceId).join());
}
assertEquals(0, keyStore.getCount(accountIdentifier, deviceId).join());
assertTrue(keyStore.take(accountIdentifier, deviceId).join().isEmpty());
}
@Test
void getCount() {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = 1;
assertEquals(0, keyStore.getCount(accountIdentifier, deviceId).join());
final List<KEMSignedPreKey> preKeys = generateRandomPreKeys();
keyStore.store(accountIdentifier, deviceId, preKeys).join();
assertEquals(KEY_COUNT, keyStore.getCount(accountIdentifier, deviceId).join());
for (int i = 0; i < KEY_COUNT; i++) {
keyStore.take(accountIdentifier, deviceId).join();
assertEquals(KEY_COUNT - (i + 1), keyStore.getCount(accountIdentifier, deviceId).join());
}
}
@Test
void deleteSingleDevice() {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = 1;
assertEquals(0, keyStore.getCount(accountIdentifier, deviceId).join());
assertDoesNotThrow(() -> keyStore.delete(accountIdentifier, deviceId).join());
final List<KEMSignedPreKey> preKeys = generateRandomPreKeys();
keyStore.store(accountIdentifier, deviceId, preKeys).join();
keyStore.store(accountIdentifier, (byte) (deviceId + 1), preKeys).join();
assertDoesNotThrow(() -> keyStore.delete(accountIdentifier, deviceId).join());
assertEquals(0, keyStore.getCount(accountIdentifier, deviceId).join());
assertEquals(KEY_COUNT, keyStore.getCount(accountIdentifier, (byte) (deviceId + 1)).join());
final List<S3Object> pages = listPages(accountIdentifier);
assertEquals(1, pages.size());
assertTrue(pages.getFirst().key().startsWith("%s/%s".formatted(accountIdentifier, deviceId + 1)));
}
@Test
void deleteAllDevices() {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = 1;
assertEquals(0, keyStore.getCount(accountIdentifier, deviceId).join());
assertDoesNotThrow(() -> keyStore.delete(accountIdentifier).join());
final List<KEMSignedPreKey> preKeys = generateRandomPreKeys();
keyStore.store(accountIdentifier, deviceId, preKeys).join();
keyStore.store(accountIdentifier, (byte) (deviceId + 1), preKeys).join();
assertDoesNotThrow(() -> keyStore.delete(accountIdentifier).join());
assertEquals(0, keyStore.getCount(accountIdentifier, deviceId).join());
assertEquals(0, keyStore.getCount(accountIdentifier, (byte) (deviceId + 1)).join());
assertEquals(0, listPages(accountIdentifier).size());
}
@Test
void listPages() {
final UUID aci1 = UUID.randomUUID();
final UUID aci2 = new UUID(aci1.getMostSignificantBits(), aci1.getLeastSignificantBits() + 1);
final byte deviceId = 1;
keyStore.store(aci1, deviceId, generateRandomPreKeys()).join();
keyStore.store(aci1, (byte) (deviceId + 1), generateRandomPreKeys()).join();
keyStore.store(aci2, deviceId, generateRandomPreKeys()).join();
List<DeviceKEMPreKeyPages> stored = keyStore.listStoredPages(1).collectList().block();
assertEquals(3, stored.size());
for (DeviceKEMPreKeyPages pages : stored) {
assertEquals(1, pages.pageIdToLastModified().size());
}
assertEquals(List.of(aci1, aci1, aci2), stored.stream().map(DeviceKEMPreKeyPages::identifier).toList());
assertEquals(
List.of(deviceId, (byte) (deviceId + 1), deviceId),
stored.stream().map(DeviceKEMPreKeyPages::deviceId).toList());
}
@Test
void listPagesWithOrphans() {
final UUID aci1 = UUID.randomUUID();
final UUID aci2 = new UUID(aci1.getMostSignificantBits(), aci1.getLeastSignificantBits() + 1);
final byte deviceId = 1;
// Two orphans
keyStore.store(aci1, deviceId, generateRandomPreKeys()).join();
writeOrphanedS3Object(aci1, deviceId);
writeOrphanedS3Object(aci1, deviceId);
// No orphans
keyStore.store(aci1, (byte) (deviceId + 1), generateRandomPreKeys()).join();
// One orphan
keyStore.store(aci2, deviceId, generateRandomPreKeys()).join();
writeOrphanedS3Object(aci2, deviceId);
// Orphan with no database record
writeOrphanedS3Object(aci2, (byte) (deviceId + 2));
List<DeviceKEMPreKeyPages> stored = keyStore.listStoredPages(1).collectList().block();
assertEquals(4, stored.size());
assertEquals(
List.of(3, 1, 2, 1),
stored.stream().map(s -> s.pageIdToLastModified().size()).toList());
}
private void writeOrphanedS3Object(final UUID identifier, final byte deviceId) {
S3_EXTENSION.getS3Client()
.putObject(PutObjectRequest.builder()
.bucket(BUCKET_NAME)
.key("%s/%s/%s".formatted(identifier, deviceId, UUID.randomUUID())).build(),
AsyncRequestBody.fromBytes(TestRandomUtil.nextBytes(10)))
.join();
}
private List<S3Object> listPages(final UUID identifier) {
return Flux.from(S3_EXTENSION.getS3Client().listObjectsV2Paginator(ListObjectsV2Request.builder()
.bucket(BUCKET_NAME)
.prefix(identifier.toString())
.build()))
.concatMap(response -> Flux.fromIterable(response.contents()))
.collectList()
.block();
}
private List<KEMSignedPreKey> generateRandomPreKeys() {
final Set<Integer> keyIds = new HashSet<>(KEY_COUNT);
while (keyIds.size() < KEY_COUNT) {
keyIds.add(Math.abs(ThreadLocalRandom.current().nextInt()));
}
return keyIds.stream()
.map(keyId -> KeysHelper.signedKEMPreKey(keyId, IDENTITY_KEY_PAIR))
.toList();
}
}

View File

@ -0,0 +1,93 @@
/*
* Copyright 2021-2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.testcontainers.containers.localstack.LocalStackContainer.Service.S3;
import java.util.Objects;
import org.junit.jupiter.api.extension.AfterAllCallback;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.testcontainers.containers.localstack.LocalStackContainer;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.utility.DockerImageName;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.CreateBucketRequest;
import software.amazon.awssdk.services.s3.model.DeleteBucketRequest;
import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
import software.amazon.awssdk.services.s3.model.ListObjectsV2Request;
@Testcontainers
public class S3LocalStackExtension implements BeforeEachCallback, AfterEachCallback, BeforeAllCallback,
AfterAllCallback {
private final static DockerImageName LOCAL_STACK_IMAGE =
DockerImageName.parse(Objects.requireNonNull(
System.getProperty("localstackImage"),
"Local stack image not found; must provide localstackImage system property"));
private static LocalStackContainer LOCAL_STACK = new LocalStackContainer(LOCAL_STACK_IMAGE).withServices(S3);
private final String bucketName;
private S3AsyncClient s3Client;
public S3LocalStackExtension(final String bucketName) {
this.bucketName = bucketName;
}
@Override
public void afterEach(ExtensionContext context) {
Flux.from(s3Client.listObjectsV2Paginator(ListObjectsV2Request.builder()
.bucket(bucketName)
.build())
.contents())
.flatMap(obj -> Mono.fromFuture(() -> s3Client.deleteObject(DeleteObjectRequest.builder()
.bucket(bucketName)
.key(obj.key())
.build())), 100)
.then()
.block();
s3Client.deleteBucket(DeleteBucketRequest.builder().bucket(bucketName).build()).join();
}
@Override
public void beforeEach(ExtensionContext context) throws Exception {
s3Client.createBucket(CreateBucketRequest.builder().bucket(bucketName).build()).join();
}
public S3AsyncClient getS3Client() {
return s3Client;
}
@Override
public void afterAll(final ExtensionContext context) throws Exception {
s3Client.close();
LOCAL_STACK.close();
}
@Override
public void beforeAll(final ExtensionContext context) throws Exception {
LOCAL_STACK.start();
s3Client = S3AsyncClient.builder()
.endpointOverride(LOCAL_STACK.getEndpoint())
.credentialsProvider(StaticCredentialsProvider
.create(AwsBasicCredentials.create(LOCAL_STACK.getAccessKey(), LOCAL_STACK.getSecretKey())))
.region(Region.of(LOCAL_STACK.getRegion()))
.build();
}
public String getBucketName() {
return bucketName;
}
}

View File

@ -0,0 +1,53 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import com.apple.foundationdb.Database;
import com.apple.foundationdb.FDB;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Manages the lifecycle of a database connected to a FoundationDB instance running as an external service container.
*/
class ServiceContainerFoundationDbDatabaseLifecycleManager implements FoundationDbDatabaseLifecycleManager {
private final String foundationDbServiceContainerName;
private Database database;
private static final Logger log = LoggerFactory.getLogger(ServiceContainerFoundationDbDatabaseLifecycleManager.class);
ServiceContainerFoundationDbDatabaseLifecycleManager(final String foundationDbServiceContainerName) {
log.info("Using FoundationDB service container: {}", foundationDbServiceContainerName);
this.foundationDbServiceContainerName = foundationDbServiceContainerName;
}
@Override
public void initializeDatabase(final FDB fdb) throws IOException {
final File clusterFile = File.createTempFile("fdb.cluster", "");
clusterFile.deleteOnExit();
try (final FileWriter fileWriter = new FileWriter(clusterFile)) {
fileWriter.write(String.format("docker:docker@%s:4500", foundationDbServiceContainerName));
}
database = fdb.open(clusterFile.getAbsolutePath());
}
@Override
public Database getDatabase() {
return database;
}
@Override
public void closeDatabase() {
database.close();
}
}

View File

@ -0,0 +1,44 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import com.apple.foundationdb.Database;
import com.apple.foundationdb.FDB;
import earth.adi.testcontainers.containers.FoundationDBContainer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testcontainers.utility.DockerImageName;
class TestcontainersFoundationDbDatabaseLifecycleManager implements FoundationDbDatabaseLifecycleManager {
private FoundationDBContainer foundationDBContainer;
private Database database;
private static final String FOUNDATIONDB_IMAGE_NAME = "foundationdb/foundationdb:" + FoundationDbVersion.getFoundationDbVersion();
private static final Logger log = LoggerFactory.getLogger(TestcontainersFoundationDbDatabaseLifecycleManager.class);
@Override
public void initializeDatabase(final FDB fdb) {
log.info("Using Testcontainers FoundationDB container: {}", FOUNDATIONDB_IMAGE_NAME);
foundationDBContainer = new FoundationDBContainer(DockerImageName.parse(FOUNDATIONDB_IMAGE_NAME));
foundationDBContainer.start();
database = fdb.open(foundationDBContainer.getClusterFilePath());
}
@Override
public Database getDatabase() {
return database;
}
@Override
public void closeDatabase() {
database.close();
foundationDBContainer.close();
}
}

View File

@ -42,6 +42,7 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.storage.SubscriptionException;
import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil;
import org.whispersystems.textsecuregcm.util.MockUtils;
@ -191,6 +192,15 @@ class GooglePlayBillingManagerTest {
verifyNoInteractions(cancel);
}
@Test
public void handle429() throws IOException {
final HttpResponseException mockException = mock(HttpResponseException.class);
when(mockException.getStatusCode()).thenReturn(429);
when(subscriptionsv2Get.execute()).thenThrow(mockException);
CompletableFutureTestUtil.assertFailsWithCause(
RateLimitExceededException.class, googlePlayBillingManager.getSubscriptionInformation(PURCHASE_TOKEN));
}
@Test
public void getReceiptUnacknowledged() throws IOException {
when(subscriptionsv2Get.execute()).thenReturn(new SubscriptionPurchaseV2()

View File

@ -0,0 +1,138 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.workers;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.anyInt;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import io.dropwizard.core.setup.Environment;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import net.sourceforge.argparse4j.inf.Namespace;
import org.assertj.core.api.Assertions;
import org.junit.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
import org.whispersystems.textsecuregcm.storage.DeviceKEMPreKeyPages;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.util.TestClock;
import reactor.core.publisher.Flux;
public class RemoveOrphanedPreKeyPagesCommandTest {
@ParameterizedTest
@ValueSource(booleans = {true, false})
public void removeStalePages(boolean dryRun) throws Exception {
final TestClock clock = TestClock.pinned(Instant.EPOCH.plus(Duration.ofSeconds(10)));
final KeysManager keysManager = mock(KeysManager.class);
final UUID currentPage = UUID.randomUUID();
final UUID freshOrphanedPage = UUID.randomUUID();
final UUID staleOrphanedPage = UUID.randomUUID();
when(keysManager.listStoredKEMPreKeyPages(anyInt())).thenReturn(Flux.fromIterable(List.of(
new DeviceKEMPreKeyPages(UUID.randomUUID(), (byte) 1, Optional.of(currentPage), Map.of(
currentPage, Instant.EPOCH,
staleOrphanedPage, Instant.EPOCH.plus(Duration.ofSeconds(4)),
freshOrphanedPage, Instant.EPOCH.plus(Duration.ofSeconds(5)))))));
when(keysManager.pruneDeadPage(any(), anyByte(), any()))
.thenReturn(CompletableFuture.completedFuture(null));
runCommand(clock, Duration.ofSeconds(5), dryRun, keysManager);
verify(keysManager, times(dryRun ? 0 : 1))
.pruneDeadPage(any(), eq((byte) 1), eq(staleOrphanedPage));
verify(keysManager, times(1)).listStoredKEMPreKeyPages(anyInt());
verifyNoMoreInteractions(keysManager);
}
@Test
public void noCurrentPage() throws Exception {
final TestClock clock = TestClock.pinned(Instant.EPOCH.plus(Duration.ofSeconds(10)));
final KeysManager keysManager = mock(KeysManager.class);
final UUID freshOrphanedPage = UUID.randomUUID();
final UUID staleOrphanedPage = UUID.randomUUID();
when(keysManager.listStoredKEMPreKeyPages(anyInt())).thenReturn(Flux.fromIterable(List.of(
new DeviceKEMPreKeyPages(UUID.randomUUID(), (byte) 1, Optional.empty(), Map.of(
staleOrphanedPage, Instant.EPOCH.plus(Duration.ofSeconds(4)),
freshOrphanedPage, Instant.EPOCH.plus(Duration.ofSeconds(5)))))));
when(keysManager.pruneDeadPage(any(), anyByte(), any()))
.thenReturn(CompletableFuture.completedFuture(null));
runCommand(clock, Duration.ofSeconds(5), false, keysManager);
verify(keysManager, times(1))
.pruneDeadPage(any(), eq((byte) 1), eq(staleOrphanedPage));
verify(keysManager, times(1)).listStoredKEMPreKeyPages(anyInt());
verifyNoMoreInteractions(keysManager);
}
@Test
public void noPages() throws Exception {
final TestClock clock = TestClock.pinned(Instant.EPOCH);
final KeysManager keysManager = mock(KeysManager.class);
when(keysManager.listStoredKEMPreKeyPages(anyInt())).thenReturn(Flux.empty());
runCommand(clock, Duration.ofSeconds(5), false, keysManager);
verify(keysManager).listStoredKEMPreKeyPages(anyInt());
verifyNoMoreInteractions(keysManager);
}
private enum PageStatus {NO_CURRENT, MATCH_CURRENT, MISMATCH_CURRENT}
@CartesianTest
void shouldDeletePage(
@CartesianTest.Enum final PageStatus pageStatus,
@CartesianTest.Values(booleans = {false, true}) final boolean isOld) {
final Optional<UUID> currentPage = pageStatus == PageStatus.NO_CURRENT
? Optional.empty()
: Optional.of(UUID.randomUUID());
final UUID page = switch (pageStatus) {
case MATCH_CURRENT -> currentPage.orElseThrow();
case NO_CURRENT, MISMATCH_CURRENT -> UUID.randomUUID();
};
final Instant threshold = Instant.EPOCH.plus(Duration.ofSeconds(10));
final Instant lastModified = isOld ? threshold.minus(Duration.ofSeconds(1)) : threshold;
final boolean shouldDelete = pageStatus != PageStatus.MATCH_CURRENT && isOld;
Assertions.assertThat(RemoveOrphanedPreKeyPagesCommand.shouldDeletePage(currentPage, page, threshold, lastModified))
.isEqualTo(shouldDelete);
}
private void runCommand(final Clock clock, final Duration minimumOrphanAge, final boolean dryRun,
final KeysManager keysManager) throws Exception {
final CommandDependencies commandDependencies = mock(CommandDependencies.class);
when(commandDependencies.keysManager()).thenReturn(keysManager);
final Namespace namespace = mock(Namespace.class);
when(namespace.getBoolean(RemoveOrphanedPreKeyPagesCommand.DRY_RUN_ARGUMENT)).thenReturn(dryRun);
when(namespace.getInt(RemoveOrphanedPreKeyPagesCommand.CONCURRENCY_ARGUMENT)).thenReturn(2);
when(namespace.getString(RemoveOrphanedPreKeyPagesCommand.MINIMUM_ORPHAN_AGE_ARGUMENT))
.thenReturn(minimumOrphanAge.toString());
final RemoveOrphanedPreKeyPagesCommand command = new RemoveOrphanedPreKeyPagesCommand(clock);
command.run(mock(Environment.class), namespace, mock(WhisperServerConfiguration.class), commandDependencies);
}
}

View File

@ -53,6 +53,9 @@ directoryV2.client.userIdTokenSharedSecret: bbcdefghijklmnopqrstuvwxyz0123456789
svr2.userAuthenticationTokenSharedSecret: abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG= # base64-encoded secret shared with SVR2 to generate auth tokens for Signal users
svr2.userIdTokenSharedSecret: bbcdefghijklmnopqrstuvwxyz0123456789ABCDEFG= # base64-encoded secret shared with SVR2 to generate auth identity tokens for Signal users
svrb.userAuthenticationTokenSharedSecret: abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG= # base64-encoded secret shared with SVRB to generate auth tokens for Signal users
svrb.userIdTokenSharedSecret: bbcdefghijklmnopqrstuvwxyz0123456789ABCDEFG= # base64-encoded secret shared with SVRB to generate auth identity tokens for Signal users
tus.userAuthenticationTokenSharedSecret: abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG=
# The below private key was key generated exclusively for testing purposes. Do not use it in any other context.

View File

@ -135,6 +135,8 @@ dynamoDbTables:
tableName: repeated_use_signed_ec_pre_keys_test
pqKeys:
tableName: pq_keys_test
pagedPqKeys:
tableName: paged_pq_keys_test
pqLastResortKeys:
tableName: repeated_use_signed_kem_pre_keys_test
messages:
@ -171,6 +173,10 @@ dynamoDbTables:
verificationSessions:
tableName: verification_sessions_test
pagedSingleUseKEMPreKeyStore:
bucket: preKeyBucket # S3 Bucket name
region: us-west-2 # AWS region
cacheCluster: # Redis server configuration for cache cluster
type: local
@ -217,6 +223,35 @@ svr2:
9Kxq0DY7RCEpdHMCKcOL
-----END CERTIFICATE-----
svrb:
uri: svrb.example.com
userAuthenticationTokenSharedSecret: secret://svrb.userAuthenticationTokenSharedSecret
userIdTokenSharedSecret: secret://svrb.userIdTokenSharedSecret
svrCaCertificates:
# this is a randomly generated test certificate
- |
-----BEGIN CERTIFICATE-----
MIIDazCCAlOgAwIBAgIUW5lcNWkuynRVc8Rq5pO6mHQBuZAwDQYJKoZIhvcNAQEL
BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yNDAzMjUwMzE4MTNaFw0yOTAz
MjQwMzE4MTNaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw
HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEB
AQUAA4IBDwAwggEKAoIBAQCfH4Um+fv2r4KudhD37/UXp8duRLTmp4XvpBTpDHpD
2HF8p2yThVKlJnMkP/9Ey1Rb0vhxO7DCltLdW8IYcxJuHoyMvyhGUEtxxkOZbrk8
ciUR9jTZ37x7vXRGj/RxcdlS6iD0MeF0D/LAkImt4T/kiKwDbENrVEnYWJmipCKP
ribxWky7HqxDCoYMQr0zatxB3A9mx5stH+H3kbw3CZcm+ugF9ZIKDEVHb0lf28gq
llmD120q/vs9YV3rzVL7sBGDqf6olkulvHQJKElZg2rdcHWFcngSlU2BjR04oyuH
c/SSiLSB3YB0tdFGta5uorXyV1y7RElPeBfOfvEjsG3TAgMBAAGjUzBRMB0GA1Ud
DgQWBBQX+xlgSWWbDjv0SrJ+h67xauJ80zAfBgNVHSMEGDAWgBQX+xlgSWWbDjv0
SrJ+h67xauJ80zAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQAw
ZG2MCCjscn6h/QOoJU+IDfa68OqLq0I37gMnLMde4yEhAmm//miePIq4Uz9GRJ+h
rAmdEnspKgyQ93PjF7Xpk/JdJA4B1bIrsOl/cSwqx2sFhRt8Kt1DHGlGWXqOaHRP
UkZ86MyRL3sXly6WkxEYxZJeQaOzMy2XmQh7grzrlTBuSI+0xf7vsRRDipxr6LVQ
6qGWyGODLLc2JD1IXj/1HpRVT2LoGGlKMuyxACQAm4oak1vvJ9mGxgfd9AU+eo58
O/esB2Eaf+QqMPELdFSZQfG2jvp+3WQTZK8fDKHyLr076G3UetEMy867F6fzTSZd
9Kxq0DY7RCEpdHMCKcOL
-----END CERTIFICATE-----
messageCache: # Redis server configuration for message store cache
persistDelayMinutes: 1
cluster:

@ -1 +1 @@
Subproject commit 915fcf6626862bd4070dd577b5e007bb5907ef79
Subproject commit 01c0bbfd6253e9532802e789dd801e9673ae4c2c