diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index bfe524c20..c12ed591e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -205,7 +205,7 @@ import org.whispersystems.textsecuregcm.storage.ReportMessageManager; import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager; import org.whispersystems.textsecuregcm.storage.SubscriptionManager; import org.whispersystems.textsecuregcm.storage.VerificationCodeStore; -import org.whispersystems.textsecuregcm.stripe.StripeManager; +import org.whispersystems.textsecuregcm.subscriptions.StripeManager; import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.DynamoDbFromConfig; import org.whispersystems.textsecuregcm.util.HostnameUtil; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SubscriptionController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SubscriptionController.java index 00feb2d5e..c3cc6e2ec 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SubscriptionController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SubscriptionController.java @@ -48,6 +48,7 @@ import javax.validation.constraints.NotNull; import javax.ws.rs.BadRequestException; import javax.ws.rs.Consumes; import javax.ws.rs.DELETE; +import javax.ws.rs.DefaultValue; import javax.ws.rs.ForbiddenException; import javax.ws.rs.GET; import javax.ws.rs.InternalServerErrorException; @@ -58,6 +59,7 @@ import javax.ws.rs.Path; import javax.ws.rs.PathParam; import javax.ws.rs.ProcessingException; import javax.ws.rs.Produces; +import javax.ws.rs.QueryParam; import javax.ws.rs.WebApplicationException; import javax.ws.rs.container.ContainerRequestContext; import javax.ws.rs.core.Context; @@ -87,7 +89,11 @@ import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.storage.IssuedReceiptsManager; import org.whispersystems.textsecuregcm.storage.SubscriptionManager; import org.whispersystems.textsecuregcm.storage.SubscriptionManager.GetResult; -import org.whispersystems.textsecuregcm.stripe.StripeManager; +import org.whispersystems.textsecuregcm.subscriptions.PaymentMethod; +import org.whispersystems.textsecuregcm.subscriptions.ProcessorCustomer; +import org.whispersystems.textsecuregcm.subscriptions.StripeManager; +import org.whispersystems.textsecuregcm.subscriptions.SubscriptionProcessor; +import org.whispersystems.textsecuregcm.subscriptions.SubscriptionProcessorManager; import org.whispersystems.textsecuregcm.util.ExactlySize; @Path("/v1/subscription") @@ -179,15 +185,13 @@ public class SubscriptionController { throw new ForbiddenException("subscriberId mismatch"); } else if (getResult == GetResult.NOT_STORED) { // create a customer and write it to ddb - return stripeManager.createCustomer(requestData.subscriberUser).thenCompose( - customer -> subscriptionManager.create( - requestData.subscriberUser, requestData.hmac, customer.getId(), requestData.now) - .thenApply(updatedRecord -> { - if (updatedRecord == null) { - throw new NotFoundException(); - } - return updatedRecord; - })); + return subscriptionManager.create(requestData.subscriberUser, requestData.hmac, requestData.now) + .thenApply(updatedRecord -> { + if (updatedRecord == null) { + throw new ForbiddenException(); + } + return updatedRecord; + }); } else { // already exists so just touch access time and return return subscriptionManager.accessedAt(requestData.subscriberUser, requestData.now) @@ -197,20 +201,8 @@ public class SubscriptionController { .thenApply(record -> Response.ok().build()); } - public static class CreatePaymentMethodResponse { + record CreatePaymentMethodResponse(String clientSecret, SubscriptionProcessor processor) { - private final String clientSecret; - - @JsonCreator - public CreatePaymentMethodResponse( - @JsonProperty("clientSecret") String clientSecret) { - this.clientSecret = clientSecret; - } - - @SuppressWarnings("unused") - public String getClientSecret() { - return clientSecret; - } } @Timed @@ -220,12 +212,39 @@ public class SubscriptionController { @Produces(MediaType.APPLICATION_JSON) public CompletableFuture createPaymentMethod( @Auth Optional authenticatedAccount, - @PathParam("subscriberId") String subscriberId) { + @PathParam("subscriberId") String subscriberId, + @QueryParam("type") @DefaultValue("CARD") PaymentMethod paymentMethodType) { + RequestData requestData = RequestData.process(authenticatedAccount, subscriberId, clock); + + final SubscriptionProcessorManager subscriptionProcessorManager = getManagerForPaymentMethod(paymentMethodType); + return subscriptionManager.get(requestData.subscriberUser, requestData.hmac) .thenApply(this::requireRecordFromGetResult) - .thenCompose(record -> stripeManager.createSetupIntent(record.customerId)) - .thenApply(setupIntent -> Response.ok(new CreatePaymentMethodResponse(setupIntent.getClientSecret())).build()); + .thenCompose(record -> { + final CompletableFuture updatedRecordFuture; + if (record.customerId == null) { + updatedRecordFuture = subscriptionProcessorManager.createCustomer(requestData.subscriberUser) + .thenApply(ProcessorCustomer::customerId) + .thenCompose(customerId -> subscriptionManager.updateProcessorAndCustomerId(record, + new ProcessorCustomer(customerId, + subscriptionProcessorManager.getProcessor()), Instant.now())); + } else { + updatedRecordFuture = CompletableFuture.completedFuture(record); + } + + return updatedRecordFuture.thenCompose( + updatedRecord -> subscriptionProcessorManager.createPaymentMethodSetupToken(updatedRecord.customerId)); + }) + .thenApply( + token -> Response.ok(new CreatePaymentMethodResponse(token, subscriptionProcessorManager.getProcessor())) + .build()); + } + + private SubscriptionProcessorManager getManagerForPaymentMethod(PaymentMethod paymentMethod) { + return switch (paymentMethod) { + case CARD -> stripeManager; + }; } @Timed diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/SubscriptionManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SubscriptionManager.java index 0dc867b8f..333a2bcc7 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/SubscriptionManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SubscriptionManager.java @@ -6,23 +6,32 @@ package org.whispersystems.textsecuregcm.storage; import static org.whispersystems.textsecuregcm.util.AttributeValues.b; +import static org.whispersystems.textsecuregcm.util.AttributeValues.m; import static org.whispersystems.textsecuregcm.util.AttributeValues.n; import static org.whispersystems.textsecuregcm.util.AttributeValues.s; import com.google.common.base.Throwables; +import java.nio.charset.StandardCharsets; import java.security.MessageDigest; import java.time.Instant; +import java.util.Collections; +import java.util.HashMap; import java.util.Map; import java.util.Objects; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.subscriptions.ProcessorCustomer; +import org.whispersystems.textsecuregcm.subscriptions.SubscriptionProcessor; +import org.whispersystems.textsecuregcm.util.Pair; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException; import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; +import software.amazon.awssdk.services.dynamodb.model.GetItemResponse; import software.amazon.awssdk.services.dynamodb.model.QueryRequest; import software.amazon.awssdk.services.dynamodb.model.ReturnValue; import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest; @@ -35,8 +44,11 @@ public class SubscriptionManager { public static final String KEY_USER = "U"; // B (Hash Key) public static final String KEY_PASSWORD = "P"; // B + @Deprecated public static final String KEY_CUSTOMER_ID = "C"; // S (GSI Hash Key of `c_to_u` index) + public static final String KEY_PROCESSOR_ID_CUSTOMER_ID = "PC"; // B (GSI Hash Key of `pc_to_u` index) public static final String KEY_CREATED_AT = "R"; // N + public static final String KEY_PROCESSOR_CUSTOMER_IDS_MAP = "PCI"; // M public static final String KEY_SUBSCRIPTION_ID = "S"; // S public static final String KEY_SUBSCRIPTION_CREATED_AT = "T"; // N public static final String KEY_SUBSCRIPTION_LEVEL = "L"; @@ -51,8 +63,10 @@ public class SubscriptionManager { public final byte[] user; public final byte[] password; - public final String customerId; public final Instant createdAt; + public @Nullable String customerId; + public @Nullable SubscriptionProcessor processor; + public Map processorsToCustomerIds; public String subscriptionId; public Instant subscriptionCreatedAt; public Long subscriptionLevel; @@ -61,10 +75,9 @@ public class SubscriptionManager { public Instant canceledAt; public Instant currentPeriodEndsAt; - private Record(byte[] user, byte[] password, String customerId, Instant createdAt) { + private Record(byte[] user, byte[] password, Instant createdAt) { this.user = checkUserLength(user); this.password = Objects.requireNonNull(password); - this.customerId = Objects.requireNonNull(customerId); this.createdAt = Objects.requireNonNull(createdAt); } @@ -72,8 +85,17 @@ public class SubscriptionManager { Record self = new Record( user, item.get(KEY_PASSWORD).b().asByteArray(), - item.get(KEY_CUSTOMER_ID).s(), getInstant(item, KEY_CREATED_AT)); + + final Pair processorCustomerId = getProcessorAndCustomer(item); + if (processorCustomerId != null) { + self.customerId = processorCustomerId.second(); + self.processor = processorCustomerId.first(); + } else { + // Until all existing data is migrated to KEY_PROCESSOR_ID_CUSTOMER_ID, fall back to KEY_CUSTOMER_ID + self.customerId = getString(item, KEY_CUSTOMER_ID); + } + self.processorsToCustomerIds = getProcessorsToCustomerIds(item); self.subscriptionId = getString(item, KEY_SUBSCRIPTION_ID); self.subscriptionCreatedAt = getInstant(item, KEY_SUBSCRIPTION_CREATED_AT); self.subscriptionLevel = getLong(item, KEY_SUBSCRIPTION_LEVEL); @@ -84,8 +106,45 @@ public class SubscriptionManager { return self; } - public Map asKey() { - return Map.of(KEY_USER, b(user)); + private static Map getProcessorsToCustomerIds(Map item) { + final AttributeValue attributeValue = item.get(KEY_PROCESSOR_CUSTOMER_IDS_MAP); + final Map attribute = + attributeValue == null ? Collections.emptyMap() : attributeValue.m(); + + final Map processorsToCustomerIds = new HashMap<>(); + attribute.forEach((processorName, customerId) -> + processorsToCustomerIds.put(SubscriptionProcessor.valueOf(processorName), customerId.s())); + + return processorsToCustomerIds; + } + + /** + * Extracts the active processor and customer from a single attribute value in the given item. + *

+ * Until existing data is migrated, this may return {@code null}. + */ + @Nullable + private static Pair getProcessorAndCustomer(Map item) { + + final AttributeValue attributeValue = item.get(KEY_PROCESSOR_ID_CUSTOMER_ID); + + if (attributeValue == null) { + // temporarily allow null values + return null; + } + + final byte[] processorAndCustomerId = attributeValue.b().asByteArray(); + final byte processorId = processorAndCustomerId[0]; + + final SubscriptionProcessor processor = SubscriptionProcessor.forId(processorId); + if (processor == null) { + throw new IllegalStateException("unknown processor id: " + processorId); + } + + final String customerId = new String(processorAndCustomerId, 1, processorAndCustomerId.length - 1, + StandardCharsets.UTF_8); + + return new Pair<>(processor, customerId); } private static String getString(Map item, String key) { @@ -181,14 +240,7 @@ public class SubscriptionManager { * Looks up a record with the given {@code user} and validates the {@code hmac} before returning it. */ public CompletableFuture get(byte[] user, byte[] hmac) { - checkUserLength(user); - - GetItemRequest request = GetItemRequest.builder() - .consistentRead(Boolean.TRUE) - .tableName(table) - .key(Map.of(KEY_USER, b(user))) - .build(); - return client.getItem(request).thenApply(getItemResponse -> { + return getUser(user).thenApply(getItemResponse -> { if (!getItemResponse.hasItem()) { return GetResult.NOT_STORED; } @@ -201,7 +253,19 @@ public class SubscriptionManager { }); } - public CompletableFuture create(byte[] user, byte[] password, String customerId, Instant createdAt) { + private CompletableFuture getUser(byte[] user) { + checkUserLength(user); + + GetItemRequest request = GetItemRequest.builder() + .consistentRead(Boolean.TRUE) + .tableName(table) + .key(Map.of(KEY_USER, b(user))) + .build(); + + return client.getItem(request); + } + + public CompletableFuture create(byte[] user, byte[] password, Instant createdAt) { checkUserLength(user); UpdateItemRequest request = UpdateItemRequest.builder() @@ -211,20 +275,23 @@ public class SubscriptionManager { .conditionExpression("attribute_not_exists(#user) OR #password = :password") .updateExpression("SET " + "#password = if_not_exists(#password, :password), " - + "#customer_id = if_not_exists(#customer_id, :customer_id), " + "#created_at = if_not_exists(#created_at, :created_at), " - + "#accessed_at = if_not_exists(#accessed_at, :accessed_at)") + + "#accessed_at = if_not_exists(#accessed_at, :accessed_at), " + + "#processors_to_customer_ids = if_not_exists(#processors_to_customer_ids, :initial_empty_map)" + ) .expressionAttributeNames(Map.of( "#user", KEY_USER, "#password", KEY_PASSWORD, - "#customer_id", KEY_CUSTOMER_ID, "#created_at", KEY_CREATED_AT, - "#accessed_at", KEY_ACCESSED_AT)) + "#accessed_at", KEY_ACCESSED_AT, + "#processors_to_customer_ids", KEY_PROCESSOR_CUSTOMER_IDS_MAP) + ) .expressionAttributeValues(Map.of( ":password", b(password), - ":customer_id", s(customerId), ":created_at", n(createdAt.getEpochSecond()), - ":accessed_at", n(createdAt.getEpochSecond()))) + ":accessed_at", n(createdAt.getEpochSecond()), + ":initial_empty_map", m(Map.of())) + ) .build(); return client.updateItem(request).handle((updateItemResponse, throwable) -> { if (throwable != null) { @@ -239,6 +306,76 @@ public class SubscriptionManager { }); } + /** + * Updates the active processor and customer ID for the given user record. + * + * @return the updated user record. + */ + public CompletableFuture updateProcessorAndCustomerId(Record userRecord, + ProcessorCustomer activeProcessorCustomer, Instant updatedAt) { + + // Don’t attempt to modify the existing map, since it may be immutable, and we also don’t want to have side effects + final Map allProcessorsAndCustomerIds = new HashMap<>( + userRecord.processorsToCustomerIds); + allProcessorsAndCustomerIds.put(activeProcessorCustomer.processor(), activeProcessorCustomer.customerId()); + + UpdateItemRequest request = UpdateItemRequest.builder() + .tableName(table) + .key(Map.of(KEY_USER, b(userRecord.user))) + .returnValues(ReturnValue.ALL_NEW) + .conditionExpression( + // there is no customer attribute yet + "attribute_not_exists(#customer_id) " + + // OR this record doesn't have the new processor+customer attributes yet + "OR (#customer_id = :customer_id " + + "AND attribute_not_exists(#processor_customer_id) " + + // TODO once all records are guaranteed to have the map, we can do a more targeted update + // "AND attribute_not_exists(#processors_to_customer_ids.#processor_name) " + + "AND attribute_not_exists(#processors_to_customer_ids))" + ) + .updateExpression("SET " + + "#customer_id = :customer_id, " + + "#processor_customer_id = :processor_customer_id, " + // TODO once all records are guaranteed to have the map, we can do a more targeted update + // + "#processors_to_customer_ids.#processor_name = :customer_id, " + + "#processors_to_customer_ids = :processors_and_customer_ids, " + + "#accessed_at = :accessed_at" + ) + .expressionAttributeNames(Map.of( + "#accessed_at", KEY_ACCESSED_AT, + "#customer_id", KEY_CUSTOMER_ID, + "#processor_customer_id", KEY_PROCESSOR_ID_CUSTOMER_ID, + // TODO "#processor_name", activeProcessor.name(), + "#processors_to_customer_ids", KEY_PROCESSOR_CUSTOMER_IDS_MAP + )) + .expressionAttributeValues(Map.of( + ":accessed_at", n(updatedAt.getEpochSecond()), + ":customer_id", s(activeProcessorCustomer.customerId()), + ":processor_customer_id", b(activeProcessorCustomer.toDynamoBytes()), + ":processors_and_customer_ids", m(createProcessorsToCustomerIdsAttributeMap(allProcessorsAndCustomerIds)) + )).build(); + + return client.updateItem(request) + .thenApply(updateItemResponse -> Record.from(userRecord.user, updateItemResponse.attributes())) + .exceptionallyCompose(throwable -> { + if (Throwables.getRootCause(throwable) instanceof ConditionalCheckFailedException) { + return getUser(userRecord.user).thenApply(getItemResponse -> + Record.from(userRecord.user, getItemResponse.item())); + } + Throwables.throwIfUnchecked(throwable); + throw new CompletionException(throwable); + }); + } + + private Map createProcessorsToCustomerIdsAttributeMap( + Map allProcessorsAndCustomerIds) { + final Map result = new HashMap<>(); + + allProcessorsAndCustomerIds.forEach((processor, customerId) -> result.put(processor.name(), s(customerId))); + + return result; + } + public CompletableFuture accessedAt(byte[] user, Instant accessedAt) { checkUserLength(user); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/subscriptions/PaymentMethod.java b/service/src/main/java/org/whispersystems/textsecuregcm/subscriptions/PaymentMethod.java new file mode 100644 index 000000000..306d51f4e --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/subscriptions/PaymentMethod.java @@ -0,0 +1,13 @@ +/* + * Copyright 2022 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.subscriptions; + +public enum PaymentMethod { + /** + * A credit card or debit card, including those from Apple Pay and Google Pay + */ + CARD, +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/subscriptions/ProcessorCustomer.java b/service/src/main/java/org/whispersystems/textsecuregcm/subscriptions/ProcessorCustomer.java new file mode 100644 index 000000000..73c5d6b46 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/subscriptions/ProcessorCustomer.java @@ -0,0 +1,16 @@ +/* + * Copyright 2022 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.subscriptions; + +import java.nio.charset.StandardCharsets; +import org.whispersystems.dispatch.util.Util; + +public record ProcessorCustomer(String customerId, SubscriptionProcessor processor) { + + public byte[] toDynamoBytes() { + return Util.combine(new byte[]{processor.getId()}, customerId.getBytes(StandardCharsets.UTF_8)); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/stripe/StripeManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/subscriptions/StripeManager.java similarity index 91% rename from service/src/main/java/org/whispersystems/textsecuregcm/stripe/StripeManager.java rename to service/src/main/java/org/whispersystems/textsecuregcm/subscriptions/StripeManager.java index cd7b00bb2..5f15fc7d5 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/stripe/StripeManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/subscriptions/StripeManager.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: AGPL-3.0-only */ -package org.whispersystems.textsecuregcm.stripe; +package org.whispersystems.textsecuregcm.subscriptions; import com.google.common.base.Strings; import com.google.common.collect.Lists; @@ -61,7 +61,7 @@ import javax.ws.rs.core.Response.Status; import org.apache.commons.codec.binary.Hex; import org.whispersystems.textsecuregcm.util.Conversions; -public class StripeManager { +public class StripeManager implements SubscriptionProcessorManager { private static final String METADATA_KEY_LEVEL = "level"; @@ -87,6 +87,16 @@ public class StripeManager { this.boostDescription = Objects.requireNonNull(boostDescription); } + @Override + public SubscriptionProcessor getProcessor() { + return SubscriptionProcessor.STRIPE; + } + + @Override + public boolean supportsPaymentMethod(PaymentMethod paymentMethod) { + return paymentMethod == PaymentMethod.CARD; + } + private RequestOptions commonOptions() { return commonOptions(null); } @@ -98,17 +108,19 @@ public class StripeManager { .build(); } - public CompletableFuture createCustomer(byte[] subscriberUser) { + @Override + public CompletableFuture createCustomer(byte[] subscriberUser) { return CompletableFuture.supplyAsync(() -> { - CustomerCreateParams params = CustomerCreateParams.builder() - .putMetadata("subscriberUser", Hex.encodeHexString(subscriberUser)) - .build(); - try { - return Customer.create(params, commonOptions(generateIdempotencyKeyForSubscriberUser(subscriberUser))); - } catch (StripeException e) { - throw new CompletionException(e); - } - }, executor); + CustomerCreateParams params = CustomerCreateParams.builder() + .putMetadata("subscriberUser", Hex.encodeHexString(subscriberUser)) + .build(); + try { + return Customer.create(params, commonOptions(generateIdempotencyKeyForSubscriberUser(subscriberUser))); + } catch (StripeException e) { + throw new CompletionException(e); + } + }, executor) + .thenApply(customer -> new ProcessorCustomer(customer.getId(), getProcessor())); } public CompletableFuture getCustomer(String customerId) { @@ -139,17 +151,19 @@ public class StripeManager { }, executor); } - public CompletableFuture createSetupIntent(String customerId) { + @Override + public CompletableFuture createPaymentMethodSetupToken(String customerId) { return CompletableFuture.supplyAsync(() -> { - SetupIntentCreateParams params = SetupIntentCreateParams.builder() - .setCustomer(customerId) - .build(); - try { - return SetupIntent.create(params, commonOptions()); - } catch (StripeException e) { - throw new CompletionException(e); - } - }, executor); + SetupIntentCreateParams params = SetupIntentCreateParams.builder() + .setCustomer(customerId) + .build(); + try { + return SetupIntent.create(params, commonOptions()); + } catch (StripeException e) { + throw new CompletionException(e); + } + }, executor) + .thenApply(SetupIntent::getClientSecret); } /** diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/subscriptions/SubscriptionProcessor.java b/service/src/main/java/org/whispersystems/textsecuregcm/subscriptions/SubscriptionProcessor.java new file mode 100644 index 000000000..fe7a9eaeb --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/subscriptions/SubscriptionProcessor.java @@ -0,0 +1,48 @@ +/* + * Copyright 2022 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.subscriptions; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +/** + * A set of payment providers used for donations + */ +public enum SubscriptionProcessor { + // because provider IDs are stored, they should not be reused, and great care + // must be used if a provider is removed from the list + STRIPE(1), + ; + + private static final Map IDS_TO_PROCESSORS = new HashMap<>(); + + static { + Arrays.stream(SubscriptionProcessor.values()) + .forEach(provider -> IDS_TO_PROCESSORS.put((int) provider.id, provider)); + } + + /** + * @return the provider associated with the given ID, or {@code null} if none exists + */ + public static SubscriptionProcessor forId(byte id) { + return IDS_TO_PROCESSORS.get((int) id); + } + + private final byte id; + + SubscriptionProcessor(int id) { + if (id > 256) { + throw new IllegalArgumentException("ID must fit in one byte: " + id); + } + + this.id = (byte) id; + } + + public byte getId() { + return id; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/subscriptions/SubscriptionProcessorManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/subscriptions/SubscriptionProcessorManager.java new file mode 100644 index 000000000..dff259aae --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/subscriptions/SubscriptionProcessorManager.java @@ -0,0 +1,19 @@ +/* + * Copyright 2022 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.subscriptions; + +import java.util.concurrent.CompletableFuture; + +public interface SubscriptionProcessorManager { + + SubscriptionProcessor getProcessor(); + + boolean supportsPaymentMethod(PaymentMethod paymentMethod); + + CompletableFuture createCustomer(byte[] subscriberUser); + + CompletableFuture createPaymentMethodSetupToken(String customerId); +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/AttributeValues.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/AttributeValues.java index 8cfd6ba77..494703e00 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/AttributeValues.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/AttributeValues.java @@ -5,12 +5,12 @@ package org.whispersystems.textsecuregcm.util; -import software.amazon.awssdk.core.SdkBytes; -import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import java.nio.ByteBuffer; import java.util.Map; import java.util.Optional; import java.util.UUID; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.dynamodb.model.AttributeValue; /** AwsAV provides static helper methods for working with AWS AttributeValues. */ public class AttributeValues { @@ -37,6 +37,9 @@ public class AttributeValues { return AttributeValue.builder().s(value).build(); } + public static AttributeValue m(Map value) { + return AttributeValue.builder().m(value).build(); + } // More opinionated methods diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/SubscriptionControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/SubscriptionControllerTest.java index b352d7e21..7c7075f41 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/SubscriptionControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/SubscriptionControllerTest.java @@ -11,20 +11,27 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.when; +import static org.whispersystems.textsecuregcm.util.AttributeValues.b; +import static org.whispersystems.textsecuregcm.util.AttributeValues.n; import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.ResourceExtension; import java.math.BigDecimal; import java.time.Clock; +import java.time.Instant; +import java.util.Arrays; +import java.util.Base64; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.CompletableFuture; import javax.ws.rs.client.Entity; import javax.ws.rs.core.Response; import org.glassfish.jersey.server.ServerProperties; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.signal.libsignal.zkgroup.receipts.ServerZkReceiptOperations; @@ -42,9 +49,12 @@ import org.whispersystems.textsecuregcm.entities.Badge; import org.whispersystems.textsecuregcm.entities.BadgeSvg; import org.whispersystems.textsecuregcm.storage.IssuedReceiptsManager; import org.whispersystems.textsecuregcm.storage.SubscriptionManager; -import org.whispersystems.textsecuregcm.stripe.StripeManager; +import org.whispersystems.textsecuregcm.subscriptions.ProcessorCustomer; +import org.whispersystems.textsecuregcm.subscriptions.StripeManager; +import org.whispersystems.textsecuregcm.subscriptions.SubscriptionProcessor; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.util.SystemMapper; +import software.amazon.awssdk.services.dynamodb.model.AttributeValue; @ExtendWith(DropwizardExtensionsSupport.class) class SubscriptionControllerTest { @@ -72,6 +82,11 @@ class SubscriptionControllerTest { .addResource(SUBSCRIPTION_CONTROLLER) .build(); + @BeforeEach + void setUp() { + when(STRIPE_MANAGER.getProcessor()).thenReturn(SubscriptionProcessor.STRIPE); + } + @AfterEach void tearDown() { reset(CLOCK, SUBSCRIPTION_CONFIG, SUBSCRIPTION_MANAGER, STRIPE_MANAGER, ZK_OPS, ISSUED_RECEIPTS_MANAGER, @@ -95,19 +110,161 @@ class SubscriptionControllerTest { assertThat(response.getStatus()).isEqualTo(422); } + @Test + void createSubscriber() { + when(CLOCK.instant()).thenReturn(Instant.now()); + + // basic create + final byte[] subscriberUserAndKey = new byte[32]; + Arrays.fill(subscriberUserAndKey, (byte) 1); + final String subscriberId = Base64.getEncoder().encodeToString(subscriberUserAndKey); + + when(SUBSCRIPTION_MANAGER.get(any(), any())).thenReturn(CompletableFuture.completedFuture( + SubscriptionManager.GetResult.NOT_STORED)); + + final Map dynamoItem = Map.of(SubscriptionManager.KEY_PASSWORD, b(new byte[16]), + SubscriptionManager.KEY_CREATED_AT, n(Instant.now().getEpochSecond()), + SubscriptionManager.KEY_ACCESSED_AT, n(Instant.now().getEpochSecond()) + ); + final SubscriptionManager.Record record = SubscriptionManager.Record.from( + Arrays.copyOfRange(subscriberUserAndKey, 0, 16), dynamoItem); + when(SUBSCRIPTION_MANAGER.create(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(record)); + + final Response createResponse = RESOURCE_EXTENSION.target(String.format("/v1/subscription/%s", subscriberId)) + .request() + .put(Entity.json("")); + assertThat(createResponse.getStatus()).isEqualTo(200); + + // creating should be idempotent + when(SUBSCRIPTION_MANAGER.get(any(), any())).thenReturn(CompletableFuture.completedFuture( + SubscriptionManager.GetResult.found(record))); + when(SUBSCRIPTION_MANAGER.accessedAt(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); + + final Response idempotentCreateResponse = RESOURCE_EXTENSION.target( + String.format("/v1/subscription/%s", subscriberId)) + .request() + .put(Entity.json("")); + assertThat(idempotentCreateResponse.getStatus()).isEqualTo(200); + + // when the manager returns `null`, it means there was a password mismatch from the storage layer `create`. + // this could happen if there is a race between two concurrent `create` requests for the same user ID + when(SUBSCRIPTION_MANAGER.get(any(), any())).thenReturn(CompletableFuture.completedFuture( + SubscriptionManager.GetResult.NOT_STORED)); + when(SUBSCRIPTION_MANAGER.create(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null)); + + final Response managerCreateNullResponse = RESOURCE_EXTENSION.target( + String.format("/v1/subscription/%s", subscriberId)) + .request() + .put(Entity.json("")); + assertThat(managerCreateNullResponse.getStatus()).isEqualTo(403); + + final byte[] subscriberUserAndMismatchedKey = new byte[32]; + Arrays.fill(subscriberUserAndMismatchedKey, 0, 16, (byte) 1); + Arrays.fill(subscriberUserAndMismatchedKey, 16, 32, (byte) 2); + final String mismatchedSubscriberId = Base64.getEncoder().encodeToString(subscriberUserAndMismatchedKey); + + // a password mismatch for an existing record + when(SUBSCRIPTION_MANAGER.get(any(), any())).thenReturn(CompletableFuture.completedFuture( + SubscriptionManager.GetResult.PASSWORD_MISMATCH)); + + final Response passwordMismatchResponse = RESOURCE_EXTENSION.target( + String.format("/v1/subscription/%s", mismatchedSubscriberId)) + .request() + .put(Entity.json("")); + + assertThat(passwordMismatchResponse.getStatus()).isEqualTo(403); + + // invalid request data is a 404 + final byte[] malformedUserAndKey = new byte[16]; + Arrays.fill(malformedUserAndKey, (byte) 1); + final String malformedUserId = Base64.getEncoder().encodeToString(malformedUserAndKey); + + final Response malformedUserAndKeyResponse = RESOURCE_EXTENSION.target( + String.format("/v1/subscription/%s", malformedUserId)) + .request() + .put(Entity.json("")); + + assertThat(malformedUserAndKeyResponse.getStatus()).isEqualTo(404); + } + + @Test + void createPaymentMethod() { + final byte[] subscriberUserAndKey = new byte[32]; + Arrays.fill(subscriberUserAndKey, (byte) 1); + final String subscriberId = Base64.getEncoder().encodeToString(subscriberUserAndKey); + + when(CLOCK.instant()).thenReturn(Instant.now()); + when(SUBSCRIPTION_MANAGER.get(any(), any())).thenReturn(CompletableFuture.completedFuture( + SubscriptionManager.GetResult.NOT_STORED)); + + final Map dynamoItem = Map.of(SubscriptionManager.KEY_PASSWORD, b(new byte[16]), + SubscriptionManager.KEY_CREATED_AT, n(Instant.now().getEpochSecond()), + SubscriptionManager.KEY_ACCESSED_AT, n(Instant.now().getEpochSecond()) + ); + final SubscriptionManager.Record record = SubscriptionManager.Record.from( + Arrays.copyOfRange(subscriberUserAndKey, 0, 16), dynamoItem); + when(SUBSCRIPTION_MANAGER.create(any(), any(), any(Instant.class))) + .thenReturn(CompletableFuture.completedFuture( + record)); + + final Response createSubscriberResponse = RESOURCE_EXTENSION + .target(String.format("/v1/subscription/%s", subscriberId)) + .request() + .put(Entity.json("")); + + assertThat(createSubscriberResponse.getStatus()).isEqualTo(200); + + when(SUBSCRIPTION_MANAGER.get(any(), any())) + .thenReturn(CompletableFuture.completedFuture(SubscriptionManager.GetResult.found(record))); + + final String customerId = "some-customer-id"; + final ProcessorCustomer customer = new ProcessorCustomer( + customerId, SubscriptionProcessor.STRIPE); + when(STRIPE_MANAGER.createCustomer(any())) + .thenReturn(CompletableFuture.completedFuture(customer)); + + final SubscriptionManager.Record recordWithCustomerId = SubscriptionManager.Record.from(record.user, dynamoItem); + recordWithCustomerId.customerId = customerId; + recordWithCustomerId.processorsToCustomerIds.put(SubscriptionProcessor.STRIPE, customerId); + + when(SUBSCRIPTION_MANAGER.updateProcessorAndCustomerId(any(SubscriptionManager.Record.class), any(), + any(Instant.class))) + .thenReturn(CompletableFuture.completedFuture(recordWithCustomerId)); + + final String clientSecret = "some-client-secret"; + when(STRIPE_MANAGER.createPaymentMethodSetupToken(customerId)) + .thenReturn(CompletableFuture.completedFuture(clientSecret)); + + final SubscriptionController.CreatePaymentMethodResponse createPaymentMethodResponse = RESOURCE_EXTENSION + .target(String.format("/v1/subscription/%s/create_payment_method", subscriberId)) + .request() + .post(Entity.json("")) + .readEntity(SubscriptionController.CreatePaymentMethodResponse.class); + + assertThat(createPaymentMethodResponse.processor()).isEqualTo(SubscriptionProcessor.STRIPE); + assertThat(createPaymentMethodResponse.clientSecret()).isEqualTo(clientSecret); + + } + @Test void getLevels() { when(SUBSCRIPTION_CONFIG.getLevels()).thenReturn(Map.of( - 1L, new SubscriptionLevelConfiguration("B1", "P1", Map.of("USD", new SubscriptionPriceConfiguration("R1", BigDecimal.valueOf(100)))), - 2L, new SubscriptionLevelConfiguration("B2", "P2", Map.of("USD", new SubscriptionPriceConfiguration("R2", BigDecimal.valueOf(200)))), - 3L, new SubscriptionLevelConfiguration("B3", "P3", Map.of("USD", new SubscriptionPriceConfiguration("R3", BigDecimal.valueOf(300)))) + 1L, new SubscriptionLevelConfiguration("B1", "P1", + Map.of("USD", new SubscriptionPriceConfiguration("R1", BigDecimal.valueOf(100)))), + 2L, new SubscriptionLevelConfiguration("B2", "P2", + Map.of("USD", new SubscriptionPriceConfiguration("R2", BigDecimal.valueOf(200)))), + 3L, new SubscriptionLevelConfiguration("B3", "P3", + Map.of("USD", new SubscriptionPriceConfiguration("R3", BigDecimal.valueOf(300)))) )); when(BADGE_TRANSLATOR.translate(any(), eq("B1"))).thenReturn(new Badge("B1", "cat1", "name1", "desc1", - List.of("l", "m", "h", "x", "xx", "xxx"), "SVG", List.of(new BadgeSvg("sl", "sd"), new BadgeSvg("ml", "md"), new BadgeSvg("ll", "ld")))); + List.of("l", "m", "h", "x", "xx", "xxx"), "SVG", + List.of(new BadgeSvg("sl", "sd"), new BadgeSvg("ml", "md"), new BadgeSvg("ll", "ld")))); when(BADGE_TRANSLATOR.translate(any(), eq("B2"))).thenReturn(new Badge("B2", "cat2", "name2", "desc2", - List.of("l", "m", "h", "x", "xx", "xxx"), "SVG", List.of(new BadgeSvg("sl", "sd"), new BadgeSvg("ml", "md"), new BadgeSvg("ll", "ld")))); + List.of("l", "m", "h", "x", "xx", "xxx"), "SVG", + List.of(new BadgeSvg("sl", "sd"), new BadgeSvg("ml", "md"), new BadgeSvg("ll", "ld")))); when(BADGE_TRANSLATOR.translate(any(), eq("B3"))).thenReturn(new Badge("B3", "cat3", "name3", "desc3", - List.of("l", "m", "h", "x", "xx", "xxx"), "SVG", List.of(new BadgeSvg("sl", "sd"), new BadgeSvg("ml", "md"), new BadgeSvg("ll", "ld")))); + List.of("l", "m", "h", "x", "xx", "xxx"), "SVG", + List.of(new BadgeSvg("sl", "sd"), new BadgeSvg("ml", "md"), new BadgeSvg("ll", "ld")))); when(LEVEL_TRANSLATOR.translate(any(), eq("B1"))).thenReturn("Z1"); when(LEVEL_TRANSLATOR.translate(any(), eq("B2"))).thenReturn("Z2"); when(LEVEL_TRANSLATOR.translate(any(), eq("B3"))).thenReturn("Z3"); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/SubscriptionManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/SubscriptionManagerTest.java index 279bbea14..760fcf3f7 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/SubscriptionManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/SubscriptionManagerTest.java @@ -9,11 +9,16 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.whispersystems.textsecuregcm.storage.SubscriptionManager.GetResult.Type.FOUND; import static org.whispersystems.textsecuregcm.storage.SubscriptionManager.GetResult.Type.NOT_STORED; import static org.whispersystems.textsecuregcm.storage.SubscriptionManager.GetResult.Type.PASSWORD_MISMATCH; +import static org.whispersystems.textsecuregcm.util.AttributeValues.b; +import static org.whispersystems.textsecuregcm.util.AttributeValues.n; +import static org.whispersystems.textsecuregcm.util.AttributeValues.s; import java.security.SecureRandom; import java.time.Duration; import java.time.Instant; +import java.util.Arrays; import java.util.Base64; +import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.function.Consumer; import javax.annotation.Nonnull; @@ -22,6 +27,8 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import org.whispersystems.textsecuregcm.storage.SubscriptionManager.GetResult; import org.whispersystems.textsecuregcm.storage.SubscriptionManager.Record; +import org.whispersystems.textsecuregcm.subscriptions.ProcessorCustomer; +import org.whispersystems.textsecuregcm.subscriptions.SubscriptionProcessor; import software.amazon.awssdk.services.dynamodb.model.AttributeDefinition; import software.amazon.awssdk.services.dynamodb.model.GlobalSecondaryIndex; import software.amazon.awssdk.services.dynamodb.model.KeySchemaElement; @@ -85,8 +92,6 @@ class SubscriptionManagerTest { void testCreateOnlyOnce() { byte[] password1 = getRandomBytes(16); byte[] password2 = getRandomBytes(16); - String customer1 = Base64.getEncoder().encodeToString(getRandomBytes(16)); - String customer2 = Base64.getEncoder().encodeToString(getRandomBytes(16)); Instant created1 = Instant.ofEpochSecond(NOW_EPOCH_SECONDS); Instant created2 = Instant.ofEpochSecond(NOW_EPOCH_SECONDS + 1); @@ -103,16 +108,16 @@ class SubscriptionManagerTest { }); CompletableFuture createFuture = - subscriptionManager.create(user, password1, customer1, created1); - Consumer recordRequirements = checkFreshlyCreatedRecord(user, password1, customer1, created1); + subscriptionManager.create(user, password1, created1); + Consumer recordRequirements = checkFreshlyCreatedRecord(user, password1, created1); assertThat(createFuture).succeedsWithin(Duration.ofSeconds(3)).satisfies(recordRequirements); // password check fails so this should return null - createFuture = subscriptionManager.create(user, password2, customer2, created2); + createFuture = subscriptionManager.create(user, password2, created2); assertThat(createFuture).succeedsWithin(Duration.ofSeconds(3)).isNull(); // password check matches, but the record already exists so nothing should get updated - createFuture = subscriptionManager.create(user, password1, customer2, created2); + createFuture = subscriptionManager.create(user, password1, created2); assertThat(createFuture).succeedsWithin(Duration.ofSeconds(3)).satisfies(recordRequirements); } @@ -120,27 +125,67 @@ class SubscriptionManagerTest { void testGet() { byte[] wrongUser = getRandomBytes(16); byte[] wrongPassword = getRandomBytes(16); - assertThat(subscriptionManager.create(user, password, customer, created)).succeedsWithin(Duration.ofSeconds(3)); + assertThat(subscriptionManager.create(user, password, created)).succeedsWithin(Duration.ofSeconds(3)); assertThat(subscriptionManager.get(user, password)).succeedsWithin(Duration.ofSeconds(3)).satisfies(getResult -> { assertThat(getResult.type).isEqualTo(FOUND); - assertThat(getResult.record).isNotNull().satisfies(checkFreshlyCreatedRecord(user, password, customer, created)); + assertThat(getResult.record).isNotNull().satisfies(checkFreshlyCreatedRecord(user, password, created)); }); - assertThat(subscriptionManager.get(user, wrongPassword)).succeedsWithin(Duration.ofSeconds(3)).satisfies(getResult -> { - assertThat(getResult.type).isEqualTo(PASSWORD_MISMATCH); - assertThat(getResult.record).isNull(); - }); + assertThat(subscriptionManager.get(user, wrongPassword)).succeedsWithin(Duration.ofSeconds(3)) + .satisfies(getResult -> { + assertThat(getResult.type).isEqualTo(PASSWORD_MISMATCH); + assertThat(getResult.record).isNull(); + }); - assertThat(subscriptionManager.get(wrongUser, password)).succeedsWithin(Duration.ofSeconds(3)).satisfies(getResult -> { - assertThat(getResult.type).isEqualTo(NOT_STORED); - assertThat(getResult.record).isNull(); - }); + assertThat(subscriptionManager.get(wrongUser, password)).succeedsWithin(Duration.ofSeconds(3)) + .satisfies(getResult -> { + assertThat(getResult.type).isEqualTo(NOT_STORED); + assertThat(getResult.record).isNull(); + }); } @Test - void testLookupByCustomerId() { - assertThat(subscriptionManager.create(user, password, customer, created)).succeedsWithin(Duration.ofSeconds(3)); + void testUpdateCustomerIdAndProcessor() throws Exception { + Instant subscriptionUpdated = Instant.ofEpochSecond(NOW_EPOCH_SECONDS + 1); + assertThat(subscriptionManager.create(user, password, created)).succeedsWithin(Duration.ofSeconds(3)); + + final CompletableFuture getUser = subscriptionManager.get(user, password); + assertThat(getUser).succeedsWithin(Duration.ofSeconds(3)); + final Record userRecord = getUser.get().record; + + assertThat(subscriptionManager.updateProcessorAndCustomerId(userRecord, + new ProcessorCustomer(customer, SubscriptionProcessor.STRIPE), + subscriptionUpdated)).succeedsWithin(Duration.ofSeconds(3)) + .hasFieldOrPropertyWithValue("customerId", customer) + .hasFieldOrPropertyWithValue("processorsToCustomerIds", Map.of(SubscriptionProcessor.STRIPE, customer)); + + assertThat( + subscriptionManager.updateProcessorAndCustomerId(userRecord, + new ProcessorCustomer(customer + "1", SubscriptionProcessor.STRIPE), + subscriptionUpdated)).succeedsWithin(Duration.ofSeconds(3)) + .hasFieldOrPropertyWithValue("customerId", customer) + .hasFieldOrPropertyWithValue("processorsToCustomerIds", Map.of(SubscriptionProcessor.STRIPE, customer)); + + // TODO test new customer ID with new processor does change the customer ID, once there is another processor + + assertThat(subscriptionManager.getSubscriberUserByStripeCustomerId(customer)) + .succeedsWithin(Duration.ofSeconds(3)). + isEqualTo(user); + } + + @Test + void testLookupByCustomerId() throws Exception { + Instant subscriptionUpdated = Instant.ofEpochSecond(NOW_EPOCH_SECONDS + 1); + assertThat(subscriptionManager.create(user, password, created)).succeedsWithin(Duration.ofSeconds(3)); + + final CompletableFuture getUser = subscriptionManager.get(user, password); + assertThat(getUser).succeedsWithin(Duration.ofSeconds(3)); + final Record userRecord = getUser.get().record; + + assertThat(subscriptionManager.updateProcessorAndCustomerId(userRecord, + new ProcessorCustomer(customer, SubscriptionProcessor.STRIPE), + subscriptionUpdated)).succeedsWithin(Duration.ofSeconds(3)); assertThat(subscriptionManager.getSubscriberUserByStripeCustomerId(customer)). succeedsWithin(Duration.ofSeconds(3)). isEqualTo(user); @@ -149,7 +194,7 @@ class SubscriptionManagerTest { @Test void testCanceledAt() { Instant canceled = Instant.ofEpochSecond(NOW_EPOCH_SECONDS + 42); - assertThat(subscriptionManager.create(user, password, customer, created)).succeedsWithin(Duration.ofSeconds(3)); + assertThat(subscriptionManager.create(user, password, created)).succeedsWithin(Duration.ofSeconds(3)); assertThat(subscriptionManager.canceledAt(user, canceled)).succeedsWithin(Duration.ofSeconds(3)); assertThat(subscriptionManager.get(user, password)).succeedsWithin(Duration.ofSeconds(3)).satisfies(getResult -> { assertThat(getResult).isNotNull(); @@ -167,7 +212,7 @@ class SubscriptionManagerTest { String subscriptionId = Base64.getEncoder().encodeToString(getRandomBytes(16)); Instant subscriptionCreated = Instant.ofEpochSecond(NOW_EPOCH_SECONDS + 1); long level = 42; - assertThat(subscriptionManager.create(user, password, customer, created)).succeedsWithin(Duration.ofSeconds(3)); + assertThat(subscriptionManager.create(user, password, created)).succeedsWithin(Duration.ofSeconds(3)); assertThat(subscriptionManager.subscriptionCreated(user, subscriptionId, subscriptionCreated, level)). succeedsWithin(Duration.ofSeconds(3)); assertThat(subscriptionManager.get(user, password)).succeedsWithin(Duration.ofSeconds(3)).satisfies(getResult -> { @@ -187,7 +232,7 @@ class SubscriptionManagerTest { void testSubscriptionLevelChanged() { Instant at = Instant.ofEpochSecond(NOW_EPOCH_SECONDS + 500); long level = 1776; - assertThat(subscriptionManager.create(user, password, customer, created)).succeedsWithin(Duration.ofSeconds(3)); + assertThat(subscriptionManager.create(user, password, created)).succeedsWithin(Duration.ofSeconds(3)); assertThat(subscriptionManager.subscriptionLevelChanged(user, at, level)).succeedsWithin(Duration.ofSeconds(3)); assertThat(subscriptionManager.get(user, password)).succeedsWithin(Duration.ofSeconds(3)).satisfies(getResult -> { assertThat(getResult).isNotNull(); @@ -200,6 +245,74 @@ class SubscriptionManagerTest { }); } + @Test + void testSubscriptionAddProcessorAttribute() throws Exception { + + final byte[] user = new byte[16]; + Arrays.fill(user, (byte) 1); + final byte[] hmac = new byte[16]; + Arrays.fill(hmac, (byte) 2); + final String customerId = "abcdef"; + + // manually create an existing record, with only KEY_CUSTOMER_ID + dynamoDbExtension.getDynamoDbClient().putItem(p -> + p.tableName(dynamoDbExtension.getTableName()) + .item(Map.of( + SubscriptionManager.KEY_USER, b(user), + SubscriptionManager.KEY_PASSWORD, b(hmac), + SubscriptionManager.KEY_CREATED_AT, n(Instant.now().getEpochSecond()), + SubscriptionManager.KEY_CUSTOMER_ID, s(customerId), + SubscriptionManager.KEY_ACCESSED_AT, n(Instant.now().getEpochSecond()) + )) + ); + + final CompletableFuture firstGetResult = subscriptionManager.get(user, hmac); + assertThat(firstGetResult).succeedsWithin(Duration.ofSeconds(1)); + + final Record firstRecord = firstGetResult.get().record; + + assertThat(firstRecord.customerId).isEqualTo(customerId); + assertThat(firstRecord.processor).isNull(); + assertThat(firstRecord.processorsToCustomerIds).isEmpty(); + + // Try to update the user to have a different customer ID. This should quietly fail, + // and just return the existing customer ID. + final CompletableFuture firstUpdate = subscriptionManager.updateProcessorAndCustomerId(firstRecord, + new ProcessorCustomer(customerId + "something else", SubscriptionProcessor.STRIPE), + Instant.now()); + + assertThat(firstUpdate).succeedsWithin(Duration.ofSeconds(1)); + + final String firstUpdateCustomerId = firstUpdate.get().customerId; + assertThat(firstUpdateCustomerId).isEqualTo(customerId); + + // Now update with the existing customer ID. All fields should now be populated. + final CompletableFuture secondUpdate = subscriptionManager.updateProcessorAndCustomerId(firstRecord, + new ProcessorCustomer(customerId, SubscriptionProcessor.STRIPE), Instant.now()); + + assertThat(secondUpdate).succeedsWithin(Duration.ofSeconds(1)); + + final String secondUpdateCustomerId = secondUpdate.get().customerId; + assertThat(secondUpdateCustomerId).isEqualTo(customerId); + + final CompletableFuture secondGetResult = subscriptionManager.get(user, hmac); + assertThat(secondGetResult).succeedsWithin(Duration.ofSeconds(1)); + + final Record secondRecord = secondGetResult.get().record; + + assertThat(secondRecord.customerId).isEqualTo(customerId); + assertThat(secondRecord.processor).isEqualTo(SubscriptionProcessor.STRIPE); + assertThat(secondRecord.processorsToCustomerIds).isEqualTo(Map.of(SubscriptionProcessor.STRIPE, customerId)); + } + + @Test + void testProcessorAndCustomerId() { + final ProcessorCustomer processorCustomer = + new ProcessorCustomer("abc", SubscriptionProcessor.STRIPE); + + assertThat(processorCustomer.toDynamoBytes()).isEqualTo(new byte[]{1, 97, 98, 99}); + } + private static byte[] getRandomBytes(int length) { byte[] result = new byte[length]; SECURE_RANDOM.nextBytes(result); @@ -208,12 +321,12 @@ class SubscriptionManagerTest { @Nonnull private static Consumer checkFreshlyCreatedRecord( - byte[] user, byte[] password, String customer, Instant created) { + byte[] user, byte[] password, Instant created) { return record -> { assertThat(record).isNotNull(); assertThat(record.user).isEqualTo(user); assertThat(record.password).isEqualTo(password); - assertThat(record.customerId).isEqualTo(customer); + assertThat(record.customerId).isNull(); assertThat(record.createdAt).isEqualTo(created); assertThat(record.subscriptionId).isNull(); assertThat(record.subscriptionCreatedAt).isNull();