forked from Nixius/authelia
1
0
Fork 0
ATLAS/docker/ss-atlas/internal/handlers/subscription.go

254 lines
8.1 KiB
Go

package handlers
import (
"errors"
"fmt"
"log"
"net/http"
"strings"
"time"
"git.nixc.us/a250/ss-atlas/internal/pricing"
"git.nixc.us/a250/ss-atlas/internal/stripe"
"git.nixc.us/a250/ss-atlas/internal/validation"
"git.nixc.us/a250/ss-atlas/internal/version"
)
// handleLanding renders the public landing page.
//
// Redirect rules, checked in order:
//   - Remote-Groups contains "customers"  -> /dashboard
//   - Remote-User is set and that user already has a Stripe customer ID in
//     LDAP (paid but not yet activated) -> /activate
//
// Otherwise landing.html is rendered with the sold-out flag, the pricing tier
// for the current customer count, and the checkout configuration.
func (a *App) handleLanding(w http.ResponseWriter, r *http.Request) {
	if contains(r.Header.Get("Remote-Groups"), "customers") {
		http.Redirect(w, r, a.cfg.AppURL+"/dashboard", http.StatusSeeOther)
		return
	}
	// Logged-in user who paid but hasn't activated yet — send to activate.
	if remoteUser := r.Header.Get("Remote-User"); remoteUser != "" {
		// Lookup failure is treated like "no customer ID": show the landing page.
		custID, _ := a.ldap.GetStripeCustomerID(remoteUser)
		if custID != "" {
			http.Redirect(w, r, a.cfg.AppURL+"/activate", http.StatusSeeOther)
			return
		}
	}
	count, err := a.ldap.CountCustomers()
	if err != nil {
		// Best-effort page: still render, but tier and sold-out state below are
		// computed from a zero count, so make the failure visible in the logs.
		log.Printf("landing: count customers failed: %v", err)
	}
	soldOut := a.cfg.MaxSignups > 0 && count >= a.cfg.MaxSignups
	tier := pricing.ForCustomer(count, a.cfg.FreeTierLimit, a.cfg.YearTierLimit)
	// The checkout form is offered when any Stripe price ID is configured.
	useForm := a.cfg.StripePriceID != "" || a.cfg.StripePriceIDFree != "" ||
		a.cfg.StripePriceIDYear != "" || a.cfg.StripePriceIDMonth100 != "" ||
		a.cfg.StripePriceIDMonth200 != ""
	data := map[string]any{
		"AppURL":            a.cfg.AppURL,
		"Commit":            version.Commit,
		"BuildTime":         version.BuildTime,
		"StripePaymentLink": a.cfg.StripePaymentLink,
		"SoldOut":           soldOut,
		"PricingTier":       int(tier),
		"UseCheckoutForm":   useForm,
		"Tagline":           a.cfg.LandingTagline,
		"Features":          a.cfg.LandingFeatures,
	}
	if err := a.tmpl.ExecuteTemplate(w, "landing.html", data); err != nil {
		log.Printf("template error: %v", err)
		http.Error(w, "internal error", http.StatusInternalServerError)
	}
}
// handleCreateCheckout validates the signup form (email, domain, phone),
// enforces the optional MaxSignups cap, and redirects to a new Stripe Checkout
// session priced for the current customer count.
//
// Responses:
//   - 503 if the customer count cannot be read, the cap is reached, or no
//     Stripe price is configured for the current tier
//   - 400 for a missing/invalid email, domain, or phone, or a domain that does
//     not resolve
//   - 500 if Stripe fails to create the session
func (a *App) handleCreateCheckout(w http.ResponseWriter, r *http.Request) {
	// A single count read drives both the cap check and tier pricing. A failed
	// read must not fall back to count 0, which would price at the lowest tier.
	count, err := a.ldap.CountCustomers()
	if err != nil {
		log.Printf("checkout: count customers failed: %v", err)
		http.Error(w, "signup temporarily unavailable, try again later", http.StatusServiceUnavailable)
		return
	}
	if a.cfg.MaxSignups > 0 && count >= a.cfg.MaxSignups {
		http.Error(w, "signup limit reached, try again later", http.StatusServiceUnavailable)
		return
	}

	email := validation.SanitizeEmail(r.FormValue("email"))
	if email == "" {
		http.Error(w, "valid email required", http.StatusBadRequest)
		return
	}
	domain := strings.TrimSpace(r.FormValue("domain"))
	if domain == "" {
		http.Error(w, "domain required", http.StatusBadRequest)
		return
	}
	phone := validation.SanitizePhone(r.FormValue("phone"))
	if phone == "" {
		http.Error(w, "valid phone number required", http.StatusBadRequest)
		return
	}
	// The DNS lookup may block for up to 5s, so it runs only after the cheap
	// field checks have passed.
	if !validation.DomainResolves(domain, 5*time.Second) {
		http.Error(w, "domain does not resolve; please enter a valid domain", http.StatusBadRequest)
		return
	}

	sess, err := a.stripe.CreateCheckoutSession(email, domain, phone, count)
	if err != nil {
		if errors.Is(err, stripe.ErrNoPriceForTier) {
			http.Error(w, "pricing not configured for current tier — set STRIPE_PRICE_ID or tier prices in env", http.StatusServiceUnavailable)
			return
		}
		log.Printf("stripe checkout error: %v", err)
		http.Error(w, "failed to create checkout", http.StatusInternalServerError)
		return
	}
	http.Redirect(w, r, sess.URL, http.StatusSeeOther)
}
// handleSuccess handles the return from Stripe Checkout (?session_id=...).
//
// It verifies with Stripe that the session is paid, then:
//   - provisions (or reuses) the LDAP account for the paying email and
//     customer ID
//   - stores the customer domain and phone carried in session metadata
//   - adds the account to the "customers" group
//
// New accounts, and accounts that end up outside "customers", get a
// password-reset email and the success page. All other accounts get their
// "customer-<username>" stack restored and deployed if it does not exist, and
// are redirected to the dashboard.
func (a *App) handleSuccess(w http.ResponseWriter, r *http.Request) {
	sessionID := r.URL.Query().Get("session_id")
	if sessionID == "" {
		http.Redirect(w, r, "/", http.StatusSeeOther)
		return
	}
	// NOTE(review): this cap is applied after the customer has already paid.
	// If CountCustomers includes accounts provisioned on an earlier visit, a
	// reload of this page at the cap is rejected too. Confirm this is intended.
	if a.cfg.MaxSignups > 0 {
		count, err := a.ldap.CountCustomers()
		if err == nil && count >= a.cfg.MaxSignups {
			http.Error(w, "signup limit reached, contact support", http.StatusServiceUnavailable)
			return
		}
	}
	sess, err := a.stripe.GetCheckoutSession(sessionID)
	if err != nil {
		log.Printf("stripe get session error: %v", err)
		http.Error(w, "could not verify payment", http.StatusInternalServerError)
		return
	}
	if sess.PaymentStatus != "paid" {
		http.Redirect(w, r, "/", http.StatusSeeOther)
		return
	}
	// Dereferencing missing customer data would panic, and an empty email would
	// produce an empty username, so stop here instead.
	if sess.CustomerDetails == nil || sess.CustomerDetails.Email == "" ||
		sess.Customer == nil || sess.Customer.ID == "" {
		log.Printf("stripe session %q is paid but lacks customer email or ID", sessionID)
		http.Error(w, "account creation failed, contact support", http.StatusInternalServerError)
		return
	}
	email := sess.CustomerDetails.Email
	customerID := sess.Customer.ID
	username := sanitizeUsername(email)
	// Indexing a nil metadata map yields "", matching the "no phone" case.
	phone := sess.Metadata["customer_phone"]

	result, err := a.ldap.ProvisionUser(username, email, customerID, phone)
	if err != nil {
		log.Printf("ldap provision failed for %s: %v", email, err)
		http.Error(w, "account creation failed, contact support", http.StatusInternalServerError)
		return
	}
	if domain := sess.Metadata["customer_domain"]; domain != "" {
		if err := a.ldap.SetCustomerDomain(result.Username, domain); err != nil {
			log.Printf("ldap set customer domain failed for %s: %v", result.Username, err)
		}
	}
	// Grant active subscription: add to customers group so dashboard shows subscribed.
	if err := a.ldap.AddToGroup(result.Username, "customers"); err != nil {
		log.Printf("ldap add to customers failed for %s: %v (create group 'customers' in LLDAP admin if missing)", result.Username, err)
	}
	// NOTE(review): membership is read *after* the AddToGroup call above, so
	// inGroup is false only when that add did not take effect or this lookup
	// failed. An existing account whose add succeeds therefore takes the
	// dashboard path below, even if it had lapsed.
	inGroup, _ := a.ldap.IsInGroup(result.Username, "customers")
	if result.IsNew || !inGroup {
		// New or lapsed: send password email, show success page.
		if err := a.triggerPasswordReset(r, result.Username); err != nil {
			log.Printf("authelia reset trigger failed for %s: %v", result.Username, err)
		} else {
			resendRateLimiter.record(result.Username)
		}
		if err := a.tmpl.ExecuteTemplate(w, "success.html", map[string]any{
			"AppURL":   a.cfg.AppURL,
			"Username": result.Username,
		}); err != nil {
			log.Printf("template error: %v", err)
			http.Error(w, "internal error", http.StatusInternalServerError)
		}
		return
	}
	// Returning customer: ensure the stack exists, then go to the dashboard.
	stackName := fmt.Sprintf("customer-%s", result.Username)
	exists, err := a.swarm.StackExists(stackName)
	switch {
	case err != nil:
		// Unknown state: do not restore archived volumes over what may be a
		// live stack. Log and let the redirect proceed.
		log.Printf("resubscribe: stack existence check failed for %s: %v", result.Username, err)
	case !exists:
		if err := a.swarm.RestoreVolumes(stackName, a.cfg.ArchivePath); err != nil {
			log.Printf("resubscribe: volume restore failed for %s: %v", result.Username, err)
		}
		if err := a.swarm.DeployStack(stackName, result.Username, a.cfg.TraefikDomain); err != nil {
			log.Printf("resubscribe: stack deploy failed for %s: %v", result.Username, err)
		}
	}
	log.Printf("resubscribe: %s payment verified, redirecting to dashboard", result.Username)
	http.Redirect(w, r, a.cfg.AppURL+"/dashboard", http.StatusSeeOther)
}
// handlePortal redirects to a Stripe billing-portal session for the Stripe
// customer named in the "customer_id" form field.
//
// customer_id comes from the client and is untrusted. When the request carries
// a Remote-User header, the id must equal the Stripe customer ID stored in LDAP
// for that user (403 otherwise), so that one user cannot open another
// customer's billing portal.
//
// NOTE(review): requests without Remote-User skip this check. Confirm this
// route is only reachable behind the authenticating proxy.
func (a *App) handlePortal(w http.ResponseWriter, r *http.Request) {
	customerID := r.FormValue("customer_id")
	if customerID == "" {
		http.Error(w, "customer_id required", http.StatusBadRequest)
		return
	}
	if remoteUser := r.Header.Get("Remote-User"); remoteUser != "" {
		ownedID, err := a.ldap.GetStripeCustomerID(remoteUser)
		if err != nil {
			log.Printf("portal: ldap customer lookup failed for %s: %v", remoteUser, err)
			http.Error(w, "failed to create portal session", http.StatusInternalServerError)
			return
		}
		if ownedID != customerID {
			log.Printf("portal: %s requested a customer_id it does not own", remoteUser)
			http.Error(w, "forbidden", http.StatusForbidden)
			return
		}
	}
	sess, err := a.stripe.CreatePortalSession(customerID)
	if err != nil {
		log.Printf("stripe portal error: %v", err)
		http.Error(w, "failed to create portal session", http.StatusInternalServerError)
		return
	}
	http.Redirect(w, r, sess.URL, http.StatusSeeOther)
}
// handleResubscribe creates a fresh checkout session for an existing Stripe
// customer whose subscription has expired/been cancelled. This differs from
// the portal flow which only manages active or scheduled-to-cancel subs.
//
// customer_id comes from the client and is untrusted. When the request carries
// a Remote-User header, the id must equal the Stripe customer ID stored in LDAP
// for that user (403 otherwise).
//
// The price tier depends on the current customer count. If that count cannot
// be read, the handler responds 503 instead of pricing at the count-0 tier.
//
// NOTE(review): requests without Remote-User skip the ownership check. Confirm
// this route is only reachable behind the authenticating proxy.
func (a *App) handleResubscribe(w http.ResponseWriter, r *http.Request) {
	customerID := r.FormValue("customer_id")
	if customerID == "" {
		http.Error(w, "customer_id required", http.StatusBadRequest)
		return
	}
	if remoteUser := r.Header.Get("Remote-User"); remoteUser != "" {
		ownedID, err := a.ldap.GetStripeCustomerID(remoteUser)
		if err != nil {
			log.Printf("resubscribe: ldap customer lookup failed for %s: %v", remoteUser, err)
			http.Error(w, "failed to create checkout session", http.StatusInternalServerError)
			return
		}
		if ownedID != customerID {
			log.Printf("resubscribe: %s requested a customer_id it does not own", remoteUser)
			http.Error(w, "forbidden", http.StatusForbidden)
			return
		}
	}
	count, err := a.ldap.CountCustomers()
	if err != nil {
		log.Printf("resubscribe: count customers failed: %v", err)
		http.Error(w, "pricing temporarily unavailable, try again later", http.StatusServiceUnavailable)
		return
	}
	sess, err := a.stripe.CreateCheckoutForCustomer(customerID, count)
	if err != nil {
		if errors.Is(err, stripe.ErrNoPriceForTier) {
			http.Error(w, "pricing not configured for current tier — set STRIPE_PRICE_ID or tier prices in env", http.StatusServiceUnavailable)
			return
		}
		log.Printf("stripe resubscribe error: %v", err)
		http.Error(w, "failed to create checkout session", http.StatusInternalServerError)
		return
	}
	http.Redirect(w, r, sess.URL, http.StatusSeeOther)
}
// sanitizeUsername derives an account username from an email address.
//
// The result is built from two parts:
//   - the part before the first '@'
//   - if the part after it contains at least one '.', a "-" followed by the
//     second-to-last dot-separated label (e.g. "nixc" for "nixc.us", "gmail"
//     for "gmail.com", "co" for "example.co.uk")
//
// Both parts are lower-cased, and every rune outside [a-z0-9_-] becomes '-'.
// Input without '@' is treated entirely as the local part.
func sanitizeUsername(email string) string {
	localPart, host, _ := strings.Cut(email, "@")

	suffix := ""
	if lastDot := strings.LastIndex(host, "."); lastDot >= 0 {
		// Everything before the final dot; its last label is the one we want.
		head := host[:lastDot]
		suffix = "-" + head[strings.LastIndex(head, ".")+1:]
	}

	toSafe := func(s string) string {
		var sb strings.Builder
		for _, c := range strings.ToLower(s) {
			switch {
			case 'a' <= c && c <= 'z', '0' <= c && c <= '9', c == '-', c == '_':
				sb.WriteRune(c)
			default:
				sb.WriteByte('-')
			}
		}
		return sb.String()
	}

	return toSafe(localPart) + toSafe(suffix)
}