Support changing just the currency of an existing subscription

This commit is contained in:
Chris Eager 2023-01-18 12:54:34 -06:00 committed by Chris Eager
parent e8978ef91c
commit dd98f7f043
5 changed files with 93 additions and 83 deletions

View File

@ -510,64 +510,64 @@ public class SubscriptionController {
.thenApply(this::requireRecordFromGetResult) .thenApply(this::requireRecordFromGetResult)
.thenCompose(record -> { .thenCompose(record -> {
final ProcessorCustomer processorCustomer = record.getProcessorCustomer() final ProcessorCustomer processorCustomer = record.getProcessorCustomer()
.orElseThrow(() -> .orElseThrow(() ->
// a missing customer ID indicates the client made requests out of order, // a missing customer ID indicates the client made requests out of order,
// and needs to call create_payment_method to create a customer for the given payment method // and needs to call create_payment_method to create a customer for the given payment method
new ClientErrorException(Status.CONFLICT)); new ClientErrorException(Status.CONFLICT));
final String subscriptionTemplateId = getSubscriptionTemplateId(level, currency, processorCustomer.processor()); final String subscriptionTemplateId = getSubscriptionTemplateId(level, currency,
processorCustomer.processor());
final SubscriptionProcessorManager manager = getManagerForProcessor(processorCustomer.processor()); final SubscriptionProcessorManager manager = getManagerForProcessor(processorCustomer.processor());
return Optional.ofNullable(record.subscriptionId) return Optional.ofNullable(record.subscriptionId).map(subId -> {
.map(subId -> { // we already have a subscription in our records so let's check the level and currency,
// we already have a subscription in our records so let's check the level and change it if needed // and only change it if needed
return manager.getSubscription(subId).thenCompose( return manager.getSubscription(subId).thenCompose(
subscription -> manager.getLevelForSubscription(subscription).thenCompose(existingLevel -> { subscription -> manager.getLevelAndCurrencyForSubscription(subscription)
if (level == existingLevel) { .thenCompose(existingLevelAndCurrency -> {
return CompletableFuture.completedFuture(subscription); if (existingLevelAndCurrency.equals(new SubscriptionProcessorManager.LevelAndCurrency(level,
} currency.toLowerCase(Locale.ROOT)))) {
return manager.updateSubscription( return CompletableFuture.completedFuture(subscription);
subscription, subscriptionTemplateId, level, idempotencyKey)
.thenCompose(updatedSubscription ->
subscriptionManager.subscriptionLevelChanged(requestData.subscriberUser,
requestData.now,
level, updatedSubscription.id())
.thenApply(unused -> updatedSubscription));
}));
}).orElseGet(() -> {
long lastSubscriptionCreatedAt =
record.subscriptionCreatedAt != null ? record.subscriptionCreatedAt.getEpochSecond() : 0;
// we don't have a subscription yet so create it and then record the subscription id
//
// this relies on stripe's idempotency key to avoid creating more than one subscription if the client
// retries this request
return manager.createSubscription(processorCustomer.customerId(),
subscriptionTemplateId,
level,
lastSubscriptionCreatedAt)
.exceptionally(e -> {
if (e.getCause() instanceof StripeException stripeException
&& stripeException.getCode().equals("subscription_payment_intent_requires_action")) {
throw new BadRequestException(Response.status(Status.BAD_REQUEST)
.entity(new SetSubscriptionLevelErrorResponse(List.of(
new SetSubscriptionLevelErrorResponse.Error(
SetSubscriptionLevelErrorResponse.Error.Type.PAYMENT_REQUIRES_ACTION, null
)
))).build());
} }
if (e instanceof RuntimeException re) { return manager.updateSubscription(
throw re; subscription, subscriptionTemplateId, level, idempotencyKey)
} .thenCompose(updatedSubscription ->
subscriptionManager.subscriptionLevelChanged(requestData.subscriberUser,
requestData.now,
level, updatedSubscription.id())
.thenApply(unused -> updatedSubscription));
}));
}).orElseGet(() -> {
long lastSubscriptionCreatedAt =
record.subscriptionCreatedAt != null ? record.subscriptionCreatedAt.getEpochSecond() : 0;
throw new CompletionException(e); // we don't have a subscription yet so create it and then record the subscription id
}) return manager.createSubscription(processorCustomer.customerId(),
.thenCompose(subscription -> subscriptionManager.subscriptionCreated( subscriptionTemplateId,
requestData.subscriberUser, subscription.id(), requestData.now, level) level,
.thenApply(unused -> subscription)); lastSubscriptionCreatedAt)
}); .exceptionally(e -> {
if (e.getCause() instanceof StripeException stripeException
&& stripeException.getCode().equals("subscription_payment_intent_requires_action")) {
throw new BadRequestException(Response.status(Status.BAD_REQUEST)
.entity(new SetSubscriptionLevelErrorResponse(List.of(
new SetSubscriptionLevelErrorResponse.Error(
SetSubscriptionLevelErrorResponse.Error.Type.PAYMENT_REQUIRES_ACTION, null
)
))).build());
}
if (e instanceof RuntimeException re) {
throw re;
}
throw new CompletionException(e);
})
.thenCompose(subscription -> subscriptionManager.subscriptionCreated(
requestData.subscriberUser, subscription.id(), requestData.now, level)
.thenApply(unused -> subscription));
});
}) })
.thenApply(unused -> Response.ok(new SetSubscriptionLevelSuccessResponse(level)).build()); .thenApply(unused -> Response.ok(new SetSubscriptionLevelSuccessResponse(level)).build());
} }

View File

@ -358,11 +358,13 @@ public class BraintreeManager implements SubscriptionProcessorManager {
} }
@Override @Override
public CompletableFuture<Long> getLevelForSubscription(Object subscriptionObj) { public CompletableFuture<LevelAndCurrency> getLevelAndCurrencyForSubscription(Object subscriptionObj) {
final Subscription subscription = getSubscription(subscriptionObj); final Subscription subscription = getSubscription(subscriptionObj);
return findPlan(subscription.getPlanId()) return findPlan(subscription.getPlanId())
.thenApply(this::getLevelForPlan); .thenApply(
plan -> new LevelAndCurrency(getLevelForPlan(plan), plan.getCurrencyIsoCode().toLowerCase(Locale.ROOT)));
} }
private CompletableFuture<Plan> findPlan(String planId) { private CompletableFuture<Plan> findPlan(String planId) {

View File

@ -246,37 +246,37 @@ public class StripeManager implements SubscriptionProcessorManager {
@Override @Override
public CompletableFuture<SubscriptionId> createSubscription(String customerId, String priceId, long level, public CompletableFuture<SubscriptionId> createSubscription(String customerId, String priceId, long level,
long lastSubscriptionCreatedAt) { long lastSubscriptionCreatedAt) {
// this relies on Stripe's idempotency key to avoid creating more than one subscription if the client
// retries this request
return CompletableFuture.supplyAsync(() -> { return CompletableFuture.supplyAsync(() -> {
SubscriptionCreateParams params = SubscriptionCreateParams.builder() SubscriptionCreateParams params = SubscriptionCreateParams.builder()
.setCustomer(customerId) .setCustomer(customerId)
.setOffSession(true) .setOffSession(true)
.setPaymentBehavior(SubscriptionCreateParams.PaymentBehavior.ERROR_IF_INCOMPLETE) .setPaymentBehavior(SubscriptionCreateParams.PaymentBehavior.ERROR_IF_INCOMPLETE)
.addItem(SubscriptionCreateParams.Item.builder() .addItem(SubscriptionCreateParams.Item.builder()
.setPrice(priceId) .setPrice(priceId)
.build()) .build())
.putMetadata(METADATA_KEY_LEVEL, Long.toString(level)) .putMetadata(METADATA_KEY_LEVEL, Long.toString(level))
.build(); .build();
try { try {
// the idempotency key intentionally excludes priceId // the idempotency key intentionally excludes priceId
// //
// If the client tells the server several times in a row before the initial creation of a subscription to // If the client tells the server several times in a row before the initial creation of a subscription to
// create a subscription, we want to ensure only one gets created. // create a subscription, we want to ensure only one gets created.
return Subscription.create(params, commonOptions(generateIdempotencyKeyForCreateSubscription( return Subscription.create(params, commonOptions(generateIdempotencyKeyForCreateSubscription(
customerId, lastSubscriptionCreatedAt))); customerId, lastSubscriptionCreatedAt)));
} catch (StripeException e) { } catch (StripeException e) {
throw new CompletionException(e); throw new CompletionException(e);
} }
}, executor) }, executor)
.thenApply(subscription -> new SubscriptionId(subscription.getId())); .thenApply(subscription -> new SubscriptionId(subscription.getId()));
} }
@Override @Override
public CompletableFuture<SubscriptionId> updateSubscription( public CompletableFuture<SubscriptionId> updateSubscription(
Object subscriptionObj, String priceId, long level, String idempotencyKey) { Object subscriptionObj, String priceId, long level, String idempotencyKey) {
if (!(subscriptionObj instanceof final Subscription subscription)) { final Subscription subscription = getSubscription(subscriptionObj);
throw new IllegalArgumentException("invalid subscription object: " + subscriptionObj.getClass().getName());
}
return CompletableFuture.supplyAsync(() -> { return CompletableFuture.supplyAsync(() -> {
List<SubscriptionUpdateParams.Item> items = new ArrayList<>(); List<SubscriptionUpdateParams.Item> items = new ArrayList<>();
@ -400,12 +400,12 @@ public class StripeManager implements SubscriptionProcessorManager {
} }
@Override @Override
public CompletableFuture<Long> getLevelForSubscription(Object subscriptionObj) { public CompletableFuture<LevelAndCurrency> getLevelAndCurrencyForSubscription(Object subscriptionObj) {
if (!(subscriptionObj instanceof final Subscription subscription)) { final Subscription subscription = getSubscription(subscriptionObj);
throw new IllegalArgumentException("Invalid subscription object: " + subscriptionObj.getClass().getName()); return getProductForSubscription(subscription).thenApply(
} product -> new LevelAndCurrency(getLevelForProduct(product), subscription.getCurrency().toLowerCase(
return getProductForSubscription(subscription).thenApply(this::getLevelForProduct); Locale.ROOT)));
} }
public CompletableFuture<Long> getLevelForPrice(Price price) { public CompletableFuture<Long> getLevelForPrice(Price price) {

View File

@ -42,12 +42,16 @@ public interface SubscriptionProcessorManager {
CompletableFuture<Object> getSubscription(String subscriptionId); CompletableFuture<Object> getSubscription(String subscriptionId);
CompletableFuture<SubscriptionId> createSubscription(String customerId, String templateId, long level, CompletableFuture<SubscriptionId> createSubscription(String customerId, String templateId, long level,
long lastSubscriptionCreatedAt); long lastSubscriptionCreatedAt);
CompletableFuture<SubscriptionId> updateSubscription( CompletableFuture<SubscriptionId> updateSubscription(
Object subscription, String templateId, long level, String idempotencyKey); Object subscription, String templateId, long level, String idempotencyKey);
CompletableFuture<Long> getLevelForSubscription(Object subscription); /**
* @param subscription
* @return the subscriptions current level and lower-case currency code
*/
CompletableFuture<LevelAndCurrency> getLevelAndCurrencyForSubscription(Object subscription);
CompletableFuture<Void> cancelAllActiveSubscriptions(String customerId); CompletableFuture<Void> cancelAllActiveSubscriptions(String customerId);
@ -160,4 +164,8 @@ public interface SubscriptionProcessorManager {
} }
record LevelAndCurrency(long level, String currency) {
}
} }

View File

@ -72,7 +72,6 @@ import org.whispersystems.textsecuregcm.subscriptions.StripeManager;
import org.whispersystems.textsecuregcm.subscriptions.SubscriptionProcessor; import org.whispersystems.textsecuregcm.subscriptions.SubscriptionProcessor;
import org.whispersystems.textsecuregcm.subscriptions.SubscriptionProcessorManager; import org.whispersystems.textsecuregcm.subscriptions.SubscriptionProcessorManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
@ -548,7 +547,8 @@ class SubscriptionControllerTest {
when(BRAINTREE_MANAGER.getSubscription(any())) when(BRAINTREE_MANAGER.getSubscription(any()))
.thenReturn(CompletableFuture.completedFuture(subscriptionObj)); .thenReturn(CompletableFuture.completedFuture(subscriptionObj));
when(BRAINTREE_MANAGER.getLevelAndCurrencyForSubscription(subscriptionObj)) when(BRAINTREE_MANAGER.getLevelAndCurrencyForSubscription(subscriptionObj))
.thenReturn(CompletableFuture.completedFuture(new Pair<>(existingLevel, existingCurrency))); .thenReturn(CompletableFuture.completedFuture(
new SubscriptionProcessorManager.LevelAndCurrency(existingLevel, existingCurrency)));
final String updatedSubscriptionId = "updatedSubscriptionId"; final String updatedSubscriptionId = "updatedSubscriptionId";
if (expectUpdate) { if (expectUpdate) {