forked from Nixius/authelia
1
0
Fork 0
ATLAS/docker/ss-atlas/internal/accounts/queries.go

193 lines
6.5 KiB
Go

package accounts
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
)
// txQueryer is the minimal query surface the helpers in this file depend on.
// Both *sql.DB and *sql.Tx satisfy it, so a helper can be run either inside a
// transaction (UpsertCheckout passes its *sql.Tx) or directly against the
// store's database handle.
type txQueryer interface {
	// QueryRowContext runs a query expected to return at most one row.
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
	// ExecContext runs a statement that returns no rows.
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}
// UpsertCheckout applies a completed checkout to the store inside a single
// transaction. It performs these steps:
//   - resolves or creates the account;
//   - refreshes the account's checkout fields and marks it active;
//   - links the Stripe customer id as a "stripe" identity;
//   - ensures the account has an instance row;
//   - when input.StripeEventID is set, records a "checkout.session.completed"
//     billing event.
//
// Account resolution order:
//  1. input.AccountID, when > 0;
//  2. input.StripeCustomerID;
//  3. insert-or-update keyed by the (normalized) email.
//
// If input.StripeEventID was already processed (per eventProcessed), nothing
// is changed and the current account/instance for the Stripe customer is
// returned instead.
//
// The email is trimmed and lower-cased once, here, and that form is used for
// every write. The returned Account is the row as read before
// updateCheckoutAccount ran; it is not reloaded afterwards.
//
// On error both returned pointers are nil.
func (s *Store) UpsertCheckout(ctx context.Context, input CheckoutInput) (*Account, *Instance, error) {
	email := strings.ToLower(strings.TrimSpace(input.Email))
	if input.StripeCustomerID == "" {
		return nil, nil, errors.New("Stripe customer id is required")
	}
	if email == "" && input.AccountID == 0 {
		return nil, nil, errors.New("email is required when account id is missing")
	}
	// Helpers below read input.Email directly; give them the normalized form
	// so the accounts row, the linked identity and the instance slug agree.
	input.Email = email

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, nil, err
	}
	// Rolls back on any early return; a no-op once Commit has succeeded.
	defer tx.Rollback()

	if input.StripeEventID != "" {
		processed, err := eventProcessed(ctx, tx, input.StripeEventID)
		if err != nil {
			return nil, nil, err
		}
		if processed {
			// Already applied: report current state without re-applying.
			acct, inst, err := s.accountAndInstanceByStripeTx(ctx, tx, input.StripeCustomerID)
			if err != nil {
				return nil, nil, err
			}
			if err := tx.Commit(); err != nil {
				return nil, nil, err
			}
			return acct, inst, nil
		}
	}

	var acct *Account
	if input.AccountID > 0 {
		acct, err = accountByID(ctx, tx, input.AccountID)
	}
	if acct == nil && (input.AccountID == 0 || errors.Is(err, ErrNotFound)) {
		acct, err = accountByStripeCustomerID(ctx, tx, input.StripeCustomerID)
	}
	if errors.Is(err, ErrNotFound) {
		// Creating (or email-matching) an account needs an email. The earlier
		// guard lets an empty email through when AccountID is set, but that id
		// did not resolve, so refuse rather than insert an account keyed by ''.
		if email == "" {
			return nil, nil, errors.New("email is required when account is not found")
		}
		acct, err = upsertCheckoutAccount(ctx, tx, input)
	}
	if err != nil {
		return nil, nil, err
	}

	if email == "" {
		// Email may be omitted only when input.AccountID was supplied; fall
		// back to the email already on file for the resolved account.
		email = strings.ToLower(strings.TrimSpace(acct.PrimaryEmail))
		input.Email = email
	}
	if input.DisplayName == "" {
		input.DisplayName = acct.DisplayName
	}
	if err := updateCheckoutAccount(ctx, tx, acct.ID, input); err != nil {
		return nil, nil, err
	}
	if err := linkIdentity(ctx, tx, acct.ID, "stripe", input.StripeCustomerID, email); err != nil {
		return nil, nil, err
	}
	inst, err := ensureInstance(ctx, tx, acct.ID, email, input.CustomerDomain)
	if err != nil {
		return nil, nil, err
	}
	if input.StripeEventID != "" {
		// Recorded in the same transaction, so if any step above fails the
		// rollback also discards this record and a redelivered event is not
		// mistaken for an already-processed one.
		if err := insertBillingEvent(ctx, tx, input, "checkout.session.completed"); err != nil {
			return nil, nil, err
		}
	}
	if err := tx.Commit(); err != nil {
		return nil, nil, err
	}
	return acct, inst, nil
}
// AccountByStripeCustomerID fetches the account linked to the given Stripe
// customer id together with that account's instance, querying s.db directly
// rather than inside a transaction.
//
// If the account is found but the instance lookup fails, the account is
// still returned alongside the instance lookup's error.
func (s *Store) AccountByStripeCustomerID(ctx context.Context, customerID string) (*Account, *Instance, error) {
	acct, inst, err := s.accountAndInstanceByStripeTx(ctx, s.db, customerID)
	return acct, inst, err
}
// MarkSubscriptionStatus sets subscription_status (and bumps updated_at) on
// the account whose stripe_customer_id is customerID.
//
// It returns ErrNotFound when the UPDATE affects no rows. Errors from the
// statement or from reading the affected-row count are returned as-is.
func (s *Store) MarkSubscriptionStatus(ctx context.Context, customerID, status string) error {
	res, err := s.db.ExecContext(ctx, `
UPDATE accounts
SET subscription_status = $2, updated_at = now()
WHERE stripe_customer_id = $1
`, customerID, status)
	if err != nil {
		return err
	}
	// Previously this error was discarded, so a driver failure was reported
	// to callers as ErrNotFound.
	n, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if n == 0 {
		return ErrNotFound
	}
	return nil
}
// InstanceByAccountID returns the instance belonging to accountID, looked up
// directly through s.db (no transaction). Errors from the lookup, including
// a not-found result, are passed through unchanged.
func (s *Store) InstanceByAccountID(ctx context.Context, accountID int64) (*Instance, error) {
	inst, err := instanceByAccountID(ctx, s.db, accountID)
	return inst, err
}
// InstanceBySlug returns the instance with the given slug, looked up directly
// through s.db (no transaction). Errors from the lookup are passed through
// unchanged.
func (s *Store) InstanceBySlug(ctx context.Context, slug string) (*Instance, error) {
	inst, err := instanceBySlug(ctx, s.db, slug)
	return inst, err
}
// UpdateInstanceState sets state and bumps updated_at on the instance whose
// stack_name is stackName. When deployed is true, last_deployed_at is also
// set to now().
//
// The affected-row count is not checked, so a stackName that matches no
// instance does not produce an error.
func (s *Store) UpdateInstanceState(ctx context.Context, stackName, state string, deployed bool) error {
	const (
		setClause   = `UPDATE instances SET state = $2, updated_at = now()`
		deployStamp = `, last_deployed_at = now()`
		whereClause = ` WHERE stack_name = $1`
	)
	stmt := setClause
	if deployed {
		stmt += deployStamp
	}
	stmt += whereClause
	if _, err := s.db.ExecContext(ctx, stmt, stackName, state); err != nil {
		return err
	}
	return nil
}
// DeleteAccountByInstanceSlug finds the instance with the given slug and
// deletes the account that owns it. It returns the instance as it was before
// the delete, so the caller knows which stack/slug belonged to the account.
//
// On any error (lookup or delete) the returned instance is nil. Previously a
// failed DELETE returned the instance alongside the error, which invites a
// caller to tear down resources for an account that still exists.
//
// NOTE(review): the lookup and the delete are separate statements outside a
// transaction.
func (s *Store) DeleteAccountByInstanceSlug(ctx context.Context, slug string) (*Instance, error) {
	inst, err := instanceBySlug(ctx, s.db, slug)
	if err != nil {
		return nil, err
	}
	if _, err := s.db.ExecContext(ctx, `DELETE FROM accounts WHERE id = $1`, inst.AccountID); err != nil {
		return nil, err
	}
	return inst, nil
}
// upsertCheckoutAccount inserts a new active account for a checkout. If an
// account with the same primary_email already exists, it instead updates that
// account:
//   - display_name and phone are replaced only by non-empty new values;
//   - stripe_customer_id is overwritten;
//   - the account is marked active.
//
// The row is returned via scanAccount.
//
// The email is trimmed and lower-cased before use. An empty email is
// rejected: the conflict target is primary_email, so inserting '' would
// otherwise merge into (and re-point the Stripe id of) any existing account
// whose primary_email is also ''.
func upsertCheckoutAccount(ctx context.Context, q txQueryer, input CheckoutInput) (*Account, error) {
	email := strings.ToLower(strings.TrimSpace(input.Email))
	if email == "" {
		return nil, errors.New("email is required to create an account")
	}
	return scanAccount(q.QueryRowContext(ctx, `
INSERT INTO accounts (primary_email, display_name, phone, stripe_customer_id, subscription_status)
VALUES ($1, $2, $3, $4, 'active')
ON CONFLICT (primary_email) DO UPDATE SET
display_name = COALESCE(NULLIF(EXCLUDED.display_name, ''), accounts.display_name),
phone = COALESCE(NULLIF(EXCLUDED.phone, ''), accounts.phone),
stripe_customer_id = EXCLUDED.stripe_customer_id,
subscription_status = 'active',
updated_at = now()
RETURNING id, primary_email, display_name, stripe_customer_id, subscription_status, created_at, updated_at
`, email, input.DisplayName, input.Phone, input.StripeCustomerID))
}
// updateCheckoutAccount refreshes the checkout-related columns of the account
// with id accountID and marks it active.
//
// An empty email, display name, phone or Stripe customer id leaves the stored
// column unchanged: NULLIF(x, '') yields NULL and COALESCE then keeps the
// current value.
//
// The email is trimmed and lower-cased first, the same normalization
// UpsertCheckout applies, so the stored primary_email never carries stray
// whitespace. The affected-row count is not checked.
func updateCheckoutAccount(ctx context.Context, q txQueryer, accountID int64, input CheckoutInput) error {
	email := strings.ToLower(strings.TrimSpace(input.Email))
	_, err := q.ExecContext(ctx, `
UPDATE accounts
SET primary_email = COALESCE(NULLIF($2, ''), primary_email),
display_name = COALESCE(NULLIF($3, ''), display_name),
phone = COALESCE(NULLIF($4, ''), phone),
stripe_customer_id = COALESCE(NULLIF($5, ''), stripe_customer_id),
subscription_status = 'active',
updated_at = now()
WHERE id = $1
`, accountID, email, input.DisplayName, input.Phone, input.StripeCustomerID)
	return err
}
// accountAndInstanceByStripeTx loads the account for a Stripe customer id and
// then that account's instance, both through q (a transaction or the
// database handle).
//
// If the account lookup fails, nothing is returned but the error. If only the
// instance lookup fails, the account is still returned together with that
// error.
func (s *Store) accountAndInstanceByStripeTx(ctx context.Context, q txQueryer, customerID string) (*Account, *Instance, error) {
	acct, acctErr := accountByStripeCustomerID(ctx, q, customerID)
	if acctErr != nil {
		return nil, nil, acctErr
	}
	inst, instErr := instanceByAccountID(ctx, q, acct.ID)
	return acct, inst, instErr
}
// accountByStripeCustomerID selects the account whose stripe_customer_id
// equals customerID and decodes it with scanAccount. Not-found handling is
// whatever scanAccount does with an empty result.
func accountByStripeCustomerID(ctx context.Context, q txQueryer, customerID string) (*Account, error) {
	const query = `
SELECT id, primary_email, display_name, stripe_customer_id, subscription_status, created_at, updated_at
FROM accounts WHERE stripe_customer_id = $1
`
	row := q.QueryRowContext(ctx, query, customerID)
	return scanAccount(row)
}
// ensureInstance returns the instance for accountID, creating it if needed.
//
// If the instance exists and a non-empty domain differs from its
// customer_domain, the column is updated first. Errors from that update are
// returned; they were previously discarded while the in-memory value was
// still changed.
//
// Otherwise a new 'pending' instance is inserted with:
//   - slug = SlugFromEmail(email), suffixed with "-<accountID>" when another
//     account already owns that slug;
//   - stack_name = "customer-" + slug.
//
// Errors from the slug-collision lookup other than "not found" are returned
// instead of being ignored.
func ensureInstance(ctx context.Context, q txQueryer, accountID int64, email, domain string) (*Instance, error) {
	inst, err := instanceByAccountID(ctx, q, accountID)
	switch {
	case err == nil:
		if domain != "" && inst.CustomerDomain != domain {
			if _, err := q.ExecContext(ctx, `UPDATE instances SET customer_domain = $2, updated_at = now() WHERE id = $1`, inst.ID, domain); err != nil {
				return nil, err
			}
			inst.CustomerDomain = domain
		}
		return inst, nil
	case !errors.Is(err, ErrNotFound):
		return nil, err
	}

	slug := SlugFromEmail(email)
	owner, err := instanceBySlug(ctx, q, slug)
	switch {
	case err == nil:
		if owner.AccountID != accountID {
			slug = fmt.Sprintf("%s-%d", slug, accountID)
		}
	// Treat a free slug as the normal case. sql.ErrNoRows is accepted as well
	// in case the helper does not map it to ErrNotFound — TODO confirm.
	case errors.Is(err, ErrNotFound), errors.Is(err, sql.ErrNoRows):
	default:
		return nil, err
	}

	stackName := "customer-" + slug
	return scanInstance(q.QueryRowContext(ctx, `
INSERT INTO instances (account_id, slug, stack_name, customer_domain, state)
VALUES ($1, $2, $3, $4, 'pending')
RETURNING id, account_id, slug, stack_name, customer_domain, state, last_deployed_at
`, accountID, slug, stackName, domain))
}