forked from Nixius/authelia
346 lines
11 KiB
Go
346 lines
11 KiB
Go
package handlers
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"log"
|
|
"net/http"
|
|
"net/url"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.nixc.us/a250/ss-atlas/internal/accounts"
|
|
"git.nixc.us/a250/ss-atlas/internal/pricing"
|
|
"git.nixc.us/a250/ss-atlas/internal/stripe"
|
|
"git.nixc.us/a250/ss-atlas/internal/validation"
|
|
"git.nixc.us/a250/ss-atlas/internal/version"
|
|
stripego "github.com/stripe/stripe-go/v84"
|
|
)
|
|
|
|
func (a *App) handleLanding(w http.ResponseWriter, r *http.Request) {
|
|
acct, _, _ := a.currentAccount(r)
|
|
|
|
if isSubscribedAccount(acct) {
|
|
http.Redirect(w, r, a.cfg.AppURL+"/dashboard", http.StatusSeeOther)
|
|
return
|
|
}
|
|
|
|
if acct != nil && acct.StripeCustomerID != "" {
|
|
http.Redirect(w, r, a.cfg.AppURL+"/activate", http.StatusSeeOther)
|
|
return
|
|
}
|
|
|
|
count, _ := a.accounts.CountCustomers(r.Context())
|
|
soldOut := a.cfg.MaxSignups > 0 && count >= a.cfg.MaxSignups
|
|
tier := pricing.ForCustomer(count, a.cfg.FreeTierLimit, a.cfg.YearTierLimit)
|
|
|
|
useForm := a.cfg.StripePriceID != "" || a.cfg.StripePriceIDFree != "" ||
|
|
a.cfg.StripePriceIDYear != "" || a.cfg.StripePriceIDMonth100 != "" ||
|
|
a.cfg.StripePriceIDMonth200 != ""
|
|
data := map[string]any{
|
|
"AppURL": a.cfg.AppURL,
|
|
"Commit": version.Commit,
|
|
"BuildTime": version.BuildTime,
|
|
"StripePaymentLink": a.cfg.StripePaymentLink,
|
|
"SoldOut": soldOut,
|
|
"PricingTier": int(tier),
|
|
"UseCheckoutForm": useForm,
|
|
"Tagline": a.cfg.LandingTagline,
|
|
"Features": a.cfg.LandingFeatures,
|
|
}
|
|
if err := a.tmpl.ExecuteTemplate(w, "landing.html", data); err != nil {
|
|
log.Printf("template error: %v", err)
|
|
http.Error(w, "internal error", http.StatusInternalServerError)
|
|
}
|
|
}
|
|
|
|
func (a *App) handleCheckout(w http.ResponseWriter, r *http.Request) {
|
|
acct, _, err := a.currentAccount(r)
|
|
if err != nil || acct == nil {
|
|
http.Redirect(w, r, a.cfg.IdentityURL, http.StatusSeeOther)
|
|
return
|
|
}
|
|
|
|
if isSubscribedAccount(acct) {
|
|
http.Redirect(w, r, a.cfg.AppURL+"/dashboard", http.StatusSeeOther)
|
|
return
|
|
}
|
|
|
|
if acct.StripeCustomerID != "" {
|
|
http.Redirect(w, r, a.cfg.AppURL+"/activate", http.StatusSeeOther)
|
|
return
|
|
}
|
|
|
|
count, _ := a.accounts.CountCustomers(r.Context())
|
|
soldOut := a.cfg.MaxSignups > 0 && count >= a.cfg.MaxSignups
|
|
if soldOut {
|
|
http.Error(w, "signup limit reached, try again later", http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
|
|
tier := pricing.ForCustomer(count, a.cfg.FreeTierLimit, a.cfg.YearTierLimit)
|
|
if err := a.tmpl.ExecuteTemplate(w, "checkout.html", map[string]any{
|
|
"AppURL": a.cfg.AppURL,
|
|
"Email": acct.PrimaryEmail,
|
|
"PricingTier": int(tier),
|
|
}); err != nil {
|
|
log.Printf("template error: %v", err)
|
|
http.Error(w, "internal error", http.StatusInternalServerError)
|
|
}
|
|
}
|
|
|
|
func (a *App) handleCreateCheckout(w http.ResponseWriter, r *http.Request) {
|
|
acct, identity, err := a.currentAccount(r)
|
|
if err != nil || acct == nil {
|
|
http.Redirect(w, r, a.cfg.IdentityURL, http.StatusSeeOther)
|
|
return
|
|
}
|
|
|
|
if a.cfg.MaxSignups > 0 {
|
|
count, err := a.accounts.CountCustomers(r.Context())
|
|
if err == nil && count >= a.cfg.MaxSignups {
|
|
http.Error(w, "signup limit reached, try again later", http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
}
|
|
|
|
email := validation.SanitizeEmail(firstNonEmpty(acct.PrimaryEmail, identity.Email))
|
|
if email == "" {
|
|
http.Error(w, "signed-in account is missing a valid email", http.StatusBadRequest)
|
|
return
|
|
}
|
|
domain := strings.TrimSpace(r.FormValue("domain"))
|
|
if domain == "" {
|
|
http.Error(w, "domain required", http.StatusBadRequest)
|
|
return
|
|
}
|
|
if !validation.DomainResolves(domain, 5*time.Second) {
|
|
http.Error(w, "domain does not resolve; please enter a valid domain", http.StatusBadRequest)
|
|
return
|
|
}
|
|
rawPhone := r.FormValue("phone")
|
|
phone := validation.SanitizePhone(rawPhone)
|
|
if phone == "" {
|
|
http.Error(w, "valid phone number required", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
count, _ := a.accounts.CountCustomers(r.Context())
|
|
sess, err := a.stripe.CreateCheckoutSession(email, domain, phone, acct.ID, count)
|
|
if err != nil {
|
|
if errors.Is(err, stripe.ErrNoPriceForTier) {
|
|
http.Error(w, "pricing not configured for current tier — set STRIPE_PRICE_ID or tier prices in env", http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
log.Printf("stripe checkout error: %v", err)
|
|
http.Error(w, "failed to create checkout", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
http.Redirect(w, r, sess.URL, http.StatusSeeOther)
|
|
}
|
|
|
|
func (a *App) handleSuccess(w http.ResponseWriter, r *http.Request) {
|
|
sessionID := r.URL.Query().Get("session_id")
|
|
if sessionID == "" {
|
|
http.Redirect(w, r, "/", http.StatusSeeOther)
|
|
return
|
|
}
|
|
|
|
if a.cfg.MaxSignups > 0 {
|
|
count, err := a.accounts.CountCustomers(r.Context())
|
|
if err == nil && count >= a.cfg.MaxSignups {
|
|
http.Error(w, "signup limit reached, contact support", http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
}
|
|
|
|
sess, err := a.stripe.GetCheckoutSession(sessionID)
|
|
if err != nil {
|
|
log.Printf("stripe get session error: %v", err)
|
|
http.Error(w, "could not verify payment", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if sess.PaymentStatus != "paid" {
|
|
http.Redirect(w, r, "/", http.StatusSeeOther)
|
|
return
|
|
}
|
|
|
|
email := sess.CustomerDetails.Email
|
|
customerID := sess.Customer.ID
|
|
phone := ""
|
|
domain := ""
|
|
accountID := int64(0)
|
|
if sess.Metadata != nil {
|
|
phone = sess.Metadata["customer_phone"]
|
|
domain = sess.Metadata["customer_domain"]
|
|
accountID, _ = strconv.ParseInt(sess.Metadata["account_id"], 10, 64)
|
|
}
|
|
|
|
input := accounts.CheckoutInput{
|
|
AccountID: accountID,
|
|
Email: email,
|
|
DisplayName: email,
|
|
Phone: phone,
|
|
CustomerDomain: domain,
|
|
StripeCustomerID: customerID,
|
|
StripeSubscriptionID: subscriptionIDFromSession(sess),
|
|
StripeSessionID: sess.ID,
|
|
}
|
|
acct, inst, err := a.accounts.UpsertCheckout(r.Context(), input)
|
|
if err != nil {
|
|
log.Printf("account provision failed for %s: %v", email, err)
|
|
http.Error(w, "account creation failed, contact support", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
exists, _ := a.swarm.StackExists(inst.StackName)
|
|
if !exists {
|
|
if err := a.swarm.RestoreVolumes(inst.StackName, a.cfg.ArchivePath); err != nil {
|
|
log.Printf("success: volume restore failed for %s: %v", inst.StackName, err)
|
|
}
|
|
if err := a.swarm.DeployStack(inst.StackName, inst.Slug, a.cfg.TraefikDomain); err != nil {
|
|
log.Printf("success: stack deploy failed for %s: %v", inst.StackName, err)
|
|
} else if err := a.accounts.UpdateInstanceState(context.Background(), inst.StackName, "running", true); err != nil {
|
|
log.Printf("success: update instance state failed for %s: %v", inst.StackName, err)
|
|
}
|
|
}
|
|
if err := a.tmpl.ExecuteTemplate(w, "success.html", map[string]any{
|
|
"AppURL": a.cfg.AppURL,
|
|
"IdentityURL": a.cfg.IdentityURL,
|
|
"Email": acct.PrimaryEmail,
|
|
}); err != nil {
|
|
log.Printf("template error: %v", err)
|
|
http.Error(w, "internal error", http.StatusInternalServerError)
|
|
}
|
|
}
|
|
|
|
// handleLinkStripeCustomer creates a Stripe customer for the current user and saves the ID,
|
|
// so "Manage Subscription" works. Used when the user is in customers group but has no customer_id (e.g. manual add).
|
|
func (a *App) handleLinkStripeCustomer(w http.ResponseWriter, r *http.Request) {
|
|
acct, identity, err := a.currentAccount(r)
|
|
if err != nil || acct == nil {
|
|
http.Redirect(w, r, a.cfg.IdentityURL, http.StatusSeeOther)
|
|
return
|
|
}
|
|
if !isSubscribedAccount(acct) {
|
|
http.Error(w, "forbidden", http.StatusForbidden)
|
|
return
|
|
}
|
|
if acct.StripeCustomerID != "" {
|
|
redirectWithPortalError(w, r, a.cfg.AppURL+"/dashboard", "Billing account already linked. Use Manage Subscription below.")
|
|
return
|
|
}
|
|
email := strings.TrimSpace(firstNonEmpty(identity.Email, acct.PrimaryEmail))
|
|
if email == "" {
|
|
redirectWithPortalError(w, r, a.cfg.AppURL+"/dashboard", "No email on account. Contact support to manage your subscription.")
|
|
return
|
|
}
|
|
customerID, err := a.stripe.CreateCustomer(email)
|
|
if err != nil {
|
|
log.Printf("link-stripe-customer: create customer failed for %s: %v", email, err)
|
|
redirectWithPortalError(w, r, a.cfg.AppURL+"/dashboard", "Failed to create billing account: "+err.Error())
|
|
return
|
|
}
|
|
if _, _, err := a.accounts.UpsertCheckout(r.Context(), accounts.CheckoutInput{
|
|
Email: email,
|
|
DisplayName: firstNonEmpty(identity.Name, identity.Username, email),
|
|
StripeCustomerID: customerID,
|
|
}); err != nil {
|
|
log.Printf("link-stripe-customer: set account failed for %s: %v", email, err)
|
|
redirectWithPortalError(w, r, a.cfg.AppURL+"/dashboard", "Billing account created but link failed. Contact support to manage your subscription.")
|
|
return
|
|
}
|
|
log.Printf("link-stripe-customer: linked account %d -> Stripe customer %s", acct.ID, customerID)
|
|
u, _ := url.Parse(a.cfg.AppURL + "/dashboard")
|
|
q := u.Query()
|
|
q.Set("linked", "1")
|
|
u.RawQuery = q.Encode()
|
|
http.Redirect(w, r, u.String(), http.StatusSeeOther)
|
|
}
|
|
|
|
func (a *App) handlePortal(w http.ResponseWriter, r *http.Request) {
|
|
customerID := r.FormValue("customer_id")
|
|
if customerID == "" {
|
|
redirectWithPortalError(w, r, a.cfg.AppURL+"/dashboard", "No billing account linked. Contact support to manage your subscription.")
|
|
return
|
|
}
|
|
|
|
sess, err := a.stripe.CreatePortalSession(customerID)
|
|
if err != nil {
|
|
log.Printf("stripe portal error: %v", err)
|
|
http.Error(w, "failed to create portal session", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
http.Redirect(w, r, sess.URL, http.StatusSeeOther)
|
|
}
|
|
|
|
// handleResubscribe creates a fresh checkout session for an existing Stripe
|
|
// customer whose subscription has expired/been cancelled. This differs from
|
|
// the portal flow which only manages active or scheduled-to-cancel subs.
|
|
func (a *App) handleResubscribe(w http.ResponseWriter, r *http.Request) {
|
|
customerID := r.FormValue("customer_id")
|
|
if customerID == "" {
|
|
redirectWithPortalError(w, r, a.cfg.AppURL+"/dashboard", "No billing account linked. Contact support to resubscribe.")
|
|
return
|
|
}
|
|
|
|
count, _ := a.accounts.CountCustomers(r.Context())
|
|
sess, err := a.stripe.CreateCheckoutForCustomer(customerID, count)
|
|
if err != nil {
|
|
if errors.Is(err, stripe.ErrNoPriceForTier) {
|
|
http.Error(w, "pricing not configured for current tier — set STRIPE_PRICE_ID or tier prices in env", http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
log.Printf("stripe resubscribe error: %v", err)
|
|
http.Error(w, "failed to create checkout session", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
http.Redirect(w, r, sess.URL, http.StatusSeeOther)
|
|
}
|
|
|
|
func isSubscribedAccount(acct *accounts.Account) bool {
|
|
return acct != nil && acct.SubscriptionStatus == "active"
|
|
}
|
|
|
|
// sanitizeUsername derives a username from an email address.
//
// The result is the lowercased local part, plus "-" and the second-to-last
// domain label when the domain has at least two labels (e.g. "nixc" from
// "nixc.us", "gmail" from "gmail.com"). Only the first '@' splits local part
// from domain. After lowercasing, every rune outside [a-z0-9_-] is replaced
// by '-'.
func sanitizeUsername(email string) string {
	local, host, found := strings.Cut(email, "@")

	var suffix string
	if found {
		if labels := strings.Split(host, "."); len(labels) > 1 {
			suffix = "-" + labels[len(labels)-2]
		}
	}

	toSafe := func(c rune) rune {
		switch {
		case 'a' <= c && c <= 'z', '0' <= c && c <= '9', c == '-', c == '_':
			return c
		default:
			return '-'
		}
	}
	return strings.Map(toSafe, strings.ToLower(local)) + strings.Map(toSafe, strings.ToLower(suffix))
}
|
|
|
|
// redirectWithPortalError sends a 303 See Other to baseURL with the query
// parameter portal_error set to message. Any query parameters already in
// baseURL are kept. The parse error is ignored, so baseURL must be a valid
// URL (callers build it from configured AppURL); an unparsable baseURL makes
// url.Parse return nil and this function panics.
func redirectWithPortalError(w http.ResponseWriter, r *http.Request, baseURL, message string) {
	dest, _ := url.Parse(baseURL)

	params := dest.Query()
	params.Set("portal_error", message)
	dest.RawQuery = params.Encode()

	location := dest.String()
	http.Redirect(w, r, location, http.StatusSeeOther)
}
|
|
|
|
func subscriptionIDFromSession(sess *stripego.CheckoutSession) string {
|
|
if sess == nil || sess.Subscription == nil {
|
|
return ""
|
|
}
|
|
return sess.Subscription.ID
|
|
}
|