forked from Nixius/authelia
180 lines
6.2 KiB
Go
180 lines
6.2 KiB
Go
package stripe
|
|
|
|
import (
|
|
"errors"
|
|
"log"
|
|
"strconv"
|
|
"time"
|
|
|
|
"git.nixc.us/a250/ss-atlas/internal/config"
|
|
"git.nixc.us/a250/ss-atlas/internal/pricing"
|
|
stripego "github.com/stripe/stripe-go/v84"
|
|
portalsession "github.com/stripe/stripe-go/v84/billingportal/session"
|
|
checkoutsession "github.com/stripe/stripe-go/v84/checkout/session"
|
|
"github.com/stripe/stripe-go/v84/subscription"
|
|
)
|
|
|
|
// ErrNoPriceForTier is returned by the checkout helpers when neither a
// tier-specific Stripe price ID nor the fallback price ID is configured for the
// pricing tier a customer falls into.
var (
	ErrNoPriceForTier = errors.New("no Stripe price configured for this tier")
)
|
|
|
|
// SubscriptionStatus is a display-oriented summary of a customer's Stripe
// subscription.
//
// Label is the human-readable state (this package produces values such as
// "Active", "Expiring" and "Expired"). Badge is the CSS class paired with it
// ("badge-active" or "badge-inactive"). CancelAt is either empty or the
// formatted date on which a scheduled cancellation takes effect.
type SubscriptionStatus struct {
	Label, Badge, CancelAt string
}
|
|
|
|
type Client struct {
|
|
cfg *config.Config
|
|
}
|
|
|
|
func New(cfg *config.Config) *Client {
|
|
stripego.Key = cfg.StripeSecretKey
|
|
return &Client{cfg: cfg}
|
|
}
|
|
|
|
func (c *Client) priceForTier(t pricing.Tier) string {
|
|
switch t {
|
|
case pricing.TierFree:
|
|
if c.cfg.StripePriceIDFree != "" {
|
|
return c.cfg.StripePriceIDFree
|
|
}
|
|
case pricing.TierYear:
|
|
if c.cfg.StripePriceIDYear != "" {
|
|
return c.cfg.StripePriceIDYear
|
|
}
|
|
case pricing.TierPremium:
|
|
if c.cfg.StripePriceIDMonth200 != "" {
|
|
return c.cfg.StripePriceIDMonth200
|
|
}
|
|
if c.cfg.StripePriceIDMonth100 != "" {
|
|
return c.cfg.StripePriceIDMonth100
|
|
}
|
|
}
|
|
return c.cfg.StripePriceID
|
|
}
|
|
|
|
func (c *Client) CreateCheckoutSession(email, customerDomain, customerPhone string, customerCount int) (*stripego.CheckoutSession, error) {
|
|
t := pricing.ForCustomer(customerCount, c.cfg.FreeTierLimit, c.cfg.YearTierLimit)
|
|
priceID := c.priceForTier(t)
|
|
if priceID == "" {
|
|
log.Printf("stripe: no price for tier %d (count=%d freeLimit=%d yearLimit=%d); set STRIPE_PRICE_ID or tier-specific price",
|
|
t, customerCount, c.cfg.FreeTierLimit, c.cfg.YearTierLimit)
|
|
return nil, ErrNoPriceForTier
|
|
}
|
|
|
|
params := &stripego.CheckoutSessionParams{
|
|
Mode: stripego.String(string(stripego.CheckoutSessionModeSubscription)),
|
|
LineItems: []*stripego.CheckoutSessionLineItemParams{
|
|
{Price: stripego.String(priceID), Quantity: stripego.Int64(1)},
|
|
},
|
|
CustomerEmail: stripego.String(email),
|
|
SuccessURL: stripego.String(c.cfg.AppURL + "/success?session_id={CHECKOUT_SESSION_ID}"),
|
|
CancelURL: stripego.String(c.cfg.AppURL + "/"),
|
|
}
|
|
if customerDomain != "" {
|
|
params.AddMetadata("customer_domain", customerDomain)
|
|
}
|
|
if customerPhone != "" {
|
|
params.AddMetadata("customer_phone", customerPhone)
|
|
}
|
|
params.AddMetadata("pricing_tier", strconv.Itoa(int(t)))
|
|
return checkoutsession.New(params)
|
|
}
|
|
|
|
// CreateCheckoutForCustomer creates a new subscription checkout for an existing
|
|
// Stripe customer (e.g. resubscribe after expiry). Uses current tier by customer count.
|
|
func (c *Client) CreateCheckoutForCustomer(customerID string, customerCount int) (*stripego.CheckoutSession, error) {
|
|
t := pricing.ForCustomer(customerCount, c.cfg.FreeTierLimit, c.cfg.YearTierLimit)
|
|
priceID := c.priceForTier(t)
|
|
if priceID == "" {
|
|
log.Printf("stripe: no price for tier %d (count=%d); set STRIPE_PRICE_ID or tier-specific price", t, customerCount)
|
|
return nil, ErrNoPriceForTier
|
|
}
|
|
|
|
params := &stripego.CheckoutSessionParams{
|
|
Mode: stripego.String(string(stripego.CheckoutSessionModeSubscription)),
|
|
LineItems: []*stripego.CheckoutSessionLineItemParams{
|
|
{Price: stripego.String(priceID), Quantity: stripego.Int64(1)},
|
|
},
|
|
Customer: stripego.String(customerID),
|
|
SuccessURL: stripego.String(c.cfg.AppURL + "/success?session_id={CHECKOUT_SESSION_ID}"),
|
|
CancelURL: stripego.String(c.cfg.AppURL + "/dashboard"),
|
|
}
|
|
return checkoutsession.New(params)
|
|
}
|
|
|
|
func (c *Client) CreatePortalSession(customerID string) (*stripego.BillingPortalSession, error) {
|
|
params := &stripego.BillingPortalSessionParams{
|
|
Customer: stripego.String(customerID),
|
|
ReturnURL: stripego.String(c.cfg.AppURL + "/dashboard"),
|
|
}
|
|
return portalsession.New(params)
|
|
}
|
|
|
|
func (c *Client) GetCheckoutSession(sessionID string) (*stripego.CheckoutSession, error) {
|
|
params := &stripego.CheckoutSessionParams{}
|
|
params.AddExpand("customer")
|
|
params.AddExpand("subscription")
|
|
return checkoutsession.Get(sessionID, params)
|
|
}
|
|
|
|
func (c *Client) GetSubscription(subID string) (*stripego.Subscription, error) {
|
|
return subscription.Get(subID, nil)
|
|
}
|
|
|
|
// ScheduleSubscriptionCancelAt sets the subscription to cancel at the given unix timestamp.
|
|
// Used for free-tier subs that should auto-cancel after 3 months.
|
|
func (c *Client) ScheduleSubscriptionCancelAt(subID string, cancelAt int64) error {
|
|
_, err := subscription.Update(subID, &stripego.SubscriptionParams{
|
|
CancelAt: stripego.Int64(cancelAt),
|
|
})
|
|
return err
|
|
}
|
|
|
|
func (c *Client) IsFreeTierPrice(priceID string) bool {
|
|
return c.cfg.StripePriceIDFree != "" && priceID == c.cfg.StripePriceIDFree
|
|
}
|
|
|
|
func (c *Client) IsYearTierPrice(priceID string) bool {
|
|
return c.cfg.StripePriceIDYear != "" && priceID == c.cfg.StripePriceIDYear
|
|
}
|
|
|
|
func (c *Client) GetCustomerSubscriptionStatus(customerID string) *SubscriptionStatus {
|
|
if customerID == "" {
|
|
return &SubscriptionStatus{Label: "Active", Badge: "badge-active"}
|
|
}
|
|
params := &stripego.SubscriptionListParams{
|
|
Customer: stripego.String(customerID),
|
|
Status: stripego.String(string(stripego.SubscriptionStatusActive)),
|
|
}
|
|
iter := subscription.List(params)
|
|
if iter.Next() {
|
|
sub := iter.Subscription()
|
|
log.Printf("stripe: customer=%s sub=%s cancel_at_period_end=%v cancel_at=%d",
|
|
customerID, sub.ID, sub.CancelAtPeriodEnd, sub.CancelAt)
|
|
if sub.CancelAtPeriodEnd || sub.CancelAt > 0 {
|
|
// Prefer explicit cancel_at; fall back to current_period_end from the first item
|
|
endTs := sub.CancelAt
|
|
if endTs == 0 && sub.Items != nil && len(sub.Items.Data) > 0 {
|
|
endTs = sub.Items.Data[0].CurrentPeriodEnd
|
|
}
|
|
var cancelAt string
|
|
if endTs > 0 {
|
|
cancelAt = time.Unix(endTs, 0).Format("Jan 2, 2006")
|
|
}
|
|
return &SubscriptionStatus{
|
|
Label: "Expiring",
|
|
Badge: "badge-inactive",
|
|
CancelAt: cancelAt,
|
|
}
|
|
}
|
|
return &SubscriptionStatus{Label: "Active", Badge: "badge-active"}
|
|
}
|
|
log.Printf("stripe: no active subscription found for customer=%s", customerID)
|
|
// No active subscription; user was a customer so subscription has expired
|
|
return &SubscriptionStatus{Label: "Expired", Badge: "badge-inactive"}
|
|
}
|
|
|
|
func (c *Client) WebhookSecret() string {
|
|
return c.cfg.StripeWebhookSecret
|
|
}
|