diff --git a/pom.xml b/pom.xml index e386469b5..1ddf8211a 100644 --- a/pom.xml +++ b/pom.xml @@ -36,6 +36,7 @@ 2.9.0 2.0.22 1.1.13 + 2.8.8 30.1.1-jre 2.3.1 2.9.0 @@ -53,6 +54,7 @@ 1.5.0 3.1.0 1.7.30 + 20.79.0 UTF-8 @@ -241,6 +243,16 @@ 9.2 test + + com.stripe + stripe-java + ${stripe.version} + + + com.google.code.gson + gson + ${gson.version} + diff --git a/service/config/sample.yml b/service/config/sample.yml index 5fe26335b..96497bfd9 100644 --- a/service/config/sample.yml +++ b/service/config/sample.yml @@ -1,3 +1,21 @@ +stripe: + apiKey: + idempotencyKeyGenerator: + +dynamoDbClientConfiguration: + region: # AWS Region + +dynamoDbTables: + issuedReceipts: + tableName: # DDB Table Name + expiration: # Duration of time until rows expire + generator: # random binary sequence + redeemedReceipts: + tableName: # DDB Table Name + expiration: # Duration of time until rows expire + subscriptions: + tableName: # DDB Table Name + twilio: # Twilio gateway configuration accountId: accountToken: @@ -249,7 +267,6 @@ asnTable: donation: uri: # value - apiKey: # value supportedCurrencies: - # 1st supported currency - # 2nd supported currency diff --git a/service/pom.xml b/service/pom.xml index 6ffa1ca23..306da335d 100644 --- a/service/pom.xml +++ b/service/pom.xml @@ -449,6 +449,11 @@ google-cloud-recaptchaenterprise + + com.stripe + stripe-java + + pl.pragmatists JUnitParams diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java index 9ae7a5dec..2e347b1c1 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java @@ -26,7 +26,9 @@ import org.whispersystems.textsecuregcm.configuration.DatadogConfiguration; import org.whispersystems.textsecuregcm.configuration.DeletedAccountsDynamoDbConfiguration; import org.whispersystems.textsecuregcm.configuration.DirectoryConfiguration; import org.whispersystems.textsecuregcm.configuration.DonationConfiguration; +import org.whispersystems.textsecuregcm.configuration.DynamoDbClientConfiguration; import org.whispersystems.textsecuregcm.configuration.DynamoDbConfiguration; +import org.whispersystems.textsecuregcm.configuration.DynamoDbTables; import org.whispersystems.textsecuregcm.configuration.GcmConfiguration; import org.whispersystems.textsecuregcm.configuration.GcpAttachmentsConfiguration; import org.whispersystems.textsecuregcm.configuration.MaxDeviceConfiguration; @@ -38,12 +40,13 @@ import org.whispersystems.textsecuregcm.configuration.PushConfiguration; import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration; import org.whispersystems.textsecuregcm.configuration.RecaptchaConfiguration; import org.whispersystems.textsecuregcm.configuration.RecaptchaV2Configuration; -import org.whispersystems.textsecuregcm.configuration.RedeemedReceiptsDynamoDbConfiguration; import org.whispersystems.textsecuregcm.configuration.RedisClusterConfiguration; import org.whispersystems.textsecuregcm.configuration.RedisConfiguration; import org.whispersystems.textsecuregcm.configuration.RemoteConfigConfiguration; import org.whispersystems.textsecuregcm.configuration.SecureBackupServiceConfiguration; import org.whispersystems.textsecuregcm.configuration.SecureStorageServiceConfiguration; +import org.whispersystems.textsecuregcm.configuration.StripeConfiguration; +import org.whispersystems.textsecuregcm.configuration.SubscriptionConfiguration; import org.whispersystems.textsecuregcm.configuration.TestDeviceConfiguration; import org.whispersystems.textsecuregcm.configuration.TurnConfiguration; import org.whispersystems.textsecuregcm.configuration.TwilioConfiguration; @@ -55,6 +58,21 @@ import org.whispersystems.websocket.configuration.WebSocketConfiguration; /** @noinspection MismatchedQueryAndUpdateOfCollection, WeakerAccess */ public class WhisperServerConfiguration extends Configuration { + @NotNull + @Valid + @JsonProperty + private StripeConfiguration stripe; + + @NotNull + @Valid + @JsonProperty + private DynamoDbClientConfiguration dynamoDbClientConfiguration; + + @NotNull + @Valid + @JsonProperty + private DynamoDbTables dynamoDbTables; + @NotNull @Valid @JsonProperty @@ -155,11 +173,6 @@ public class WhisperServerConfiguration extends Configuration { @JsonProperty private DynamoDbConfiguration deletedAccountsLockDynamoDb; - @Valid - @NotNull - @JsonProperty - private RedeemedReceiptsDynamoDbConfiguration redeemedReceiptsDynamoDb; - @Valid @NotNull @JsonProperty @@ -300,8 +313,25 @@ public class WhisperServerConfiguration extends Configuration { @JsonProperty private BadgesConfiguration badges; + @Valid + @JsonProperty + // TODO: Mark as @NotNull when enabled for production. + private SubscriptionConfiguration subscription; + private Map transparentDataIndex = new HashMap<>(); + public StripeConfiguration getStripe() { + return stripe; + } + + public DynamoDbClientConfiguration getDynamoDbClientConfiguration() { + return dynamoDbClientConfiguration; + } + + public DynamoDbTables getDynamoDbTables() { + return dynamoDbTables; + } + public RecaptchaConfiguration getRecaptchaConfiguration() { return recaptcha; } @@ -398,10 +428,6 @@ public class WhisperServerConfiguration extends Configuration { return deletedAccountsLockDynamoDb; } - public RedeemedReceiptsDynamoDbConfiguration getRedeemedReceiptsDynamoDbConfiguration() { - return redeemedReceiptsDynamoDb; - } - public DatabaseConfiguration getAbuseDatabaseConfiguration() { return abuseDatabase; } @@ -515,4 +541,8 @@ public class WhisperServerConfiguration extends Configuration { public BadgesConfiguration getBadges() { return badges; } + + public SubscriptionConfiguration getSubscription() { + return subscription; + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 77014ed16..f1fb4bde3 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -17,6 +17,7 @@ import com.fasterxml.jackson.annotation.PropertyAccessor; import com.fasterxml.jackson.databind.DeserializationFeature; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Lists; import io.dropwizard.Application; import io.dropwizard.auth.AuthFilter; import io.dropwizard.auth.PolymorphicAuthDynamicFeature; @@ -91,6 +92,7 @@ import org.whispersystems.textsecuregcm.controllers.RemoteConfigController; import org.whispersystems.textsecuregcm.controllers.SecureBackupController; import org.whispersystems.textsecuregcm.controllers.SecureStorageController; import org.whispersystems.textsecuregcm.controllers.StickerController; +import org.whispersystems.textsecuregcm.controllers.SubscriptionController; import org.whispersystems.textsecuregcm.controllers.VoiceVerificationController; import org.whispersystems.textsecuregcm.currency.CurrencyConversionManager; import org.whispersystems.textsecuregcm.currency.FixerClient; @@ -167,6 +169,7 @@ import org.whispersystems.textsecuregcm.storage.DirectoryReconciler; import org.whispersystems.textsecuregcm.storage.DirectoryReconciliationClient; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.FaultTolerantDatabase; +import org.whispersystems.textsecuregcm.storage.IssuedReceiptsManager; import org.whispersystems.textsecuregcm.storage.KeysDynamoDb; import org.whispersystems.textsecuregcm.storage.MessagePersister; import org.whispersystems.textsecuregcm.storage.MessagesCache; @@ -185,9 +188,11 @@ import org.whispersystems.textsecuregcm.storage.ReportMessageDynamoDb; import org.whispersystems.textsecuregcm.storage.ReportMessageManager; import org.whispersystems.textsecuregcm.storage.ReservedUsernames; import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager; +import org.whispersystems.textsecuregcm.storage.SubscriptionManager; import org.whispersystems.textsecuregcm.storage.Usernames; import org.whispersystems.textsecuregcm.storage.UsernamesManager; import org.whispersystems.textsecuregcm.storage.VerificationCodeStore; +import org.whispersystems.textsecuregcm.stripe.StripeManager; import org.whispersystems.textsecuregcm.util.AsnManager; import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.DynamoDbFromConfig; @@ -250,10 +255,9 @@ public class WhisperServerService extends Application commonControllers = List.of( + final List commonControllers = Lists.newArrayList( new AttachmentControllerV1(rateLimiters, config.getAwsAttachmentsConfiguration().getAccessKey(), config.getAwsAttachmentsConfiguration().getAccessSecret(), config.getAwsAttachmentsConfiguration().getBucket()), new AttachmentControllerV2(rateLimiters, config.getAwsAttachmentsConfiguration().getAccessKey(), config.getAwsAttachmentsConfiguration().getAccessSecret(), config.getAwsAttachmentsConfiguration().getRegion(), config.getAwsAttachmentsConfiguration().getBucket()), new AttachmentControllerV3(rateLimiters, config.getGcpAttachmentsConfiguration().getDomain(), config.getGcpAttachmentsConfiguration().getEmail(), config.getGcpAttachmentsConfiguration().getMaxSizeInBytes(), config.getGcpAttachmentsConfiguration().getPathPrefix(), config.getGcpAttachmentsConfiguration().getRsaSigningKey()), @@ -604,7 +622,7 @@ public class WhisperServerService extends Application supportedCurrencies; private CircuitBreakerConfiguration circuitBreaker = new CircuitBreakerConfiguration(); @@ -32,17 +31,6 @@ public class DonationConfiguration { this.uri = uri; } - @JsonProperty - @NotEmpty - public String getApiKey() { - return apiKey; - } - - @VisibleForTesting - public void setApiKey(final String apiKey) { - this.apiKey = apiKey; - } - @JsonProperty public String getDescription() { return description; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/DynamoDbClientConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/DynamoDbClientConfiguration.java new file mode 100644 index 000000000..3a7c84a7c --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/DynamoDbClientConfiguration.java @@ -0,0 +1,41 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.configuration; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import java.time.Duration; +import javax.validation.constraints.NotEmpty; + +public class DynamoDbClientConfiguration { + + private final String region; + private final Duration clientExecutionTimeout; + private final Duration clientRequestTimeout; + + @JsonCreator + public DynamoDbClientConfiguration( + @JsonProperty("region") final String region, + @JsonProperty("clientExcecutionTimeout") final Duration clientExecutionTimeout, + @JsonProperty("clientRequestTimeout") final Duration clientRequestTimeout) { + this.region = region; + this.clientExecutionTimeout = clientExecutionTimeout != null ? clientExecutionTimeout : Duration.ofSeconds(30); + this.clientRequestTimeout = clientRequestTimeout != null ? clientRequestTimeout : Duration.ofSeconds(10); + } + + @NotEmpty + public String getRegion() { + return region; + } + + public Duration getClientExecutionTimeout() { + return clientExecutionTimeout; + } + + public Duration getClientRequestTimeout() { + return clientRequestTimeout; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/DynamoDbTables.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/DynamoDbTables.java new file mode 100644 index 000000000..d4e6b9ff0 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/DynamoDbTables.java @@ -0,0 +1,80 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.configuration; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import java.time.Duration; +import javax.validation.Valid; +import javax.validation.constraints.NotEmpty; +import javax.validation.constraints.NotNull; + +public class DynamoDbTables { + + public static class Table { + private final String tableName; + + @JsonCreator + public Table( + @JsonProperty("tableName") final String tableName) { + this.tableName = tableName; + } + + @NotEmpty + public String getTableName() { + return tableName; + } + } + + public static class TableWithExpiration extends Table { + private final Duration expiration; + + @JsonCreator + public TableWithExpiration( + @JsonProperty("tableName") final String tableName, + @JsonProperty("expiration") final Duration expiration) { + super(tableName); + this.expiration = expiration; + } + + @NotNull + public Duration getExpiration() { + return expiration; + } + } + + private final IssuedReceiptsTableConfiguration issuedReceipts; + private final TableWithExpiration redeemedReceipts; + private final Table subscriptions; + + @JsonCreator + public DynamoDbTables( + @JsonProperty("issuedReceipts") final IssuedReceiptsTableConfiguration issuedReceipts, + @JsonProperty("redeemedReceipts") final TableWithExpiration redeemedReceipts, + @JsonProperty("subscriptions") final Table subscriptions) { + this.issuedReceipts = issuedReceipts; + this.redeemedReceipts = redeemedReceipts; + this.subscriptions = subscriptions; + } + + @Valid + @NotNull + public IssuedReceiptsTableConfiguration getIssuedReceipts() { + return issuedReceipts; + } + + @Valid + @NotNull + public TableWithExpiration getRedeemedReceipts() { + return redeemedReceipts; + } + + @Valid + @NotNull + public Table getSubscriptions() { + return subscriptions; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/IssuedReceiptsTableConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/IssuedReceiptsTableConfiguration.java new file mode 100644 index 000000000..e7969f521 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/IssuedReceiptsTableConfiguration.java @@ -0,0 +1,28 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.configuration; + +import com.fasterxml.jackson.annotation.JsonProperty; +import java.time.Duration; +import javax.validation.constraints.NotEmpty; + +public class IssuedReceiptsTableConfiguration extends DynamoDbTables.TableWithExpiration { + + private final byte[] generator; + + public IssuedReceiptsTableConfiguration( + @JsonProperty("tableName") final String tableName, + @JsonProperty("expiration") final Duration expiration, + @JsonProperty("generator") final byte[] generator) { + super(tableName, expiration); + this.generator = generator; + } + + @NotEmpty + public byte[] getGenerator() { + return generator; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/RedeemedReceiptsDynamoDbConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/RedeemedReceiptsDynamoDbConfiguration.java deleted file mode 100644 index 972297ff7..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/RedeemedReceiptsDynamoDbConfiguration.java +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright 2021 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.configuration; - -import com.fasterxml.jackson.annotation.JsonProperty; -import java.time.Duration; -import javax.validation.constraints.NotNull; - -public class RedeemedReceiptsDynamoDbConfiguration extends DynamoDbConfiguration { - - private Duration expirationTime; - - @NotNull - @JsonProperty - public Duration getExpirationTime() { - return expirationTime; - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/StripeConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/StripeConfiguration.java new file mode 100644 index 000000000..20dc192d1 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/StripeConfiguration.java @@ -0,0 +1,34 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.configuration; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import javax.validation.constraints.NotEmpty; + +public class StripeConfiguration { + + private final String apiKey; + private final byte[] idempotencyKeyGenerator; + + @JsonCreator + public StripeConfiguration( + @JsonProperty("apiKey") final String apiKey, + @JsonProperty("idempotencyKeyGenerator") final byte[] idempotencyKeyGenerator) { + this.apiKey = apiKey; + this.idempotencyKeyGenerator = idempotencyKeyGenerator; + } + + @NotEmpty + public String getApiKey() { + return apiKey; + } + + @NotEmpty + public byte[] getIdempotencyKeyGenerator() { + return idempotencyKeyGenerator; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/SubscriptionConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/SubscriptionConfiguration.java new file mode 100644 index 000000000..a9115ab55 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/SubscriptionConfiguration.java @@ -0,0 +1,52 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.configuration; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.dropwizard.validation.ValidationMethod; +import java.time.Duration; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import javax.validation.Valid; +import javax.validation.constraints.Min; +import javax.validation.constraints.NotNull; + +public class SubscriptionConfiguration { + + private final Duration badgeGracePeriod; + private final Map levels; + + @JsonCreator + public SubscriptionConfiguration( + @JsonProperty("badgeGracePeriod") @Valid Duration badgeGracePeriod, + @JsonProperty("levels") @Valid Map<@NotNull @Min(1) Long, @NotNull @Valid SubscriptionLevelConfiguration> levels) { + this.badgeGracePeriod = badgeGracePeriod; + this.levels = levels; + } + + public Duration getBadgeGracePeriod() { + return badgeGracePeriod; + } + + public Map getLevels() { + return levels; + } + + @JsonIgnore + @ValidationMethod(message = "has a mismatch between the levels supported currencies") + public boolean isCurrencyListSameAcrossAllLevels() { + Optional any = levels.values().stream().findAny(); + if (any.isEmpty()) { + return true; + } + + Set currencies = any.get().getPrices().keySet(); + return levels.values().stream().allMatch(level -> currencies.equals(level.getPrices().keySet())); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/SubscriptionLevelConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/SubscriptionLevelConfiguration.java new file mode 100644 index 000000000..0f6397984 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/SubscriptionLevelConfiguration.java @@ -0,0 +1,42 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.configuration; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.Map; +import javax.validation.Valid; +import javax.validation.constraints.NotEmpty; +import javax.validation.constraints.NotNull; + +public class SubscriptionLevelConfiguration { + + private final String badge; + private final String product; + private final Map prices; + + @JsonCreator + public SubscriptionLevelConfiguration( + @JsonProperty("badge") @NotEmpty String badge, + @JsonProperty("product") @NotEmpty String product, + @JsonProperty("prices") @Valid Map<@NotEmpty String, @NotNull @Valid SubscriptionPriceConfiguration> prices) { + this.badge = badge; + this.product = product; + this.prices = prices; + } + + public String getBadge() { + return badge; + } + + public String getProduct() { + return product; + } + + public Map getPrices() { + return prices; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/SubscriptionPriceConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/SubscriptionPriceConfiguration.java new file mode 100644 index 000000000..3e2aaef1b --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/SubscriptionPriceConfiguration.java @@ -0,0 +1,35 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.configuration; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import java.math.BigDecimal; +import javax.validation.constraints.DecimalMin; +import javax.validation.constraints.NotEmpty; +import javax.validation.constraints.NotNull; + +public class SubscriptionPriceConfiguration { + + private final String id; + private final BigDecimal amount; + + @JsonCreator + public SubscriptionPriceConfiguration( + @JsonProperty("id") @NotEmpty String id, + @JsonProperty("amount") @NotNull @DecimalMin("0.01") BigDecimal amount) { + this.id = id; + this.amount = amount; + } + + public String getId() { + return id; + } + + public BigDecimal getAmount() { + return amount; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DonationController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DonationController.java index e8f5334f3..8dc4f7d5c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DonationController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DonationController.java @@ -52,6 +52,7 @@ import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.configuration.BadgesConfiguration; import org.whispersystems.textsecuregcm.configuration.DonationConfiguration; +import org.whispersystems.textsecuregcm.configuration.StripeConfiguration; import org.whispersystems.textsecuregcm.entities.ApplePayAuthorizationRequest; import org.whispersystems.textsecuregcm.entities.ApplePayAuthorizationResponse; import org.whispersystems.textsecuregcm.entities.RedeemReceiptRequest; @@ -70,7 +71,7 @@ public class DonationController { ReceiptCredentialPresentation build(byte[] bytes) throws InvalidInputException; } - private final Logger logger = LoggerFactory.getLogger(DonationController.class); + private static final Logger logger = LoggerFactory.getLogger(DonationController.class); private final Clock clock; private final ServerZkReceiptOperations serverZkReceiptOperations; @@ -92,7 +93,8 @@ public class DonationController { @Nonnull final BadgesConfiguration badgesConfiguration, @Nonnull final ReceiptCredentialPresentationFactory receiptCredentialPresentationFactory, @Nonnull final Executor httpClientExecutor, - @Nonnull final DonationConfiguration configuration) { + @Nonnull final DonationConfiguration configuration, + @Nonnull final StripeConfiguration stripeConfiguration) { this.clock = Objects.requireNonNull(clock); this.serverZkReceiptOperations = Objects.requireNonNull(serverZkReceiptOperations); this.redeemedReceiptsManager = Objects.requireNonNull(redeemedReceiptsManager); @@ -100,7 +102,7 @@ public class DonationController { this.badgesConfiguration = Objects.requireNonNull(badgesConfiguration); this.receiptCredentialPresentationFactory = Objects.requireNonNull(receiptCredentialPresentationFactory); this.uri = URI.create(configuration.getUri()); - this.apiKey = configuration.getApiKey(); + this.apiKey = stripeConfiguration.getApiKey(); this.description = configuration.getDescription(); this.supportedCurrencies = configuration.getSupportedCurrencies(); this.httpClient = FaultTolerantHttpClient.newBuilder() diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SubscriptionController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SubscriptionController.java new file mode 100644 index 000000000..6a218d7b5 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SubscriptionController.java @@ -0,0 +1,689 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.controllers; + +import com.codahale.metrics.annotation.Timed; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.base.Strings; +import com.stripe.model.Invoice; +import com.stripe.model.InvoiceLineItem; +import com.stripe.model.Product; +import com.stripe.model.Subscription; +import io.dropwizard.auth.Auth; +import java.math.BigDecimal; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.time.Clock; +import java.time.Instant; +import java.util.Base64; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; +import javax.annotation.Nonnull; +import javax.crypto.Mac; +import javax.crypto.spec.SecretKeySpec; +import javax.ws.rs.BadRequestException; +import javax.ws.rs.Consumes; +import javax.ws.rs.DELETE; +import javax.ws.rs.ForbiddenException; +import javax.ws.rs.GET; +import javax.ws.rs.InternalServerErrorException; +import javax.ws.rs.NotFoundException; +import javax.ws.rs.POST; +import javax.ws.rs.PUT; +import javax.ws.rs.Path; +import javax.ws.rs.PathParam; +import javax.ws.rs.Produces; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; +import javax.ws.rs.core.Response.Status; +import org.signal.zkgroup.InvalidInputException; +import org.signal.zkgroup.VerificationFailedException; +import org.signal.zkgroup.receipts.ReceiptCredentialRequest; +import org.signal.zkgroup.receipts.ReceiptCredentialResponse; +import org.signal.zkgroup.receipts.ServerZkReceiptOperations; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.configuration.SubscriptionConfiguration; +import org.whispersystems.textsecuregcm.configuration.SubscriptionLevelConfiguration; +import org.whispersystems.textsecuregcm.configuration.SubscriptionPriceConfiguration; +import org.whispersystems.textsecuregcm.storage.IssuedReceiptsManager; +import org.whispersystems.textsecuregcm.storage.SubscriptionManager; +import org.whispersystems.textsecuregcm.storage.SubscriptionManager.GetResult; +import org.whispersystems.textsecuregcm.stripe.StripeManager; +import org.whispersystems.textsecuregcm.util.ExactlySize; + +@Path("/v1/subscription") +public class SubscriptionController { + + private static final Logger logger = LoggerFactory.getLogger(SubscriptionController.class); + + private final Clock clock; + private final SubscriptionConfiguration config; + private final SubscriptionManager subscriptionManager; + private final StripeManager stripeManager; + private final ServerZkReceiptOperations zkReceiptOperations; + private final IssuedReceiptsManager issuedReceiptsManager; + + public SubscriptionController( + @Nonnull Clock clock, + @Nonnull SubscriptionConfiguration config, + @Nonnull SubscriptionManager subscriptionManager, + @Nonnull StripeManager stripeManager, + @Nonnull ServerZkReceiptOperations zkReceiptOperations, + @Nonnull IssuedReceiptsManager issuedReceiptsManager) { + this.clock = Objects.requireNonNull(clock); + this.config = Objects.requireNonNull(config); + this.subscriptionManager = Objects.requireNonNull(subscriptionManager); + this.stripeManager = Objects.requireNonNull(stripeManager); + this.zkReceiptOperations = Objects.requireNonNull(zkReceiptOperations); + this.issuedReceiptsManager = Objects.requireNonNull(issuedReceiptsManager); + } + + @Timed + @DELETE + @Path("/{subscriberId}") + @Consumes(MediaType.APPLICATION_JSON) + @Produces(MediaType.APPLICATION_JSON) + public CompletableFuture deleteSubscriber( + @Auth Optional authenticatedAccount, + @PathParam("subscriberId") String subscriberId) { + RequestData requestData = RequestData.process(authenticatedAccount, subscriberId, clock); + return subscriptionManager.get(requestData.subscriberUser, requestData.hmac) + .thenCompose(getResult -> { + if (getResult == GetResult.NOT_STORED || getResult == GetResult.PASSWORD_MISMATCH) { + throw new NotFoundException(); + } + String customerId = getResult.record.customerId; + if (Strings.isNullOrEmpty(customerId)) { + throw new InternalServerErrorException("no customer id found"); + } + return stripeManager.getCustomer(customerId).thenCompose(customer -> { + if (customer == null) { + throw new InternalServerErrorException("no customer record found for id " + customerId); + } + return stripeManager.listNonCanceledSubscriptions(customer); + }).thenCompose(subscriptions -> { + @SuppressWarnings("unchecked") + CompletableFuture[] futures = (CompletableFuture[]) subscriptions.stream() + .map(stripeManager::cancelSubscriptionAtEndOfCurrentPeriod).toArray(CompletableFuture[]::new); + return CompletableFuture.allOf(futures); + }); + }) + .thenCompose(unused -> subscriptionManager.canceledAt(requestData.subscriberUser, requestData.now)) + .thenApply(unused -> Response.ok().build()); + } + + @Timed + @PUT + @Path("/{subscriberId}") + @Consumes(MediaType.APPLICATION_JSON) + @Produces(MediaType.APPLICATION_JSON) + public CompletableFuture updateSubscriber( + @Auth Optional authenticatedAccount, + @PathParam("subscriberId") String subscriberId) { + RequestData requestData = RequestData.process(authenticatedAccount, subscriberId, clock); + return subscriptionManager.get(requestData.subscriberUser, requestData.hmac) + .thenCompose(getResult -> { + if (getResult == GetResult.PASSWORD_MISMATCH) { + throw new ForbiddenException("subscriberId mismatch"); + } else if (getResult == GetResult.NOT_STORED) { + // create a customer and write it to ddb + return stripeManager.createCustomer(requestData.subscriberUser).thenCompose( + customer -> subscriptionManager.create( + requestData.subscriberUser, requestData.hmac, customer.getId(), requestData.now) + .thenApply(updatedRecord -> { + if (updatedRecord == null) { + throw new NotFoundException(); + } + return updatedRecord; + })); + } else { + // already exists so just touch access time and return + return subscriptionManager.accessedAt(requestData.subscriberUser, requestData.now) + .thenApply(unused -> getResult.record); + } + }) + .thenApply(record -> Response.ok().build()); + } + + public static class CreatePaymentMethodResponse { + + private final String clientSecret; + + @JsonCreator + public CreatePaymentMethodResponse( + @JsonProperty("clientSecret") String clientSecret) { + this.clientSecret = clientSecret; + } + + @SuppressWarnings("unused") + public String getClientSecret() { + return clientSecret; + } + } + + @Timed + @POST + @Path("/{subscriberId}/create_payment_method") + @Consumes(MediaType.APPLICATION_JSON) + @Produces(MediaType.APPLICATION_JSON) + public CompletableFuture createPaymentMethod( + @Auth Optional authenticatedAccount, + @PathParam("subscriberId") String subscriberId) { + RequestData requestData = RequestData.process(authenticatedAccount, subscriberId, clock); + return subscriptionManager.get(requestData.subscriberUser, requestData.hmac) + .thenApply(this::requireRecordFromGetResult) + .thenCompose(record -> stripeManager.createSetupIntent(record.customerId)) + .thenApply(setupIntent -> Response.ok(new CreatePaymentMethodResponse(setupIntent.getClientSecret())).build()); + } + + public static class SetSubscriptionLevelSuccessResponse { + + private final long level; + + @JsonCreator + public SetSubscriptionLevelSuccessResponse( + @JsonProperty("level") long level) { + this.level = level; + } + + public long getLevel() { + return level; + } + } + + public static class SetSubscriptionLevelErrorResponse { + + public static class Error { + + public enum Type { + UNSUPPORTED_LEVEL, + UNSUPPORTED_CURRENCY, + } + + private final Type type; + private final String message; + + @JsonCreator + public Error( + @JsonProperty("type") Type type, + @JsonProperty("message") String message) { + this.type = type; + this.message = message; + } + + public Type getType() { + return type; + } + + public String getMessage() { + return message; + } + } + + private final List errors; + + @JsonCreator + public SetSubscriptionLevelErrorResponse( + @JsonProperty("errors") List errors) { + this.errors = errors; + } + + public List getErrors() { + return errors; + } + } + + @Timed + @PUT + @Path("/{subscriberId}/level/{level}/{currency}/{idempotencyKey}") + @Consumes(MediaType.APPLICATION_JSON) + @Produces(MediaType.APPLICATION_JSON) + public CompletableFuture setSubscriptionLevel( + @Auth Optional authenticatedAccount, + @PathParam("subscriberId") String subscriberId, + @PathParam("level") long level, + @PathParam("currency") String currency, + @PathParam("idempotencyKey") String idempotencyKey) { + RequestData requestData = RequestData.process(authenticatedAccount, subscriberId, clock); + return subscriptionManager.get(requestData.subscriberUser, requestData.hmac) + .thenApply(this::requireRecordFromGetResult) + .thenCompose(record -> { + SubscriptionLevelConfiguration levelConfiguration = config.getLevels().get(level); + if (levelConfiguration == null) { + throw new BadRequestException(Response.status(Status.BAD_REQUEST) + .entity(new SetSubscriptionLevelErrorResponse(List.of( + new SetSubscriptionLevelErrorResponse.Error( + SetSubscriptionLevelErrorResponse.Error.Type.UNSUPPORTED_LEVEL, null)))) + .build()); + } + SubscriptionPriceConfiguration priceConfiguration = levelConfiguration.getPrices() + .get(currency.toLowerCase(Locale.ROOT)); + if (priceConfiguration == null) { + throw new BadRequestException(Response.status(Status.BAD_REQUEST) + .entity(new SetSubscriptionLevelErrorResponse(List.of( + new SetSubscriptionLevelErrorResponse.Error( + SetSubscriptionLevelErrorResponse.Error.Type.UNSUPPORTED_CURRENCY, null)))) + .build()); + } + + if (record.subscriptionId == null) { + // we don't have one yet so create it and then record the subscription id + // + // this relies on stripe's idempotency key to avoid creating more than one subscription if the client + // retries this request + return stripeManager.createSubscription(record.customerId, priceConfiguration.getId(), level) + .thenCompose(subscription -> subscriptionManager.subscriptionCreated( + requestData.subscriberUser, subscription.getId(), requestData.now, level) + .thenApply(unused -> subscription)); + } else { + // we already have a subscription in our records so let's check the level and change it if needed + return stripeManager.getSubscription(record.subscriptionId).thenCompose( + subscription -> stripeManager.getLevelForSubscription(subscription).thenCompose(existingLevel -> { + if (level == existingLevel) { + return CompletableFuture.completedFuture(subscription); + } + return stripeManager.updateSubscription( + subscription, priceConfiguration.getId(), level, idempotencyKey) + .thenCompose(updatedSubscription -> + subscriptionManager.subscriptionLevelChanged(requestData.subscriberUser, requestData.now, level) + .thenApply(unused -> updatedSubscription)); + })); + } + }) + .thenApply(subscription -> Response.ok(new SetSubscriptionLevelSuccessResponse(level)).build()); + } + + public static class GetLevelsResponse { + + public static class Level { + + public static class Price { + + private final BigDecimal amount; + + @JsonCreator + public Price( + @JsonProperty("amount") BigDecimal amount) { + this.amount = amount; + } + + public BigDecimal getAmount() { + return amount; + } + } + + private final String badgeId; + private final Map currencies; + + @JsonCreator + public Level( + @JsonProperty("badgeId") String badgeId, + @JsonProperty("currencies") Map currencies) { + this.badgeId = badgeId; + this.currencies = currencies; + } + + public String getBadgeId() { + return badgeId; + } + + public Map getCurrencies() { + return currencies; + } + } + + private final Map levels; + + @JsonCreator + public GetLevelsResponse( + @JsonProperty("levels") Map levels) { + this.levels = levels; + } + + public Map getLevels() { + return levels; + } + } + + @Timed + @GET + @Path("/levels") + @Consumes(MediaType.APPLICATION_JSON) + @Produces(MediaType.APPLICATION_JSON) + public CompletableFuture getLevels() { + return CompletableFuture.supplyAsync(() -> { + GetLevelsResponse getLevelsResponse = new GetLevelsResponse( + config.getLevels().entrySet().stream().collect(Collectors.toMap(Entry::getKey, + entry -> new GetLevelsResponse.Level(entry.getValue().getBadge(), + entry.getValue().getPrices().entrySet().stream().collect( + Collectors.toMap(levelEntry -> levelEntry.getKey().toUpperCase(Locale.ROOT), + levelEntry -> new GetLevelsResponse.Level.Price(levelEntry.getValue().getAmount()))))))); + return Response.ok(getLevelsResponse).build(); + }); + } + + public static class GetSubscriptionInformationResponse { + + public static class Subscription { + + private final long level; + private final Instant billingCycleAnchor; + private final Instant endOfCurrentPeriod; + private final boolean active; + private final boolean cancelAtPeriodEnd; + private final String currency; + private final BigDecimal amount; + + public Subscription( + @JsonProperty("level") long level, + @JsonProperty("billingCycleAnchor") Instant billingCycleAnchor, + @JsonProperty("endOfCurrentPeriod") Instant endOfCurrentPeriod, + @JsonProperty("active") boolean active, + @JsonProperty("cancelAtPeriodEnd") boolean cancelAtPeriodEnd, + @JsonProperty("currency") String currency, + @JsonProperty("amount") BigDecimal amount) { + this.level = level; + this.billingCycleAnchor = billingCycleAnchor; + this.endOfCurrentPeriod = endOfCurrentPeriod; + this.active = active; + this.cancelAtPeriodEnd = cancelAtPeriodEnd; + this.currency = currency; + this.amount = amount; + } + + public long getLevel() { + return level; + } + + public Instant getBillingCycleAnchor() { + return billingCycleAnchor; + } + + public Instant getEndOfCurrentPeriod() { + return endOfCurrentPeriod; + } + + public boolean isActive() { + return active; + } + + public boolean isCancelAtPeriodEnd() { + return cancelAtPeriodEnd; + } + + public String getCurrency() { + return currency; + } + + public BigDecimal getAmount() { + return amount; + } + } + + private final Subscription subscription; + + @JsonCreator + public GetSubscriptionInformationResponse( + @JsonProperty("subscription") Subscription subscription) { + this.subscription = subscription; + } + + public Subscription getSubscription() { + return subscription; + } + } + + @Timed + @GET + @Path("/{subscriberId}") + @Consumes(MediaType.APPLICATION_JSON) + @Produces(MediaType.APPLICATION_JSON) + public CompletableFuture getSubscriptionInformation( + @Auth Optional authenticatedAccount, + @PathParam("subscriberId") String subscriberId) { + RequestData requestData = RequestData.process(authenticatedAccount, subscriberId, clock); + return subscriptionManager.get(requestData.subscriberUser, requestData.hmac) + .thenApply(this::requireRecordFromGetResult) + .thenCompose(record -> { + if (record.subscriptionId == null) { + return CompletableFuture.completedFuture(Response.ok(new GetSubscriptionInformationResponse(null)).build()); + } + return stripeManager.getSubscription(record.subscriptionId).thenCompose(subscription -> + stripeManager.getPriceForSubscription(subscription).thenCompose(price -> + stripeManager.getLevelForPrice(price).thenApply(level -> Response.ok( + new GetSubscriptionInformationResponse(new GetSubscriptionInformationResponse.Subscription( + level, + Instant.ofEpochSecond(subscription.getBillingCycleAnchor()), + Instant.ofEpochSecond(subscription.getCurrentPeriodEnd()), + Objects.equals(subscription.getStatus(), "active"), + subscription.getCancelAtPeriodEnd(), + price.getCurrency().toUpperCase(Locale.ROOT), + price.getUnitAmountDecimal() + ))).build()))); + }); + } + + public static class GetReceiptCredentialsRequest { + private final byte[] receiptCredentialRequest; + + @JsonCreator + public GetReceiptCredentialsRequest( + @JsonProperty("receiptCredentialRequest") byte[] receiptCredentialRequest) { + this.receiptCredentialRequest = receiptCredentialRequest; + } + + @ExactlySize(ReceiptCredentialRequest.SIZE) + public byte[] getReceiptCredentialRequest() { + return receiptCredentialRequest; + } + } + + public static class GetReceiptCredentialsResponse { + private final byte[] receiptCredentialResponse; + + @JsonCreator + public GetReceiptCredentialsResponse( + @JsonProperty("receiptCredentialResponse") byte[] receiptCredentialResponse) { + this.receiptCredentialResponse = receiptCredentialResponse; + } + + @ExactlySize(ReceiptCredentialResponse.SIZE) + public byte[] getReceiptCredentialResponse() { + return receiptCredentialResponse; + } + } + + @Timed + @POST + @Path("/{subscriberId}/receipt_credentials") + @Consumes(MediaType.APPLICATION_JSON) + @Produces(MediaType.APPLICATION_JSON) + public CompletableFuture getReceiptCredentials( + @Auth Optional authenticatedAccount, + @PathParam("subscriberId") String subscriberId, + GetReceiptCredentialsRequest request) { + RequestData requestData = RequestData.process(authenticatedAccount, subscriberId, clock); + return subscriptionManager.get(requestData.subscriberUser, requestData.hmac) + .thenApply(this::requireRecordFromGetResult) + .thenCompose(record -> { + if (record.subscriptionId == null) { + return CompletableFuture.completedFuture(Response.noContent().build()); + } + ReceiptCredentialRequest receiptCredentialRequest; + try { + receiptCredentialRequest = new ReceiptCredentialRequest(request.getReceiptCredentialRequest()); + } catch (InvalidInputException e) { + throw new BadRequestException("invalid receipt credential request", e); + } + return stripeManager.getPaidInvoicesForSubscription(record.subscriptionId, requestData.now) + .thenCompose(invoices -> checkNextInvoice(invoices.iterator(), record.subscriptionId)) + .thenCompose(receipt -> { + if (receipt == null) { + return CompletableFuture.completedFuture(null); + } + return issuedReceiptsManager.recordIssuance( + receipt.invoiceLineItemId, receiptCredentialRequest, requestData.now).thenApply(unused -> receipt); + }) + .thenApply(receipt -> { + if (receipt == null) { + return Response.noContent().build(); + } else { + ReceiptCredentialResponse receiptCredentialResponse; + try { + receiptCredentialResponse = zkReceiptOperations.issueReceiptCredential( + receiptCredentialRequest, receipt.getExpiration().getEpochSecond(), receipt.getLevel()); + } catch (VerificationFailedException e) { + throw new BadRequestException("receipt credential request failed verification", e); + } + return Response.ok(new GetReceiptCredentialsResponse(receiptCredentialResponse.serialize())).build(); + } + }); + }); + } + + public static class Receipt { + private final Instant expiration; + private final long level; + private final String invoiceLineItemId; + + public Receipt(Instant expiration, long level, String invoiceLineItemId) { + this.expiration = expiration; + this.level = level; + this.invoiceLineItemId = invoiceLineItemId; + } + + public Instant getExpiration() { + return expiration; + } + + public long getLevel() { + return level; + } + + public String getInvoiceLineItemId() { + return invoiceLineItemId; + } + } + + private CompletableFuture checkNextInvoice(Iterator invoiceIterator, String subscriptionId) { + if (!invoiceIterator.hasNext()) { + return null; + } + + Invoice invoice = invoiceIterator.next(); + return stripeManager.getInvoiceLineItemsForInvoice(invoice).thenCompose(invoiceLineItems -> { + Collection subscriptionLineItems = invoiceLineItems.stream() + .filter(invoiceLineItem -> Objects.equals("subscription", invoiceLineItem.getType())) + .collect(Collectors.toList()); + if (subscriptionLineItems.isEmpty()) { + return checkNextInvoice(invoiceIterator, subscriptionId); + } + if (subscriptionLineItems.size() > 1) { + throw new IllegalStateException("invoice has more than one subscription; subscriptionId=" + subscriptionId + + "; count=" + subscriptionLineItems.size()); + } + + InvoiceLineItem subscriptionLineItem = subscriptionLineItems.stream().findAny().get(); + Product product = subscriptionLineItem.getPrice().getProductObject(); + return CompletableFuture.completedFuture(new Receipt( + Instant.ofEpochSecond(subscriptionLineItem.getPeriod().getEnd()).plus(config.getBadgeGracePeriod()), + stripeManager.getLevelForProduct(product), + subscriptionLineItem.getId())); + }); + } + + private SubscriptionManager.Record requireRecordFromGetResult(SubscriptionManager.GetResult getResult) { + if (getResult == GetResult.PASSWORD_MISMATCH) { + throw new ForbiddenException("subscriberId mismatch"); + } else if (getResult == GetResult.NOT_STORED) { + throw new NotFoundException(); + } else { + return getResult.record; + } + } + + private static class RequestData { + + public final byte[] subscriberBytes; + public final byte[] subscriberUser; + public final byte[] subscriberKey; + public final byte[] hmac; + public final Instant now; + + private RequestData( + @Nonnull byte[] subscriberBytes, + @Nonnull byte[] subscriberUser, + @Nonnull byte[] subscriberKey, + @Nonnull byte[] hmac, + @Nonnull Instant now) { + this.subscriberBytes = Objects.requireNonNull(subscriberBytes); + this.subscriberUser = Objects.requireNonNull(subscriberUser); + this.subscriberKey = Objects.requireNonNull(subscriberKey); + this.hmac = Objects.requireNonNull(hmac); + this.now = Objects.requireNonNull(now); + } + + public static RequestData process( + Optional authenticatedAccount, + String subscriberId, + Clock clock) { + Instant now = clock.instant(); + if (authenticatedAccount.isPresent()) { + throw new ForbiddenException("must not use authenticated connection for subscriber operations"); + } + byte[] subscriberBytes = convertSubscriberIdStringToBytes(subscriberId); + byte[] subscriberUser = getUser(subscriberBytes); + byte[] subscriberKey = getKey(subscriberBytes); + byte[] hmac = computeHmac(subscriberUser, subscriberKey); + return new RequestData(subscriberBytes, subscriberUser, subscriberKey, hmac, now); + } + + private static byte[] convertSubscriberIdStringToBytes(String subscriberId) { + try { + byte[] bytes = Base64.getUrlDecoder().decode(subscriberId); + if (bytes.length != 32) { + throw new NotFoundException(); + } + return bytes; + } catch (IllegalArgumentException e) { + throw new NotFoundException(e); + } + } + + private static byte[] getUser(byte[] subscriberBytes) { + byte[] user = new byte[16]; + System.arraycopy(subscriberBytes, 0, user, 0, user.length); + return user; + } + + private static byte[] getKey(byte[] subscriberBytes) { + byte[] key = new byte[16]; + System.arraycopy(subscriberBytes, 16, key, 0, key.length); + return key; + } + + private static byte[] computeHmac(byte[] subscriberUser, byte[] subscriberKey) { + try { + Mac mac = Mac.getInstance("HmacSHA256"); + mac.init(new SecretKeySpec(subscriberKey, "HmacSHA256")); + return mac.doFinal(subscriberUser); + } catch (NoSuchAlgorithmException | InvalidKeyException e) { + throw new InternalServerErrorException(e); + } + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/IssuedReceiptsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/IssuedReceiptsManager.java new file mode 100644 index 000000000..65a50292b --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/IssuedReceiptsManager.java @@ -0,0 +1,112 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import static org.whispersystems.textsecuregcm.util.AttributeValues.b; +import static org.whispersystems.textsecuregcm.util.AttributeValues.n; +import static org.whispersystems.textsecuregcm.util.AttributeValues.s; + +import com.google.common.base.Throwables; +import java.nio.charset.StandardCharsets; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.time.Duration; +import java.time.Instant; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.function.Consumer; +import javax.annotation.Nonnull; +import javax.crypto.Mac; +import javax.crypto.spec.SecretKeySpec; +import javax.ws.rs.ClientErrorException; +import javax.ws.rs.core.Response.Status; +import org.signal.zkgroup.receipts.ReceiptCredentialRequest; +import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; +import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException; +import software.amazon.awssdk.services.dynamodb.model.ReturnValue; +import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest; + +public class IssuedReceiptsManager { + + public static final String KEY_INVOICE_LINE_ITEM_ID = "A"; // S (HashKey) + public static final String KEY_ISSUED_RECEIPT_TAG = "B"; // B + public static final String KEY_EXPIRATION = "E"; // N + + private final String table; + private final Duration expiration; + private final DynamoDbAsyncClient dynamoDbAsyncClient; + private final byte[] receiptTagGenerator; + + public IssuedReceiptsManager( + @Nonnull String table, + @Nonnull Duration expiration, + @Nonnull DynamoDbAsyncClient dynamoDbAsyncClient, + @Nonnull byte[] receiptTagGenerator) { + this.table = Objects.requireNonNull(table); + this.expiration = Objects.requireNonNull(expiration); + this.dynamoDbAsyncClient = Objects.requireNonNull(dynamoDbAsyncClient); + this.receiptTagGenerator = Objects.requireNonNull(receiptTagGenerator); + } + + /** + * Returns a future that completes normally if either this invoice line item was never issued a receipt credential + * previously OR if it was issued a receipt credential previously for the exact same receipt credential request + * enabling clients to retry in case they missed the original response. + * + * If this invoice line item id has already been used to issue another receipt, throws a 409 conflict web application + * exception. + */ + public CompletableFuture recordIssuance( + String invoiceLineItemId, + ReceiptCredentialRequest request, + Instant now) { + UpdateItemRequest updateItemRequest = UpdateItemRequest.builder() + .tableName(table) + .key(Map.of(KEY_INVOICE_LINE_ITEM_ID, s(invoiceLineItemId))) + .conditionExpression("attribute_not_exists(#key) OR #tag = :tag") + .returnValues(ReturnValue.NONE) + .updateExpression("SET " + + "#tag = if_not_exists(#tag, :tag), " + + "#exp = if_not_exists(#exp, :exp)") + .expressionAttributeNames(Map.of( + "#key", KEY_INVOICE_LINE_ITEM_ID, + "#tag", KEY_ISSUED_RECEIPT_TAG, + "#exp", KEY_EXPIRATION)) + .expressionAttributeValues(Map.of( + ":tag", b(generateIssuedReceiptTag(request)), + ":exp", n(now.plus(expiration).getEpochSecond()))) + .build(); + return dynamoDbAsyncClient.updateItem(updateItemRequest).handle((updateItemResponse, throwable) -> { + if (throwable != null) { + Throwable rootCause = Throwables.getRootCause(throwable); + if (rootCause instanceof ConditionalCheckFailedException) { + throw new ClientErrorException(Status.CONFLICT, rootCause); + } + Throwables.throwIfUnchecked(throwable); + throw new CompletionException(throwable); + } + return null; + }); + } + + private byte[] generateIssuedReceiptTag(ReceiptCredentialRequest request) { + return generateHmac("issuedReceiptTag", mac -> mac.update(request.serialize())); + } + + private byte[] generateHmac(String type, Consumer byteConsumer) { + try { + Mac mac = Mac.getInstance("HmacSHA256"); + mac.init(new SecretKeySpec(receiptTagGenerator, "HmacSHA256")); + mac.update(type.getBytes(StandardCharsets.UTF_8)); + byteConsumer.accept(mac); + return mac.doFinal(); + } catch (NoSuchAlgorithmException | InvalidKeyException e) { + throw new AssertionError(e); + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/SubscriptionManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SubscriptionManager.java new file mode 100644 index 000000000..0dc867b8f --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SubscriptionManager.java @@ -0,0 +1,338 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import static org.whispersystems.textsecuregcm.util.AttributeValues.b; +import static org.whispersystems.textsecuregcm.util.AttributeValues.n; +import static org.whispersystems.textsecuregcm.util.AttributeValues.s; + +import com.google.common.base.Throwables; +import java.security.MessageDigest; +import java.time.Instant; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import javax.annotation.Nonnull; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; +import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException; +import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; +import software.amazon.awssdk.services.dynamodb.model.QueryRequest; +import software.amazon.awssdk.services.dynamodb.model.ReturnValue; +import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest; + +public class SubscriptionManager { + + private static final Logger logger = LoggerFactory.getLogger(SubscriptionManager.class); + + private static final int USER_LENGTH = 16; + + public static final String KEY_USER = "U"; // B (Hash Key) + public static final String KEY_PASSWORD = "P"; // B + public static final String KEY_CUSTOMER_ID = "C"; // S (GSI Hash Key of `c_to_u` index) + public static final String KEY_CREATED_AT = "R"; // N + public static final String KEY_SUBSCRIPTION_ID = "S"; // S + public static final String KEY_SUBSCRIPTION_CREATED_AT = "T"; // N + public static final String KEY_SUBSCRIPTION_LEVEL = "L"; + public static final String KEY_SUBSCRIPTION_LEVEL_CHANGED_AT = "V"; // N + public static final String KEY_ACCESSED_AT = "A"; // N + public static final String KEY_CANCELED_AT = "B"; // N + public static final String KEY_CURRENT_PERIOD_ENDS_AT = "D"; // N + + public static final String INDEX_NAME = "c_to_u"; // Hash Key "C" + + public static class Record { + + public final byte[] user; + public final byte[] password; + public final String customerId; + public final Instant createdAt; + public String subscriptionId; + public Instant subscriptionCreatedAt; + public Long subscriptionLevel; + public Instant subscriptionLevelChangedAt; + public Instant accessedAt; + public Instant canceledAt; + public Instant currentPeriodEndsAt; + + private Record(byte[] user, byte[] password, String customerId, Instant createdAt) { + this.user = checkUserLength(user); + this.password = Objects.requireNonNull(password); + this.customerId = Objects.requireNonNull(customerId); + this.createdAt = Objects.requireNonNull(createdAt); + } + + public static Record from(byte[] user, Map item) { + Record self = new Record( + user, + item.get(KEY_PASSWORD).b().asByteArray(), + item.get(KEY_CUSTOMER_ID).s(), + getInstant(item, KEY_CREATED_AT)); + self.subscriptionId = getString(item, KEY_SUBSCRIPTION_ID); + self.subscriptionCreatedAt = getInstant(item, KEY_SUBSCRIPTION_CREATED_AT); + self.subscriptionLevel = getLong(item, KEY_SUBSCRIPTION_LEVEL); + self.subscriptionLevelChangedAt = getInstant(item, KEY_SUBSCRIPTION_LEVEL_CHANGED_AT); + self.accessedAt = getInstant(item, KEY_ACCESSED_AT); + self.canceledAt = getInstant(item, KEY_CANCELED_AT); + self.currentPeriodEndsAt = getInstant(item, KEY_CURRENT_PERIOD_ENDS_AT); + return self; + } + + public Map asKey() { + return Map.of(KEY_USER, b(user)); + } + + private static String getString(Map item, String key) { + AttributeValue attributeValue = item.get(key); + if (attributeValue == null) { + return null; + } + return attributeValue.s(); + } + + private static Long getLong(Map item, String key) { + AttributeValue attributeValue = item.get(key); + if (attributeValue == null || attributeValue.n() == null) { + return null; + } + return Long.valueOf(attributeValue.n()); + } + + private static Instant getInstant(Map item, String key) { + AttributeValue attributeValue = item.get(key); + if (attributeValue == null || attributeValue.n() == null) { + return null; + } + return Instant.ofEpochSecond(Long.parseLong(attributeValue.n())); + } + } + + private final String table; + private final DynamoDbAsyncClient client; + + public SubscriptionManager( + @Nonnull String table, + @Nonnull DynamoDbAsyncClient client) { + this.table = Objects.requireNonNull(table); + this.client = Objects.requireNonNull(client); + } + + /** + * Looks in the GSI for a record with the given customer id and returns the user id. + */ + public CompletableFuture getSubscriberUserByStripeCustomerId(@Nonnull String customerId) { + QueryRequest query = QueryRequest.builder() + .tableName(table) + .indexName(INDEX_NAME) + .keyConditionExpression("#customer_id = :customer_id") + .projectionExpression("#user") + .expressionAttributeNames(Map.of( + "#customer_id", KEY_CUSTOMER_ID, + "#user", KEY_USER)) + .expressionAttributeValues(Map.of( + ":customer_id", s(Objects.requireNonNull(customerId)))) + .build(); + return client.query(query).thenApply(queryResponse -> { + int count = queryResponse.count(); + if (count == 0) { + return null; + } else if (count > 1) { + logger.error("expected invariant of 1-1 subscriber-customer violated for customer {}", customerId); + throw new IllegalStateException( + "expected invariant of 1-1 subscriber-customer violated for customer " + customerId); + } else { + Map result = queryResponse.items().get(0); + return result.get(KEY_USER).b().asByteArray(); + } + }); + } + + public static class GetResult { + + public static final GetResult NOT_STORED = new GetResult(Type.NOT_STORED, null); + public static final GetResult PASSWORD_MISMATCH = new GetResult(Type.PASSWORD_MISMATCH, null); + + public enum Type { + NOT_STORED, + PASSWORD_MISMATCH, + FOUND + } + + public final Type type; + public final Record record; + + private GetResult(Type type, Record record) { + this.type = type; + this.record = record; + } + + public static GetResult found(Record record) { + return new GetResult(Type.FOUND, record); + } + } + + /** + * Looks up a record with the given {@code user} and validates the {@code hmac} before returning it. + */ + public CompletableFuture get(byte[] user, byte[] hmac) { + checkUserLength(user); + + GetItemRequest request = GetItemRequest.builder() + .consistentRead(Boolean.TRUE) + .tableName(table) + .key(Map.of(KEY_USER, b(user))) + .build(); + return client.getItem(request).thenApply(getItemResponse -> { + if (!getItemResponse.hasItem()) { + return GetResult.NOT_STORED; + } + + Record record = Record.from(user, getItemResponse.item()); + if (!MessageDigest.isEqual(hmac, record.password)) { + return GetResult.PASSWORD_MISMATCH; + } + return GetResult.found(record); + }); + } + + public CompletableFuture create(byte[] user, byte[] password, String customerId, Instant createdAt) { + checkUserLength(user); + + UpdateItemRequest request = UpdateItemRequest.builder() + .tableName(table) + .key(Map.of(KEY_USER, b(user))) + .returnValues(ReturnValue.ALL_NEW) + .conditionExpression("attribute_not_exists(#user) OR #password = :password") + .updateExpression("SET " + + "#password = if_not_exists(#password, :password), " + + "#customer_id = if_not_exists(#customer_id, :customer_id), " + + "#created_at = if_not_exists(#created_at, :created_at), " + + "#accessed_at = if_not_exists(#accessed_at, :accessed_at)") + .expressionAttributeNames(Map.of( + "#user", KEY_USER, + "#password", KEY_PASSWORD, + "#customer_id", KEY_CUSTOMER_ID, + "#created_at", KEY_CREATED_AT, + "#accessed_at", KEY_ACCESSED_AT)) + .expressionAttributeValues(Map.of( + ":password", b(password), + ":customer_id", s(customerId), + ":created_at", n(createdAt.getEpochSecond()), + ":accessed_at", n(createdAt.getEpochSecond()))) + .build(); + return client.updateItem(request).handle((updateItemResponse, throwable) -> { + if (throwable != null) { + if (Throwables.getRootCause(throwable) instanceof ConditionalCheckFailedException) { + return null; + } + Throwables.throwIfUnchecked(throwable); + throw new CompletionException(throwable); + } + + return Record.from(user, updateItemResponse.attributes()); + }); + } + + public CompletableFuture accessedAt(byte[] user, Instant accessedAt) { + checkUserLength(user); + + UpdateItemRequest request = UpdateItemRequest.builder() + .tableName(table) + .key(Map.of(KEY_USER, b(user))) + .returnValues(ReturnValue.NONE) + .updateExpression("SET #accessed_at = :accessed_at") + .expressionAttributeNames(Map.of("#accessed_at", KEY_ACCESSED_AT)) + .expressionAttributeValues(Map.of(":accessed_at", n(accessedAt.getEpochSecond()))) + .build(); + return client.updateItem(request).thenApply(updateItemResponse -> null); + } + + public CompletableFuture canceledAt(byte[] user, Instant canceledAt) { + checkUserLength(user); + + UpdateItemRequest request = UpdateItemRequest.builder() + .tableName(table) + .key(Map.of(KEY_USER, b(user))) + .returnValues(ReturnValue.NONE) + .updateExpression("SET " + + "#accessed_at = :accessed_at, " + + "#canceled_at = :canceled_at " + + "REMOVE #subscription_id") + .expressionAttributeNames(Map.of( + "#accessed_at", KEY_ACCESSED_AT, + "#canceled_at", KEY_CANCELED_AT, + "#subscription_id", KEY_SUBSCRIPTION_ID)) + .expressionAttributeValues(Map.of( + ":accessed_at", n(canceledAt.getEpochSecond()), + ":canceled_at", n(canceledAt.getEpochSecond()))) + .build(); + return client.updateItem(request).thenApply(updateItemResponse -> null); + } + + public CompletableFuture subscriptionCreated( + byte[] user, String subscriptionId, Instant subscriptionCreatedAt, long level) { + checkUserLength(user); + + UpdateItemRequest request = UpdateItemRequest.builder() + .tableName(table) + .key(Map.of(KEY_USER, b(user))) + .returnValues(ReturnValue.NONE) + .updateExpression("SET " + + "#accessed_at = :accessed_at, " + + "#subscription_id = :subscription_id, " + + "#subscription_created_at = :subscription_created_at, " + + "#subscription_level = :subscription_level, " + + "#subscription_level_changed_at = :subscription_level_changed_at") + .expressionAttributeNames(Map.of( + "#accessed_at", KEY_ACCESSED_AT, + "#subscription_id", KEY_SUBSCRIPTION_ID, + "#subscription_created_at", KEY_SUBSCRIPTION_CREATED_AT, + "#subscription_level", KEY_SUBSCRIPTION_LEVEL, + "#subscription_level_changed_at", KEY_SUBSCRIPTION_LEVEL_CHANGED_AT)) + .expressionAttributeValues(Map.of( + ":accessed_at", n(subscriptionCreatedAt.getEpochSecond()), + ":subscription_id", s(subscriptionId), + ":subscription_created_at", n(subscriptionCreatedAt.getEpochSecond()), + ":subscription_level", n(level), + ":subscription_level_changed_at", n(subscriptionCreatedAt.getEpochSecond()))) + .build(); + return client.updateItem(request).thenApply(updateItemResponse -> null); + } + + public CompletableFuture subscriptionLevelChanged( + byte[] user, Instant subscriptionLevelChangedAt, long level) { + checkUserLength(user); + + UpdateItemRequest request = UpdateItemRequest.builder() + .tableName(table) + .key(Map.of(KEY_USER, b(user))) + .returnValues(ReturnValue.NONE) + .updateExpression("SET " + + "#accessed_at = :accessed_at, " + + "#subscription_level = :subscription_level, " + + "#subscription_level_changed_at = :subscription_level_changed_at") + .expressionAttributeNames(Map.of( + "#accessed_at", KEY_ACCESSED_AT, + "#subscription_level", KEY_SUBSCRIPTION_LEVEL, + "#subscription_level_changed_at", KEY_SUBSCRIPTION_LEVEL_CHANGED_AT)) + .expressionAttributeValues(Map.of( + ":accessed_at", n(subscriptionLevelChangedAt.getEpochSecond()), + ":subscription_level", n(level), + ":subscription_level_changed_at", n(subscriptionLevelChangedAt.getEpochSecond()))) + .build(); + return client.updateItem(request).thenApply(updateItemResponse -> null); + } + + private static byte[] checkUserLength(final byte[] user) { + if (user.length != USER_LENGTH) { + throw new IllegalArgumentException("user length is wrong; expected " + USER_LENGTH + "; was " + user.length); + } + return user; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/stripe/StripeManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/stripe/StripeManager.java new file mode 100644 index 000000000..525d8fedc --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/stripe/StripeManager.java @@ -0,0 +1,332 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.stripe; + +import com.google.common.base.Strings; +import com.google.common.collect.Lists; +import com.stripe.exception.StripeException; +import com.stripe.model.Customer; +import com.stripe.model.Invoice; +import com.stripe.model.InvoiceLineItem; +import com.stripe.model.Price; +import com.stripe.model.Product; +import com.stripe.model.SetupIntent; +import com.stripe.model.Subscription; +import com.stripe.model.SubscriptionItem; +import com.stripe.net.RequestOptions; +import com.stripe.param.CustomerCreateParams; +import com.stripe.param.CustomerRetrieveParams; +import com.stripe.param.InvoiceListParams; +import com.stripe.param.PriceRetrieveParams; +import com.stripe.param.SetupIntentCreateParams; +import com.stripe.param.SubscriptionCancelParams; +import com.stripe.param.SubscriptionCreateParams; +import com.stripe.param.SubscriptionListParams; +import com.stripe.param.SubscriptionUpdateParams; +import com.stripe.param.SubscriptionUpdateParams.BillingCycleAnchor; +import com.stripe.param.SubscriptionUpdateParams.ProrationBehavior; +import java.nio.charset.StandardCharsets; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Base64; +import java.util.Collection; +import java.util.Comparator; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.Executor; +import java.util.function.Consumer; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import javax.crypto.Mac; +import javax.crypto.spec.SecretKeySpec; +import org.apache.commons.codec.binary.Hex; + +public class StripeManager { + + private static final String METADATA_KEY_LEVEL = "level"; + + private final String apiKey; + private final Executor executor; + private final byte[] idempotencyKeyGenerator; + + public StripeManager( + @Nonnull String apiKey, + @Nonnull Executor executor, + @Nonnull byte[] idempotencyKeyGenerator) { + this.apiKey = Objects.requireNonNull(apiKey); + if (Strings.isNullOrEmpty(apiKey)) { + throw new IllegalArgumentException("apiKey cannot be empty"); + } + this.executor = Objects.requireNonNull(executor); + this.idempotencyKeyGenerator = Objects.requireNonNull(idempotencyKeyGenerator); + if (idempotencyKeyGenerator.length == 0) { + throw new IllegalArgumentException("idempotencyKeyGenerator cannot be empty"); + } + } + + private RequestOptions commonOptions() { + return commonOptions(null); + } + + private RequestOptions commonOptions(@Nullable String idempotencyKey) { + return RequestOptions.builder() + .setIdempotencyKey(idempotencyKey) + .setApiKey(apiKey) + .build(); + } + + public CompletableFuture createCustomer(byte[] subscriberUser) { + return CompletableFuture.supplyAsync(() -> { + CustomerCreateParams params = CustomerCreateParams.builder() + .putMetadata("subscriberUser", Hex.encodeHexString(subscriberUser)) + .build(); + try { + return Customer.create(params, commonOptions(generateIdempotencyKeyForSubscriberUser(subscriberUser))); + } catch (StripeException e) { + throw new CompletionException(e); + } + }, executor); + } + + public CompletableFuture getCustomer(String customerId) { + return CompletableFuture.supplyAsync(() -> { + CustomerRetrieveParams params = CustomerRetrieveParams.builder().build(); + try { + return Customer.retrieve(customerId, params, commonOptions()); + } catch (StripeException e) { + throw new CompletionException(e); + } + }, executor); + } + + public CompletableFuture createSetupIntent(String customerId) { + return CompletableFuture.supplyAsync(() -> { + SetupIntentCreateParams params = SetupIntentCreateParams.builder() + .setCustomer(customerId) + .build(); + try { + return SetupIntent.create(params, commonOptions()); + } catch (StripeException e) { + throw new CompletionException(e); + } + }, executor); + } + + public CompletableFuture createSubscription(String customerId, String priceId, long level) { + return CompletableFuture.supplyAsync(() -> { + SubscriptionCreateParams params = SubscriptionCreateParams.builder() + .setCustomer(customerId) + .addItem(SubscriptionCreateParams.Item.builder() + .setPrice(priceId) + .build()) + .putMetadata(METADATA_KEY_LEVEL, Long.toString(level)) + .build(); + try { + // the idempotency key intentionally excludes priceId + // + // If the client tells the server several times in a row before the initial creation of a subscription to + // create a subscription, we want to ensure only one gets created. If the prices are different each time, + // whichever one gets to stripe first will win (depending on how idempotent the idempotency keys are...) + return Subscription.create(params, commonOptions(generateIdempotencyKeyForCustomerId(customerId))); + } catch (StripeException e) { + throw new CompletionException(e); + } + }, executor); + } + + public CompletableFuture updateSubscription( + Subscription subscription, String priceId, long level, String idempotencyKey) { + return CompletableFuture.supplyAsync(() -> { + List items = new ArrayList<>(); + for (final SubscriptionItem item : subscription.getItems().autoPagingIterable(null, commonOptions())) { + items.add(SubscriptionUpdateParams.Item.builder() + .setId(item.getId()) + .setDeleted(true) + .build()); + } + items.add(SubscriptionUpdateParams.Item.builder() + .setPrice(priceId) + .build()); + SubscriptionUpdateParams params = SubscriptionUpdateParams.builder() + .putMetadata(METADATA_KEY_LEVEL, Long.toString(level)) + + // since badge redemption is untrackable by design and unrevokable, subscription changes must be immediate and + // not prorated + .setProrationBehavior(ProrationBehavior.NONE) + .setBillingCycleAnchor(BillingCycleAnchor.NOW) + .addAllItem(items) + .build(); + try { + return subscription.update(params, commonOptions(generateIdempotencyKeyForSubscriptionUpdate( + subscription.getCustomer(), idempotencyKey))); + } catch (StripeException e) { + throw new CompletionException(e); + } + }, executor); + } + + public CompletableFuture getSubscription(String subscriptionId) { + return CompletableFuture.supplyAsync(() -> { + try { + return Subscription.retrieve(subscriptionId, commonOptions()); + } catch (StripeException e) { + throw new CompletionException(e); + } + }, executor); + } + + public CompletableFuture> listNonCanceledSubscriptions(Customer customer) { + return CompletableFuture.supplyAsync(() -> { + SubscriptionListParams params = SubscriptionListParams.builder() + .setCustomer(customer.getId()) + .build(); + try { + return Lists.newArrayList(Subscription.list(params, commonOptions()).autoPagingIterable(null, commonOptions())); + } catch (StripeException e) { + throw new CompletionException(e); + } + }, executor); + } + + public CompletableFuture cancelSubscriptionImmediately(Subscription subscription) { + return CompletableFuture.supplyAsync(() -> { + SubscriptionCancelParams params = SubscriptionCancelParams.builder().build(); + try { + return subscription.cancel(params, commonOptions()); + } catch (StripeException e) { + throw new CompletionException(e); + } + }, executor); + } + + public CompletableFuture cancelSubscriptionAtEndOfCurrentPeriod(Subscription subscription) { + return CompletableFuture.supplyAsync(() -> { + SubscriptionUpdateParams params = SubscriptionUpdateParams.builder() + .setCancelAtPeriodEnd(true) + .build(); + try { + return subscription.update(params, commonOptions()); + } catch (StripeException e) { + throw new CompletionException(e); + } + }, executor); + } + + public CompletableFuture> getItemsForSubscription(Subscription subscription) { + return CompletableFuture.supplyAsync( + () -> Lists.newArrayList(subscription.getItems().autoPagingIterable(null, commonOptions())), + executor); + } + + public CompletableFuture getPriceForSubscription(Subscription subscription) { + return getItemsForSubscription(subscription).thenApply(subscriptionItems -> { + if (subscriptionItems.isEmpty()) { + throw new IllegalStateException("no items found in subscription " + subscription.getId()); + } else if (subscriptionItems.size() > 1) { + throw new IllegalStateException( + "too many items found in subscription " + subscription.getId() + "; items=" + subscriptionItems.size()); + } else { + return subscriptionItems.stream().findAny().get().getPrice(); + } + }); + } + + public CompletableFuture getProductForSubscription(Subscription subscription) { + return getPriceForSubscription(subscription).thenCompose(price -> getProductForPrice(price.getId())); + } + + public CompletableFuture getLevelForSubscription(Subscription subscription) { + return getProductForSubscription(subscription).thenApply(this::getLevelForProduct); + } + + public CompletableFuture getLevelForPrice(Price price) { + return getProductForPrice(price.getId()).thenApply(this::getLevelForProduct); + } + + public CompletableFuture getProductForPrice(String priceId) { + return CompletableFuture.supplyAsync(() -> { + PriceRetrieveParams params = PriceRetrieveParams.builder().addExpand("product").build(); + try { + return Price.retrieve(priceId, params, commonOptions()).getProductObject(); + } catch (StripeException e) { + throw new CompletionException(e); + } + }, executor); + } + + public long getLevelForProduct(Product product) { + return Long.parseLong(product.getMetadata().get(METADATA_KEY_LEVEL)); + } + + /** + * Returns the paid invoices within the past 90 days for a subscription ordered by the creation date in descending + * order (latest first). + */ + public CompletableFuture> getPaidInvoicesForSubscription(String subscriptionId, Instant now) { + return CompletableFuture.supplyAsync(() -> { + InvoiceListParams params = InvoiceListParams.builder() + .setSubscription(subscriptionId) + .setStatus(InvoiceListParams.Status.PAID) + .setCreated(InvoiceListParams.Created.builder() + .setGte(now.minus(Duration.ofDays(90)).getEpochSecond()) + .build()) + .addExpand("lines.data.price.product") + .build(); + try { + ArrayList invoices = Lists.newArrayList(Invoice.list(params, commonOptions()) + .autoPagingIterable(null, commonOptions())); + invoices.sort(Comparator.comparingLong(Invoice::getCreated).reversed()); + return invoices; + } catch (StripeException e) { + throw new CompletionException(e); + } + }, executor); + } + + public CompletableFuture> getInvoiceLineItemsForInvoice(Invoice invoice) { + return CompletableFuture.supplyAsync( + () -> Lists.newArrayList(invoice.getLines().autoPagingIterable(null, commonOptions())), executor); + } + + /** + * We use a client generated idempotency key for subscription updates due to not being able to distinguish between a + * call to update to level 2, then back to level 1, then back to level 2. If this all happens within Stripe's + * idempotency window the subsequent update call would not happen unless we get some indication from the client that + * it is intentionally sending a repeat of the update to level 2 request because user is changing again, so in this + * case we derive idempotency from the client. + */ + private String generateIdempotencyKeyForSubscriptionUpdate(String customerId, String idempotencyKey) { + return generateIdempotencyKey("subscriptionUpdate", mac -> { + mac.update(customerId.getBytes(StandardCharsets.UTF_8)); + mac.update(idempotencyKey.getBytes(StandardCharsets.UTF_8)); + }); + } + + private String generateIdempotencyKeyForSubscriberUser(byte[] subscriberUser) { + return generateIdempotencyKey("subscriberUser", mac -> mac.update(subscriberUser)); + } + + private String generateIdempotencyKeyForCustomerId(String customerId) { + return generateIdempotencyKey("customerId", mac -> mac.update(customerId.getBytes(StandardCharsets.UTF_8))); + } + + private String generateIdempotencyKey(String type, Consumer byteConsumer) { + try { + Mac mac = Mac.getInstance("HmacSHA256"); + mac.init(new SecretKeySpec(idempotencyKeyGenerator, "HmacSHA256")); + mac.update(type.getBytes(StandardCharsets.UTF_8)); + byteConsumer.accept(mac); + return Base64.getUrlEncoder().encodeToString(mac.doFinal()); + } catch (NoSuchAlgorithmException | InvalidKeyException e) { + throw new AssertionError(e); + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/AttributeValues.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/AttributeValues.java index de77e859a..ab35e5e76 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/AttributeValues.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/AttributeValues.java @@ -33,6 +33,11 @@ public class AttributeValues { return AttributeValue.builder().n(String.valueOf(value)).build(); } + public static AttributeValue s(String value) { + return AttributeValue.builder().s(value).build(); + } + + // More opinionated methods public static AttributeValue fromString(String value) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/DynamoDbFromConfig.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/DynamoDbFromConfig.java index 1e30d05dc..52004cee5 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/DynamoDbFromConfig.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/DynamoDbFromConfig.java @@ -1,36 +1,36 @@ package org.whispersystems.textsecuregcm.util; +import org.whispersystems.textsecuregcm.configuration.DynamoDbClientConfiguration; import org.whispersystems.textsecuregcm.configuration.DynamoDbConfiguration; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; -import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClientBuilder; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; public class DynamoDbFromConfig { - private static ClientOverrideConfiguration clientOverrideConfiguration(DynamoDbConfiguration config) { - return ClientOverrideConfiguration.builder() - .apiCallTimeout(config.getClientExecutionTimeout()) - .apiCallAttemptTimeout(config.getClientRequestTimeout()) - .build(); - } - public static DynamoDbClient client(DynamoDbConfiguration config, AwsCredentialsProvider credentialsProvider) { return DynamoDbClient.builder() .region(Region.of(config.getRegion())) .credentialsProvider(credentialsProvider) - .overrideConfiguration(clientOverrideConfiguration(config)) + .overrideConfiguration(ClientOverrideConfiguration.builder() + .apiCallTimeout(config.getClientExecutionTimeout()) + .apiCallAttemptTimeout(config.getClientRequestTimeout()) + .build()) .build(); } public static DynamoDbAsyncClient asyncClient( - DynamoDbConfiguration config, AwsCredentialsProvider credentialsProvider) { - DynamoDbAsyncClientBuilder builder = DynamoDbAsyncClient.builder() + DynamoDbClientConfiguration config, + AwsCredentialsProvider credentialsProvider) { + return DynamoDbAsyncClient.builder() .region(Region.of(config.getRegion())) .credentialsProvider(credentialsProvider) - .overrideConfiguration(clientOverrideConfiguration(config)); - return builder.build(); + .overrideConfiguration(ClientOverrideConfiguration.builder() + .apiCallTimeout(config.getClientExecutionTimeout()) + .apiCallAttemptTimeout(config.getClientRequestTimeout()) + .build()) + .build(); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/DynamoDbExtension.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/DynamoDbExtension.java index 38cad7696..0c8fb0bde 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/DynamoDbExtension.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/DynamoDbExtension.java @@ -30,7 +30,7 @@ import software.amazon.awssdk.services.dynamodb.model.ProvisionedThroughput; public class DynamoDbExtension implements BeforeEachCallback, AfterEachCallback { - static final String DEFAULT_TABLE_NAME = "test_table"; + static final String DEFAULT_TABLE_NAME = "test_table"; static final ProvisionedThroughput DEFAULT_PROVISIONED_THROUGHPUT = ProvisionedThroughput.builder() .readCapacityUnits(20L) @@ -164,12 +164,12 @@ public class DynamoDbExtension implements BeforeEachCallback, AfterEachCallback private String hashKey; private String rangeKey; - private List attributeDefinitions = new ArrayList<>(); - private List globalSecondaryIndexes = new ArrayList<>(); - private List localSecondaryIndexes = new ArrayList<>(); + private final List attributeDefinitions = new ArrayList<>(); + private final List globalSecondaryIndexes = new ArrayList<>(); + private final List localSecondaryIndexes = new ArrayList<>(); - private long readCapacityUnits = DEFAULT_PROVISIONED_THROUGHPUT.readCapacityUnits(); - private long writeCapacityUnits = DEFAULT_PROVISIONED_THROUGHPUT.writeCapacityUnits(); + private final long readCapacityUnits = DEFAULT_PROVISIONED_THROUGHPUT.readCapacityUnits(); + private final long writeCapacityUnits = DEFAULT_PROVISIONED_THROUGHPUT.writeCapacityUnits(); private DynamoDbExtensionBuilder() { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/IssuedReceiptsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/IssuedReceiptsManagerTest.java new file mode 100644 index 000000000..f2cb25aa7 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/IssuedReceiptsManagerTest.java @@ -0,0 +1,86 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.security.SecureRandom; +import java.time.Duration; +import java.time.Instant; +import java.util.concurrent.CompletableFuture; +import javax.ws.rs.ClientErrorException; +import org.assertj.core.api.Condition; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.signal.zkgroup.receipts.ReceiptCredentialRequest; +import software.amazon.awssdk.services.dynamodb.model.AttributeDefinition; +import software.amazon.awssdk.services.dynamodb.model.ScalarAttributeType; + +class IssuedReceiptsManagerTest { + + private static final long NOW_EPOCH_SECONDS = 1_500_000_000L; + private static final String ISSUED_RECEIPTS_TABLE_NAME = "issued_receipts"; + private static final SecureRandom SECURE_RANDOM = new SecureRandom(); + + @RegisterExtension + static DynamoDbExtension dynamoDbExtension = DynamoDbExtension.builder() + .tableName(ISSUED_RECEIPTS_TABLE_NAME) + .hashKey(IssuedReceiptsManager.KEY_INVOICE_LINE_ITEM_ID) + .attributeDefinition(AttributeDefinition.builder() + .attributeName(IssuedReceiptsManager.KEY_INVOICE_LINE_ITEM_ID) + .attributeType(ScalarAttributeType.S) + .build()) + .build(); + + ReceiptCredentialRequest receiptCredentialRequest; + IssuedReceiptsManager issuedReceiptsManager; + + @BeforeEach + void beforeEach() { + receiptCredentialRequest = mock(ReceiptCredentialRequest.class); + byte[] generator = new byte[16]; + SECURE_RANDOM.nextBytes(generator); + issuedReceiptsManager = new IssuedReceiptsManager( + ISSUED_RECEIPTS_TABLE_NAME, + Duration.ofDays(90), + dynamoDbExtension.getDynamoDbAsyncClient(), + generator); + } + + @Test + void testRecordIssuance() { + Instant now = Instant.ofEpochSecond(NOW_EPOCH_SECONDS); + byte[] request1 = new byte[ReceiptCredentialRequest.SIZE]; + SECURE_RANDOM.nextBytes(request1); + when(receiptCredentialRequest.serialize()).thenReturn(request1); + CompletableFuture future = issuedReceiptsManager.recordIssuance("item-1", receiptCredentialRequest, now); + assertThat(future).succeedsWithin(Duration.ofSeconds(3)); + + // same request should succeed + future = issuedReceiptsManager.recordIssuance("item-1", receiptCredentialRequest, now); + assertThat(future).succeedsWithin(Duration.ofSeconds(3)); + + // same item with new request should fail + byte[] request2 = new byte[ReceiptCredentialRequest.SIZE]; + SECURE_RANDOM.nextBytes(request2); + when(receiptCredentialRequest.serialize()).thenReturn(request2); + future = issuedReceiptsManager.recordIssuance("item-1", receiptCredentialRequest, now); + assertThat(future).failsWithin(Duration.ofSeconds(3)). + withThrowableOfType(Throwable.class). + havingCause(). + isExactlyInstanceOf(ClientErrorException.class). + has(new Condition<>( + e -> e instanceof ClientErrorException && ((ClientErrorException) e).getResponse().getStatus() == 409, + "status 409")); + + // different item with new request should be okay though + future = issuedReceiptsManager.recordIssuance("item-2", receiptCredentialRequest, now); + assertThat(future).succeedsWithin(Duration.ofSeconds(3)); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/SubscriptionManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/SubscriptionManagerTest.java new file mode 100644 index 000000000..279bbea14 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/SubscriptionManagerTest.java @@ -0,0 +1,227 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.whispersystems.textsecuregcm.storage.SubscriptionManager.GetResult.Type.FOUND; +import static org.whispersystems.textsecuregcm.storage.SubscriptionManager.GetResult.Type.NOT_STORED; +import static org.whispersystems.textsecuregcm.storage.SubscriptionManager.GetResult.Type.PASSWORD_MISMATCH; + +import java.security.SecureRandom; +import java.time.Duration; +import java.time.Instant; +import java.util.Base64; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; +import javax.annotation.Nonnull; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.whispersystems.textsecuregcm.storage.SubscriptionManager.GetResult; +import org.whispersystems.textsecuregcm.storage.SubscriptionManager.Record; +import software.amazon.awssdk.services.dynamodb.model.AttributeDefinition; +import software.amazon.awssdk.services.dynamodb.model.GlobalSecondaryIndex; +import software.amazon.awssdk.services.dynamodb.model.KeySchemaElement; +import software.amazon.awssdk.services.dynamodb.model.KeyType; +import software.amazon.awssdk.services.dynamodb.model.Projection; +import software.amazon.awssdk.services.dynamodb.model.ProjectionType; +import software.amazon.awssdk.services.dynamodb.model.ProvisionedThroughput; +import software.amazon.awssdk.services.dynamodb.model.ScalarAttributeType; + +class SubscriptionManagerTest { + + private static final long NOW_EPOCH_SECONDS = 1_500_000_000L; + private static final String SUBSCRIPTIONS_TABLE_NAME = "subscriptions"; + private static final SecureRandom SECURE_RANDOM = new SecureRandom(); + + @RegisterExtension + static DynamoDbExtension dynamoDbExtension = DynamoDbExtension.builder(). + tableName(SUBSCRIPTIONS_TABLE_NAME). + hashKey(SubscriptionManager.KEY_USER). + attributeDefinition(AttributeDefinition.builder(). + attributeName(SubscriptionManager.KEY_USER). + attributeType(ScalarAttributeType.B). + build()). + attributeDefinition(AttributeDefinition.builder(). + attributeName(SubscriptionManager.KEY_CUSTOMER_ID). + attributeType(ScalarAttributeType.S). + build()). + globalSecondaryIndex(GlobalSecondaryIndex.builder(). + indexName("c_to_u"). + keySchema(KeySchemaElement.builder(). + attributeName(SubscriptionManager.KEY_CUSTOMER_ID). + keyType(KeyType.HASH). + build()). + projection(Projection.builder(). + projectionType(ProjectionType.KEYS_ONLY). + build()). + provisionedThroughput(ProvisionedThroughput.builder(). + readCapacityUnits(20L). + writeCapacityUnits(20L). + build()). + build()). + build(); + + byte[] user; + byte[] password; + String customer; + Instant created; + SubscriptionManager subscriptionManager; + + @BeforeEach + void beforeEach() { + user = getRandomBytes(16); + password = getRandomBytes(16); + customer = Base64.getEncoder().encodeToString(getRandomBytes(16)); + created = Instant.ofEpochSecond(NOW_EPOCH_SECONDS); + subscriptionManager = new SubscriptionManager( + SUBSCRIPTIONS_TABLE_NAME, dynamoDbExtension.getDynamoDbAsyncClient()); + } + + @Test + void testCreateOnlyOnce() { + byte[] password1 = getRandomBytes(16); + byte[] password2 = getRandomBytes(16); + String customer1 = Base64.getEncoder().encodeToString(getRandomBytes(16)); + String customer2 = Base64.getEncoder().encodeToString(getRandomBytes(16)); + Instant created1 = Instant.ofEpochSecond(NOW_EPOCH_SECONDS); + Instant created2 = Instant.ofEpochSecond(NOW_EPOCH_SECONDS + 1); + + CompletableFuture getFuture = subscriptionManager.get(user, password1); + assertThat(getFuture).succeedsWithin(Duration.ofSeconds(3)).satisfies(getResult -> { + assertThat(getResult.type).isEqualTo(NOT_STORED); + assertThat(getResult.record).isNull(); + }); + + getFuture = subscriptionManager.get(user, password2); + assertThat(getFuture).succeedsWithin(Duration.ofSeconds(3)).satisfies(getResult -> { + assertThat(getResult.type).isEqualTo(NOT_STORED); + assertThat(getResult.record).isNull(); + }); + + CompletableFuture createFuture = + subscriptionManager.create(user, password1, customer1, created1); + Consumer recordRequirements = checkFreshlyCreatedRecord(user, password1, customer1, created1); + assertThat(createFuture).succeedsWithin(Duration.ofSeconds(3)).satisfies(recordRequirements); + + // password check fails so this should return null + createFuture = subscriptionManager.create(user, password2, customer2, created2); + assertThat(createFuture).succeedsWithin(Duration.ofSeconds(3)).isNull(); + + // password check matches, but the record already exists so nothing should get updated + createFuture = subscriptionManager.create(user, password1, customer2, created2); + assertThat(createFuture).succeedsWithin(Duration.ofSeconds(3)).satisfies(recordRequirements); + } + + @Test + void testGet() { + byte[] wrongUser = getRandomBytes(16); + byte[] wrongPassword = getRandomBytes(16); + assertThat(subscriptionManager.create(user, password, customer, created)).succeedsWithin(Duration.ofSeconds(3)); + + assertThat(subscriptionManager.get(user, password)).succeedsWithin(Duration.ofSeconds(3)).satisfies(getResult -> { + assertThat(getResult.type).isEqualTo(FOUND); + assertThat(getResult.record).isNotNull().satisfies(checkFreshlyCreatedRecord(user, password, customer, created)); + }); + + assertThat(subscriptionManager.get(user, wrongPassword)).succeedsWithin(Duration.ofSeconds(3)).satisfies(getResult -> { + assertThat(getResult.type).isEqualTo(PASSWORD_MISMATCH); + assertThat(getResult.record).isNull(); + }); + + assertThat(subscriptionManager.get(wrongUser, password)).succeedsWithin(Duration.ofSeconds(3)).satisfies(getResult -> { + assertThat(getResult.type).isEqualTo(NOT_STORED); + assertThat(getResult.record).isNull(); + }); + } + + @Test + void testLookupByCustomerId() { + assertThat(subscriptionManager.create(user, password, customer, created)).succeedsWithin(Duration.ofSeconds(3)); + assertThat(subscriptionManager.getSubscriberUserByStripeCustomerId(customer)). + succeedsWithin(Duration.ofSeconds(3)). + isEqualTo(user); + } + + @Test + void testCanceledAt() { + Instant canceled = Instant.ofEpochSecond(NOW_EPOCH_SECONDS + 42); + assertThat(subscriptionManager.create(user, password, customer, created)).succeedsWithin(Duration.ofSeconds(3)); + assertThat(subscriptionManager.canceledAt(user, canceled)).succeedsWithin(Duration.ofSeconds(3)); + assertThat(subscriptionManager.get(user, password)).succeedsWithin(Duration.ofSeconds(3)).satisfies(getResult -> { + assertThat(getResult).isNotNull(); + assertThat(getResult.type).isEqualTo(FOUND); + assertThat(getResult.record).isNotNull().satisfies(record -> { + assertThat(record.accessedAt).isEqualTo(canceled); + assertThat(record.canceledAt).isEqualTo(canceled); + assertThat(record.subscriptionId).isNull(); + }); + }); + } + + @Test + void testSubscriptionCreated() { + String subscriptionId = Base64.getEncoder().encodeToString(getRandomBytes(16)); + Instant subscriptionCreated = Instant.ofEpochSecond(NOW_EPOCH_SECONDS + 1); + long level = 42; + assertThat(subscriptionManager.create(user, password, customer, created)).succeedsWithin(Duration.ofSeconds(3)); + assertThat(subscriptionManager.subscriptionCreated(user, subscriptionId, subscriptionCreated, level)). + succeedsWithin(Duration.ofSeconds(3)); + assertThat(subscriptionManager.get(user, password)).succeedsWithin(Duration.ofSeconds(3)).satisfies(getResult -> { + assertThat(getResult).isNotNull(); + assertThat(getResult.type).isEqualTo(FOUND); + assertThat(getResult.record).isNotNull().satisfies(record -> { + assertThat(record.accessedAt).isEqualTo(subscriptionCreated); + assertThat(record.subscriptionId).isEqualTo(subscriptionId); + assertThat(record.subscriptionCreatedAt).isEqualTo(subscriptionCreated); + assertThat(record.subscriptionLevel).isEqualTo(level); + assertThat(record.subscriptionLevelChangedAt).isEqualTo(subscriptionCreated); + }); + }); + } + + @Test + void testSubscriptionLevelChanged() { + Instant at = Instant.ofEpochSecond(NOW_EPOCH_SECONDS + 500); + long level = 1776; + assertThat(subscriptionManager.create(user, password, customer, created)).succeedsWithin(Duration.ofSeconds(3)); + assertThat(subscriptionManager.subscriptionLevelChanged(user, at, level)).succeedsWithin(Duration.ofSeconds(3)); + assertThat(subscriptionManager.get(user, password)).succeedsWithin(Duration.ofSeconds(3)).satisfies(getResult -> { + assertThat(getResult).isNotNull(); + assertThat(getResult.type).isEqualTo(FOUND); + assertThat(getResult.record).isNotNull().satisfies(record -> { + assertThat(record.accessedAt).isEqualTo(at); + assertThat(record.subscriptionLevelChangedAt).isEqualTo(at); + assertThat(record.subscriptionLevel).isEqualTo(level); + }); + }); + } + + private static byte[] getRandomBytes(int length) { + byte[] result = new byte[length]; + SECURE_RANDOM.nextBytes(result); + return result; + } + + @Nonnull + private static Consumer checkFreshlyCreatedRecord( + byte[] user, byte[] password, String customer, Instant created) { + return record -> { + assertThat(record).isNotNull(); + assertThat(record.user).isEqualTo(user); + assertThat(record.password).isEqualTo(password); + assertThat(record.customerId).isEqualTo(customer); + assertThat(record.createdAt).isEqualTo(created); + assertThat(record.subscriptionId).isNull(); + assertThat(record.subscriptionCreatedAt).isNull(); + assertThat(record.subscriptionLevel).isNull(); + assertThat(record.subscriptionLevelChangedAt).isNull(); + assertThat(record.accessedAt).isEqualTo(created); + assertThat(record.canceledAt).isNull(); + assertThat(record.currentPeriodEndsAt).isNull(); + }; + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DonationControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DonationControllerTest.java index efe483b56..41fdab729 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DonationControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DonationControllerTest.java @@ -49,6 +49,7 @@ import org.whispersystems.textsecuregcm.configuration.BadgesConfiguration; import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; import org.whispersystems.textsecuregcm.configuration.DonationConfiguration; import org.whispersystems.textsecuregcm.configuration.RetryConfiguration; +import org.whispersystems.textsecuregcm.configuration.StripeConfiguration; import org.whispersystems.textsecuregcm.controllers.DonationController; import org.whispersystems.textsecuregcm.entities.ApplePayAuthorizationRequest; import org.whispersystems.textsecuregcm.entities.ApplePayAuthorizationResponse; @@ -73,7 +74,6 @@ class DonationControllerTest { static DonationConfiguration getDonationConfiguration() { DonationConfiguration configuration = new DonationConfiguration(); - configuration.setApiKey("test-api-key"); configuration.setDescription("some description"); configuration.setUri("http://localhost:" + wm.getRuntimeInfo().getHttpPort() + "/foo/bar"); configuration.setCircuitBreaker(new CircuitBreakerConfiguration()); @@ -82,6 +82,10 @@ class DonationControllerTest { return configuration; } + static StripeConfiguration getStripeConfiguration() { + return new StripeConfiguration("test-api-key", new byte[16]); + } + static BadgesConfiguration getBadgesConfiguration() { return new BadgesConfiguration( List.of( @@ -135,7 +139,7 @@ class DonationControllerTest { .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .addResource(new DonationController(clock, zkReceiptOperations, redeemedReceiptsManager, accountsManager, getBadgesConfiguration(), receiptCredentialPresentationFactory, httpClientExecutor, - getDonationConfiguration())) + getDonationConfiguration(), getStripeConfiguration())) .build(); resources.before(); }