// Forked from Nixius/authelia (193 lines, 6.5 KiB, Go).
package accounts
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
)
|
|
|
|
// txQueryer is the minimal query surface shared by a database handle and a
// transaction (e.g. *sql.DB and *sql.Tx). The package-level query helpers
// accept it so they can run either directly against the store's db or inside
// a transaction opened by a Store method.
type txQueryer interface {
	// ExecContext runs a statement whose rows, if any, are not needed.
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
	// QueryRowContext runs a query whose single result row is scanned by the caller.
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}
|
|
|
|
func (s *Store) UpsertCheckout(ctx context.Context, input CheckoutInput) (*Account, *Instance, error) {
|
|
email := strings.ToLower(strings.TrimSpace(input.Email))
|
|
if input.StripeCustomerID == "" {
|
|
return nil, nil, errors.New("Stripe customer id is required")
|
|
}
|
|
if email == "" && input.AccountID == 0 {
|
|
return nil, nil, errors.New("email is required when account id is missing")
|
|
}
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
if input.StripeEventID != "" {
|
|
if processed, err := eventProcessed(ctx, tx, input.StripeEventID); err != nil {
|
|
return nil, nil, err
|
|
} else if processed {
|
|
acct, inst, loadErr := s.accountAndInstanceByStripeTx(ctx, tx, input.StripeCustomerID)
|
|
if loadErr != nil {
|
|
return nil, nil, loadErr
|
|
}
|
|
return acct, inst, tx.Commit()
|
|
}
|
|
}
|
|
|
|
var acct *Account
|
|
if input.AccountID > 0 {
|
|
acct, err = accountByID(ctx, tx, input.AccountID)
|
|
}
|
|
if acct == nil && (input.AccountID == 0 || errors.Is(err, ErrNotFound)) {
|
|
acct, err = accountByStripeCustomerID(ctx, tx, input.StripeCustomerID)
|
|
}
|
|
if errors.Is(err, ErrNotFound) {
|
|
acct, err = upsertCheckoutAccount(ctx, tx, input)
|
|
}
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
if email == "" {
|
|
input.Email = acct.PrimaryEmail
|
|
email = strings.ToLower(strings.TrimSpace(input.Email))
|
|
}
|
|
if input.DisplayName == "" {
|
|
input.DisplayName = acct.DisplayName
|
|
}
|
|
if err := updateCheckoutAccount(ctx, tx, acct.ID, input); err != nil {
|
|
return nil, nil, err
|
|
}
|
|
if err := linkIdentity(ctx, tx, acct.ID, "stripe", input.StripeCustomerID, email); err != nil {
|
|
return nil, nil, err
|
|
}
|
|
inst, err := ensureInstance(ctx, tx, acct.ID, email, input.CustomerDomain)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
if input.StripeEventID != "" {
|
|
if err := insertBillingEvent(ctx, tx, input, "checkout.session.completed"); err != nil {
|
|
return nil, nil, err
|
|
}
|
|
}
|
|
return acct, inst, tx.Commit()
|
|
}
|
|
|
|
func (s *Store) AccountByStripeCustomerID(ctx context.Context, customerID string) (*Account, *Instance, error) {
|
|
return s.accountAndInstanceByStripeTx(ctx, s.db, customerID)
|
|
}
|
|
|
|
func (s *Store) MarkSubscriptionStatus(ctx context.Context, customerID, status string) error {
|
|
res, err := s.db.ExecContext(ctx, `
|
|
UPDATE accounts
|
|
SET subscription_status = $2, updated_at = now()
|
|
WHERE stripe_customer_id = $1
|
|
`, customerID, status)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
n, _ := res.RowsAffected()
|
|
if n == 0 {
|
|
return ErrNotFound
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Store) InstanceByAccountID(ctx context.Context, accountID int64) (*Instance, error) {
|
|
return instanceByAccountID(ctx, s.db, accountID)
|
|
}
|
|
|
|
func (s *Store) InstanceBySlug(ctx context.Context, slug string) (*Instance, error) {
|
|
return instanceBySlug(ctx, s.db, slug)
|
|
}
|
|
|
|
func (s *Store) UpdateInstanceState(ctx context.Context, stackName, state string, deployed bool) error {
|
|
q := `UPDATE instances SET state = $2, updated_at = now()`
|
|
args := []any{stackName, state}
|
|
if deployed {
|
|
q += `, last_deployed_at = now()`
|
|
}
|
|
q += ` WHERE stack_name = $1`
|
|
_, err := s.db.ExecContext(ctx, q, args...)
|
|
return err
|
|
}
|
|
|
|
func (s *Store) DeleteAccountByInstanceSlug(ctx context.Context, slug string) (*Instance, error) {
|
|
inst, err := instanceBySlug(ctx, s.db, slug)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
_, err = s.db.ExecContext(ctx, `DELETE FROM accounts WHERE id = $1`, inst.AccountID)
|
|
return inst, err
|
|
}
|
|
|
|
func upsertCheckoutAccount(ctx context.Context, q txQueryer, input CheckoutInput) (*Account, error) {
|
|
return scanAccount(q.QueryRowContext(ctx, `
|
|
INSERT INTO accounts (primary_email, display_name, phone, stripe_customer_id, subscription_status)
|
|
VALUES ($1, $2, $3, $4, 'active')
|
|
ON CONFLICT (primary_email) DO UPDATE SET
|
|
display_name = COALESCE(NULLIF(EXCLUDED.display_name, ''), accounts.display_name),
|
|
phone = COALESCE(NULLIF(EXCLUDED.phone, ''), accounts.phone),
|
|
stripe_customer_id = EXCLUDED.stripe_customer_id,
|
|
subscription_status = 'active',
|
|
updated_at = now()
|
|
RETURNING id, primary_email, display_name, stripe_customer_id, subscription_status, created_at, updated_at
|
|
`, strings.ToLower(input.Email), input.DisplayName, input.Phone, input.StripeCustomerID))
|
|
}
|
|
|
|
func updateCheckoutAccount(ctx context.Context, q txQueryer, accountID int64, input CheckoutInput) error {
|
|
_, err := q.ExecContext(ctx, `
|
|
UPDATE accounts
|
|
SET primary_email = COALESCE(NULLIF($2, ''), primary_email),
|
|
display_name = COALESCE(NULLIF($3, ''), display_name),
|
|
phone = COALESCE(NULLIF($4, ''), phone),
|
|
stripe_customer_id = COALESCE(NULLIF($5, ''), stripe_customer_id),
|
|
subscription_status = 'active',
|
|
updated_at = now()
|
|
WHERE id = $1
|
|
`, accountID, strings.ToLower(input.Email), input.DisplayName, input.Phone, input.StripeCustomerID)
|
|
return err
|
|
}
|
|
|
|
func (s *Store) accountAndInstanceByStripeTx(ctx context.Context, q txQueryer, customerID string) (*Account, *Instance, error) {
|
|
acct, err := accountByStripeCustomerID(ctx, q, customerID)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
inst, err := instanceByAccountID(ctx, q, acct.ID)
|
|
return acct, inst, err
|
|
}
|
|
|
|
func accountByStripeCustomerID(ctx context.Context, q txQueryer, customerID string) (*Account, error) {
|
|
return scanAccount(q.QueryRowContext(ctx, `
|
|
SELECT id, primary_email, display_name, stripe_customer_id, subscription_status, created_at, updated_at
|
|
FROM accounts WHERE stripe_customer_id = $1
|
|
`, customerID))
|
|
}
|
|
|
|
func ensureInstance(ctx context.Context, q txQueryer, accountID int64, email, domain string) (*Instance, error) {
|
|
if inst, err := instanceByAccountID(ctx, q, accountID); err == nil {
|
|
if domain != "" && inst.CustomerDomain != domain {
|
|
_, _ = q.ExecContext(ctx, `UPDATE instances SET customer_domain = $2, updated_at = now() WHERE id = $1`, inst.ID, domain)
|
|
inst.CustomerDomain = domain
|
|
}
|
|
return inst, nil
|
|
} else if !errors.Is(err, ErrNotFound) {
|
|
return nil, err
|
|
}
|
|
slug := SlugFromEmail(email)
|
|
if owner, err := instanceBySlug(ctx, q, slug); err == nil && owner.AccountID != accountID {
|
|
slug = fmt.Sprintf("%s-%d", slug, accountID)
|
|
}
|
|
stackName := "customer-" + slug
|
|
return scanInstance(q.QueryRowContext(ctx, `
|
|
INSERT INTO instances (account_id, slug, stack_name, customer_domain, state)
|
|
VALUES ($1, $2, $3, $4, 'pending')
|
|
RETURNING id, account_id, slug, stack_name, customer_domain, state, last_deployed_at
|
|
`, accountID, slug, stackName, domain))
|
|
}
|