forked from Nixius/authelia
126 lines
4.3 KiB
Go
126 lines
4.3 KiB
Go
package accounts
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
)
|
|
|
|
func upsertAccountByEmail(ctx context.Context, q txQueryer, email, displayName string) (*Account, error) {
|
|
return scanAccount(q.QueryRowContext(ctx, `
|
|
INSERT INTO accounts (primary_email, display_name)
|
|
VALUES ($1, $2)
|
|
ON CONFLICT (primary_email) DO UPDATE SET
|
|
display_name = COALESCE(NULLIF(EXCLUDED.display_name, ''), accounts.display_name),
|
|
updated_at = now()
|
|
RETURNING id, primary_email, display_name, stripe_customer_id, subscription_status, created_at, updated_at
|
|
`, email, displayName))
|
|
}
|
|
|
|
func updateAccountProfile(ctx context.Context, q txQueryer, accountID int64, email, displayName string) error {
|
|
_, err := q.ExecContext(ctx, `
|
|
UPDATE accounts
|
|
SET primary_email = COALESCE(NULLIF($2, ''), primary_email),
|
|
display_name = COALESCE(NULLIF($3, ''), display_name),
|
|
updated_at = now()
|
|
WHERE id = $1
|
|
`, accountID, email, displayName)
|
|
return err
|
|
}
|
|
|
|
func linkIdentity(ctx context.Context, q txQueryer, accountID int64, provider, subject, email string) error {
|
|
_, err := q.ExecContext(ctx, `
|
|
INSERT INTO account_identities (account_id, provider, provider_subject, email_at_login)
|
|
VALUES ($1, $2, $3, $4)
|
|
ON CONFLICT (provider, provider_subject) DO UPDATE SET
|
|
account_id = EXCLUDED.account_id,
|
|
email_at_login = EXCLUDED.email_at_login
|
|
`, accountID, provider, subject, email)
|
|
return err
|
|
}
|
|
|
|
func accountByIdentity(ctx context.Context, q txQueryer, provider, subject string) (*Account, error) {
|
|
return scanAccount(q.QueryRowContext(ctx, `
|
|
SELECT a.id, a.primary_email, a.display_name, a.stripe_customer_id,
|
|
a.subscription_status, a.created_at, a.updated_at
|
|
FROM accounts a
|
|
JOIN account_identities i ON i.account_id = a.id
|
|
WHERE i.provider = $1 AND i.provider_subject = $2
|
|
`, provider, subject))
|
|
}
|
|
|
|
func accountByID(ctx context.Context, q txQueryer, accountID int64) (*Account, error) {
|
|
return scanAccount(q.QueryRowContext(ctx, `
|
|
SELECT id, primary_email, display_name, stripe_customer_id, subscription_status, created_at, updated_at
|
|
FROM accounts WHERE id = $1
|
|
`, accountID))
|
|
}
|
|
|
|
func accountByEmail(ctx context.Context, q txQueryer, email string) (*Account, error) {
|
|
return scanAccount(q.QueryRowContext(ctx, `
|
|
SELECT id, primary_email, display_name, stripe_customer_id, subscription_status, created_at, updated_at
|
|
FROM accounts WHERE primary_email = $1
|
|
`, email))
|
|
}
|
|
|
|
func instanceByAccountID(ctx context.Context, q txQueryer, accountID int64) (*Instance, error) {
|
|
return scanInstance(q.QueryRowContext(ctx, `
|
|
SELECT id, account_id, slug, stack_name, customer_domain, state, last_deployed_at
|
|
FROM instances WHERE account_id = $1
|
|
`, accountID))
|
|
}
|
|
|
|
func instanceBySlug(ctx context.Context, q txQueryer, slug string) (*Instance, error) {
|
|
return scanInstance(q.QueryRowContext(ctx, `
|
|
SELECT id, account_id, slug, stack_name, customer_domain, state, last_deployed_at
|
|
FROM instances WHERE slug = $1
|
|
`, slug))
|
|
}
|
|
|
|
func eventProcessed(ctx context.Context, q txQueryer, eventID string) (bool, error) {
|
|
var one int
|
|
err := q.QueryRowContext(ctx, `SELECT 1 FROM billing_events WHERE event_id = $1`, eventID).Scan(&one)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return false, nil
|
|
}
|
|
return err == nil, err
|
|
}
|
|
|
|
func insertBillingEvent(ctx context.Context, q txQueryer, input CheckoutInput, eventType string) error {
|
|
_, err := q.ExecContext(ctx, `
|
|
INSERT INTO billing_events (
|
|
event_id, event_type, stripe_session_id, stripe_subscription_id, stripe_customer_id
|
|
) VALUES ($1, $2, $3, $4, $5)
|
|
ON CONFLICT (event_id) DO NOTHING
|
|
`, input.StripeEventID, eventType, input.StripeSessionID, input.StripeSubscriptionID, input.StripeCustomerID)
|
|
return err
|
|
}
|
|
|
|
func scanAccount(row *sql.Row) (*Account, error) {
|
|
var acct Account
|
|
var stripe sql.NullString
|
|
err := row.Scan(&acct.ID, &acct.PrimaryEmail, &acct.DisplayName, &stripe,
|
|
&acct.SubscriptionStatus, &acct.CreatedAt, &acct.UpdatedAt)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
acct.StripeCustomerID = stripe.String
|
|
return &acct, nil
|
|
}
|
|
|
|
func scanInstance(row *sql.Row) (*Instance, error) {
|
|
var inst Instance
|
|
err := row.Scan(&inst.ID, &inst.AccountID, &inst.Slug, &inst.StackName,
|
|
&inst.CustomerDomain, &inst.State, &inst.LastDeployedAt)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &inst, nil
|
|
}
|