Move common subscription management out of controller

This commit is contained in:
Ravi Khadiwala 2024-08-14 10:14:18 -05:00 committed by ravi-signal
parent a8eaf2d0ad
commit 97e566d470
21 changed files with 1242 additions and 948 deletions

View File

@ -177,6 +177,7 @@ import org.whispersystems.textsecuregcm.mappers.NonNormalizedPhoneNumberExceptio
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.RegistrationServiceSenderExceptionMapper; import org.whispersystems.textsecuregcm.mappers.RegistrationServiceSenderExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.ServerRejectedExceptionMapper; import org.whispersystems.textsecuregcm.mappers.ServerRejectedExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.SubscriptionExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.SubscriptionProcessorExceptionMapper; import org.whispersystems.textsecuregcm.mappers.SubscriptionProcessorExceptionMapper;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.metrics.MetricsApplicationEventListener; import org.whispersystems.textsecuregcm.metrics.MetricsApplicationEventListener;
@ -234,6 +235,7 @@ import org.whispersystems.textsecuregcm.storage.RemoteConfigsManager;
import org.whispersystems.textsecuregcm.storage.ReportMessageDynamoDb; import org.whispersystems.textsecuregcm.storage.ReportMessageDynamoDb;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager; import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.storage.SubscriptionManager; import org.whispersystems.textsecuregcm.storage.SubscriptionManager;
import org.whispersystems.textsecuregcm.storage.Subscriptions;
import org.whispersystems.textsecuregcm.storage.VerificationSessionManager; import org.whispersystems.textsecuregcm.storage.VerificationSessionManager;
import org.whispersystems.textsecuregcm.storage.VerificationSessions; import org.whispersystems.textsecuregcm.storage.VerificationSessions;
import org.whispersystems.textsecuregcm.subscriptions.BankMandateTranslator; import org.whispersystems.textsecuregcm.subscriptions.BankMandateTranslator;
@ -667,7 +669,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.getDynamoDbTables().getRedeemedReceipts().getTableName(), config.getDynamoDbTables().getRedeemedReceipts().getTableName(),
dynamoDbAsyncClient, dynamoDbAsyncClient,
config.getDynamoDbTables().getRedeemedReceipts().getExpiration()); config.getDynamoDbTables().getRedeemedReceipts().getExpiration());
SubscriptionManager subscriptionManager = new SubscriptionManager( Subscriptions subscriptions = new Subscriptions(
config.getDynamoDbTables().getSubscriptions().getTableName(), dynamoDbAsyncClient); config.getDynamoDbTables().getSubscriptions().getTableName(), dynamoDbAsyncClient);
final RegistrationLockVerificationManager registrationLockVerificationManager = new RegistrationLockVerificationManager( final RegistrationLockVerificationManager registrationLockVerificationManager = new RegistrationLockVerificationManager(
@ -1119,9 +1121,11 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
accountsManager, registrationFraudChecker, dynamicConfigurationManager, clock) accountsManager, registrationFraudChecker, dynamicConfigurationManager, clock)
); );
if (config.getSubscription() != null && config.getOneTimeDonations() != null) { if (config.getSubscription() != null && config.getOneTimeDonations() != null) {
SubscriptionManager subscriptionManager = new SubscriptionManager(subscriptions,
List.of(stripeManager, braintreeManager), zkReceiptOperations, issuedReceiptsManager);
commonControllers.add(new SubscriptionController(clock, config.getSubscription(), config.getOneTimeDonations(), commonControllers.add(new SubscriptionController(clock, config.getSubscription(), config.getOneTimeDonations(),
subscriptionManager, stripeManager, braintreeManager, zkReceiptOperations, issuedReceiptsManager, subscriptionManager, stripeManager, braintreeManager, profileBadgeConverter, resourceBundleLevelTranslator,
profileBadgeConverter, resourceBundleLevelTranslator, bankMandateTranslator)); bankMandateTranslator));
commonControllers.add(new OneTimeDonationController(clock, config.getOneTimeDonations(), stripeManager, braintreeManager, commonControllers.add(new OneTimeDonationController(clock, config.getOneTimeDonations(), stripeManager, braintreeManager,
zkReceiptOperations, issuedReceiptsManager, oneTimeDonationsManager)); zkReceiptOperations, issuedReceiptsManager, oneTimeDonationsManager));
} }
@ -1188,6 +1192,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
new NonNormalizedPhoneNumberExceptionMapper(), new NonNormalizedPhoneNumberExceptionMapper(),
new RegistrationServiceSenderExceptionMapper(), new RegistrationServiceSenderExceptionMapper(),
new SubscriptionProcessorExceptionMapper(), new SubscriptionProcessorExceptionMapper(),
new SubscriptionExceptionMapper(),
new JsonMappingExceptionMapper() new JsonMappingExceptionMapper()
).forEach(exceptionMapper -> { ).forEach(exceptionMapper -> {
environment.jersey().register(exceptionMapper); environment.jersey().register(exceptionMapper);

View File

@ -12,9 +12,9 @@ import javax.validation.constraints.DecimalMin;
import javax.validation.constraints.NotBlank; import javax.validation.constraints.NotBlank;
import javax.validation.constraints.NotEmpty; import javax.validation.constraints.NotEmpty;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import org.whispersystems.textsecuregcm.subscriptions.SubscriptionProcessor; import org.whispersystems.textsecuregcm.subscriptions.PaymentProvider;
public record SubscriptionPriceConfiguration(@Valid @NotEmpty Map<SubscriptionProcessor, @NotBlank String> processorIds, public record SubscriptionPriceConfiguration(@Valid @NotEmpty Map<PaymentProvider, @NotBlank String> processorIds,
@NotNull @DecimalMin("0.01") BigDecimal amount) { @NotNull @DecimalMin("0.01") BigDecimal amount) {
} }

View File

@ -56,8 +56,8 @@ import org.whispersystems.textsecuregcm.subscriptions.PaymentDetails;
import org.whispersystems.textsecuregcm.subscriptions.PaymentMethod; import org.whispersystems.textsecuregcm.subscriptions.PaymentMethod;
import org.whispersystems.textsecuregcm.subscriptions.StripeManager; import org.whispersystems.textsecuregcm.subscriptions.StripeManager;
import org.whispersystems.textsecuregcm.subscriptions.SubscriptionCurrencyUtil; import org.whispersystems.textsecuregcm.subscriptions.SubscriptionCurrencyUtil;
import org.whispersystems.textsecuregcm.subscriptions.SubscriptionProcessor; import org.whispersystems.textsecuregcm.subscriptions.PaymentProvider;
import org.whispersystems.textsecuregcm.subscriptions.SubscriptionProcessorManager; import org.whispersystems.textsecuregcm.subscriptions.SubscriptionPaymentProcessor;
import org.whispersystems.textsecuregcm.util.ExactlySize; import org.whispersystems.textsecuregcm.util.ExactlySize;
import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
@ -170,7 +170,7 @@ public class OneTimeDonationController {
* @throws BadRequestException indicates validation failed. Inspect {@code response.error} for details * @throws BadRequestException indicates validation failed. Inspect {@code response.error} for details
*/ */
private void validateRequestCurrencyAmount(CreateBoostRequest request, BigDecimal amount, private void validateRequestCurrencyAmount(CreateBoostRequest request, BigDecimal amount,
SubscriptionProcessorManager manager) { SubscriptionPaymentProcessor manager) {
if (!manager.getSupportedCurrenciesForPaymentMethod(request.paymentMethod) if (!manager.getSupportedCurrenciesForPaymentMethod(request.paymentMethod)
.contains(request.currency.toLowerCase(Locale.ROOT))) { .contains(request.currency.toLowerCase(Locale.ROOT))) {
throw new BadRequestException(Response.status(Response.Status.BAD_REQUEST) throw new BadRequestException(Response.status(Response.Status.BAD_REQUEST)
@ -302,7 +302,7 @@ public class OneTimeDonationController {
public byte[] receiptCredentialRequest; public byte[] receiptCredentialRequest;
@NotNull @NotNull
public SubscriptionProcessor processor = SubscriptionProcessor.STRIPE; public PaymentProvider processor = PaymentProvider.STRIPE;
} }
public record CreateBoostReceiptCredentialsSuccessResponse(byte[] receiptCredentialResponse) { public record CreateBoostReceiptCredentialsSuccessResponse(byte[] receiptCredentialResponse) {

View File

@ -9,7 +9,6 @@ import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.net.HttpHeaders; import com.google.common.net.HttpHeaders;
import com.stripe.exception.StripeException;
import io.dropwizard.auth.Auth; import io.dropwizard.auth.Auth;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tag;
@ -19,13 +18,10 @@ import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.responses.ApiResponse; import io.swagger.v3.oas.annotations.responses.ApiResponse;
import java.math.BigDecimal; import java.math.BigDecimal;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.time.Clock; import java.time.Clock;
import java.time.Instant; import java.time.Instant;
import java.time.temporal.ChronoUnit; import java.time.temporal.ChronoUnit;
import java.util.Arrays; import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Locale; import java.util.Locale;
@ -34,13 +30,10 @@ import java.util.Map.Entry;
import java.util.Objects; import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import javax.annotation.Nonnull; import javax.annotation.Nonnull;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.NotBlank; import javax.validation.constraints.NotBlank;
import javax.validation.constraints.NotEmpty; import javax.validation.constraints.NotEmpty;
@ -50,11 +43,8 @@ import javax.ws.rs.ClientErrorException;
import javax.ws.rs.Consumes; import javax.ws.rs.Consumes;
import javax.ws.rs.DELETE; import javax.ws.rs.DELETE;
import javax.ws.rs.DefaultValue; import javax.ws.rs.DefaultValue;
import javax.ws.rs.ForbiddenException;
import javax.ws.rs.GET; import javax.ws.rs.GET;
import javax.ws.rs.HeaderParam; import javax.ws.rs.HeaderParam;
import javax.ws.rs.InternalServerErrorException;
import javax.ws.rs.NotFoundException;
import javax.ws.rs.POST; import javax.ws.rs.POST;
import javax.ws.rs.PUT; import javax.ws.rs.PUT;
import javax.ws.rs.Path; import javax.ws.rs.Path;
@ -66,11 +56,7 @@ import javax.ws.rs.core.Context;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import javax.ws.rs.core.Response.Status; import javax.ws.rs.core.Response.Status;
import org.signal.libsignal.zkgroup.InvalidInputException;
import org.signal.libsignal.zkgroup.VerificationFailedException;
import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialRequest;
import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialResponse; import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialResponse;
import org.signal.libsignal.zkgroup.receipts.ServerZkReceiptOperations;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
@ -85,18 +71,20 @@ import org.whispersystems.textsecuregcm.entities.Badge;
import org.whispersystems.textsecuregcm.entities.PurchasableBadge; import org.whispersystems.textsecuregcm.entities.PurchasableBadge;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.IssuedReceiptsManager; import org.whispersystems.textsecuregcm.storage.SubscriberCredentials;
import org.whispersystems.textsecuregcm.storage.SubscriptionException;
import org.whispersystems.textsecuregcm.storage.SubscriptionManager; import org.whispersystems.textsecuregcm.storage.SubscriptionManager;
import org.whispersystems.textsecuregcm.storage.SubscriptionManager.GetResult; import org.whispersystems.textsecuregcm.storage.Subscriptions;
import org.whispersystems.textsecuregcm.subscriptions.BankMandateTranslator; import org.whispersystems.textsecuregcm.subscriptions.BankMandateTranslator;
import org.whispersystems.textsecuregcm.subscriptions.BankTransferType; import org.whispersystems.textsecuregcm.subscriptions.BankTransferType;
import org.whispersystems.textsecuregcm.subscriptions.BraintreeManager; import org.whispersystems.textsecuregcm.subscriptions.BraintreeManager;
import org.whispersystems.textsecuregcm.subscriptions.ChargeFailure; import org.whispersystems.textsecuregcm.subscriptions.ChargeFailure;
import org.whispersystems.textsecuregcm.subscriptions.PaymentMethod; import org.whispersystems.textsecuregcm.subscriptions.PaymentMethod;
import org.whispersystems.textsecuregcm.subscriptions.PaymentProvider;
import org.whispersystems.textsecuregcm.subscriptions.ProcessorCustomer; import org.whispersystems.textsecuregcm.subscriptions.ProcessorCustomer;
import org.whispersystems.textsecuregcm.subscriptions.StripeManager; import org.whispersystems.textsecuregcm.subscriptions.StripeManager;
import org.whispersystems.textsecuregcm.subscriptions.SubscriptionProcessor; import org.whispersystems.textsecuregcm.subscriptions.SubscriptionPaymentProcessor;
import org.whispersystems.textsecuregcm.subscriptions.SubscriptionProcessorManager; import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
@ -115,8 +103,6 @@ public class SubscriptionController {
private final SubscriptionManager subscriptionManager; private final SubscriptionManager subscriptionManager;
private final StripeManager stripeManager; private final StripeManager stripeManager;
private final BraintreeManager braintreeManager; private final BraintreeManager braintreeManager;
private final ServerZkReceiptOperations zkReceiptOperations;
private final IssuedReceiptsManager issuedReceiptsManager;
private final BadgeTranslator badgeTranslator; private final BadgeTranslator badgeTranslator;
private final LevelTranslator levelTranslator; private final LevelTranslator levelTranslator;
private final BankMandateTranslator bankMandateTranslator; private final BankMandateTranslator bankMandateTranslator;
@ -132,26 +118,22 @@ public class SubscriptionController {
@Nonnull SubscriptionManager subscriptionManager, @Nonnull SubscriptionManager subscriptionManager,
@Nonnull StripeManager stripeManager, @Nonnull StripeManager stripeManager,
@Nonnull BraintreeManager braintreeManager, @Nonnull BraintreeManager braintreeManager,
@Nonnull ServerZkReceiptOperations zkReceiptOperations,
@Nonnull IssuedReceiptsManager issuedReceiptsManager,
@Nonnull BadgeTranslator badgeTranslator, @Nonnull BadgeTranslator badgeTranslator,
@Nonnull LevelTranslator levelTranslator, @Nonnull LevelTranslator levelTranslator,
@Nonnull BankMandateTranslator bankMandateTranslator) { @Nonnull BankMandateTranslator bankMandateTranslator) {
this.subscriptionManager = subscriptionManager;
this.clock = Objects.requireNonNull(clock); this.clock = Objects.requireNonNull(clock);
this.subscriptionConfiguration = Objects.requireNonNull(subscriptionConfiguration); this.subscriptionConfiguration = Objects.requireNonNull(subscriptionConfiguration);
this.oneTimeDonationConfiguration = Objects.requireNonNull(oneTimeDonationConfiguration); this.oneTimeDonationConfiguration = Objects.requireNonNull(oneTimeDonationConfiguration);
this.subscriptionManager = Objects.requireNonNull(subscriptionManager);
this.stripeManager = Objects.requireNonNull(stripeManager); this.stripeManager = Objects.requireNonNull(stripeManager);
this.braintreeManager = Objects.requireNonNull(braintreeManager); this.braintreeManager = Objects.requireNonNull(braintreeManager);
this.zkReceiptOperations = Objects.requireNonNull(zkReceiptOperations);
this.issuedReceiptsManager = Objects.requireNonNull(issuedReceiptsManager);
this.badgeTranslator = Objects.requireNonNull(badgeTranslator); this.badgeTranslator = Objects.requireNonNull(badgeTranslator);
this.levelTranslator = Objects.requireNonNull(levelTranslator); this.levelTranslator = Objects.requireNonNull(levelTranslator);
this.bankMandateTranslator = Objects.requireNonNull(bankMandateTranslator); this.bankMandateTranslator = Objects.requireNonNull(bankMandateTranslator);
} }
private Map<String, CurrencyConfiguration> buildCurrencyConfiguration() { private Map<String, CurrencyConfiguration> buildCurrencyConfiguration() {
final List<SubscriptionProcessorManager> subscriptionProcessorManagers = List.of(stripeManager, braintreeManager); final List<SubscriptionPaymentProcessor> subscriptionPaymentProcessors = List.of(stripeManager, braintreeManager);
return oneTimeDonationConfiguration.currencies() return oneTimeDonationConfiguration.currencies()
.entrySet().stream() .entrySet().stream()
.collect(Collectors.toMap(Entry::getKey, currencyAndConfig -> { .collect(Collectors.toMap(Entry::getKey, currencyAndConfig -> {
@ -171,7 +153,7 @@ public class SubscriptionController {
levelIdAndConfig -> levelIdAndConfig.getValue().prices().get(currency).amount())); levelIdAndConfig -> levelIdAndConfig.getValue().prices().get(currency).amount()));
final List<String> supportedPaymentMethods = Arrays.stream(PaymentMethod.values()) final List<String> supportedPaymentMethods = Arrays.stream(PaymentMethod.values())
.filter(paymentMethod -> subscriptionProcessorManagers.stream() .filter(paymentMethod -> subscriptionPaymentProcessors.stream()
.anyMatch(manager -> manager.supportsPaymentMethod(paymentMethod) .anyMatch(manager -> manager.supportsPaymentMethod(paymentMethod)
&& manager.getSupportedCurrenciesForPaymentMethod(paymentMethod).contains(currency))) && manager.getSupportedCurrenciesForPaymentMethod(paymentMethod).contains(currency)))
.map(PaymentMethod::name) .map(PaymentMethod::name)
@ -236,20 +218,10 @@ public class SubscriptionController {
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public CompletableFuture<Response> deleteSubscriber( public CompletableFuture<Response> deleteSubscriber(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount, @ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@PathParam("subscriberId") String subscriberId) { @PathParam("subscriberId") String subscriberId) throws SubscriptionException {
RequestData requestData = RequestData.process(authenticatedAccount, subscriberId, clock); SubscriberCredentials subscriberCredentials =
return subscriptionManager.get(requestData.subscriberUser, requestData.hmac) SubscriberCredentials.process(authenticatedAccount, subscriberId, clock);
.thenCompose(getResult -> { return subscriptionManager.deleteSubscriber(subscriberCredentials).thenApply(unused -> Response.ok().build());
if (getResult == GetResult.NOT_STORED || getResult == GetResult.PASSWORD_MISMATCH) {
throw new NotFoundException();
}
return getResult.record.getProcessorCustomer()
.map(processorCustomer -> getManagerForProcessor(processorCustomer.processor()).cancelAllActiveSubscriptions(processorCustomer.customerId()))
// a missing customer ID is OK; it means the subscriber never started to add a payment method
.orElseGet(() -> CompletableFuture.completedFuture(null));
})
.thenCompose(unused -> subscriptionManager.canceledAt(requestData.subscriberUser, requestData.now))
.thenApply(unused -> Response.ok().build());
} }
@PUT @PUT
@ -258,31 +230,13 @@ public class SubscriptionController {
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public CompletableFuture<Response> updateSubscriber( public CompletableFuture<Response> updateSubscriber(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount, @ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@PathParam("subscriberId") String subscriberId) { @PathParam("subscriberId") String subscriberId) throws SubscriptionException {
RequestData requestData = RequestData.process(authenticatedAccount, subscriberId, clock); SubscriberCredentials subscriberCredentials =
return subscriptionManager.get(requestData.subscriberUser, requestData.hmac) SubscriberCredentials.process(authenticatedAccount, subscriberId, clock);
.thenCompose(getResult -> { return subscriptionManager.updateSubscriber(subscriberCredentials).thenApply(record -> Response.ok().build());
if (getResult == GetResult.PASSWORD_MISMATCH) {
throw new ForbiddenException("subscriberId mismatch");
} else if (getResult == GetResult.NOT_STORED) {
// create a customer and write it to ddb
return subscriptionManager.create(requestData.subscriberUser, requestData.hmac, requestData.now)
.thenApply(updatedRecord -> {
if (updatedRecord == null) {
throw new ForbiddenException();
}
return updatedRecord;
});
} else {
// already exists so just touch access time and return
return subscriptionManager.accessedAt(requestData.subscriberUser, requestData.now)
.thenApply(unused -> getResult.record);
}
})
.thenApply(record -> Response.ok().build());
} }
record CreatePaymentMethodResponse(String clientSecret, SubscriptionProcessor processor) { record CreatePaymentMethodResponse(String clientSecret, PaymentProvider processor) {
} }
@ -294,52 +248,25 @@ public class SubscriptionController {
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount, @ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@PathParam("subscriberId") String subscriberId, @PathParam("subscriberId") String subscriberId,
@QueryParam("type") @DefaultValue("CARD") PaymentMethod paymentMethodType, @QueryParam("type") @DefaultValue("CARD") PaymentMethod paymentMethodType,
@HeaderParam(HttpHeaders.USER_AGENT) @Nullable final String userAgentString) { @HeaderParam(HttpHeaders.USER_AGENT) @Nullable final String userAgentString) throws SubscriptionException {
RequestData requestData = RequestData.process(authenticatedAccount, subscriberId, clock); SubscriberCredentials subscriberCredentials =
SubscriberCredentials.process(authenticatedAccount, subscriberId, clock);
final SubscriptionProcessorManager subscriptionProcessorManager = getManagerForPaymentMethod(paymentMethodType); final SubscriptionPaymentProcessor subscriptionPaymentProcessor = getManagerForPaymentMethod(paymentMethodType);
return subscriptionManager.get(requestData.subscriberUser, requestData.hmac) return subscriptionManager.addPaymentMethodToCustomer(
.thenApply(this::requireRecordFromGetResult) subscriberCredentials,
.thenCompose(record -> { subscriptionPaymentProcessor,
final CompletableFuture<SubscriptionManager.Record> updatedRecordFuture = getClientPlatform(userAgentString),
record.getProcessorCustomer() SubscriptionPaymentProcessor::createPaymentMethodSetupToken)
.map(ProcessorCustomer::processor) .thenApply(token ->
.map(processor -> { Response.ok(new CreatePaymentMethodResponse(token, subscriptionPaymentProcessor.getProvider())).build());
if (processor != subscriptionProcessorManager.getProcessor()) {
throw new ClientErrorException("existing processor does not match", Status.CONFLICT);
}
return CompletableFuture.completedFuture(record);
})
.orElseGet(() -> subscriptionProcessorManager.createCustomer(requestData.subscriberUser, getClientPlatform(userAgentString))
.thenApply(ProcessorCustomer::customerId)
.thenCompose(customerId -> subscriptionManager.setProcessorAndCustomerId(record,
new ProcessorCustomer(customerId, subscriptionProcessorManager.getProcessor()),
Instant.now())));
return updatedRecordFuture.thenCompose(
updatedRecord -> {
final String customerId = updatedRecord.getProcessorCustomer()
.filter(pc -> pc.processor().equals(subscriptionProcessorManager.getProcessor()))
.orElseThrow(() -> new InternalServerErrorException("record should not be missing customer"))
.customerId();
return subscriptionProcessorManager.createPaymentMethodSetupToken(customerId);
});
})
.thenApply(
token -> Response.ok(new CreatePaymentMethodResponse(token, subscriptionProcessorManager.getProcessor()))
.build());
} }
public record CreatePayPalBillingAgreementRequest(@NotBlank String returnUrl, @NotBlank String cancelUrl) { public record CreatePayPalBillingAgreementRequest(@NotBlank String returnUrl, @NotBlank String cancelUrl) {}
} public record CreatePayPalBillingAgreementResponse(@NotBlank String approvalUrl, @NotBlank String token) {}
public record CreatePayPalBillingAgreementResponse(@NotBlank String approvalUrl, @NotBlank String token) {
}
@POST @POST
@Path("/{subscriberId}/create_payment_method/paypal") @Path("/{subscriberId}/create_payment_method/paypal")
@ -350,48 +277,29 @@ public class SubscriptionController {
@PathParam("subscriberId") String subscriberId, @PathParam("subscriberId") String subscriberId,
@NotNull @Valid CreatePayPalBillingAgreementRequest request, @NotNull @Valid CreatePayPalBillingAgreementRequest request,
@Context ContainerRequestContext containerRequestContext, @Context ContainerRequestContext containerRequestContext,
@HeaderParam(HttpHeaders.USER_AGENT) @Nullable final String userAgentString) { @HeaderParam(HttpHeaders.USER_AGENT) @Nullable final String userAgentString) throws SubscriptionException {
RequestData requestData = RequestData.process(authenticatedAccount, subscriberId, clock); final SubscriberCredentials subscriberCredentials =
SubscriberCredentials.process(authenticatedAccount, subscriberId, clock);
final Locale locale = HeaderUtils.getAcceptableLanguagesForRequest(containerRequestContext).stream()
.filter(l -> !"*".equals(l.getLanguage()))
.findFirst()
.orElse(Locale.US);
return subscriptionManager.get(requestData.subscriberUser, requestData.hmac) return subscriptionManager.addPaymentMethodToCustomer(
.thenApply(this::requireRecordFromGetResult) subscriberCredentials,
.thenCompose(record -> { braintreeManager,
getClientPlatform(userAgentString),
final CompletableFuture<SubscriptionManager.Record> updatedRecordFuture = (mgr, customerId) ->
record.getProcessorCustomer() mgr.createPayPalBillingAgreement(request.returnUrl, request.cancelUrl, locale.toLanguageTag()))
.map(ProcessorCustomer::processor) .thenApply(billingAgreementApprovalDetails -> Response.ok(
.map(processor -> { new CreatePayPalBillingAgreementResponse(
if (processor != braintreeManager.getProcessor()) { billingAgreementApprovalDetails.approvalUrl(),
throw new ClientErrorException("existing processor does not match", Status.CONFLICT); billingAgreementApprovalDetails.billingAgreementToken()))
} .build());
return CompletableFuture.completedFuture(record);
})
.orElseGet(() -> braintreeManager.createCustomer(requestData.subscriberUser, getClientPlatform(userAgentString))
.thenApply(ProcessorCustomer::customerId)
.thenCompose(customerId -> subscriptionManager.setProcessorAndCustomerId(record,
new ProcessorCustomer(customerId, braintreeManager.getProcessor()),
Instant.now())));
return updatedRecordFuture.thenCompose(
updatedRecord -> {
final Locale locale = HeaderUtils.getAcceptableLanguagesForRequest(containerRequestContext).stream()
.filter(l -> !"*".equals(l.getLanguage()))
.findFirst()
.orElse(Locale.US);
return braintreeManager.createPayPalBillingAgreement(request.returnUrl, request.cancelUrl,
locale.toLanguageTag());
});
})
.thenApply(
billingAgreementApprovalDetails -> Response.ok(
new CreatePayPalBillingAgreementResponse(billingAgreementApprovalDetails.approvalUrl(),
billingAgreementApprovalDetails.billingAgreementToken()))
.build());
} }
private SubscriptionProcessorManager getManagerForPaymentMethod(PaymentMethod paymentMethod) { private SubscriptionPaymentProcessor getManagerForPaymentMethod(PaymentMethod paymentMethod) {
return switch (paymentMethod) { return switch (paymentMethod) {
case CARD, SEPA_DEBIT, IDEAL -> stripeManager; case CARD, SEPA_DEBIT, IDEAL -> stripeManager;
case PAYPAL -> braintreeManager; case PAYPAL -> braintreeManager;
@ -399,7 +307,7 @@ public class SubscriptionController {
}; };
} }
private SubscriptionProcessorManager getManagerForProcessor(SubscriptionProcessor processor) { private SubscriptionPaymentProcessor getManagerForProcessor(PaymentProvider processor) {
return switch (processor) { return switch (processor) {
case STRIPE -> stripeManager; case STRIPE -> stripeManager;
case BRAINTREE -> braintreeManager; case BRAINTREE -> braintreeManager;
@ -413,13 +321,14 @@ public class SubscriptionController {
public CompletableFuture<Response> setDefaultPaymentMethodWithProcessor( public CompletableFuture<Response> setDefaultPaymentMethodWithProcessor(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount, @ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@PathParam("subscriberId") String subscriberId, @PathParam("subscriberId") String subscriberId,
@PathParam("processor") SubscriptionProcessor processor, @PathParam("processor") PaymentProvider processor,
@PathParam("paymentMethodToken") @NotEmpty String paymentMethodToken) { @PathParam("paymentMethodToken") @NotEmpty String paymentMethodToken) throws SubscriptionException {
RequestData requestData = RequestData.process(authenticatedAccount, subscriberId, clock); SubscriberCredentials subscriberCredentials =
SubscriberCredentials.process(authenticatedAccount, subscriberId, clock);
final SubscriptionProcessorManager manager = getManagerForProcessor(processor); final SubscriptionPaymentProcessor manager = getManagerForProcessor(processor);
return setDefaultPaymentMethod(manager, paymentMethodToken, requestData); return setDefaultPaymentMethod(manager, paymentMethodToken, subscriberCredentials);
} }
public record SetSubscriptionLevelSuccessResponse(long level) { public record SetSubscriptionLevelSuccessResponse(long level) {
@ -446,12 +355,11 @@ public class SubscriptionController {
@PathParam("subscriberId") String subscriberId, @PathParam("subscriberId") String subscriberId,
@PathParam("level") long level, @PathParam("level") long level,
@PathParam("currency") String currency, @PathParam("currency") String currency,
@PathParam("idempotencyKey") String idempotencyKey) { @PathParam("idempotencyKey") String idempotencyKey) throws SubscriptionException {
RequestData requestData = RequestData.process(authenticatedAccount, subscriberId, clock); SubscriberCredentials subscriberCredentials =
return subscriptionManager.get(requestData.subscriberUser, requestData.hmac) SubscriberCredentials.process(authenticatedAccount, subscriberId, clock);
.thenApply(this::requireRecordFromGetResult) return subscriptionManager.getSubscriber(subscriberCredentials)
.thenCompose(record -> { .thenCompose(record -> {
final ProcessorCustomer processorCustomer = record.getProcessorCustomer() final ProcessorCustomer processorCustomer = record.getProcessorCustomer()
.orElseThrow(() -> .orElseThrow(() ->
// a missing customer ID indicates the client made requests out of order, // a missing customer ID indicates the client made requests out of order,
@ -461,64 +369,25 @@ public class SubscriptionController {
final String subscriptionTemplateId = getSubscriptionTemplateId(level, currency, final String subscriptionTemplateId = getSubscriptionTemplateId(level, currency,
processorCustomer.processor()); processorCustomer.processor());
final SubscriptionProcessorManager manager = getManagerForProcessor(processorCustomer.processor()); final SubscriptionPaymentProcessor manager = getManagerForProcessor(processorCustomer.processor());
return subscriptionManager.updateSubscriptionLevelForCustomer(subscriberCredentials, record, manager, level,
return Optional.ofNullable(record.subscriptionId).map(subId -> { currency, idempotencyKey, subscriptionTemplateId, this::subscriptionsAreSameType);
// we already have a subscription in our records so let's check the level and currency,
// and only change it if needed
return manager.getSubscription(subId).thenCompose(
subscription -> manager.getLevelAndCurrencyForSubscription(subscription)
.thenCompose(existingLevelAndCurrency -> {
if (existingLevelAndCurrency.equals(new SubscriptionProcessorManager.LevelAndCurrency(level,
currency.toLowerCase(Locale.ROOT)))) {
return CompletableFuture.completedFuture(subscription);
}
if (!subscriptionsAreSameType(existingLevelAndCurrency.level(), level)) {
throw new BadRequestException(Response.status(Status.BAD_REQUEST)
.entity(new SetSubscriptionLevelErrorResponse(List.of(
new SetSubscriptionLevelErrorResponse.Error(
SetSubscriptionLevelErrorResponse.Error.Type.UNSUPPORTED_LEVEL, null))))
.build());
}
return manager.updateSubscription(
subscription, subscriptionTemplateId, level, idempotencyKey)
.thenCompose(updatedSubscription ->
subscriptionManager.subscriptionLevelChanged(requestData.subscriberUser,
requestData.now,
level, updatedSubscription.id())
.thenApply(unused -> updatedSubscription));
}));
}).orElseGet(() -> {
long lastSubscriptionCreatedAt =
record.subscriptionCreatedAt != null ? record.subscriptionCreatedAt.getEpochSecond() : 0;
// we don't have a subscription yet so create it and then record the subscription id
return manager.createSubscription(processorCustomer.customerId(),
subscriptionTemplateId,
level,
lastSubscriptionCreatedAt)
.exceptionally(e -> {
if (e.getCause() instanceof StripeException stripeException
&& "subscription_payment_intent_requires_action".equals(stripeException.getCode())) {
throw new BadRequestException(Response.status(Status.BAD_REQUEST)
.entity(new SetSubscriptionLevelErrorResponse(List.of(
new SetSubscriptionLevelErrorResponse.Error(
SetSubscriptionLevelErrorResponse.Error.Type.PAYMENT_REQUIRES_ACTION, null
)
))).build());
}
if (e instanceof RuntimeException re) {
throw re;
}
throw new CompletionException(e);
})
.thenCompose(subscription -> subscriptionManager.subscriptionCreated(
requestData.subscriberUser, subscription.id(), requestData.now, level)
.thenApply(unused -> subscription));
});
}) })
.thenApply(unused -> Response.ok(new SetSubscriptionLevelSuccessResponse(level)).build()); .exceptionally(ExceptionUtils.exceptionallyHandler(SubscriptionException.InvalidLevel.class, e -> {
throw new BadRequestException(Response.status(Response.Status.BAD_REQUEST)
.entity(new SubscriptionController.SetSubscriptionLevelErrorResponse(List.of(
new SubscriptionController.SetSubscriptionLevelErrorResponse.Error(
SubscriptionController.SetSubscriptionLevelErrorResponse.Error.Type.UNSUPPORTED_LEVEL,
null))))
.build());
}))
.exceptionally(ExceptionUtils.exceptionallyHandler(SubscriptionException.PaymentRequiresAction.class, e -> {
throw new BadRequestException(Response.status(Response.Status.BAD_REQUEST)
.entity(new SetSubscriptionLevelErrorResponse(List.of(new SetSubscriptionLevelErrorResponse.Error(
SetSubscriptionLevelErrorResponse.Error.Type.PAYMENT_REQUIRES_ACTION, null))))
.build());
}))
.thenApply(unused -> Response.ok(new SetSubscriptionLevelSuccessResponse(level)).build());
} }
public boolean subscriptionsAreSameType(long level1, long level2) { public boolean subscriptionsAreSameType(long level1, long level2) {
@ -608,7 +477,7 @@ public class SubscriptionController {
public record Subscription(long level, Instant billingCycleAnchor, Instant endOfCurrentPeriod, boolean active, public record Subscription(long level, Instant billingCycleAnchor, Instant endOfCurrentPeriod, boolean active,
boolean cancelAtPeriodEnd, String currency, BigDecimal amount, String status, boolean cancelAtPeriodEnd, String currency, BigDecimal amount, String status,
SubscriptionProcessor processor, PaymentMethod paymentMethod, boolean paymentProcessing) { PaymentProvider processor, PaymentMethod paymentMethod, boolean paymentProcessing) {
} }
} }
@ -618,16 +487,15 @@ public class SubscriptionController {
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public CompletableFuture<Response> getSubscriptionInformation( public CompletableFuture<Response> getSubscriptionInformation(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount, @ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@PathParam("subscriberId") String subscriberId) { @PathParam("subscriberId") String subscriberId) throws SubscriptionException {
RequestData requestData = RequestData.process(authenticatedAccount, subscriberId, clock); SubscriberCredentials subscriberCredentials = SubscriberCredentials.process(authenticatedAccount, subscriberId, clock);
return subscriptionManager.get(requestData.subscriberUser, requestData.hmac) return subscriptionManager.getSubscriber(subscriberCredentials)
.thenApply(this::requireRecordFromGetResult)
.thenCompose(record -> { .thenCompose(record -> {
if (record.subscriptionId == null) { if (record.subscriptionId == null) {
return CompletableFuture.completedFuture(Response.ok(new GetSubscriptionInformationResponse(null, null)).build()); return CompletableFuture.completedFuture(Response.ok(new GetSubscriptionInformationResponse(null, null)).build());
} }
final SubscriptionProcessorManager manager = getManagerForProcessor(record.getProcessorCustomer().orElseThrow().processor()); final SubscriptionPaymentProcessor manager = getManagerForProcessor(record.getProcessorCustomer().orElseThrow().processor());
return manager.getSubscription(record.subscriptionId).thenCompose(subscription -> return manager.getSubscription(record.subscriptionId).thenCompose(subscription ->
manager.getSubscriptionInformation(subscription).thenApply(subscriptionInformation -> Response.ok( manager.getSubscriptionInformation(subscription).thenApply(subscriptionInformation -> Response.ok(
@ -641,7 +509,7 @@ public class SubscriptionController {
subscriptionInformation.price().currency(), subscriptionInformation.price().currency(),
subscriptionInformation.price().amount(), subscriptionInformation.price().amount(),
subscriptionInformation.status().getApiValue(), subscriptionInformation.status().getApiValue(),
manager.getProcessor(), manager.getProvider(),
subscriptionInformation.paymentMethod(), subscriptionInformation.paymentMethod(),
subscriptionInformation.paymentProcessing()), subscriptionInformation.paymentProcessing()),
subscriptionInformation.chargeFailure() subscriptionInformation.chargeFailure()
@ -663,49 +531,22 @@ public class SubscriptionController {
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount, @ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent, @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
@PathParam("subscriberId") String subscriberId, @PathParam("subscriberId") String subscriberId,
@NotNull @Valid GetReceiptCredentialsRequest request) { @NotNull @Valid GetReceiptCredentialsRequest request) throws SubscriptionException {
RequestData requestData = RequestData.process(authenticatedAccount, subscriberId, clock); SubscriberCredentials subscriberCredentials = SubscriberCredentials.process(authenticatedAccount, subscriberId, clock);
return subscriptionManager.get(requestData.subscriberUser, requestData.hmac) return subscriptionManager.createReceiptCredentials(subscriberCredentials, request, this::receiptExpirationWithGracePeriod)
.thenApply(this::requireRecordFromGetResult) .thenApply(receiptCredential -> {
.thenCompose(record -> { final ReceiptCredentialResponse receiptCredentialResponse = receiptCredential.receiptCredentialResponse();
if (record.subscriptionId == null) { final SubscriptionPaymentProcessor.ReceiptItem receipt = receiptCredential.receiptItem();
return CompletableFuture.completedFuture(Response.status(Status.NOT_FOUND).build()); Metrics.counter(RECEIPT_ISSUED_COUNTER_NAME,
} Tags.of(
ReceiptCredentialRequest receiptCredentialRequest; Tag.of(PROCESSOR_TAG_NAME, receiptCredential.paymentProvider().toString()),
try { Tag.of(TYPE_TAG_NAME, "subscription"),
receiptCredentialRequest = new ReceiptCredentialRequest(request.receiptCredentialRequest()); Tag.of(SUBSCRIPTION_TYPE_TAG_NAME,
} catch (InvalidInputException e) { subscriptionConfiguration.getSubscriptionLevel(receipt.level()).type().name()
throw new BadRequestException("invalid receipt credential request", e); .toLowerCase(Locale.ROOT)),
} UserAgentTagUtil.getPlatformTag(userAgent)))
.increment();
final SubscriptionProcessorManager manager = getManagerForProcessor(record.getProcessorCustomer().orElseThrow().processor()); return Response.ok(new GetReceiptCredentialsResponse(receiptCredentialResponse.serialize())).build();
return manager.getReceiptItem(record.subscriptionId)
.thenCompose(receipt -> issuedReceiptsManager.recordIssuance(
receipt.itemId(), manager.getProcessor(), receiptCredentialRequest,
requestData.now)
.thenApply(unused -> receipt))
.thenApply(receipt -> {
ReceiptCredentialResponse receiptCredentialResponse;
try {
receiptCredentialResponse = zkReceiptOperations.issueReceiptCredential(
receiptCredentialRequest,
receiptExpirationWithGracePeriod(receipt.paidAt(), receipt.level()).getEpochSecond(),
receipt.level());
} catch (VerificationFailedException e) {
throw new BadRequestException("receipt credential request failed verification", e);
}
Metrics.counter(RECEIPT_ISSUED_COUNTER_NAME,
Tags.of(
Tag.of(PROCESSOR_TAG_NAME, manager.getProcessor().toString()),
Tag.of(TYPE_TAG_NAME, "subscription"),
Tag.of(SUBSCRIPTION_TYPE_TAG_NAME,
subscriptionConfiguration.getSubscriptionLevel(receipt.level()).type().name()
.toLowerCase(Locale.ROOT)),
UserAgentTagUtil.getPlatformTag(userAgent)))
.increment();
return Response.ok(new GetReceiptCredentialsResponse(receiptCredentialResponse.serialize()))
.build();
});
}); });
} }
@ -715,18 +556,18 @@ public class SubscriptionController {
public CompletableFuture<Response> setDefaultPaymentMethodForIdeal( public CompletableFuture<Response> setDefaultPaymentMethodForIdeal(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount, @ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@PathParam("subscriberId") String subscriberId, @PathParam("subscriberId") String subscriberId,
@PathParam("setupIntentId") @NotEmpty String setupIntentId) { @PathParam("setupIntentId") @NotEmpty String setupIntentId) throws SubscriptionException {
RequestData requestData = RequestData.process(authenticatedAccount, subscriberId, clock); SubscriberCredentials subscriberCredentials =
SubscriberCredentials.process(authenticatedAccount, subscriberId, clock);
return stripeManager.getGeneratedSepaIdFromSetupIntent(setupIntentId) return stripeManager.getGeneratedSepaIdFromSetupIntent(setupIntentId)
.thenCompose(generatedSepaId -> setDefaultPaymentMethod(stripeManager, generatedSepaId, requestData)); .thenCompose(generatedSepaId -> setDefaultPaymentMethod(stripeManager, generatedSepaId, subscriberCredentials));
} }
private CompletableFuture<Response> setDefaultPaymentMethod(final SubscriptionProcessorManager manager, private CompletableFuture<Response> setDefaultPaymentMethod(final SubscriptionPaymentProcessor manager,
final String paymentMethodId, final String paymentMethodId,
final RequestData requestData) { final SubscriberCredentials requestData) {
return subscriptionManager.get(requestData.subscriberUser, requestData.hmac) return subscriptionManager.getSubscriber(requestData)
.thenApply(this::requireRecordFromGetResult)
.thenCompose(record -> record.getProcessorCustomer() .thenCompose(record -> record.getProcessorCustomer()
.map(processorCustomer -> manager.setDefaultPaymentMethodForCustomer(processorCustomer.customerId(), .map(processorCustomer -> manager.setDefaultPaymentMethodForCustomer(processorCustomer.customerId(),
paymentMethodId, record.subscriptionId)) paymentMethodId, record.subscriptionId))
@ -737,8 +578,9 @@ public class SubscriptionController {
.thenApply(customer -> Response.ok().build()); .thenApply(customer -> Response.ok().build());
} }
private Instant receiptExpirationWithGracePeriod(Instant paidAt, long level) { private Instant receiptExpirationWithGracePeriod(SubscriptionPaymentProcessor.ReceiptItem receiptItem) {
return switch (subscriptionConfiguration.getSubscriptionLevel(level).type()) { final Instant paidAt = receiptItem.paidAt();
return switch (subscriptionConfiguration.getSubscriptionLevel(receiptItem.level()).type()) {
case DONATION -> paidAt.plus(subscriptionConfiguration.getBadgeExpiration()) case DONATION -> paidAt.plus(subscriptionConfiguration.getBadgeExpiration())
.plus(subscriptionConfiguration.getBadgeGracePeriod()) .plus(subscriptionConfiguration.getBadgeGracePeriod())
.truncatedTo(ChronoUnit.DAYS) .truncatedTo(ChronoUnit.DAYS)
@ -750,7 +592,7 @@ public class SubscriptionController {
} }
private String getSubscriptionTemplateId(long level, String currency, SubscriptionProcessor processor) { private String getSubscriptionTemplateId(long level, String currency, PaymentProvider processor) {
final SubscriptionLevelConfiguration config = subscriptionConfiguration.getSubscriptionLevel(level); final SubscriptionLevelConfiguration config = subscriptionConfiguration.getSubscriptionLevel(level);
if (config == null) { if (config == null) {
throw new BadRequestException(Response.status(Status.BAD_REQUEST) throw new BadRequestException(Response.status(Status.BAD_REQUEST)
@ -769,16 +611,6 @@ public class SubscriptionController {
.build())); .build()));
} }
private SubscriptionManager.Record requireRecordFromGetResult(SubscriptionManager.GetResult getResult) {
if (getResult == GetResult.PASSWORD_MISMATCH) {
throw new ForbiddenException("subscriberId mismatch");
} else if (getResult == GetResult.NOT_STORED) {
throw new NotFoundException();
} else {
return getResult.record;
}
}
@Nullable @Nullable
private static ClientPlatform getClientPlatform(@Nullable final String userAgentString) { private static ClientPlatform getClientPlatform(@Nullable final String userAgentString) {
try { try {
@ -787,60 +619,4 @@ public class SubscriptionController {
return null; return null;
} }
} }
private record RequestData(@Nonnull byte[] subscriberBytes,
@Nonnull byte[] subscriberUser,
@Nonnull byte[] subscriberKey,
@Nonnull byte[] hmac,
@Nonnull Instant now) {
public static RequestData process(
Optional<AuthenticatedDevice> authenticatedAccount,
String subscriberId,
Clock clock) {
Instant now = clock.instant();
if (authenticatedAccount.isPresent()) {
throw new ForbiddenException("must not use authenticated connection for subscriber operations");
}
byte[] subscriberBytes = convertSubscriberIdStringToBytes(subscriberId);
byte[] subscriberUser = getUser(subscriberBytes);
byte[] subscriberKey = getKey(subscriberBytes);
byte[] hmac = computeHmac(subscriberUser, subscriberKey);
return new RequestData(subscriberBytes, subscriberUser, subscriberKey, hmac, now);
}
private static byte[] convertSubscriberIdStringToBytes(String subscriberId) {
try {
byte[] bytes = Base64.getUrlDecoder().decode(subscriberId);
if (bytes.length != 32) {
throw new NotFoundException();
}
return bytes;
} catch (IllegalArgumentException e) {
throw new NotFoundException(e);
}
}
private static byte[] getUser(byte[] subscriberBytes) {
byte[] user = new byte[16];
System.arraycopy(subscriberBytes, 0, user, 0, user.length);
return user;
}
private static byte[] getKey(byte[] subscriberBytes) {
byte[] key = new byte[16];
System.arraycopy(subscriberBytes, 16, key, 0, key.length);
return key;
}
private static byte[] computeHmac(byte[] subscriberUser, byte[] subscriberKey) {
try {
Mac mac = Mac.getInstance("HmacSHA256");
mac.init(new SecretKeySpec(subscriberKey, "HmacSHA256"));
return mac.doFinal(subscriberUser);
} catch (NoSuchAlgorithmException | InvalidKeyException e) {
throw new InternalServerErrorException(e);
}
}
}
} }

View File

@ -0,0 +1,31 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.mappers;
import javax.ws.rs.BadRequestException;
import javax.ws.rs.ClientErrorException;
import javax.ws.rs.ForbiddenException;
import javax.ws.rs.InternalServerErrorException;
import javax.ws.rs.NotFoundException;
import javax.ws.rs.core.Response;
import javax.ws.rs.ext.ExceptionMapper;
import org.whispersystems.textsecuregcm.storage.SubscriptionException;
public class SubscriptionExceptionMapper implements ExceptionMapper<SubscriptionException> {
@Override
public Response toResponse(final SubscriptionException exception) {
return switch (exception) {
case SubscriptionException.NotFound e -> new NotFoundException(e.getMessage(), e.getCause()).getResponse();
case SubscriptionException.Forbidden e -> new ForbiddenException(e.getMessage(), e.getCause()).getResponse();
case SubscriptionException.InvalidArguments e ->
new BadRequestException(e.getMessage(), e.getCause()).getResponse();
case SubscriptionException.ProcessorConflict e ->
new ClientErrorException("existing processor does not match", Response.Status.CONFLICT).getResponse();
default -> new InternalServerErrorException(exception.getMessage(), exception.getCause()).getResponse();
};
}
}

View File

@ -26,7 +26,7 @@ import javax.crypto.spec.SecretKeySpec;
import javax.ws.rs.ClientErrorException; import javax.ws.rs.ClientErrorException;
import javax.ws.rs.core.Response.Status; import javax.ws.rs.core.Response.Status;
import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialRequest; import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialRequest;
import org.whispersystems.textsecuregcm.subscriptions.SubscriptionProcessor; import org.whispersystems.textsecuregcm.subscriptions.PaymentProvider;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException; import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
@ -62,17 +62,17 @@ public class IssuedReceiptsManager {
* <p> * <p>
* If this item has already been used to issue another receipt, throws a 409 conflict web application exception. * If this item has already been used to issue another receipt, throws a 409 conflict web application exception.
* <p> * <p>
* For {@link SubscriptionProcessor#STRIPE}, item is expected to refer to an invoice line item (subscriptions) or a * For {@link PaymentProvider#STRIPE}, item is expected to refer to an invoice line item (subscriptions) or a
* payment intent (one-time). * payment intent (one-time).
*/ */
public CompletableFuture<Void> recordIssuance( public CompletableFuture<Void> recordIssuance(
String processorItemId, String processorItemId,
SubscriptionProcessor processor, PaymentProvider processor,
ReceiptCredentialRequest request, ReceiptCredentialRequest request,
Instant now) { Instant now) {
final AttributeValue key; final AttributeValue key;
if (processor == SubscriptionProcessor.STRIPE) { if (processor == PaymentProvider.STRIPE) {
// As the first processor, Stripes IDs were not prefixed. Its item IDs have documented prefixes (`il_`, `pi_`) // As the first processor, Stripes IDs were not prefixed. Its item IDs have documented prefixes (`il_`, `pi_`)
// that will not collide with `SubscriptionProcessor` names // that will not collide with `SubscriptionProcessor` names
key = s(processorItemId); key = s(processorItemId);

View File

@ -0,0 +1,73 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import javax.annotation.Nonnull;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import javax.ws.rs.InternalServerErrorException;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.time.Clock;
import java.time.Instant;
import java.util.Base64;
import java.util.Optional;
public record SubscriberCredentials(@Nonnull byte[] subscriberBytes,
@Nonnull byte[] subscriberUser,
@Nonnull byte[] subscriberKey,
@Nonnull byte[] hmac,
@Nonnull Instant now) {
public static SubscriberCredentials process(
Optional<AuthenticatedDevice> authenticatedAccount,
String subscriberId,
Clock clock) throws SubscriptionException{
Instant now = clock.instant();
if (authenticatedAccount.isPresent()) {
throw new SubscriptionException.Forbidden("must not use authenticated connection for subscriber operations");
}
byte[] subscriberBytes = convertSubscriberIdStringToBytes(subscriberId);
byte[] subscriberUser = getUser(subscriberBytes);
byte[] subscriberKey = getKey(subscriberBytes);
byte[] hmac = computeHmac(subscriberUser, subscriberKey);
return new SubscriberCredentials(subscriberBytes, subscriberUser, subscriberKey, hmac, now);
}
private static byte[] convertSubscriberIdStringToBytes(String subscriberId) throws SubscriptionException.NotFound {
try {
byte[] bytes = Base64.getUrlDecoder().decode(subscriberId);
if (bytes.length != 32) {
throw new SubscriptionException.NotFound();
}
return bytes;
} catch (IllegalArgumentException e) {
throw new SubscriptionException.NotFound(e);
}
}
private static byte[] getUser(byte[] subscriberBytes) {
byte[] user = new byte[16];
System.arraycopy(subscriberBytes, 0, user, 0, user.length);
return user;
}
private static byte[] getKey(byte[] subscriberBytes) {
byte[] key = new byte[16];
System.arraycopy(subscriberBytes, 16, key, 0, key.length);
return key;
}
private static byte[] computeHmac(byte[] subscriberUser, byte[] subscriberKey) {
try {
Mac mac = Mac.getInstance("HmacSHA256");
mac.init(new SecretKeySpec(subscriberKey, "HmacSHA256"));
return mac.doFinal(subscriberUser);
} catch (NoSuchAlgorithmException | InvalidKeyException e) {
throw new InternalServerErrorException(e);
}
}
}

View File

@ -0,0 +1,51 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
public class SubscriptionException extends Exception {
public SubscriptionException(String message, Exception cause) {
super(message, cause);
}
public static class NotFound extends SubscriptionException {
public NotFound() {
super(null, null);
}
public NotFound(Exception cause) {
super(null, cause);
}
}
public static class Forbidden extends SubscriptionException {
public Forbidden(final String message) {
super(message, null);
}
}
public static class InvalidArguments extends SubscriptionException {
public InvalidArguments(final String message, final Exception cause) {
super(message, cause);
}
}
public static class InvalidLevel extends InvalidArguments {
public InvalidLevel() {
super(null, null);
}
}
public static class PaymentRequiresAction extends InvalidArguments {
public PaymentRequiresAction() {
super(null, null);
}
}
public static class ProcessorConflict extends SubscriptionException {
public ProcessorConflict(final String message) {
super(message, null);
}
}
}

View File

@ -1,429 +1,359 @@
/* /*
* Copyright 2021 Signal Messenger, LLC * Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import static org.whispersystems.textsecuregcm.util.AttributeValues.b; import com.stripe.exception.StripeException;
import static org.whispersystems.textsecuregcm.util.AttributeValues.n;
import static org.whispersystems.textsecuregcm.util.AttributeValues.s;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Throwables;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.time.Instant; import java.time.Instant;
import java.util.Map; import java.util.EnumMap;
import java.util.List;
import java.util.Locale;
import java.util.Objects; import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException; import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nonnull; import javax.annotation.Nonnull;
import javax.annotation.Nullable; import org.signal.libsignal.zkgroup.InvalidInputException;
import javax.ws.rs.ClientErrorException; import org.signal.libsignal.zkgroup.VerificationFailedException;
import javax.ws.rs.core.Response; import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialRequest;
import org.slf4j.Logger; import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialResponse;
import org.slf4j.LoggerFactory; import org.signal.libsignal.zkgroup.receipts.ServerZkReceiptOperations;
import org.whispersystems.textsecuregcm.controllers.SubscriptionController;
import org.whispersystems.textsecuregcm.subscriptions.PaymentProvider;
import org.whispersystems.textsecuregcm.subscriptions.ProcessorCustomer; import org.whispersystems.textsecuregcm.subscriptions.ProcessorCustomer;
import org.whispersystems.textsecuregcm.subscriptions.SubscriptionProcessor; import org.whispersystems.textsecuregcm.subscriptions.SubscriptionPaymentProcessor;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; import org.whispersystems.textsecuregcm.util.Util;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
import software.amazon.awssdk.services.dynamodb.model.GetItemRequest;
import software.amazon.awssdk.services.dynamodb.model.GetItemResponse;
import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
import software.amazon.awssdk.services.dynamodb.model.ReturnValue;
import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;
/**
* Manages updates to the Subscriptions table and the upstream subscription payment providers.
* <p>
* This handles a number of common subscription management operations like adding/removing subscribers and creating ZK
* receipt credentials for a subscriber's active subscription. Some subscription management operations only apply to
* certain payment providers. In those cases, the operation will take the payment provider that implements the specific
* functionality as an argument to the method.
*/
public class SubscriptionManager { public class SubscriptionManager {
private static final Logger logger = LoggerFactory.getLogger(SubscriptionManager.class); private final Subscriptions subscriptions;
private final EnumMap<PaymentProvider, Processor> processors;
private static final int USER_LENGTH = 16; private final ServerZkReceiptOperations zkReceiptOperations;
private final IssuedReceiptsManager issuedReceiptsManager;
public static final String KEY_USER = "U"; // B (Hash Key)
public static final String KEY_PASSWORD = "P"; // B
public static final String KEY_PROCESSOR_ID_CUSTOMER_ID = "PC"; // B (GSI Hash Key of `pc_to_u` index)
public static final String KEY_CREATED_AT = "R"; // N
public static final String KEY_SUBSCRIPTION_ID = "S"; // S
public static final String KEY_SUBSCRIPTION_CREATED_AT = "T"; // N
public static final String KEY_SUBSCRIPTION_LEVEL = "L";
public static final String KEY_SUBSCRIPTION_LEVEL_CHANGED_AT = "V"; // N
public static final String KEY_ACCESSED_AT = "A"; // N
public static final String KEY_CANCELED_AT = "B"; // N
public static final String KEY_CURRENT_PERIOD_ENDS_AT = "D"; // N
public static final String INDEX_NAME = "pc_to_u"; // Hash Key "PC"
public static class Record {
public final byte[] user;
public final byte[] password;
public final Instant createdAt;
@VisibleForTesting
@Nullable
ProcessorCustomer processorCustomer;
@Nullable
public String subscriptionId;
public Instant subscriptionCreatedAt;
public Long subscriptionLevel;
public Instant subscriptionLevelChangedAt;
public Instant accessedAt;
public Instant canceledAt;
public Instant currentPeriodEndsAt;
private Record(byte[] user, byte[] password, Instant createdAt) {
this.user = checkUserLength(user);
this.password = Objects.requireNonNull(password);
this.createdAt = Objects.requireNonNull(createdAt);
}
public static Record from(byte[] user, Map<String, AttributeValue> item) {
Record record = new Record(
user,
item.get(KEY_PASSWORD).b().asByteArray(),
getInstant(item, KEY_CREATED_AT));
final Pair<SubscriptionProcessor, String> processorCustomerId = getProcessorAndCustomer(item);
if (processorCustomerId != null) {
record.processorCustomer = new ProcessorCustomer(processorCustomerId.second(), processorCustomerId.first());
}
record.subscriptionId = getString(item, KEY_SUBSCRIPTION_ID);
record.subscriptionCreatedAt = getInstant(item, KEY_SUBSCRIPTION_CREATED_AT);
record.subscriptionLevel = getLong(item, KEY_SUBSCRIPTION_LEVEL);
record.subscriptionLevelChangedAt = getInstant(item, KEY_SUBSCRIPTION_LEVEL_CHANGED_AT);
record.accessedAt = getInstant(item, KEY_ACCESSED_AT);
record.canceledAt = getInstant(item, KEY_CANCELED_AT);
record.currentPeriodEndsAt = getInstant(item, KEY_CURRENT_PERIOD_ENDS_AT);
return record;
}
public Optional<ProcessorCustomer> getProcessorCustomer() {
return Optional.ofNullable(processorCustomer);
}
/**
* Extracts the active processor and customer from a single attribute value in the given item.
* <p>
* Until existing data is migrated, this may return {@code null}.
*/
@Nullable
private static Pair<SubscriptionProcessor, String> getProcessorAndCustomer(Map<String, AttributeValue> item) {
final AttributeValue attributeValue = item.get(KEY_PROCESSOR_ID_CUSTOMER_ID);
if (attributeValue == null) {
// temporarily allow null values
return null;
}
final byte[] processorAndCustomerId = attributeValue.b().asByteArray();
final byte processorId = processorAndCustomerId[0];
final SubscriptionProcessor processor = SubscriptionProcessor.forId(processorId);
if (processor == null) {
throw new IllegalStateException("unknown processor id: " + processorId);
}
final String customerId = new String(processorAndCustomerId, 1, processorAndCustomerId.length - 1,
StandardCharsets.UTF_8);
return new Pair<>(processor, customerId);
}
private static String getString(Map<String, AttributeValue> item, String key) {
AttributeValue attributeValue = item.get(key);
if (attributeValue == null) {
return null;
}
return attributeValue.s();
}
private static Long getLong(Map<String, AttributeValue> item, String key) {
AttributeValue attributeValue = item.get(key);
if (attributeValue == null || attributeValue.n() == null) {
return null;
}
return Long.valueOf(attributeValue.n());
}
private static Instant getInstant(Map<String, AttributeValue> item, String key) {
AttributeValue attributeValue = item.get(key);
if (attributeValue == null || attributeValue.n() == null) {
return null;
}
return Instant.ofEpochSecond(Long.parseLong(attributeValue.n()));
}
}
private final String table;
private final DynamoDbAsyncClient client;
public SubscriptionManager( public SubscriptionManager(
@Nonnull String table, @Nonnull Subscriptions subscriptions,
@Nonnull DynamoDbAsyncClient client) { @Nonnull List<Processor> processors,
this.table = Objects.requireNonNull(table); @Nonnull ServerZkReceiptOperations zkReceiptOperations,
this.client = Objects.requireNonNull(client); @Nonnull IssuedReceiptsManager issuedReceiptsManager) {
this.subscriptions = Objects.requireNonNull(subscriptions);
this.processors = new EnumMap<>(processors.stream()
.collect(Collectors.toMap(Processor::getProvider, Function.identity())));
this.zkReceiptOperations = Objects.requireNonNull(zkReceiptOperations);
this.issuedReceiptsManager = Objects.requireNonNull(issuedReceiptsManager);
}
public interface Processor {
PaymentProvider getProvider();
/**
* A receipt of payment from a payment provider
*
* @param itemId An identifier for the payment that should be unique within the payment provider. Note that this
* must identify an actual individual charge, not the subscription as a whole.
* @param paidAt The time this payment was made
* @param level The level which this payment corresponds to
*/
record ReceiptItem(String itemId, Instant paidAt, long level) {}
/**
* Retrieve a {@link ReceiptItem} for the subscriptionId stored in the subscriptions table
*
* @param subscriptionId A subscriptionId that potentially corresponds to a valid subscription
* @return A {@link ReceiptItem} if the subscription is valid
*/
CompletableFuture<ReceiptItem> getReceiptItem(String subscriptionId);
/**
* Cancel all active subscriptions for this key within the payment provider.
*
* @param key An identifier for the subscriber within the payment provider, corresponds to the customerId field in
* the subscriptions table
* @return A stage that completes when all subscriptions associated with the key are cancelled
*/
CompletableFuture<Void> cancelAllActiveSubscriptions(String key);
} }
/** /**
* Looks in the GSI for a record with the given customer id and returns the user id. * Cancel a subscription with the upstream payment provider and remove the subscription from the table
*/
public CompletableFuture<byte[]> getSubscriberUserByProcessorCustomer(ProcessorCustomer processorCustomer) {
QueryRequest query = QueryRequest.builder()
.tableName(table)
.indexName(INDEX_NAME)
.keyConditionExpression("#processor_customer_id = :processor_customer_id")
.projectionExpression("#user")
.expressionAttributeNames(Map.of(
"#processor_customer_id", KEY_PROCESSOR_ID_CUSTOMER_ID,
"#user", KEY_USER))
.expressionAttributeValues(Map.of(
":processor_customer_id", b(processorCustomer.toDynamoBytes())))
.build();
return client.query(query).thenApply(queryResponse -> {
int count = queryResponse.count();
if (count == 0) {
return null;
} else if (count > 1) {
logger.error("expected invariant of 1-1 subscriber-customer violated for customer {} ({})",
processorCustomer.customerId(), processorCustomer.processor());
throw new IllegalStateException(
"expected invariant of 1-1 subscriber-customer violated for customer " + processorCustomer);
} else {
Map<String, AttributeValue> result = queryResponse.items().get(0);
return result.get(KEY_USER).b().asByteArray();
}
});
}
public static class GetResult {
public static final GetResult NOT_STORED = new GetResult(Type.NOT_STORED, null);
public static final GetResult PASSWORD_MISMATCH = new GetResult(Type.PASSWORD_MISMATCH, null);
public enum Type {
NOT_STORED,
PASSWORD_MISMATCH,
FOUND
}
public final Type type;
public final Record record;
private GetResult(Type type, Record record) {
this.type = type;
this.record = record;
}
public static GetResult found(Record record) {
return new GetResult(Type.FOUND, record);
}
}
/**
* Looks up a record with the given {@code user} and validates the {@code hmac} before returning it.
*/
public CompletableFuture<GetResult> get(byte[] user, byte[] hmac) {
return getUser(user).thenApply(getItemResponse -> {
if (!getItemResponse.hasItem()) {
return GetResult.NOT_STORED;
}
Record record = Record.from(user, getItemResponse.item());
if (!MessageDigest.isEqual(hmac, record.password)) {
return GetResult.PASSWORD_MISMATCH;
}
return GetResult.found(record);
});
}
private CompletableFuture<GetItemResponse> getUser(byte[] user) {
checkUserLength(user);
GetItemRequest request = GetItemRequest.builder()
.consistentRead(Boolean.TRUE)
.tableName(table)
.key(Map.of(KEY_USER, b(user)))
.build();
return client.getItem(request);
}
public CompletableFuture<Record> create(byte[] user, byte[] password, Instant createdAt) {
checkUserLength(user);
UpdateItemRequest request = UpdateItemRequest.builder()
.tableName(table)
.key(Map.of(KEY_USER, b(user)))
.returnValues(ReturnValue.ALL_NEW)
.conditionExpression("attribute_not_exists(#user) OR #password = :password")
.updateExpression("SET "
+ "#password = if_not_exists(#password, :password), "
+ "#created_at = if_not_exists(#created_at, :created_at), "
+ "#accessed_at = if_not_exists(#accessed_at, :accessed_at)"
)
.expressionAttributeNames(Map.of(
"#user", KEY_USER,
"#password", KEY_PASSWORD,
"#created_at", KEY_CREATED_AT,
"#accessed_at", KEY_ACCESSED_AT)
)
.expressionAttributeValues(Map.of(
":password", b(password),
":created_at", n(createdAt.getEpochSecond()),
":accessed_at", n(createdAt.getEpochSecond()))
)
.build();
return client.updateItem(request).handle((updateItemResponse, throwable) -> {
if (throwable != null) {
if (Throwables.getRootCause(throwable) instanceof ConditionalCheckFailedException) {
return null;
}
Throwables.throwIfUnchecked(throwable);
throw new CompletionException(throwable);
}
return Record.from(user, updateItemResponse.attributes());
});
}
/**
* Sets the processor and customer ID for the given user record.
* *
* @return the user record. * @param subscriberCredentials Subscriber credentials derived from the subscriberId
* @return A stage that completes when the subscription has been cancelled with the upstream payment provider and the
* subscription has been removed from the table.
*/ */
public CompletableFuture<Record> setProcessorAndCustomerId(Record userRecord, public CompletableFuture<Void> deleteSubscriber(final SubscriberCredentials subscriberCredentials) {
ProcessorCustomer activeProcessorCustomer, Instant updatedAt) { return subscriptions.get(subscriberCredentials.subscriberUser(), subscriberCredentials.hmac())
.thenCompose(getResult -> {
UpdateItemRequest request = UpdateItemRequest.builder() if (getResult == Subscriptions.GetResult.NOT_STORED
.tableName(table) || getResult == Subscriptions.GetResult.PASSWORD_MISMATCH) {
.key(Map.of(KEY_USER, b(userRecord.user))) return CompletableFuture.failedFuture(new SubscriptionException.NotFound());
.returnValues(ReturnValue.ALL_NEW) }
.conditionExpression("attribute_not_exists(#processor_customer_id)") return getResult.record.getProcessorCustomer()
.updateExpression("SET " .map(processorCustomer -> getProcessor(processorCustomer.processor())
+ "#processor_customer_id = :processor_customer_id, " .cancelAllActiveSubscriptions(processorCustomer.customerId()))
+ "#accessed_at = :accessed_at" // a missing customer ID is OK; it means the subscriber never started to add a payment method
) .orElseGet(() -> CompletableFuture.completedFuture(null));
.expressionAttributeNames(Map.of( })
"#accessed_at", KEY_ACCESSED_AT, .thenCompose(unused ->
"#processor_customer_id", KEY_PROCESSOR_ID_CUSTOMER_ID subscriptions.canceledAt(subscriberCredentials.subscriberUser(), subscriberCredentials.now()));
)) }
.expressionAttributeValues(Map.of(
":accessed_at", n(updatedAt.getEpochSecond()), /**
":processor_customer_id", b(activeProcessorCustomer.toDynamoBytes()) * Create or update a subscriber in the subscriptions table
)).build(); * <p>
* If the subscriber does not exist, a subscriber with the provided credentials will be created. If the subscriber
return client.updateItem(request) * already exists, its last access time will be updated.
.thenApply(updateItemResponse -> Record.from(userRecord.user, updateItemResponse.attributes())) *
.exceptionallyCompose(throwable -> { * @param subscriberCredentials Subscriber credentials derived from the subscriberId
if (Throwables.getRootCause(throwable) instanceof ConditionalCheckFailedException) { * @return A stage that completes when the subscriber has been updated.
throw new ClientErrorException(Response.Status.CONFLICT); */
public CompletableFuture<Void> updateSubscriber(final SubscriberCredentials subscriberCredentials) {
return subscriptions.get(subscriberCredentials.subscriberUser(), subscriberCredentials.hmac())
.thenCompose(getResult -> {
if (getResult == Subscriptions.GetResult.PASSWORD_MISMATCH) {
return CompletableFuture.failedFuture(new SubscriptionException.Forbidden("subscriberId mismatch"));
} else if (getResult == Subscriptions.GetResult.NOT_STORED) {
// create a customer and write it to ddb
return subscriptions.create(subscriberCredentials.subscriberUser(), subscriberCredentials.hmac(),
subscriberCredentials.now())
.thenApply(updatedRecord -> {
if (updatedRecord == null) {
throw ExceptionUtils.wrap(new SubscriptionException.Forbidden("subscriberId mismatch"));
}
return updatedRecord;
});
} else {
// already exists so just touch access time and return
return subscriptions.accessedAt(subscriberCredentials.subscriberUser(), subscriberCredentials.now())
.thenApply(unused -> getResult.record);
}
})
.thenRun(Util.NOOP);
}
/**
* Get the subscriber record
*
* @param subscriberCredentials Subscriber credentials derived from the subscriberId
* @return A stage that completes with the requested subscriber if it exists and the credentials are correct.
*/
public CompletableFuture<Subscriptions.Record> getSubscriber(final SubscriberCredentials subscriberCredentials) {
return subscriptions.get(subscriberCredentials.subscriberUser(), subscriberCredentials.hmac())
.thenApply(getResult -> {
if (getResult == Subscriptions.GetResult.PASSWORD_MISMATCH) {
throw ExceptionUtils.wrap(new SubscriptionException.Forbidden("subscriberId mismatch"));
} else if (getResult == Subscriptions.GetResult.NOT_STORED) {
throw ExceptionUtils.wrap(new SubscriptionException.NotFound());
} else {
return getResult.record;
} }
Throwables.throwIfUnchecked(throwable);
throw new CompletionException(throwable);
}); });
} }
public CompletableFuture<Void> accessedAt(byte[] user, Instant accessedAt) { public record ReceiptResult(
checkUserLength(user); ReceiptCredentialResponse receiptCredentialResponse,
SubscriptionPaymentProcessor.ReceiptItem receiptItem,
PaymentProvider paymentProvider) {}
UpdateItemRequest request = UpdateItemRequest.builder() /**
.tableName(table) * Create a ZK receipt credential for a subscription that can be used to obtain the user entitlement
.key(Map.of(KEY_USER, b(user))) *
.returnValues(ReturnValue.NONE) * @param subscriberCredentials Subscriber credentials derived from the subscriberId
.updateExpression("SET #accessed_at = :accessed_at") * @param request The ZK Receipt credential request
.expressionAttributeNames(Map.of("#accessed_at", KEY_ACCESSED_AT)) * @param expiration A function that takes a {@link SubscriptionPaymentProcessor.ReceiptItem} and returns
.expressionAttributeValues(Map.of(":accessed_at", n(accessedAt.getEpochSecond()))) * the expiration time of the receipt
.build(); * @return If the subscription had a valid payment, the requested ZK receipt credential
return client.updateItem(request).thenApply(updateItemResponse -> null); */
public CompletableFuture<ReceiptResult> createReceiptCredentials(
final SubscriberCredentials subscriberCredentials,
final SubscriptionController.GetReceiptCredentialsRequest request,
final Function<SubscriptionPaymentProcessor.ReceiptItem, Instant> expiration) {
return getSubscriber(subscriberCredentials).thenCompose(record -> {
if (record.subscriptionId == null) {
return CompletableFuture.failedFuture(new SubscriptionException.NotFound());
}
ReceiptCredentialRequest receiptCredentialRequest;
try {
receiptCredentialRequest = new ReceiptCredentialRequest(request.receiptCredentialRequest());
} catch (InvalidInputException e) {
return CompletableFuture.failedFuture(
new SubscriptionException.InvalidArguments("invalid receipt credential request", e));
}
final PaymentProvider processor = record.getProcessorCustomer().orElseThrow().processor();
final Processor manager = getProcessor(processor);
return manager.getReceiptItem(record.subscriptionId)
.thenCompose(receipt -> issuedReceiptsManager.recordIssuance(
receipt.itemId(), manager.getProvider(), receiptCredentialRequest,
subscriberCredentials.now())
.thenApply(unused -> receipt))
.thenApply(receipt -> {
ReceiptCredentialResponse receiptCredentialResponse;
try {
receiptCredentialResponse = zkReceiptOperations.issueReceiptCredential(
receiptCredentialRequest,
expiration.apply(receipt).getEpochSecond(),
receipt.level());
} catch (VerificationFailedException e) {
throw ExceptionUtils.wrap(
new SubscriptionException.InvalidArguments("receipt credential request failed verification", e));
}
return new ReceiptResult(receiptCredentialResponse, receipt, processor);
});
});
} }
public CompletableFuture<Void> canceledAt(byte[] user, Instant canceledAt) { /**
checkUserLength(user); * Add a payment method to a customer in a payment processor and update the table.
* <p>
UpdateItemRequest request = UpdateItemRequest.builder() * If the customer does not exist in the table, a customer is created via the subscriptionPaymentProcessor and added
.tableName(table) * to the table. Not all payment processors support server-managed customers, so a payment processor that implements
.key(Map.of(KEY_USER, b(user))) * {@link SubscriptionPaymentProcessor} must be passed in.
.returnValues(ReturnValue.NONE) *
.updateExpression("SET " * @param subscriberCredentials Subscriber credentials derived from the subscriberId
+ "#accessed_at = :accessed_at, " * @param subscriptionPaymentProcessor A customer-aware payment processor to use. If the subscriber already has a
+ "#canceled_at = :canceled_at " * payment processor, it must match the existing one.
+ "REMOVE #subscription_id") * @param clientPlatform The platform of the client making the request
.expressionAttributeNames(Map.of( * @param paymentSetupFunction A function that takes the payment processor and the customer ID and begins
"#accessed_at", KEY_ACCESSED_AT, * adding a payment method. The function should return something that allows the
"#canceled_at", KEY_CANCELED_AT, * client to configure the newly added payment method like a payment method setup
"#subscription_id", KEY_SUBSCRIPTION_ID)) * token.
.expressionAttributeValues(Map.of( * @param <T> A payment processor that has a notion of server-managed customers
":accessed_at", n(canceledAt.getEpochSecond()), * @param <R> The return type of the paymentSetupFunction, which should be used by a client
":canceled_at", n(canceledAt.getEpochSecond()))) * to configure the newly created payment method
.build(); * @return A stage that completes when the payment method has been created in the payment processor and the table has
return client.updateItem(request).thenApply(updateItemResponse -> null); * been updated
*/
public <T extends SubscriptionPaymentProcessor, R> CompletableFuture<R> addPaymentMethodToCustomer(
final SubscriberCredentials subscriberCredentials,
final T subscriptionPaymentProcessor,
final ClientPlatform clientPlatform,
final BiFunction<T, String, CompletableFuture<R>> paymentSetupFunction) {
return this.getSubscriber(subscriberCredentials).thenCompose(record -> record.getProcessorCustomer()
.map(ProcessorCustomer::processor)
.map(processor -> {
if (processor != subscriptionPaymentProcessor.getProvider()) {
return CompletableFuture.<Subscriptions.Record>failedFuture(
new SubscriptionException.ProcessorConflict("existing processor does not match"));
}
return CompletableFuture.completedFuture(record);
})
.orElseGet(() -> subscriptionPaymentProcessor
.createCustomer(subscriberCredentials.subscriberUser(), clientPlatform)
.thenApply(ProcessorCustomer::customerId)
.thenCompose(customerId -> subscriptions.setProcessorAndCustomerId(record,
new ProcessorCustomer(customerId, subscriptionPaymentProcessor.getProvider()),
Instant.now()))))
.thenCompose(updatedRecord -> {
final String customerId = updatedRecord.getProcessorCustomer()
.filter(pc -> pc.processor().equals(subscriptionPaymentProcessor.getProvider()))
.orElseThrow(() ->
ExceptionUtils.wrap(new SubscriptionException("record should not be missing customer", null)))
.customerId();
return paymentSetupFunction.apply(subscriptionPaymentProcessor, customerId);
});
} }
public CompletableFuture<Void> subscriptionCreated( public interface LevelTransitionValidator {
byte[] user, String subscriptionId, Instant subscriptionCreatedAt, long level) { /**
checkUserLength(user); * Check is a level update is valid
*
UpdateItemRequest request = UpdateItemRequest.builder() * @param oldLevel The current level of the subscription
.tableName(table) * @param newLevel The proposed updated level of the subscription
.key(Map.of(KEY_USER, b(user))) * @return true if the subscription can be changed from oldLevel to newLevel, otherwise false
.returnValues(ReturnValue.NONE) */
.updateExpression("SET " boolean isTransitionValid(long oldLevel, long newLevel);
+ "#accessed_at = :accessed_at, "
+ "#subscription_id = :subscription_id, "
+ "#subscription_created_at = :subscription_created_at, "
+ "#subscription_level = :subscription_level, "
+ "#subscription_level_changed_at = :subscription_level_changed_at")
.expressionAttributeNames(Map.of(
"#accessed_at", KEY_ACCESSED_AT,
"#subscription_id", KEY_SUBSCRIPTION_ID,
"#subscription_created_at", KEY_SUBSCRIPTION_CREATED_AT,
"#subscription_level", KEY_SUBSCRIPTION_LEVEL,
"#subscription_level_changed_at", KEY_SUBSCRIPTION_LEVEL_CHANGED_AT))
.expressionAttributeValues(Map.of(
":accessed_at", n(subscriptionCreatedAt.getEpochSecond()),
":subscription_id", s(subscriptionId),
":subscription_created_at", n(subscriptionCreatedAt.getEpochSecond()),
":subscription_level", n(level),
":subscription_level_changed_at", n(subscriptionCreatedAt.getEpochSecond())))
.build();
return client.updateItem(request).thenApply(updateItemResponse -> null);
} }
public CompletableFuture<Void> subscriptionLevelChanged( /**
byte[] user, Instant subscriptionLevelChangedAt, long level, String subscriptionId) { * Update the subscription level in the payment processor and update the table.
checkUserLength(user); * <p>
* If we don't have an existing subscription, create one in the payment processor and then update the table. If we do
* already have a subscription, and it does not match the requested subscription, update it in the payment processor
* and then update the table. When an update occurs, this is where a user's recurring charge to a payment method is
* created or modified.
*
* @param subscriberCredentials Subscriber credentials derived from the subscriberId
* @param record A subscription record previous read with {@link #getSubscriber}
* @param processor A subscription payment processor with a notion of server-managed customers
* @param level The desired subscription level
* @param currency The desired currency type for the subscription
* @param idempotencyKey An idempotencyKey that can be used to deduplicate requests within the payment
* processor
* @param subscriptionTemplateId Specifies the product associated with the provided level within the payment
* processor
* @param transitionValidator A function that checks if the level update is valid
* @return A stage that completes when the level has been updated in the payment processor and the table
*/
public CompletableFuture<Void> updateSubscriptionLevelForCustomer(
final SubscriberCredentials subscriberCredentials,
final Subscriptions.Record record,
final SubscriptionPaymentProcessor processor,
final long level,
final String currency,
final String idempotencyKey,
final String subscriptionTemplateId,
final LevelTransitionValidator transitionValidator) {
UpdateItemRequest request = UpdateItemRequest.builder() return Optional.ofNullable(record.subscriptionId)
.tableName(table)
.key(Map.of(KEY_USER, b(user))) // we already have a subscription in our records so let's check the level and currency,
.returnValues(ReturnValue.NONE) // and only change it if needed
.updateExpression("SET " .map(subId -> processor
+ "#accessed_at = :accessed_at, " .getSubscription(subId)
+ "#subscription_id = :subscription_id, " .thenCompose(subscription -> processor.getLevelAndCurrencyForSubscription(subscription)
+ "#subscription_level = :subscription_level, " .thenCompose(existingLevelAndCurrency -> {
+ "#subscription_level_changed_at = :subscription_level_changed_at") if (existingLevelAndCurrency.equals(new SubscriptionPaymentProcessor.LevelAndCurrency(level,
.expressionAttributeNames(Map.of( currency.toLowerCase(Locale.ROOT)))) {
"#accessed_at", KEY_ACCESSED_AT, return CompletableFuture.completedFuture(null);
"#subscription_id", KEY_SUBSCRIPTION_ID, }
"#subscription_level", KEY_SUBSCRIPTION_LEVEL, if (!transitionValidator.isTransitionValid(existingLevelAndCurrency.level(), level)) {
"#subscription_level_changed_at", KEY_SUBSCRIPTION_LEVEL_CHANGED_AT)) return CompletableFuture.failedFuture(new SubscriptionException.InvalidLevel());
.expressionAttributeValues(Map.of( }
":accessed_at", n(subscriptionLevelChangedAt.getEpochSecond()), return processor.updateSubscription(subscription, subscriptionTemplateId, level, idempotencyKey)
":subscription_id", s(subscriptionId), .thenCompose(updatedSubscription ->
":subscription_level", n(level), subscriptions.subscriptionLevelChanged(subscriberCredentials.subscriberUser(),
":subscription_level_changed_at", n(subscriptionLevelChangedAt.getEpochSecond()))) subscriberCredentials.now(),
.build(); level, updatedSubscription.id()));
return client.updateItem(request).thenApply(updateItemResponse -> null); })))
// Otherwise, we don't have a subscription yet so create it and then record the subscription id
.orElseGet(() -> {
long lastSubscriptionCreatedAt = record.subscriptionCreatedAt != null
? record.subscriptionCreatedAt.getEpochSecond()
: 0;
return processor.createSubscription(record.processorCustomer.customerId(),
subscriptionTemplateId,
level,
lastSubscriptionCreatedAt)
.exceptionally(ExceptionUtils.exceptionallyHandler(StripeException.class, stripeException -> {
if ("subscription_payment_intent_requires_action".equals(stripeException.getCode())) {
throw ExceptionUtils.wrap(new SubscriptionException.PaymentRequiresAction());
}
throw ExceptionUtils.wrap(stripeException);
}))
.thenCompose(subscription -> subscriptions.subscriptionCreated(
subscriberCredentials.subscriberUser(), subscription.id(), subscriberCredentials.now(), level));
});
} }
private static byte[] checkUserLength(final byte[] user) { private Processor getProcessor(PaymentProvider provider) {
if (user.length != USER_LENGTH) { return processors.get(provider);
throw new IllegalArgumentException("user length is wrong; expected " + USER_LENGTH + "; was " + user.length);
}
return user;
} }
} }

View File

@ -0,0 +1,429 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.whispersystems.textsecuregcm.util.AttributeValues.b;
import static org.whispersystems.textsecuregcm.util.AttributeValues.n;
import static org.whispersystems.textsecuregcm.util.AttributeValues.s;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Throwables;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.time.Instant;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.ws.rs.ClientErrorException;
import javax.ws.rs.core.Response;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.subscriptions.ProcessorCustomer;
import org.whispersystems.textsecuregcm.subscriptions.PaymentProvider;
import org.whispersystems.textsecuregcm.util.Pair;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
import software.amazon.awssdk.services.dynamodb.model.GetItemRequest;
import software.amazon.awssdk.services.dynamodb.model.GetItemResponse;
import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
import software.amazon.awssdk.services.dynamodb.model.ReturnValue;
import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;
public class Subscriptions {
private static final Logger logger = LoggerFactory.getLogger(Subscriptions.class);
private static final int USER_LENGTH = 16;
public static final String KEY_USER = "U"; // B (Hash Key)
public static final String KEY_PASSWORD = "P"; // B
public static final String KEY_PROCESSOR_ID_CUSTOMER_ID = "PC"; // B (GSI Hash Key of `pc_to_u` index)
public static final String KEY_CREATED_AT = "R"; // N
public static final String KEY_SUBSCRIPTION_ID = "S"; // S
public static final String KEY_SUBSCRIPTION_CREATED_AT = "T"; // N
public static final String KEY_SUBSCRIPTION_LEVEL = "L";
public static final String KEY_SUBSCRIPTION_LEVEL_CHANGED_AT = "V"; // N
public static final String KEY_ACCESSED_AT = "A"; // N
public static final String KEY_CANCELED_AT = "B"; // N
public static final String KEY_CURRENT_PERIOD_ENDS_AT = "D"; // N
public static final String INDEX_NAME = "pc_to_u"; // Hash Key "PC"
public static class Record {
public final byte[] user;
public final byte[] password;
public final Instant createdAt;
@VisibleForTesting
@Nullable
ProcessorCustomer processorCustomer;
@Nullable
public String subscriptionId;
public Instant subscriptionCreatedAt;
public Long subscriptionLevel;
public Instant subscriptionLevelChangedAt;
public Instant accessedAt;
public Instant canceledAt;
public Instant currentPeriodEndsAt;
private Record(byte[] user, byte[] password, Instant createdAt) {
this.user = checkUserLength(user);
this.password = Objects.requireNonNull(password);
this.createdAt = Objects.requireNonNull(createdAt);
}
public static Record from(byte[] user, Map<String, AttributeValue> item) {
Record record = new Record(
user,
item.get(KEY_PASSWORD).b().asByteArray(),
getInstant(item, KEY_CREATED_AT));
final Pair<PaymentProvider, String> processorCustomerId = getProcessorAndCustomer(item);
if (processorCustomerId != null) {
record.processorCustomer = new ProcessorCustomer(processorCustomerId.second(), processorCustomerId.first());
}
record.subscriptionId = getString(item, KEY_SUBSCRIPTION_ID);
record.subscriptionCreatedAt = getInstant(item, KEY_SUBSCRIPTION_CREATED_AT);
record.subscriptionLevel = getLong(item, KEY_SUBSCRIPTION_LEVEL);
record.subscriptionLevelChangedAt = getInstant(item, KEY_SUBSCRIPTION_LEVEL_CHANGED_AT);
record.accessedAt = getInstant(item, KEY_ACCESSED_AT);
record.canceledAt = getInstant(item, KEY_CANCELED_AT);
record.currentPeriodEndsAt = getInstant(item, KEY_CURRENT_PERIOD_ENDS_AT);
return record;
}
public Optional<ProcessorCustomer> getProcessorCustomer() {
return Optional.ofNullable(processorCustomer);
}
/**
* Extracts the active processor and customer from a single attribute value in the given item.
* <p>
* Until existing data is migrated, this may return {@code null}.
*/
@Nullable
private static Pair<PaymentProvider, String> getProcessorAndCustomer(Map<String, AttributeValue> item) {
final AttributeValue attributeValue = item.get(KEY_PROCESSOR_ID_CUSTOMER_ID);
if (attributeValue == null) {
// temporarily allow null values
return null;
}
final byte[] processorAndCustomerId = attributeValue.b().asByteArray();
final byte processorId = processorAndCustomerId[0];
final PaymentProvider processor = PaymentProvider.forId(processorId);
if (processor == null) {
throw new IllegalStateException("unknown processor id: " + processorId);
}
final String customerId = new String(processorAndCustomerId, 1, processorAndCustomerId.length - 1,
StandardCharsets.UTF_8);
return new Pair<>(processor, customerId);
}
private static String getString(Map<String, AttributeValue> item, String key) {
AttributeValue attributeValue = item.get(key);
if (attributeValue == null) {
return null;
}
return attributeValue.s();
}
private static Long getLong(Map<String, AttributeValue> item, String key) {
AttributeValue attributeValue = item.get(key);
if (attributeValue == null || attributeValue.n() == null) {
return null;
}
return Long.valueOf(attributeValue.n());
}
private static Instant getInstant(Map<String, AttributeValue> item, String key) {
AttributeValue attributeValue = item.get(key);
if (attributeValue == null || attributeValue.n() == null) {
return null;
}
return Instant.ofEpochSecond(Long.parseLong(attributeValue.n()));
}
}
private final String table;
private final DynamoDbAsyncClient client;
public Subscriptions(
@Nonnull String table,
@Nonnull DynamoDbAsyncClient client) {
this.table = Objects.requireNonNull(table);
this.client = Objects.requireNonNull(client);
}
/**
* Looks in the GSI for a record with the given customer id and returns the user id.
*/
public CompletableFuture<byte[]> getSubscriberUserByProcessorCustomer(ProcessorCustomer processorCustomer) {
QueryRequest query = QueryRequest.builder()
.tableName(table)
.indexName(INDEX_NAME)
.keyConditionExpression("#processor_customer_id = :processor_customer_id")
.projectionExpression("#user")
.expressionAttributeNames(Map.of(
"#processor_customer_id", KEY_PROCESSOR_ID_CUSTOMER_ID,
"#user", KEY_USER))
.expressionAttributeValues(Map.of(
":processor_customer_id", b(processorCustomer.toDynamoBytes())))
.build();
return client.query(query).thenApply(queryResponse -> {
int count = queryResponse.count();
if (count == 0) {
return null;
} else if (count > 1) {
logger.error("expected invariant of 1-1 subscriber-customer violated for customer {} ({})",
processorCustomer.customerId(), processorCustomer.processor());
throw new IllegalStateException(
"expected invariant of 1-1 subscriber-customer violated for customer " + processorCustomer);
} else {
Map<String, AttributeValue> result = queryResponse.items().get(0);
return result.get(KEY_USER).b().asByteArray();
}
});
}
public static class GetResult {
public static final GetResult NOT_STORED = new GetResult(Type.NOT_STORED, null);
public static final GetResult PASSWORD_MISMATCH = new GetResult(Type.PASSWORD_MISMATCH, null);
public enum Type {
NOT_STORED,
PASSWORD_MISMATCH,
FOUND
}
public final Type type;
public final Record record;
private GetResult(Type type, Record record) {
this.type = type;
this.record = record;
}
public static GetResult found(Record record) {
return new GetResult(Type.FOUND, record);
}
}
/**
* Looks up a record with the given {@code user} and validates the {@code hmac} before returning it.
*/
public CompletableFuture<GetResult> get(byte[] user, byte[] hmac) {
return getUser(user).thenApply(getItemResponse -> {
if (!getItemResponse.hasItem()) {
return GetResult.NOT_STORED;
}
Record record = Record.from(user, getItemResponse.item());
if (!MessageDigest.isEqual(hmac, record.password)) {
return GetResult.PASSWORD_MISMATCH;
}
return GetResult.found(record);
});
}
private CompletableFuture<GetItemResponse> getUser(byte[] user) {
checkUserLength(user);
GetItemRequest request = GetItemRequest.builder()
.consistentRead(Boolean.TRUE)
.tableName(table)
.key(Map.of(KEY_USER, b(user)))
.build();
return client.getItem(request);
}
public CompletableFuture<Record> create(byte[] user, byte[] password, Instant createdAt) {
checkUserLength(user);
UpdateItemRequest request = UpdateItemRequest.builder()
.tableName(table)
.key(Map.of(KEY_USER, b(user)))
.returnValues(ReturnValue.ALL_NEW)
.conditionExpression("attribute_not_exists(#user) OR #password = :password")
.updateExpression("SET "
+ "#password = if_not_exists(#password, :password), "
+ "#created_at = if_not_exists(#created_at, :created_at), "
+ "#accessed_at = if_not_exists(#accessed_at, :accessed_at)"
)
.expressionAttributeNames(Map.of(
"#user", KEY_USER,
"#password", KEY_PASSWORD,
"#created_at", KEY_CREATED_AT,
"#accessed_at", KEY_ACCESSED_AT)
)
.expressionAttributeValues(Map.of(
":password", b(password),
":created_at", n(createdAt.getEpochSecond()),
":accessed_at", n(createdAt.getEpochSecond()))
)
.build();
return client.updateItem(request).handle((updateItemResponse, throwable) -> {
if (throwable != null) {
if (Throwables.getRootCause(throwable) instanceof ConditionalCheckFailedException) {
return null;
}
Throwables.throwIfUnchecked(throwable);
throw new CompletionException(throwable);
}
return Record.from(user, updateItemResponse.attributes());
});
}
/**
* Sets the processor and customer ID for the given user record.
*
* @return the user record.
*/
public CompletableFuture<Record> setProcessorAndCustomerId(Record userRecord,
ProcessorCustomer activeProcessorCustomer, Instant updatedAt) {
UpdateItemRequest request = UpdateItemRequest.builder()
.tableName(table)
.key(Map.of(KEY_USER, b(userRecord.user)))
.returnValues(ReturnValue.ALL_NEW)
.conditionExpression("attribute_not_exists(#processor_customer_id)")
.updateExpression("SET "
+ "#processor_customer_id = :processor_customer_id, "
+ "#accessed_at = :accessed_at"
)
.expressionAttributeNames(Map.of(
"#accessed_at", KEY_ACCESSED_AT,
"#processor_customer_id", KEY_PROCESSOR_ID_CUSTOMER_ID
))
.expressionAttributeValues(Map.of(
":accessed_at", n(updatedAt.getEpochSecond()),
":processor_customer_id", b(activeProcessorCustomer.toDynamoBytes())
)).build();
return client.updateItem(request)
.thenApply(updateItemResponse -> Record.from(userRecord.user, updateItemResponse.attributes()))
.exceptionallyCompose(throwable -> {
if (Throwables.getRootCause(throwable) instanceof ConditionalCheckFailedException) {
throw new ClientErrorException(Response.Status.CONFLICT);
}
Throwables.throwIfUnchecked(throwable);
throw new CompletionException(throwable);
});
}
public CompletableFuture<Void> accessedAt(byte[] user, Instant accessedAt) {
checkUserLength(user);
UpdateItemRequest request = UpdateItemRequest.builder()
.tableName(table)
.key(Map.of(KEY_USER, b(user)))
.returnValues(ReturnValue.NONE)
.updateExpression("SET #accessed_at = :accessed_at")
.expressionAttributeNames(Map.of("#accessed_at", KEY_ACCESSED_AT))
.expressionAttributeValues(Map.of(":accessed_at", n(accessedAt.getEpochSecond())))
.build();
return client.updateItem(request).thenApply(updateItemResponse -> null);
}
public CompletableFuture<Void> canceledAt(byte[] user, Instant canceledAt) {
checkUserLength(user);
UpdateItemRequest request = UpdateItemRequest.builder()
.tableName(table)
.key(Map.of(KEY_USER, b(user)))
.returnValues(ReturnValue.NONE)
.updateExpression("SET "
+ "#accessed_at = :accessed_at, "
+ "#canceled_at = :canceled_at "
+ "REMOVE #subscription_id")
.expressionAttributeNames(Map.of(
"#accessed_at", KEY_ACCESSED_AT,
"#canceled_at", KEY_CANCELED_AT,
"#subscription_id", KEY_SUBSCRIPTION_ID))
.expressionAttributeValues(Map.of(
":accessed_at", n(canceledAt.getEpochSecond()),
":canceled_at", n(canceledAt.getEpochSecond())))
.build();
return client.updateItem(request).thenApply(updateItemResponse -> null);
}
public CompletableFuture<Void> subscriptionCreated(
byte[] user, String subscriptionId, Instant subscriptionCreatedAt, long level) {
checkUserLength(user);
UpdateItemRequest request = UpdateItemRequest.builder()
.tableName(table)
.key(Map.of(KEY_USER, b(user)))
.returnValues(ReturnValue.NONE)
.updateExpression("SET "
+ "#accessed_at = :accessed_at, "
+ "#subscription_id = :subscription_id, "
+ "#subscription_created_at = :subscription_created_at, "
+ "#subscription_level = :subscription_level, "
+ "#subscription_level_changed_at = :subscription_level_changed_at")
.expressionAttributeNames(Map.of(
"#accessed_at", KEY_ACCESSED_AT,
"#subscription_id", KEY_SUBSCRIPTION_ID,
"#subscription_created_at", KEY_SUBSCRIPTION_CREATED_AT,
"#subscription_level", KEY_SUBSCRIPTION_LEVEL,
"#subscription_level_changed_at", KEY_SUBSCRIPTION_LEVEL_CHANGED_AT))
.expressionAttributeValues(Map.of(
":accessed_at", n(subscriptionCreatedAt.getEpochSecond()),
":subscription_id", s(subscriptionId),
":subscription_created_at", n(subscriptionCreatedAt.getEpochSecond()),
":subscription_level", n(level),
":subscription_level_changed_at", n(subscriptionCreatedAt.getEpochSecond())))
.build();
return client.updateItem(request).thenApply(updateItemResponse -> null);
}
public CompletableFuture<Void> subscriptionLevelChanged(
byte[] user, Instant subscriptionLevelChangedAt, long level, String subscriptionId) {
checkUserLength(user);
UpdateItemRequest request = UpdateItemRequest.builder()
.tableName(table)
.key(Map.of(KEY_USER, b(user)))
.returnValues(ReturnValue.NONE)
.updateExpression("SET "
+ "#accessed_at = :accessed_at, "
+ "#subscription_id = :subscription_id, "
+ "#subscription_level = :subscription_level, "
+ "#subscription_level_changed_at = :subscription_level_changed_at")
.expressionAttributeNames(Map.of(
"#accessed_at", KEY_ACCESSED_AT,
"#subscription_id", KEY_SUBSCRIPTION_ID,
"#subscription_level", KEY_SUBSCRIPTION_LEVEL,
"#subscription_level_changed_at", KEY_SUBSCRIPTION_LEVEL_CHANGED_AT))
.expressionAttributeValues(Map.of(
":accessed_at", n(subscriptionLevelChangedAt.getEpochSecond()),
":subscription_id", s(subscriptionId),
":subscription_level", n(level),
":subscription_level_changed_at", n(subscriptionLevelChangedAt.getEpochSecond())))
.build();
return client.updateItem(request).thenApply(updateItemResponse -> null);
}
private static byte[] checkUserLength(final byte[] user) {
if (user.length != USER_LENGTH) {
throw new IllegalArgumentException("user length is wrong; expected " + USER_LENGTH + "; was " + user.length);
}
return user;
}
}

View File

@ -52,7 +52,7 @@ import org.whispersystems.textsecuregcm.util.GoogleApiUtil;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
public class BraintreeManager implements SubscriptionProcessorManager { public class BraintreeManager implements SubscriptionPaymentProcessor {
private static final Logger logger = LoggerFactory.getLogger(BraintreeManager.class); private static final Logger logger = LoggerFactory.getLogger(BraintreeManager.class);
@ -124,8 +124,8 @@ public class BraintreeManager implements SubscriptionProcessorManager {
} }
@Override @Override
public SubscriptionProcessor getProcessor() { public PaymentProvider getProvider() {
return SubscriptionProcessor.BRAINTREE; return PaymentProvider.BRAINTREE;
} }
@Override @Override
@ -211,7 +211,7 @@ public class BraintreeManager implements SubscriptionProcessorManager {
return switch (unsuccessfulTx.getProcessorResponseCode()) { return switch (unsuccessfulTx.getProcessorResponseCode()) {
case GENERIC_DECLINED_PROCESSOR_CODE, PAYPAL_FUNDING_INSTRUMENT_DECLINED_PROCESSOR_CODE -> case GENERIC_DECLINED_PROCESSOR_CODE, PAYPAL_FUNDING_INSTRUMENT_DECLINED_PROCESSOR_CODE ->
CompletableFuture.failedFuture( CompletableFuture.failedFuture(
new SubscriptionProcessorException(getProcessor(), createChargeFailure(unsuccessfulTx))); new SubscriptionProcessorException(getProvider(), createChargeFailure(unsuccessfulTx)));
default -> { default -> {
logger.info("PayPal charge unexpectedly failed: {}", unsuccessfulTx.getProcessorResponseCode()); logger.info("PayPal charge unexpectedly failed: {}", unsuccessfulTx.getProcessorResponseCode());
@ -342,7 +342,7 @@ public class BraintreeManager implements SubscriptionProcessorManager {
throw new CompletionException(new BraintreeException(result.getMessage())); throw new CompletionException(new BraintreeException(result.getMessage()));
} }
return new ProcessorCustomer(result.getTarget().getId(), SubscriptionProcessor.BRAINTREE); return new ProcessorCustomer(result.getTarget().getId(), PaymentProvider.BRAINTREE);
}); });
} }
@ -423,7 +423,7 @@ public class BraintreeManager implements SubscriptionProcessorManager {
if (result.getTarget() != null) { if (result.getTarget() != null) {
completionException = result.getTarget().getTransactions().stream().findFirst() completionException = result.getTarget().getTransactions().stream().findFirst()
.map(transaction -> new CompletionException( .map(transaction -> new CompletionException(
new SubscriptionProcessorException(getProcessor(), createChargeFailure(transaction)))) new SubscriptionProcessorException(getProvider(), createChargeFailure(transaction))))
.orElseGet(() -> new CompletionException(new BraintreeException(result.getMessage()))); .orElseGet(() -> new CompletionException(new BraintreeException(result.getMessage())));
} else { } else {
completionException = new CompletionException(new BraintreeException(result.getMessage())); completionException = new CompletionException(new BraintreeException(result.getMessage()));

View File

@ -12,30 +12,30 @@ import java.util.Map;
/** /**
* A set of payment providers used for donations * A set of payment providers used for donations
*/ */
public enum SubscriptionProcessor { public enum PaymentProvider {
// because provider IDs are stored, they should not be reused, and great care // because provider IDs are stored, they should not be reused, and great care
// must be used if a provider is removed from the list // must be used if a provider is removed from the list
STRIPE(1), STRIPE(1),
BRAINTREE(2), BRAINTREE(2),
; ;
private static final Map<Integer, SubscriptionProcessor> IDS_TO_PROCESSORS = new HashMap<>(); private static final Map<Integer, PaymentProvider> IDS_TO_PROCESSORS = new HashMap<>();
static { static {
Arrays.stream(SubscriptionProcessor.values()) Arrays.stream(PaymentProvider.values())
.forEach(provider -> IDS_TO_PROCESSORS.put((int) provider.id, provider)); .forEach(provider -> IDS_TO_PROCESSORS.put((int) provider.id, provider));
} }
/** /**
* @return the provider associated with the given ID, or {@code null} if none exists * @return the provider associated with the given ID, or {@code null} if none exists
*/ */
public static SubscriptionProcessor forId(byte id) { public static PaymentProvider forId(byte id) {
return IDS_TO_PROCESSORS.get((int) id); return IDS_TO_PROCESSORS.get((int) id);
} }
private final byte id; private final byte id;
SubscriptionProcessor(int id) { PaymentProvider(int id) {
if (id > 255) { if (id > 255) {
throw new IllegalArgumentException("ID must fit in one byte: " + id); throw new IllegalArgumentException("ID must fit in one byte: " + id);
} }

View File

@ -7,7 +7,7 @@ package org.whispersystems.textsecuregcm.subscriptions;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
public record ProcessorCustomer(String customerId, SubscriptionProcessor processor) { public record ProcessorCustomer(String customerId, PaymentProvider processor) {
public byte[] toDynamoBytes() { public byte[] toDynamoBytes() {
final byte[] customerIdBytes = customerId.getBytes(StandardCharsets.UTF_8); final byte[] customerIdBytes = customerId.getBytes(StandardCharsets.UTF_8);

View File

@ -76,7 +76,7 @@ import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.util.Conversions; import org.whispersystems.textsecuregcm.util.Conversions;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
public class StripeManager implements SubscriptionProcessorManager { public class StripeManager implements SubscriptionPaymentProcessor {
private static final Logger logger = LoggerFactory.getLogger(StripeManager.class); private static final Logger logger = LoggerFactory.getLogger(StripeManager.class);
private static final String METADATA_KEY_LEVEL = "level"; private static final String METADATA_KEY_LEVEL = "level";
private static final String METADATA_KEY_CLIENT_PLATFORM = "clientPlatform"; private static final String METADATA_KEY_CLIENT_PLATFORM = "clientPlatform";
@ -107,8 +107,8 @@ public class StripeManager implements SubscriptionProcessorManager {
} }
@Override @Override
public SubscriptionProcessor getProcessor() { public PaymentProvider getProvider() {
return SubscriptionProcessor.STRIPE; return PaymentProvider.STRIPE;
} }
@Override @Override
@ -145,7 +145,7 @@ public class StripeManager implements SubscriptionProcessorManager {
throw new CompletionException(e); throw new CompletionException(e);
} }
}, executor) }, executor)
.thenApply(customer -> new ProcessorCustomer(customer.getId(), getProcessor())); .thenApply(customer -> new ProcessorCustomer(customer.getId(), getProvider()));
} }
public CompletableFuture<Customer> getCustomer(String customerId) { public CompletableFuture<Customer> getCustomer(String customerId) {
@ -300,7 +300,7 @@ public class StripeManager implements SubscriptionProcessorManager {
if (e instanceof CardException ce) { if (e instanceof CardException ce) {
throw new CompletionException( throw new CompletionException(
new SubscriptionProcessorException(getProcessor(), createChargeFailureFromCardException(e, ce))); new SubscriptionProcessorException(getProvider(), createChargeFailureFromCardException(e, ce)));
} }
throw new CompletionException(e); throw new CompletionException(e);
@ -348,7 +348,7 @@ public class StripeManager implements SubscriptionProcessorManager {
if (e instanceof CardException ce) { if (e instanceof CardException ce) {
throw new CompletionException( throw new CompletionException(
new SubscriptionProcessorException(getProcessor(), createChargeFailureFromCardException(e, ce))); new SubscriptionProcessorException(getProvider(), createChargeFailureFromCardException(e, ce)));
} }
throw new CompletionException(e); throw new CompletionException(e);
} }

View File

@ -12,10 +12,10 @@ import java.util.concurrent.CompletableFuture;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.storage.SubscriptionManager;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
public interface SubscriptionProcessorManager { public interface SubscriptionPaymentProcessor extends SubscriptionManager.Processor {
SubscriptionProcessor getProcessor();
boolean supportsPaymentMethod(PaymentMethod paymentMethod); boolean supportsPaymentMethod(PaymentMethod paymentMethod);
@ -49,10 +49,6 @@ public interface SubscriptionProcessorManager {
*/ */
CompletableFuture<LevelAndCurrency> getLevelAndCurrencyForSubscription(Object subscription); CompletableFuture<LevelAndCurrency> getLevelAndCurrencyForSubscription(Object subscription);
CompletableFuture<Void> cancelAllActiveSubscriptions(String customerId);
CompletableFuture<ReceiptItem> getReceiptItem(String subscriptionId);
CompletableFuture<SubscriptionInformation> getSubscriptionInformation(Object subscription); CompletableFuture<SubscriptionInformation> getSubscriptionInformation(Object subscription);
enum SubscriptionStatus { enum SubscriptionStatus {
@ -102,13 +98,13 @@ public interface SubscriptionProcessorManager {
case "incomplete" -> INCOMPLETE; case "incomplete" -> INCOMPLETE;
case "trialing" -> { case "trialing" -> {
final Logger logger = LoggerFactory.getLogger(SubscriptionProcessorManager.class); final Logger logger = LoggerFactory.getLogger(SubscriptionPaymentProcessor.class);
logger.error("Subscription has status that should never happen: {}", status); logger.error("Subscription has status that should never happen: {}", status);
yield UNKNOWN; yield UNKNOWN;
} }
default -> { default -> {
final Logger logger = LoggerFactory.getLogger(SubscriptionProcessorManager.class); final Logger logger = LoggerFactory.getLogger(SubscriptionPaymentProcessor.class);
logger.error("Subscription has unknown status: {}", status); logger.error("Subscription has unknown status: {}", status);
yield UNKNOWN; yield UNKNOWN;
@ -137,10 +133,6 @@ public interface SubscriptionProcessorManager {
} }
record ReceiptItem(String itemId, Instant paidAt, long level) {
}
record LevelAndCurrency(long level, String currency) { record LevelAndCurrency(long level, String currency) {
} }

View File

@ -7,16 +7,16 @@ package org.whispersystems.textsecuregcm.subscriptions;
public class SubscriptionProcessorException extends Exception { public class SubscriptionProcessorException extends Exception {
private final SubscriptionProcessor processor; private final PaymentProvider processor;
private final ChargeFailure chargeFailure; private final ChargeFailure chargeFailure;
public SubscriptionProcessorException(final SubscriptionProcessor processor, public SubscriptionProcessorException(final PaymentProvider processor,
final ChargeFailure chargeFailure) { final ChargeFailure chargeFailure) {
this.processor = processor; this.processor = processor;
this.chargeFailure = chargeFailure; this.chargeFailure = chargeFailure;
} }
public SubscriptionProcessor getProcessor() { public PaymentProvider getProcessor() {
return processor; return processor;
} }

View File

@ -74,10 +74,12 @@ import org.whispersystems.textsecuregcm.controllers.SubscriptionController.GetSu
import org.whispersystems.textsecuregcm.entities.Badge; import org.whispersystems.textsecuregcm.entities.Badge;
import org.whispersystems.textsecuregcm.entities.BadgeSvg; import org.whispersystems.textsecuregcm.entities.BadgeSvg;
import org.whispersystems.textsecuregcm.mappers.CompletionExceptionMapper; import org.whispersystems.textsecuregcm.mappers.CompletionExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.SubscriptionExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.SubscriptionProcessorExceptionMapper; import org.whispersystems.textsecuregcm.mappers.SubscriptionProcessorExceptionMapper;
import org.whispersystems.textsecuregcm.storage.IssuedReceiptsManager; import org.whispersystems.textsecuregcm.storage.IssuedReceiptsManager;
import org.whispersystems.textsecuregcm.storage.OneTimeDonationsManager; import org.whispersystems.textsecuregcm.storage.OneTimeDonationsManager;
import org.whispersystems.textsecuregcm.storage.SubscriptionManager; import org.whispersystems.textsecuregcm.storage.SubscriptionManager;
import org.whispersystems.textsecuregcm.storage.Subscriptions;
import org.whispersystems.textsecuregcm.subscriptions.BankMandateTranslator; import org.whispersystems.textsecuregcm.subscriptions.BankMandateTranslator;
import org.whispersystems.textsecuregcm.subscriptions.BraintreeManager; import org.whispersystems.textsecuregcm.subscriptions.BraintreeManager;
import org.whispersystems.textsecuregcm.subscriptions.BraintreeManager.PayPalOneTimePaymentApprovalDetails; import org.whispersystems.textsecuregcm.subscriptions.BraintreeManager.PayPalOneTimePaymentApprovalDetails;
@ -87,10 +89,11 @@ import org.whispersystems.textsecuregcm.subscriptions.PaymentMethod;
import org.whispersystems.textsecuregcm.subscriptions.PaymentStatus; import org.whispersystems.textsecuregcm.subscriptions.PaymentStatus;
import org.whispersystems.textsecuregcm.subscriptions.ProcessorCustomer; import org.whispersystems.textsecuregcm.subscriptions.ProcessorCustomer;
import org.whispersystems.textsecuregcm.subscriptions.StripeManager; import org.whispersystems.textsecuregcm.subscriptions.StripeManager;
import org.whispersystems.textsecuregcm.subscriptions.SubscriptionProcessor; import org.whispersystems.textsecuregcm.subscriptions.PaymentProvider;
import org.whispersystems.textsecuregcm.subscriptions.SubscriptionProcessorException; import org.whispersystems.textsecuregcm.subscriptions.SubscriptionProcessorException;
import org.whispersystems.textsecuregcm.subscriptions.SubscriptionProcessorManager; import org.whispersystems.textsecuregcm.subscriptions.SubscriptionPaymentProcessor;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.MockUtils;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
@ -103,9 +106,11 @@ class SubscriptionControllerTest {
private static final SubscriptionConfiguration SUBSCRIPTION_CONFIG = ConfigHelper.getSubscriptionConfig(); private static final SubscriptionConfiguration SUBSCRIPTION_CONFIG = ConfigHelper.getSubscriptionConfig();
private static final OneTimeDonationConfiguration ONETIME_CONFIG = ConfigHelper.getOneTimeConfig(); private static final OneTimeDonationConfiguration ONETIME_CONFIG = ConfigHelper.getOneTimeConfig();
private static final SubscriptionManager SUBSCRIPTION_MANAGER = mock(SubscriptionManager.class); private static final Subscriptions SUBSCRIPTIONS = mock(Subscriptions.class);
private static final StripeManager STRIPE_MANAGER = mock(StripeManager.class); private static final StripeManager STRIPE_MANAGER = MockUtils.buildMock(StripeManager.class, mgr ->
private static final BraintreeManager BRAINTREE_MANAGER = mock(BraintreeManager.class); when(mgr.getProvider()).thenReturn(PaymentProvider.STRIPE));
private static final BraintreeManager BRAINTREE_MANAGER = MockUtils.buildMock(BraintreeManager.class, mgr ->
when(mgr.getProvider()).thenReturn(PaymentProvider.BRAINTREE));
private static final PaymentIntent PAYMENT_INTENT = mock(PaymentIntent.class); private static final PaymentIntent PAYMENT_INTENT = mock(PaymentIntent.class);
private static final ServerZkReceiptOperations ZK_OPS = mock(ServerZkReceiptOperations.class); private static final ServerZkReceiptOperations ZK_OPS = mock(ServerZkReceiptOperations.class);
private static final IssuedReceiptsManager ISSUED_RECEIPTS_MANAGER = mock(IssuedReceiptsManager.class); private static final IssuedReceiptsManager ISSUED_RECEIPTS_MANAGER = mock(IssuedReceiptsManager.class);
@ -113,17 +118,19 @@ class SubscriptionControllerTest {
private static final BadgeTranslator BADGE_TRANSLATOR = mock(BadgeTranslator.class); private static final BadgeTranslator BADGE_TRANSLATOR = mock(BadgeTranslator.class);
private static final LevelTranslator LEVEL_TRANSLATOR = mock(LevelTranslator.class); private static final LevelTranslator LEVEL_TRANSLATOR = mock(LevelTranslator.class);
private static final BankMandateTranslator BANK_MANDATE_TRANSLATOR = mock(BankMandateTranslator.class); private static final BankMandateTranslator BANK_MANDATE_TRANSLATOR = mock(BankMandateTranslator.class);
private static final SubscriptionController SUBSCRIPTION_CONTROLLER = new SubscriptionController( private final static SubscriptionController SUBSCRIPTION_CONTROLLER = new SubscriptionController(CLOCK, SUBSCRIPTION_CONFIG,
CLOCK, SUBSCRIPTION_CONFIG, ONETIME_CONFIG, SUBSCRIPTION_MANAGER, STRIPE_MANAGER, BRAINTREE_MANAGER, ZK_OPS, ONETIME_CONFIG, new SubscriptionManager(SUBSCRIPTIONS, List.of(STRIPE_MANAGER, BRAINTREE_MANAGER), ZK_OPS,
ISSUED_RECEIPTS_MANAGER, BADGE_TRANSLATOR, LEVEL_TRANSLATOR, BANK_MANDATE_TRANSLATOR); ISSUED_RECEIPTS_MANAGER), STRIPE_MANAGER, BRAINTREE_MANAGER, BADGE_TRANSLATOR, LEVEL_TRANSLATOR,
private static final OneTimeDonationController ONE_TIME_CONTROLLER = new OneTimeDonationController(CLOCK, ONETIME_CONFIG, STRIPE_MANAGER, BANK_MANDATE_TRANSLATOR);
BRAINTREE_MANAGER, ZK_OPS, ISSUED_RECEIPTS_MANAGER, ONE_TIME_DONATIONS_MANAGER); private static final OneTimeDonationController ONE_TIME_CONTROLLER = new OneTimeDonationController(CLOCK,
ONETIME_CONFIG, STRIPE_MANAGER, BRAINTREE_MANAGER, ZK_OPS, ISSUED_RECEIPTS_MANAGER, ONE_TIME_DONATIONS_MANAGER);
private static final ResourceExtension RESOURCE_EXTENSION = ResourceExtension.builder() private static final ResourceExtension RESOURCE_EXTENSION = ResourceExtension.builder()
.addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE) .addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE)
.addProvider(AuthHelper.getAuthFilter()) .addProvider(AuthHelper.getAuthFilter())
.addProvider(CompletionExceptionMapper.class) .addProvider(CompletionExceptionMapper.class)
.addProvider(SubscriptionProcessorExceptionMapper.class) .addProvider(SubscriptionProcessorExceptionMapper.class)
.addProvider(new AuthValueFactoryProvider.Binder<>(AuthenticatedDevice.class)) .addProvider(new AuthValueFactoryProvider.Binder<>(AuthenticatedDevice.class))
.addProvider(SubscriptionExceptionMapper.class)
.setMapper(SystemMapper.jsonMapper()) .setMapper(SystemMapper.jsonMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(SUBSCRIPTION_CONTROLLER) .addResource(SUBSCRIPTION_CONTROLLER)
@ -132,11 +139,11 @@ class SubscriptionControllerTest {
@BeforeEach @BeforeEach
void setUp() { void setUp() {
reset(CLOCK, SUBSCRIPTION_MANAGER, STRIPE_MANAGER, BRAINTREE_MANAGER, ZK_OPS, ISSUED_RECEIPTS_MANAGER, reset(CLOCK, SUBSCRIPTIONS, STRIPE_MANAGER, BRAINTREE_MANAGER, ZK_OPS, ISSUED_RECEIPTS_MANAGER,
BADGE_TRANSLATOR, LEVEL_TRANSLATOR); BADGE_TRANSLATOR, LEVEL_TRANSLATOR);
when(STRIPE_MANAGER.getProcessor()).thenReturn(SubscriptionProcessor.STRIPE); when(STRIPE_MANAGER.getProvider()).thenReturn(PaymentProvider.STRIPE);
when(BRAINTREE_MANAGER.getProcessor()).thenReturn(SubscriptionProcessor.BRAINTREE); when(BRAINTREE_MANAGER.getProvider()).thenReturn(PaymentProvider.BRAINTREE);
List.of(STRIPE_MANAGER, BRAINTREE_MANAGER) List.of(STRIPE_MANAGER, BRAINTREE_MANAGER)
.forEach(manager -> { .forEach(manager -> {
@ -328,7 +335,7 @@ class SubscriptionControllerTest {
when(BRAINTREE_MANAGER.captureOneTimePayment(anyString(), anyString(), anyString(), anyString(), anyLong(), when(BRAINTREE_MANAGER.captureOneTimePayment(anyString(), anyString(), anyString(), anyString(), anyLong(),
anyLong(), any())) anyLong(), any()))
.thenReturn(CompletableFuture.failedFuture(new SubscriptionProcessorException(SubscriptionProcessor.BRAINTREE, .thenReturn(CompletableFuture.failedFuture(new SubscriptionProcessorException(PaymentProvider.BRAINTREE,
new ChargeFailure("2046", "Declined", null, null, null)))); new ChargeFailure("2046", "Declined", null, null, null))));
final Response response = RESOURCE_EXTENSION.target("/v1/subscription/boost/paypal/confirm") final Response response = RESOURCE_EXTENSION.target("/v1/subscription/boost/paypal/confirm")
@ -373,26 +380,26 @@ class SubscriptionControllerTest {
Arrays.fill(subscriberUserAndKey, (byte) 1); Arrays.fill(subscriberUserAndKey, (byte) 1);
subscriberId = Base64.getEncoder().encodeToString(subscriberUserAndKey); subscriberId = Base64.getEncoder().encodeToString(subscriberUserAndKey);
final ProcessorCustomer processorCustomer = new ProcessorCustomer("testCustomerId", SubscriptionProcessor.STRIPE); final ProcessorCustomer processorCustomer = new ProcessorCustomer("testCustomerId", PaymentProvider.STRIPE);
final Map<String, AttributeValue> dynamoItem = Map.of(SubscriptionManager.KEY_PASSWORD, b(new byte[16]), final Map<String, AttributeValue> dynamoItem = Map.of(Subscriptions.KEY_PASSWORD, b(new byte[16]),
SubscriptionManager.KEY_CREATED_AT, n(Instant.now().getEpochSecond()), Subscriptions.KEY_CREATED_AT, n(Instant.now().getEpochSecond()),
SubscriptionManager.KEY_ACCESSED_AT, n(Instant.now().getEpochSecond()), Subscriptions.KEY_ACCESSED_AT, n(Instant.now().getEpochSecond()),
SubscriptionManager.KEY_PROCESSOR_ID_CUSTOMER_ID, b(processorCustomer.toDynamoBytes()) Subscriptions.KEY_PROCESSOR_ID_CUSTOMER_ID, b(processorCustomer.toDynamoBytes())
); );
final SubscriptionManager.Record record = SubscriptionManager.Record.from( final Subscriptions.Record record = Subscriptions.Record.from(
Arrays.copyOfRange(subscriberUserAndKey, 0, 16), dynamoItem); Arrays.copyOfRange(subscriberUserAndKey, 0, 16), dynamoItem);
when(SUBSCRIPTION_MANAGER.get(eq(Arrays.copyOfRange(subscriberUserAndKey, 0, 16)), any())) when(SUBSCRIPTIONS.get(eq(Arrays.copyOfRange(subscriberUserAndKey, 0, 16)), any()))
.thenReturn(CompletableFuture.completedFuture(SubscriptionManager.GetResult.found(record))); .thenReturn(CompletableFuture.completedFuture(Subscriptions.GetResult.found(record)));
when(SUBSCRIPTION_MANAGER.subscriptionCreated(any(), any(), any(), anyLong())) when(SUBSCRIPTIONS.subscriptionCreated(any(), any(), any(), anyLong()))
.thenReturn(CompletableFuture.completedFuture(null)); .thenReturn(CompletableFuture.completedFuture(null));
} }
@Test @Test
void createSubscriptionSuccess() { void createSubscriptionSuccess() {
when(STRIPE_MANAGER.createSubscription(any(), any(), anyLong(), anyLong())) when(STRIPE_MANAGER.createSubscription(any(), any(), anyLong(), anyLong()))
.thenReturn(CompletableFuture.completedFuture(mock(SubscriptionProcessorManager.SubscriptionId.class))); .thenReturn(CompletableFuture.completedFuture(mock(SubscriptionPaymentProcessor.SubscriptionId.class)));
final String level = String.valueOf(levelId); final String level = String.valueOf(levelId);
final String idempotencyKey = UUID.randomUUID().toString(); final String idempotencyKey = UUID.randomUUID().toString();
@ -407,7 +414,7 @@ class SubscriptionControllerTest {
@Test @Test
void createSubscriptionProcessorDeclined() { void createSubscriptionProcessorDeclined() {
when(STRIPE_MANAGER.createSubscription(any(), any(), anyLong(), anyLong())) when(STRIPE_MANAGER.createSubscription(any(), any(), anyLong(), anyLong()))
.thenReturn(CompletableFuture.failedFuture(new SubscriptionProcessorException(SubscriptionProcessor.STRIPE, .thenReturn(CompletableFuture.failedFuture(new SubscriptionProcessorException(PaymentProvider.STRIPE,
new ChargeFailure("card_declined", "Insufficient funds", null, null, null)))); new ChargeFailure("card_declined", "Insufficient funds", null, null, null))));
final String level = String.valueOf(levelId); final String level = String.valueOf(levelId);
@ -434,15 +441,15 @@ class SubscriptionControllerTest {
Arrays.fill(subscriberUserAndKey, (byte) 1); Arrays.fill(subscriberUserAndKey, (byte) 1);
subscriberId = Base64.getEncoder().encodeToString(subscriberUserAndKey); subscriberId = Base64.getEncoder().encodeToString(subscriberUserAndKey);
final Map<String, AttributeValue> dynamoItem = Map.of(SubscriptionManager.KEY_PASSWORD, b(new byte[16]), final Map<String, AttributeValue> dynamoItem = Map.of(Subscriptions.KEY_PASSWORD, b(new byte[16]),
SubscriptionManager.KEY_CREATED_AT, n(Instant.now().getEpochSecond()), Subscriptions.KEY_CREATED_AT, n(Instant.now().getEpochSecond()),
SubscriptionManager.KEY_ACCESSED_AT, n(Instant.now().getEpochSecond()) Subscriptions.KEY_ACCESSED_AT, n(Instant.now().getEpochSecond())
// missing processor:customer field // missing processor:customer field
); );
final SubscriptionManager.Record record = SubscriptionManager.Record.from( final Subscriptions.Record record = Subscriptions.Record.from(
Arrays.copyOfRange(subscriberUserAndKey, 0, 16), dynamoItem); Arrays.copyOfRange(subscriberUserAndKey, 0, 16), dynamoItem);
when(SUBSCRIPTION_MANAGER.get(eq(Arrays.copyOfRange(subscriberUserAndKey, 0, 16)), any())) when(SUBSCRIPTIONS.get(eq(Arrays.copyOfRange(subscriberUserAndKey, 0, 16)), any()))
.thenReturn(CompletableFuture.completedFuture(SubscriptionManager.GetResult.found(record))); .thenReturn(CompletableFuture.completedFuture(Subscriptions.GetResult.found(record)));
final String level = String.valueOf(levelId); final String level = String.valueOf(levelId);
final String idempotencyKey = UUID.randomUUID().toString(); final String idempotencyKey = UUID.randomUUID().toString();
@ -490,16 +497,16 @@ class SubscriptionControllerTest {
Arrays.fill(subscriberUserAndKey, (byte) 1); Arrays.fill(subscriberUserAndKey, (byte) 1);
final String subscriberId = Base64.getEncoder().encodeToString(subscriberUserAndKey); final String subscriberId = Base64.getEncoder().encodeToString(subscriberUserAndKey);
when(SUBSCRIPTION_MANAGER.get(any(), any())).thenReturn(CompletableFuture.completedFuture( when(SUBSCRIPTIONS.get(any(), any())).thenReturn(CompletableFuture.completedFuture(
SubscriptionManager.GetResult.NOT_STORED)); Subscriptions.GetResult.NOT_STORED));
final Map<String, AttributeValue> dynamoItem = Map.of(SubscriptionManager.KEY_PASSWORD, b(new byte[16]), final Map<String, AttributeValue> dynamoItem = Map.of(Subscriptions.KEY_PASSWORD, b(new byte[16]),
SubscriptionManager.KEY_CREATED_AT, n(Instant.now().getEpochSecond()), Subscriptions.KEY_CREATED_AT, n(Instant.now().getEpochSecond()),
SubscriptionManager.KEY_ACCESSED_AT, n(Instant.now().getEpochSecond()) Subscriptions.KEY_ACCESSED_AT, n(Instant.now().getEpochSecond())
); );
final SubscriptionManager.Record record = SubscriptionManager.Record.from( final Subscriptions.Record record = Subscriptions.Record.from(
Arrays.copyOfRange(subscriberUserAndKey, 0, 16), dynamoItem); Arrays.copyOfRange(subscriberUserAndKey, 0, 16), dynamoItem);
when(SUBSCRIPTION_MANAGER.create(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(record)); when(SUBSCRIPTIONS.create(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(record));
final Response createResponse = RESOURCE_EXTENSION.target(String.format("/v1/subscription/%s", subscriberId)) final Response createResponse = RESOURCE_EXTENSION.target(String.format("/v1/subscription/%s", subscriberId))
.request() .request()
@ -507,9 +514,9 @@ class SubscriptionControllerTest {
assertThat(createResponse.getStatus()).isEqualTo(200); assertThat(createResponse.getStatus()).isEqualTo(200);
// creating should be idempotent // creating should be idempotent
when(SUBSCRIPTION_MANAGER.get(any(), any())).thenReturn(CompletableFuture.completedFuture( when(SUBSCRIPTIONS.get(any(), any())).thenReturn(CompletableFuture.completedFuture(
SubscriptionManager.GetResult.found(record))); Subscriptions.GetResult.found(record)));
when(SUBSCRIPTION_MANAGER.accessedAt(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(SUBSCRIPTIONS.accessedAt(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
final Response idempotentCreateResponse = RESOURCE_EXTENSION.target( final Response idempotentCreateResponse = RESOURCE_EXTENSION.target(
String.format("/v1/subscription/%s", subscriberId)) String.format("/v1/subscription/%s", subscriberId))
@ -519,9 +526,9 @@ class SubscriptionControllerTest {
// when the manager returns `null`, it means there was a password mismatch from the storage layer `create`. // when the manager returns `null`, it means there was a password mismatch from the storage layer `create`.
// this could happen if there is a race between two concurrent `create` requests for the same user ID // this could happen if there is a race between two concurrent `create` requests for the same user ID
when(SUBSCRIPTION_MANAGER.get(any(), any())).thenReturn(CompletableFuture.completedFuture( when(SUBSCRIPTIONS.get(any(), any())).thenReturn(CompletableFuture.completedFuture(
SubscriptionManager.GetResult.NOT_STORED)); Subscriptions.GetResult.NOT_STORED));
when(SUBSCRIPTION_MANAGER.create(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(SUBSCRIPTIONS.create(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null));
final Response managerCreateNullResponse = RESOURCE_EXTENSION.target( final Response managerCreateNullResponse = RESOURCE_EXTENSION.target(
String.format("/v1/subscription/%s", subscriberId)) String.format("/v1/subscription/%s", subscriberId))
@ -535,8 +542,8 @@ class SubscriptionControllerTest {
final String mismatchedSubscriberId = Base64.getEncoder().encodeToString(subscriberUserAndMismatchedKey); final String mismatchedSubscriberId = Base64.getEncoder().encodeToString(subscriberUserAndMismatchedKey);
// a password mismatch for an existing record // a password mismatch for an existing record
when(SUBSCRIPTION_MANAGER.get(any(), any())).thenReturn(CompletableFuture.completedFuture( when(SUBSCRIPTIONS.get(any(), any())).thenReturn(CompletableFuture.completedFuture(
SubscriptionManager.GetResult.PASSWORD_MISMATCH)); Subscriptions.GetResult.PASSWORD_MISMATCH));
final Response passwordMismatchResponse = RESOURCE_EXTENSION.target( final Response passwordMismatchResponse = RESOURCE_EXTENSION.target(
String.format("/v1/subscription/%s", mismatchedSubscriberId)) String.format("/v1/subscription/%s", mismatchedSubscriberId))
@ -565,16 +572,16 @@ class SubscriptionControllerTest {
final String subscriberId = Base64.getEncoder().encodeToString(subscriberUserAndKey); final String subscriberId = Base64.getEncoder().encodeToString(subscriberUserAndKey);
when(CLOCK.instant()).thenReturn(Instant.now()); when(CLOCK.instant()).thenReturn(Instant.now());
when(SUBSCRIPTION_MANAGER.get(any(), any())).thenReturn(CompletableFuture.completedFuture( when(SUBSCRIPTIONS.get(any(), any())).thenReturn(CompletableFuture.completedFuture(
SubscriptionManager.GetResult.NOT_STORED)); Subscriptions.GetResult.NOT_STORED));
final Map<String, AttributeValue> dynamoItem = Map.of(SubscriptionManager.KEY_PASSWORD, b(new byte[16]), final Map<String, AttributeValue> dynamoItem = Map.of(Subscriptions.KEY_PASSWORD, b(new byte[16]),
SubscriptionManager.KEY_CREATED_AT, n(Instant.now().getEpochSecond()), Subscriptions.KEY_CREATED_AT, n(Instant.now().getEpochSecond()),
SubscriptionManager.KEY_ACCESSED_AT, n(Instant.now().getEpochSecond()) Subscriptions.KEY_ACCESSED_AT, n(Instant.now().getEpochSecond())
); );
final SubscriptionManager.Record record = SubscriptionManager.Record.from( final Subscriptions.Record record = Subscriptions.Record.from(
Arrays.copyOfRange(subscriberUserAndKey, 0, 16), dynamoItem); Arrays.copyOfRange(subscriberUserAndKey, 0, 16), dynamoItem);
when(SUBSCRIPTION_MANAGER.create(any(), any(), any(Instant.class))) when(SUBSCRIPTIONS.create(any(), any(), any(Instant.class)))
.thenReturn(CompletableFuture.completedFuture(record)); .thenReturn(CompletableFuture.completedFuture(record));
final Response createSubscriberResponse = RESOURCE_EXTENSION final Response createSubscriberResponse = RESOURCE_EXTENSION
@ -584,22 +591,22 @@ class SubscriptionControllerTest {
assertThat(createSubscriberResponse.getStatus()).isEqualTo(200); assertThat(createSubscriberResponse.getStatus()).isEqualTo(200);
when(SUBSCRIPTION_MANAGER.get(any(), any())) when(SUBSCRIPTIONS.get(any(), any()))
.thenReturn(CompletableFuture.completedFuture(SubscriptionManager.GetResult.found(record))); .thenReturn(CompletableFuture.completedFuture(Subscriptions.GetResult.found(record)));
final String customerId = "some-customer-id"; final String customerId = "some-customer-id";
final ProcessorCustomer customer = new ProcessorCustomer( final ProcessorCustomer customer = new ProcessorCustomer(
customerId, SubscriptionProcessor.STRIPE); customerId, PaymentProvider.STRIPE);
when(STRIPE_MANAGER.createCustomer(any(), any())) when(STRIPE_MANAGER.createCustomer(any(), any()))
.thenReturn(CompletableFuture.completedFuture(customer)); .thenReturn(CompletableFuture.completedFuture(customer));
final Map<String, AttributeValue> dynamoItemWithProcessorCustomer = new HashMap<>(dynamoItem); final Map<String, AttributeValue> dynamoItemWithProcessorCustomer = new HashMap<>(dynamoItem);
dynamoItemWithProcessorCustomer.put(SubscriptionManager.KEY_PROCESSOR_ID_CUSTOMER_ID, dynamoItemWithProcessorCustomer.put(Subscriptions.KEY_PROCESSOR_ID_CUSTOMER_ID,
b(new ProcessorCustomer(customerId, SubscriptionProcessor.STRIPE).toDynamoBytes())); b(new ProcessorCustomer(customerId, PaymentProvider.STRIPE).toDynamoBytes()));
final SubscriptionManager.Record recordWithCustomerId = SubscriptionManager.Record.from(record.user, final Subscriptions.Record recordWithCustomerId = Subscriptions.Record.from(record.user,
dynamoItemWithProcessorCustomer); dynamoItemWithProcessorCustomer);
when(SUBSCRIPTION_MANAGER.setProcessorAndCustomerId(any(SubscriptionManager.Record.class), any(), when(SUBSCRIPTIONS.setProcessorAndCustomerId(any(Subscriptions.Record.class), any(),
any(Instant.class))) any(Instant.class)))
.thenReturn(CompletableFuture.completedFuture(recordWithCustomerId)); .thenReturn(CompletableFuture.completedFuture(recordWithCustomerId));
@ -613,7 +620,7 @@ class SubscriptionControllerTest {
.post(Entity.json("")) .post(Entity.json(""))
.readEntity(SubscriptionController.CreatePaymentMethodResponse.class); .readEntity(SubscriptionController.CreatePaymentMethodResponse.class);
assertThat(createPaymentMethodResponse.processor()).isEqualTo(SubscriptionProcessor.STRIPE); assertThat(createPaymentMethodResponse.processor()).isEqualTo(PaymentProvider.STRIPE);
assertThat(createPaymentMethodResponse.clientSecret()).isEqualTo(clientSecret); assertThat(createPaymentMethodResponse.clientSecret()).isEqualTo(clientSecret);
} }
@ -625,19 +632,19 @@ class SubscriptionControllerTest {
Arrays.fill(subscriberUserAndKey, (byte) 1); Arrays.fill(subscriberUserAndKey, (byte) 1);
final String subscriberId = Base64.getEncoder().encodeToString(subscriberUserAndKey); final String subscriberId = Base64.getEncoder().encodeToString(subscriberUserAndKey);
final Map<String, AttributeValue> dynamoItem = Map.of(SubscriptionManager.KEY_PASSWORD, b(new byte[16]), final Map<String, AttributeValue> dynamoItem = Map.of(Subscriptions.KEY_PASSWORD, b(new byte[16]),
SubscriptionManager.KEY_CREATED_AT, n(Instant.now().getEpochSecond()), Subscriptions.KEY_CREATED_AT, n(Instant.now().getEpochSecond()),
SubscriptionManager.KEY_ACCESSED_AT, n(Instant.now().getEpochSecond()) Subscriptions.KEY_ACCESSED_AT, n(Instant.now().getEpochSecond())
); );
final SubscriptionManager.Record record = SubscriptionManager.Record.from( final Subscriptions.Record record = Subscriptions.Record.from(
Arrays.copyOfRange(subscriberUserAndKey, 0, 16), dynamoItem); Arrays.copyOfRange(subscriberUserAndKey, 0, 16), dynamoItem);
when(SUBSCRIPTION_MANAGER.create(any(), any(), any(Instant.class))) when(SUBSCRIPTIONS.create(any(), any(), any(Instant.class)))
.thenReturn(CompletableFuture.completedFuture(record)); .thenReturn(CompletableFuture.completedFuture(record));
// set up mocks // set up mocks
when(CLOCK.instant()).thenReturn(Instant.now()); when(CLOCK.instant()).thenReturn(Instant.now());
when(SUBSCRIPTION_MANAGER.get(any(), any())) when(SUBSCRIPTIONS.get(any(), any()))
.thenReturn(CompletableFuture.completedFuture(SubscriptionManager.GetResult.found(record))); .thenReturn(CompletableFuture.completedFuture(Subscriptions.GetResult.found(record)));
final Response response = RESOURCE_EXTENSION final Response response = RESOURCE_EXTENSION
.target(String.format("/v1/subscription/%s/level/%d/%s/%s", subscriberId, 5, "usd", "abcd")) .target(String.format("/v1/subscription/%s/level/%d/%s/%s", subscriberId, 5, "usd", "abcd"))
@ -661,26 +668,26 @@ class SubscriptionControllerTest {
final String subscriberId = Base64.getEncoder().encodeToString(subscriberUserAndKey); final String subscriberId = Base64.getEncoder().encodeToString(subscriberUserAndKey);
final String customerId = "customer"; final String customerId = "customer";
final Map<String, AttributeValue> dynamoItem = Map.of(SubscriptionManager.KEY_PASSWORD, b(new byte[16]), final Map<String, AttributeValue> dynamoItem = Map.of(Subscriptions.KEY_PASSWORD, b(new byte[16]),
SubscriptionManager.KEY_CREATED_AT, n(Instant.now().getEpochSecond()), Subscriptions.KEY_CREATED_AT, n(Instant.now().getEpochSecond()),
SubscriptionManager.KEY_ACCESSED_AT, n(Instant.now().getEpochSecond()), Subscriptions.KEY_ACCESSED_AT, n(Instant.now().getEpochSecond()),
SubscriptionManager.KEY_PROCESSOR_ID_CUSTOMER_ID, Subscriptions.KEY_PROCESSOR_ID_CUSTOMER_ID,
b(new ProcessorCustomer(customerId, SubscriptionProcessor.BRAINTREE).toDynamoBytes()) b(new ProcessorCustomer(customerId, PaymentProvider.BRAINTREE).toDynamoBytes())
); );
final SubscriptionManager.Record record = SubscriptionManager.Record.from( final Subscriptions.Record record = Subscriptions.Record.from(
Arrays.copyOfRange(subscriberUserAndKey, 0, 16), dynamoItem); Arrays.copyOfRange(subscriberUserAndKey, 0, 16), dynamoItem);
when(SUBSCRIPTION_MANAGER.create(any(), any(), any(Instant.class))) when(SUBSCRIPTIONS.create(any(), any(), any(Instant.class)))
.thenReturn(CompletableFuture.completedFuture(record)); .thenReturn(CompletableFuture.completedFuture(record));
// set up mocks // set up mocks
when(CLOCK.instant()).thenReturn(Instant.now()); when(CLOCK.instant()).thenReturn(Instant.now());
when(SUBSCRIPTION_MANAGER.get(any(), any())) when(SUBSCRIPTIONS.get(any(), any()))
.thenReturn(CompletableFuture.completedFuture(SubscriptionManager.GetResult.found(record))); .thenReturn(CompletableFuture.completedFuture(Subscriptions.GetResult.found(record)));
when(BRAINTREE_MANAGER.createSubscription(any(), any(), anyLong(), anyLong())) when(BRAINTREE_MANAGER.createSubscription(any(), any(), anyLong(), anyLong()))
.thenReturn(CompletableFuture.completedFuture(new SubscriptionProcessorManager.SubscriptionId( .thenReturn(CompletableFuture.completedFuture(new SubscriptionPaymentProcessor.SubscriptionId(
"subscription"))); "subscription")));
when(SUBSCRIPTION_MANAGER.subscriptionCreated(any(), any(), any(), anyLong())) when(SUBSCRIPTIONS.subscriptionCreated(any(), any(), any(), anyLong()))
.thenReturn(CompletableFuture.completedFuture(null)); .thenReturn(CompletableFuture.completedFuture(null));
final Response response = RESOURCE_EXTENSION final Response response = RESOURCE_EXTENSION
@ -710,36 +717,36 @@ class SubscriptionControllerTest {
final String customerId = "customer"; final String customerId = "customer";
final String existingSubscriptionId = "existingSubscription"; final String existingSubscriptionId = "existingSubscription";
final Map<String, AttributeValue> dynamoItem = Map.of(SubscriptionManager.KEY_PASSWORD, b(new byte[16]), final Map<String, AttributeValue> dynamoItem = Map.of(Subscriptions.KEY_PASSWORD, b(new byte[16]),
SubscriptionManager.KEY_CREATED_AT, n(Instant.now().getEpochSecond()), Subscriptions.KEY_CREATED_AT, n(Instant.now().getEpochSecond()),
SubscriptionManager.KEY_ACCESSED_AT, n(Instant.now().getEpochSecond()), Subscriptions.KEY_ACCESSED_AT, n(Instant.now().getEpochSecond()),
SubscriptionManager.KEY_PROCESSOR_ID_CUSTOMER_ID, Subscriptions.KEY_PROCESSOR_ID_CUSTOMER_ID,
b(new ProcessorCustomer(customerId, SubscriptionProcessor.BRAINTREE).toDynamoBytes()), b(new ProcessorCustomer(customerId, PaymentProvider.BRAINTREE).toDynamoBytes()),
SubscriptionManager.KEY_SUBSCRIPTION_ID, s(existingSubscriptionId) Subscriptions.KEY_SUBSCRIPTION_ID, s(existingSubscriptionId)
); );
final SubscriptionManager.Record record = SubscriptionManager.Record.from( final Subscriptions.Record record = Subscriptions.Record.from(
Arrays.copyOfRange(subscriberUserAndKey, 0, 16), dynamoItem); Arrays.copyOfRange(subscriberUserAndKey, 0, 16), dynamoItem);
when(SUBSCRIPTION_MANAGER.create(any(), any(), any(Instant.class))) when(SUBSCRIPTIONS.create(any(), any(), any(Instant.class)))
.thenReturn(CompletableFuture.completedFuture(record)); .thenReturn(CompletableFuture.completedFuture(record));
// set up mocks // set up mocks
when(CLOCK.instant()).thenReturn(Instant.now()); when(CLOCK.instant()).thenReturn(Instant.now());
when(SUBSCRIPTION_MANAGER.get(any(), any())) when(SUBSCRIPTIONS.get(any(), any()))
.thenReturn(CompletableFuture.completedFuture(SubscriptionManager.GetResult.found(record))); .thenReturn(CompletableFuture.completedFuture(Subscriptions.GetResult.found(record)));
final Object subscriptionObj = new Object(); final Object subscriptionObj = new Object();
when(BRAINTREE_MANAGER.getSubscription(any())) when(BRAINTREE_MANAGER.getSubscription(any()))
.thenReturn(CompletableFuture.completedFuture(subscriptionObj)); .thenReturn(CompletableFuture.completedFuture(subscriptionObj));
when(BRAINTREE_MANAGER.getLevelAndCurrencyForSubscription(subscriptionObj)) when(BRAINTREE_MANAGER.getLevelAndCurrencyForSubscription(subscriptionObj))
.thenReturn(CompletableFuture.completedFuture( .thenReturn(CompletableFuture.completedFuture(
new SubscriptionProcessorManager.LevelAndCurrency(existingLevel, existingCurrency))); new SubscriptionPaymentProcessor.LevelAndCurrency(existingLevel, existingCurrency)));
final String updatedSubscriptionId = "updatedSubscriptionId"; final String updatedSubscriptionId = "updatedSubscriptionId";
if (expectUpdate) { if (expectUpdate) {
when(BRAINTREE_MANAGER.updateSubscription(any(), any(), anyLong(), anyString())) when(BRAINTREE_MANAGER.updateSubscription(any(), any(), anyLong(), anyString()))
.thenReturn(CompletableFuture.completedFuture(new SubscriptionProcessorManager.SubscriptionId( .thenReturn(CompletableFuture.completedFuture(new SubscriptionPaymentProcessor.SubscriptionId(
updatedSubscriptionId))); updatedSubscriptionId)));
when(SUBSCRIPTION_MANAGER.subscriptionLevelChanged(any(), any(), anyLong(), anyString())) when(SUBSCRIPTIONS.subscriptionLevelChanged(any(), any(), anyLong(), anyString()))
.thenReturn(CompletableFuture.completedFuture(null)); .thenReturn(CompletableFuture.completedFuture(null));
} }
@ -755,7 +762,7 @@ class SubscriptionControllerTest {
if (expectUpdate) { if (expectUpdate) {
verify(BRAINTREE_MANAGER).updateSubscription(any(), any(), eq(requestLevel), eq(idempotencyKey)); verify(BRAINTREE_MANAGER).updateSubscription(any(), any(), eq(requestLevel), eq(idempotencyKey));
verify(SUBSCRIPTION_MANAGER).subscriptionLevelChanged(any(), any(), eq(requestLevel), eq(updatedSubscriptionId)); verify(SUBSCRIPTIONS).subscriptionLevelChanged(any(), any(), eq(requestLevel), eq(updatedSubscriptionId));
} }
verifyNoMoreInteractions(BRAINTREE_MANAGER); verifyNoMoreInteractions(BRAINTREE_MANAGER);
@ -787,27 +794,27 @@ class SubscriptionControllerTest {
final String customerId = "customer"; final String customerId = "customer";
final String existingSubscriptionId = "existingSubscription"; final String existingSubscriptionId = "existingSubscription";
final Map<String, AttributeValue> dynamoItem = Map.of(SubscriptionManager.KEY_PASSWORD, b(new byte[16]), final Map<String, AttributeValue> dynamoItem = Map.of(Subscriptions.KEY_PASSWORD, b(new byte[16]),
SubscriptionManager.KEY_CREATED_AT, n(Instant.now().getEpochSecond()), Subscriptions.KEY_CREATED_AT, n(Instant.now().getEpochSecond()),
SubscriptionManager.KEY_ACCESSED_AT, n(Instant.now().getEpochSecond()), Subscriptions.KEY_ACCESSED_AT, n(Instant.now().getEpochSecond()),
SubscriptionManager.KEY_PROCESSOR_ID_CUSTOMER_ID, Subscriptions.KEY_PROCESSOR_ID_CUSTOMER_ID,
b(new ProcessorCustomer(customerId, SubscriptionProcessor.BRAINTREE).toDynamoBytes()), b(new ProcessorCustomer(customerId, PaymentProvider.BRAINTREE).toDynamoBytes()),
SubscriptionManager.KEY_SUBSCRIPTION_ID, s(existingSubscriptionId)); Subscriptions.KEY_SUBSCRIPTION_ID, s(existingSubscriptionId));
final SubscriptionManager.Record record = SubscriptionManager.Record.from( final Subscriptions.Record record = Subscriptions.Record.from(
Arrays.copyOfRange(subscriberUserAndKey, 0, 16), dynamoItem); Arrays.copyOfRange(subscriberUserAndKey, 0, 16), dynamoItem);
when(SUBSCRIPTION_MANAGER.create(any(), any(), any(Instant.class))) when(SUBSCRIPTIONS.create(any(), any(), any(Instant.class)))
.thenReturn(CompletableFuture.completedFuture(record)); .thenReturn(CompletableFuture.completedFuture(record));
when(CLOCK.instant()).thenReturn(Instant.now()); when(CLOCK.instant()).thenReturn(Instant.now());
when(SUBSCRIPTION_MANAGER.get(any(), any())) when(SUBSCRIPTIONS.get(any(), any()))
.thenReturn(CompletableFuture.completedFuture(SubscriptionManager.GetResult.found(record))); .thenReturn(CompletableFuture.completedFuture(Subscriptions.GetResult.found(record)));
final Object subscriptionObj = new Object(); final Object subscriptionObj = new Object();
when(BRAINTREE_MANAGER.getSubscription(any())) when(BRAINTREE_MANAGER.getSubscription(any()))
.thenReturn(CompletableFuture.completedFuture(subscriptionObj)); .thenReturn(CompletableFuture.completedFuture(subscriptionObj));
when(BRAINTREE_MANAGER.getLevelAndCurrencyForSubscription(subscriptionObj)) when(BRAINTREE_MANAGER.getLevelAndCurrencyForSubscription(subscriptionObj))
.thenReturn(CompletableFuture.completedFuture( .thenReturn(CompletableFuture.completedFuture(
new SubscriptionProcessorManager.LevelAndCurrency(201, "usd"))); new SubscriptionPaymentProcessor.LevelAndCurrency(201, "usd")));
// Try to change from a backup subscription (201) to a donation subscription (5) // Try to change from a backup subscription (201) to a donation subscription (5)
final Response response = RESOURCE_EXTENSION final Response response = RESOURCE_EXTENSION
@ -833,13 +840,13 @@ class SubscriptionControllerTest {
final String customerId = "customer"; final String customerId = "customer";
final String subscriptionId = "subscriptionId"; final String subscriptionId = "subscriptionId";
final Map<String, AttributeValue> dynamoItem = Map.of(SubscriptionManager.KEY_PASSWORD, b(new byte[16]), final Map<String, AttributeValue> dynamoItem = Map.of(Subscriptions.KEY_PASSWORD, b(new byte[16]),
SubscriptionManager.KEY_CREATED_AT, n(Instant.now().getEpochSecond()), Subscriptions.KEY_CREATED_AT, n(Instant.now().getEpochSecond()),
SubscriptionManager.KEY_ACCESSED_AT, n(Instant.now().getEpochSecond()), Subscriptions.KEY_ACCESSED_AT, n(Instant.now().getEpochSecond()),
SubscriptionManager.KEY_PROCESSOR_ID_CUSTOMER_ID, Subscriptions.KEY_PROCESSOR_ID_CUSTOMER_ID,
b(new ProcessorCustomer(customerId, SubscriptionProcessor.BRAINTREE).toDynamoBytes()), b(new ProcessorCustomer(customerId, PaymentProvider.BRAINTREE).toDynamoBytes()),
SubscriptionManager.KEY_SUBSCRIPTION_ID, s(subscriptionId)); Subscriptions.KEY_SUBSCRIPTION_ID, s(subscriptionId));
final SubscriptionManager.Record record = SubscriptionManager.Record.from( final Subscriptions.Record record = Subscriptions.Record.from(
Arrays.copyOfRange(subscriberUserAndKey, 0, 16), dynamoItem); Arrays.copyOfRange(subscriberUserAndKey, 0, 16), dynamoItem);
final ReceiptCredentialRequest receiptRequest = new ClientZkReceiptOperations( final ReceiptCredentialRequest receiptRequest = new ClientZkReceiptOperations(
ServerSecretParams.generate().getPublicParams()).createReceiptCredentialRequestContext( ServerSecretParams.generate().getPublicParams()).createReceiptCredentialRequestContext(
@ -847,15 +854,15 @@ class SubscriptionControllerTest {
final ReceiptCredentialResponse receiptCredentialResponse = mock(ReceiptCredentialResponse.class); final ReceiptCredentialResponse receiptCredentialResponse = mock(ReceiptCredentialResponse.class);
when(CLOCK.instant()).thenReturn(Instant.now()); when(CLOCK.instant()).thenReturn(Instant.now());
when(SUBSCRIPTION_MANAGER.get(any(), any())) when(SUBSCRIPTIONS.get(any(), any()))
.thenReturn(CompletableFuture.completedFuture(SubscriptionManager.GetResult.found(record))); .thenReturn(CompletableFuture.completedFuture(Subscriptions.GetResult.found(record)));
when(BRAINTREE_MANAGER.getReceiptItem(subscriptionId)).thenReturn( when(BRAINTREE_MANAGER.getReceiptItem(subscriptionId)).thenReturn(
CompletableFuture.completedFuture(new SubscriptionProcessorManager.ReceiptItem( CompletableFuture.completedFuture(new SubscriptionPaymentProcessor.ReceiptItem(
"itemId", "itemId",
Instant.ofEpochSecond(10).plus(Duration.ofDays(1)), Instant.ofEpochSecond(10).plus(Duration.ofDays(1)),
level level
))); )));
when(ISSUED_RECEIPTS_MANAGER.recordIssuance(eq("itemId"), eq(SubscriptionProcessor.BRAINTREE), eq(receiptRequest), any())) when(ISSUED_RECEIPTS_MANAGER.recordIssuance(eq("itemId"), eq(PaymentProvider.BRAINTREE), eq(receiptRequest), any()))
.thenReturn(CompletableFuture.completedFuture(null)); .thenReturn(CompletableFuture.completedFuture(null));
when(ZK_OPS.issueReceiptCredential(any(), anyLong(), eq(level))).thenReturn(receiptCredentialResponse); when(ZK_OPS.issueReceiptCredential(any(), anyLong(), eq(level))).thenReturn(receiptCredentialResponse);
when(receiptCredentialResponse.serialize()).thenReturn(new byte[0]); when(receiptCredentialResponse.serialize()).thenReturn(new byte[0]);

View File

@ -330,21 +330,21 @@ public final class DynamoDbExtensionSchema {
List.of()), List.of()),
SUBSCRIPTIONS("subscriptions_test", SUBSCRIPTIONS("subscriptions_test",
SubscriptionManager.KEY_USER, Subscriptions.KEY_USER,
null, null,
List.of( List.of(
AttributeDefinition.builder() AttributeDefinition.builder()
.attributeName(SubscriptionManager.KEY_USER) .attributeName(Subscriptions.KEY_USER)
.attributeType(ScalarAttributeType.B) .attributeType(ScalarAttributeType.B)
.build(), .build(),
AttributeDefinition.builder() AttributeDefinition.builder()
.attributeName(SubscriptionManager.KEY_PROCESSOR_ID_CUSTOMER_ID) .attributeName(Subscriptions.KEY_PROCESSOR_ID_CUSTOMER_ID)
.attributeType(ScalarAttributeType.B) .attributeType(ScalarAttributeType.B)
.build()), .build()),
List.of(GlobalSecondaryIndex.builder() List.of(GlobalSecondaryIndex.builder()
.indexName(SubscriptionManager.INDEX_NAME) .indexName(Subscriptions.INDEX_NAME)
.keySchema(KeySchemaElement.builder() .keySchema(KeySchemaElement.builder()
.attributeName(SubscriptionManager.KEY_PROCESSOR_ID_CUSTOMER_ID) .attributeName(Subscriptions.KEY_PROCESSOR_ID_CUSTOMER_ID)
.keyType(KeyType.HASH) .keyType(KeyType.HASH)
.build()) .build())
.projection(Projection.builder() .projection(Projection.builder()

View File

@ -19,7 +19,7 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.extension.RegisterExtension;
import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialRequest; import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialRequest;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.subscriptions.SubscriptionProcessor; import org.whispersystems.textsecuregcm.subscriptions.PaymentProvider;
import org.whispersystems.textsecuregcm.util.TestRandomUtil; import org.whispersystems.textsecuregcm.util.TestRandomUtil;
class IssuedReceiptsManagerTest { class IssuedReceiptsManagerTest {
@ -47,19 +47,19 @@ class IssuedReceiptsManagerTest {
Instant now = Instant.ofEpochSecond(NOW_EPOCH_SECONDS); Instant now = Instant.ofEpochSecond(NOW_EPOCH_SECONDS);
byte[] request1 = TestRandomUtil.nextBytes(20); byte[] request1 = TestRandomUtil.nextBytes(20);
when(receiptCredentialRequest.serialize()).thenReturn(request1); when(receiptCredentialRequest.serialize()).thenReturn(request1);
CompletableFuture<Void> future = issuedReceiptsManager.recordIssuance("item-1", SubscriptionProcessor.STRIPE, CompletableFuture<Void> future = issuedReceiptsManager.recordIssuance("item-1", PaymentProvider.STRIPE,
receiptCredentialRequest, now); receiptCredentialRequest, now);
assertThat(future).succeedsWithin(Duration.ofSeconds(3)); assertThat(future).succeedsWithin(Duration.ofSeconds(3));
// same request should succeed // same request should succeed
future = issuedReceiptsManager.recordIssuance("item-1", SubscriptionProcessor.STRIPE, receiptCredentialRequest, future = issuedReceiptsManager.recordIssuance("item-1", PaymentProvider.STRIPE, receiptCredentialRequest,
now); now);
assertThat(future).succeedsWithin(Duration.ofSeconds(3)); assertThat(future).succeedsWithin(Duration.ofSeconds(3));
// same item with new request should fail // same item with new request should fail
byte[] request2 = TestRandomUtil.nextBytes(20); byte[] request2 = TestRandomUtil.nextBytes(20);
when(receiptCredentialRequest.serialize()).thenReturn(request2); when(receiptCredentialRequest.serialize()).thenReturn(request2);
future = issuedReceiptsManager.recordIssuance("item-1", SubscriptionProcessor.STRIPE, receiptCredentialRequest, future = issuedReceiptsManager.recordIssuance("item-1", PaymentProvider.STRIPE, receiptCredentialRequest,
now); now);
assertThat(future).failsWithin(Duration.ofSeconds(3)). assertThat(future).failsWithin(Duration.ofSeconds(3)).
withThrowableOfType(Throwable.class). withThrowableOfType(Throwable.class).
@ -70,7 +70,7 @@ class IssuedReceiptsManagerTest {
"status 409")); "status 409"));
// different item with new request should be okay though // different item with new request should be okay though
future = issuedReceiptsManager.recordIssuance("item-2", SubscriptionProcessor.STRIPE, receiptCredentialRequest, future = issuedReceiptsManager.recordIssuance("item-2", PaymentProvider.STRIPE, receiptCredentialRequest,
now); now);
assertThat(future).succeedsWithin(Duration.ofSeconds(3)); assertThat(future).succeedsWithin(Duration.ofSeconds(3));
} }

View File

@ -6,9 +6,9 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.whispersystems.textsecuregcm.storage.SubscriptionManager.GetResult.Type.FOUND; import static org.whispersystems.textsecuregcm.storage.Subscriptions.GetResult.Type.FOUND;
import static org.whispersystems.textsecuregcm.storage.SubscriptionManager.GetResult.Type.NOT_STORED; import static org.whispersystems.textsecuregcm.storage.Subscriptions.GetResult.Type.NOT_STORED;
import static org.whispersystems.textsecuregcm.storage.SubscriptionManager.GetResult.Type.PASSWORD_MISMATCH; import static org.whispersystems.textsecuregcm.storage.Subscriptions.GetResult.Type.PASSWORD_MISMATCH;
import java.security.SecureRandom; import java.security.SecureRandom;
import java.time.Duration; import java.time.Duration;
@ -25,13 +25,13 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.storage.SubscriptionManager.GetResult; import org.whispersystems.textsecuregcm.storage.Subscriptions.GetResult;
import org.whispersystems.textsecuregcm.storage.SubscriptionManager.Record; import org.whispersystems.textsecuregcm.storage.Subscriptions.Record;
import org.whispersystems.textsecuregcm.subscriptions.ProcessorCustomer; import org.whispersystems.textsecuregcm.subscriptions.ProcessorCustomer;
import org.whispersystems.textsecuregcm.subscriptions.SubscriptionProcessor; import org.whispersystems.textsecuregcm.subscriptions.PaymentProvider;
import org.whispersystems.textsecuregcm.util.TestRandomUtil; import org.whispersystems.textsecuregcm.util.TestRandomUtil;
class SubscriptionManagerTest { class SubscriptionsTest {
private static final long NOW_EPOCH_SECONDS = 1_500_000_000L; private static final long NOW_EPOCH_SECONDS = 1_500_000_000L;
private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(3); private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(3);
@ -44,7 +44,7 @@ class SubscriptionManagerTest {
byte[] password; byte[] password;
String customer; String customer;
Instant created; Instant created;
SubscriptionManager subscriptionManager; Subscriptions subscriptions;
@BeforeEach @BeforeEach
void beforeEach() { void beforeEach() {
@ -52,7 +52,7 @@ class SubscriptionManagerTest {
password = TestRandomUtil.nextBytes(16); password = TestRandomUtil.nextBytes(16);
customer = Base64.getEncoder().encodeToString(TestRandomUtil.nextBytes(16)); customer = Base64.getEncoder().encodeToString(TestRandomUtil.nextBytes(16));
created = Instant.ofEpochSecond(NOW_EPOCH_SECONDS); created = Instant.ofEpochSecond(NOW_EPOCH_SECONDS);
subscriptionManager = new SubscriptionManager( subscriptions = new Subscriptions(
Tables.SUBSCRIPTIONS.tableName(), DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient()); Tables.SUBSCRIPTIONS.tableName(), DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient());
} }
@ -63,29 +63,29 @@ class SubscriptionManagerTest {
Instant created1 = Instant.ofEpochSecond(NOW_EPOCH_SECONDS); Instant created1 = Instant.ofEpochSecond(NOW_EPOCH_SECONDS);
Instant created2 = Instant.ofEpochSecond(NOW_EPOCH_SECONDS + 1); Instant created2 = Instant.ofEpochSecond(NOW_EPOCH_SECONDS + 1);
CompletableFuture<GetResult> getFuture = subscriptionManager.get(user, password1); CompletableFuture<GetResult> getFuture = subscriptions.get(user, password1);
assertThat(getFuture).succeedsWithin(DEFAULT_TIMEOUT).satisfies(getResult -> { assertThat(getFuture).succeedsWithin(DEFAULT_TIMEOUT).satisfies(getResult -> {
assertThat(getResult.type).isEqualTo(NOT_STORED); assertThat(getResult.type).isEqualTo(NOT_STORED);
assertThat(getResult.record).isNull(); assertThat(getResult.record).isNull();
}); });
getFuture = subscriptionManager.get(user, password2); getFuture = subscriptions.get(user, password2);
assertThat(getFuture).succeedsWithin(DEFAULT_TIMEOUT).satisfies(getResult -> { assertThat(getFuture).succeedsWithin(DEFAULT_TIMEOUT).satisfies(getResult -> {
assertThat(getResult.type).isEqualTo(NOT_STORED); assertThat(getResult.type).isEqualTo(NOT_STORED);
assertThat(getResult.record).isNull(); assertThat(getResult.record).isNull();
}); });
CompletableFuture<SubscriptionManager.Record> createFuture = CompletableFuture<Subscriptions.Record> createFuture =
subscriptionManager.create(user, password1, created1); subscriptions.create(user, password1, created1);
Consumer<Record> recordRequirements = checkFreshlyCreatedRecord(user, password1, created1); Consumer<Record> recordRequirements = checkFreshlyCreatedRecord(user, password1, created1);
assertThat(createFuture).succeedsWithin(DEFAULT_TIMEOUT).satisfies(recordRequirements); assertThat(createFuture).succeedsWithin(DEFAULT_TIMEOUT).satisfies(recordRequirements);
// password check fails so this should return null // password check fails so this should return null
createFuture = subscriptionManager.create(user, password2, created2); createFuture = subscriptions.create(user, password2, created2);
assertThat(createFuture).succeedsWithin(DEFAULT_TIMEOUT).isNull(); assertThat(createFuture).succeedsWithin(DEFAULT_TIMEOUT).isNull();
// password check matches, but the record already exists so nothing should get updated // password check matches, but the record already exists so nothing should get updated
createFuture = subscriptionManager.create(user, password1, created2); createFuture = subscriptions.create(user, password1, created2);
assertThat(createFuture).succeedsWithin(DEFAULT_TIMEOUT).satisfies(recordRequirements); assertThat(createFuture).succeedsWithin(DEFAULT_TIMEOUT).satisfies(recordRequirements);
} }
@ -93,20 +93,20 @@ class SubscriptionManagerTest {
void testGet() { void testGet() {
byte[] wrongUser = TestRandomUtil.nextBytes(16); byte[] wrongUser = TestRandomUtil.nextBytes(16);
byte[] wrongPassword = TestRandomUtil.nextBytes(16); byte[] wrongPassword = TestRandomUtil.nextBytes(16);
assertThat(subscriptionManager.create(user, password, created)).succeedsWithin(DEFAULT_TIMEOUT); assertThat(subscriptions.create(user, password, created)).succeedsWithin(DEFAULT_TIMEOUT);
assertThat(subscriptionManager.get(user, password)).succeedsWithin(DEFAULT_TIMEOUT).satisfies(getResult -> { assertThat(subscriptions.get(user, password)).succeedsWithin(DEFAULT_TIMEOUT).satisfies(getResult -> {
assertThat(getResult.type).isEqualTo(FOUND); assertThat(getResult.type).isEqualTo(FOUND);
assertThat(getResult.record).isNotNull().satisfies(checkFreshlyCreatedRecord(user, password, created)); assertThat(getResult.record).isNotNull().satisfies(checkFreshlyCreatedRecord(user, password, created));
}); });
assertThat(subscriptionManager.get(user, wrongPassword)).succeedsWithin(DEFAULT_TIMEOUT) assertThat(subscriptions.get(user, wrongPassword)).succeedsWithin(DEFAULT_TIMEOUT)
.satisfies(getResult -> { .satisfies(getResult -> {
assertThat(getResult.type).isEqualTo(PASSWORD_MISMATCH); assertThat(getResult.type).isEqualTo(PASSWORD_MISMATCH);
assertThat(getResult.record).isNull(); assertThat(getResult.record).isNull();
}); });
assertThat(subscriptionManager.get(wrongUser, password)).succeedsWithin(DEFAULT_TIMEOUT) assertThat(subscriptions.get(wrongUser, password)).succeedsWithin(DEFAULT_TIMEOUT)
.satisfies(getResult -> { .satisfies(getResult -> {
assertThat(getResult.type).isEqualTo(NOT_STORED); assertThat(getResult.type).isEqualTo(NOT_STORED);
assertThat(getResult.record).isNull(); assertThat(getResult.record).isNull();
@ -116,25 +116,25 @@ class SubscriptionManagerTest {
@Test @Test
void testSetCustomerIdAndProcessor() throws Exception { void testSetCustomerIdAndProcessor() throws Exception {
Instant subscriptionUpdated = Instant.ofEpochSecond(NOW_EPOCH_SECONDS + 1); Instant subscriptionUpdated = Instant.ofEpochSecond(NOW_EPOCH_SECONDS + 1);
assertThat(subscriptionManager.create(user, password, created)).succeedsWithin(DEFAULT_TIMEOUT); assertThat(subscriptions.create(user, password, created)).succeedsWithin(DEFAULT_TIMEOUT);
final CompletableFuture<GetResult> getUser = subscriptionManager.get(user, password); final CompletableFuture<GetResult> getUser = subscriptions.get(user, password);
assertThat(getUser).succeedsWithin(DEFAULT_TIMEOUT); assertThat(getUser).succeedsWithin(DEFAULT_TIMEOUT);
final Record userRecord = getUser.get().record; final Record userRecord = getUser.get().record;
assertThat(subscriptionManager.setProcessorAndCustomerId(userRecord, assertThat(subscriptions.setProcessorAndCustomerId(userRecord,
new ProcessorCustomer(customer, SubscriptionProcessor.STRIPE), new ProcessorCustomer(customer, PaymentProvider.STRIPE),
subscriptionUpdated)).succeedsWithin(DEFAULT_TIMEOUT) subscriptionUpdated)).succeedsWithin(DEFAULT_TIMEOUT)
.hasFieldOrPropertyWithValue("processorCustomer", .hasFieldOrPropertyWithValue("processorCustomer",
Optional.of(new ProcessorCustomer(customer, SubscriptionProcessor.STRIPE))); Optional.of(new ProcessorCustomer(customer, PaymentProvider.STRIPE)));
final Condition<Throwable> clientError409Condition = new Condition<>(e -> final Condition<Throwable> clientError409Condition = new Condition<>(e ->
e instanceof ClientErrorException cee && cee.getResponse().getStatus() == 409, "Client error: 409"); e instanceof ClientErrorException cee && cee.getResponse().getStatus() == 409, "Client error: 409");
// changing the customer ID is not permitted // changing the customer ID is not permitted
assertThat( assertThat(
subscriptionManager.setProcessorAndCustomerId(userRecord, subscriptions.setProcessorAndCustomerId(userRecord,
new ProcessorCustomer(customer + "1", SubscriptionProcessor.STRIPE), new ProcessorCustomer(customer + "1", PaymentProvider.STRIPE),
subscriptionUpdated)).failsWithin(DEFAULT_TIMEOUT) subscriptionUpdated)).failsWithin(DEFAULT_TIMEOUT)
.withThrowableOfType(ExecutionException.class) .withThrowableOfType(ExecutionException.class)
.withCauseInstanceOf(ClientErrorException.class) .withCauseInstanceOf(ClientErrorException.class)
@ -143,16 +143,16 @@ class SubscriptionManagerTest {
// calling setProcessorAndCustomerId() with the same customer ID is also an error // calling setProcessorAndCustomerId() with the same customer ID is also an error
assertThat( assertThat(
subscriptionManager.setProcessorAndCustomerId(userRecord, subscriptions.setProcessorAndCustomerId(userRecord,
new ProcessorCustomer(customer, SubscriptionProcessor.STRIPE), new ProcessorCustomer(customer, PaymentProvider.STRIPE),
subscriptionUpdated)).failsWithin(DEFAULT_TIMEOUT) subscriptionUpdated)).failsWithin(DEFAULT_TIMEOUT)
.withThrowableOfType(ExecutionException.class) .withThrowableOfType(ExecutionException.class)
.withCauseInstanceOf(ClientErrorException.class) .withCauseInstanceOf(ClientErrorException.class)
.extracting(Throwable::getCause) .extracting(Throwable::getCause)
.satisfies(clientError409Condition); .satisfies(clientError409Condition);
assertThat(subscriptionManager.getSubscriberUserByProcessorCustomer( assertThat(subscriptions.getSubscriberUserByProcessorCustomer(
new ProcessorCustomer(customer, SubscriptionProcessor.STRIPE))) new ProcessorCustomer(customer, PaymentProvider.STRIPE)))
.succeedsWithin(DEFAULT_TIMEOUT). .succeedsWithin(DEFAULT_TIMEOUT).
isEqualTo(user); isEqualTo(user);
} }
@ -160,17 +160,17 @@ class SubscriptionManagerTest {
@Test @Test
void testLookupByCustomerId() throws Exception { void testLookupByCustomerId() throws Exception {
Instant subscriptionUpdated = Instant.ofEpochSecond(NOW_EPOCH_SECONDS + 1); Instant subscriptionUpdated = Instant.ofEpochSecond(NOW_EPOCH_SECONDS + 1);
assertThat(subscriptionManager.create(user, password, created)).succeedsWithin(DEFAULT_TIMEOUT); assertThat(subscriptions.create(user, password, created)).succeedsWithin(DEFAULT_TIMEOUT);
final CompletableFuture<GetResult> getUser = subscriptionManager.get(user, password); final CompletableFuture<GetResult> getUser = subscriptions.get(user, password);
assertThat(getUser).succeedsWithin(DEFAULT_TIMEOUT); assertThat(getUser).succeedsWithin(DEFAULT_TIMEOUT);
final Record userRecord = getUser.get().record; final Record userRecord = getUser.get().record;
assertThat(subscriptionManager.setProcessorAndCustomerId(userRecord, assertThat(subscriptions.setProcessorAndCustomerId(userRecord,
new ProcessorCustomer(customer, SubscriptionProcessor.STRIPE), new ProcessorCustomer(customer, PaymentProvider.STRIPE),
subscriptionUpdated)).succeedsWithin(DEFAULT_TIMEOUT); subscriptionUpdated)).succeedsWithin(DEFAULT_TIMEOUT);
assertThat(subscriptionManager.getSubscriberUserByProcessorCustomer( assertThat(subscriptions.getSubscriberUserByProcessorCustomer(
new ProcessorCustomer(customer, SubscriptionProcessor.STRIPE))). new ProcessorCustomer(customer, PaymentProvider.STRIPE))).
succeedsWithin(DEFAULT_TIMEOUT). succeedsWithin(DEFAULT_TIMEOUT).
isEqualTo(user); isEqualTo(user);
} }
@ -178,9 +178,9 @@ class SubscriptionManagerTest {
@Test @Test
void testCanceledAt() { void testCanceledAt() {
Instant canceled = Instant.ofEpochSecond(NOW_EPOCH_SECONDS + 42); Instant canceled = Instant.ofEpochSecond(NOW_EPOCH_SECONDS + 42);
assertThat(subscriptionManager.create(user, password, created)).succeedsWithin(DEFAULT_TIMEOUT); assertThat(subscriptions.create(user, password, created)).succeedsWithin(DEFAULT_TIMEOUT);
assertThat(subscriptionManager.canceledAt(user, canceled)).succeedsWithin(DEFAULT_TIMEOUT); assertThat(subscriptions.canceledAt(user, canceled)).succeedsWithin(DEFAULT_TIMEOUT);
assertThat(subscriptionManager.get(user, password)).succeedsWithin(DEFAULT_TIMEOUT).satisfies(getResult -> { assertThat(subscriptions.get(user, password)).succeedsWithin(DEFAULT_TIMEOUT).satisfies(getResult -> {
assertThat(getResult).isNotNull(); assertThat(getResult).isNotNull();
assertThat(getResult.type).isEqualTo(FOUND); assertThat(getResult.type).isEqualTo(FOUND);
assertThat(getResult.record).isNotNull().satisfies(record -> { assertThat(getResult.record).isNotNull().satisfies(record -> {
@ -196,10 +196,10 @@ class SubscriptionManagerTest {
String subscriptionId = Base64.getEncoder().encodeToString(TestRandomUtil.nextBytes(16)); String subscriptionId = Base64.getEncoder().encodeToString(TestRandomUtil.nextBytes(16));
Instant subscriptionCreated = Instant.ofEpochSecond(NOW_EPOCH_SECONDS + 1); Instant subscriptionCreated = Instant.ofEpochSecond(NOW_EPOCH_SECONDS + 1);
long level = 42; long level = 42;
assertThat(subscriptionManager.create(user, password, created)).succeedsWithin(DEFAULT_TIMEOUT); assertThat(subscriptions.create(user, password, created)).succeedsWithin(DEFAULT_TIMEOUT);
assertThat(subscriptionManager.subscriptionCreated(user, subscriptionId, subscriptionCreated, level)). assertThat(subscriptions.subscriptionCreated(user, subscriptionId, subscriptionCreated, level)).
succeedsWithin(DEFAULT_TIMEOUT); succeedsWithin(DEFAULT_TIMEOUT);
assertThat(subscriptionManager.get(user, password)).succeedsWithin(DEFAULT_TIMEOUT).satisfies(getResult -> { assertThat(subscriptions.get(user, password)).succeedsWithin(DEFAULT_TIMEOUT).satisfies(getResult -> {
assertThat(getResult).isNotNull(); assertThat(getResult).isNotNull();
assertThat(getResult.type).isEqualTo(FOUND); assertThat(getResult.type).isEqualTo(FOUND);
assertThat(getResult.record).isNotNull().satisfies(record -> { assertThat(getResult.record).isNotNull().satisfies(record -> {
@ -217,12 +217,12 @@ class SubscriptionManagerTest {
Instant at = Instant.ofEpochSecond(NOW_EPOCH_SECONDS + 500); Instant at = Instant.ofEpochSecond(NOW_EPOCH_SECONDS + 500);
long level = 1776; long level = 1776;
String updatedSubscriptionId = "new"; String updatedSubscriptionId = "new";
assertThat(subscriptionManager.create(user, password, created)).succeedsWithin(DEFAULT_TIMEOUT); assertThat(subscriptions.create(user, password, created)).succeedsWithin(DEFAULT_TIMEOUT);
assertThat(subscriptionManager.subscriptionCreated(user, "original", created, level - 1)).succeedsWithin( assertThat(subscriptions.subscriptionCreated(user, "original", created, level - 1)).succeedsWithin(
DEFAULT_TIMEOUT); DEFAULT_TIMEOUT);
assertThat(subscriptionManager.subscriptionLevelChanged(user, at, level, updatedSubscriptionId)).succeedsWithin( assertThat(subscriptions.subscriptionLevelChanged(user, at, level, updatedSubscriptionId)).succeedsWithin(
DEFAULT_TIMEOUT); DEFAULT_TIMEOUT);
assertThat(subscriptionManager.get(user, password)).succeedsWithin(DEFAULT_TIMEOUT).satisfies(getResult -> { assertThat(subscriptions.get(user, password)).succeedsWithin(DEFAULT_TIMEOUT).satisfies(getResult -> {
assertThat(getResult).isNotNull(); assertThat(getResult).isNotNull();
assertThat(getResult.type).isEqualTo(FOUND); assertThat(getResult.type).isEqualTo(FOUND);
assertThat(getResult.record).isNotNull().satisfies(record -> { assertThat(getResult.record).isNotNull().satisfies(record -> {
@ -237,7 +237,7 @@ class SubscriptionManagerTest {
@Test @Test
void testProcessorAndCustomerId() { void testProcessorAndCustomerId() {
final ProcessorCustomer processorCustomer = final ProcessorCustomer processorCustomer =
new ProcessorCustomer("abc", SubscriptionProcessor.STRIPE); new ProcessorCustomer("abc", PaymentProvider.STRIPE);
assertThat(processorCustomer.toDynamoBytes()).isEqualTo(new byte[]{1, 97, 98, 99}); assertThat(processorCustomer.toDynamoBytes()).isEqualTo(new byte[]{1, 97, 98, 99});
} }

View File

@ -8,9 +8,9 @@ class ProcessorCustomerTest {
@Test @Test
void toDynamoBytes() { void toDynamoBytes() {
final ProcessorCustomer processorCustomer = new ProcessorCustomer("Test", SubscriptionProcessor.BRAINTREE); final ProcessorCustomer processorCustomer = new ProcessorCustomer("Test", PaymentProvider.BRAINTREE);
assertArrayEquals(new byte[] { SubscriptionProcessor.BRAINTREE.getId(), 'T', 'e', 's', 't' }, assertArrayEquals(new byte[] { PaymentProvider.BRAINTREE.getId(), 'T', 'e', 's', 't' },
processorCustomer.toDynamoBytes()); processorCustomer.toDynamoBytes());
} }
} }