// File: ATLAS/docker/ss-atlas/internal/accounts/scans.go
// Repository forked from Nixius/authelia.

package accounts
import (
"context"
"database/sql"
"errors"
)
// upsertAccountByEmail inserts an accounts row for email, or, when a row with
// the same primary_email already exists, updates it in place. On conflict the
// stored display_name is replaced only if displayName is non-empty, and
// updated_at is set to now(). The resulting row is decoded by scanAccount.
func upsertAccountByEmail(ctx context.Context, q txQueryer, email, displayName string) (*Account, error) {
	const query = `
INSERT INTO accounts (primary_email, display_name)
VALUES ($1, $2)
ON CONFLICT (primary_email) DO UPDATE SET
display_name = COALESCE(NULLIF(EXCLUDED.display_name, ''), accounts.display_name),
updated_at = now()
RETURNING id, primary_email, display_name, stripe_customer_id, subscription_status, created_at, updated_at
`
	row := q.QueryRowContext(ctx, query, email, displayName)
	return scanAccount(row)
}
// updateAccountProfile updates primary_email and display_name on the account
// identified by accountID. An empty email or displayName leaves the
// corresponding column untouched; updated_at is always set to now().
//
// The statement result is discarded, so a nil error does not indicate that a
// row actually matched accountID.
func updateAccountProfile(ctx context.Context, q txQueryer, accountID int64, email, displayName string) error {
	const stmt = `
UPDATE accounts
SET primary_email = COALESCE(NULLIF($2, ''), primary_email),
display_name = COALESCE(NULLIF($3, ''), display_name),
updated_at = now()
WHERE id = $1
`
	if _, err := q.ExecContext(ctx, stmt, accountID, email, displayName); err != nil {
		return err
	}
	return nil
}
// linkIdentity records that the (provider, subject) identity belongs to
// accountID, storing email as email_at_login. If that identity is already
// present, the existing row is re-pointed at accountID and its email_at_login
// is overwritten.
func linkIdentity(ctx context.Context, q txQueryer, accountID int64, provider, subject, email string) error {
	const stmt = `
INSERT INTO account_identities (account_id, provider, provider_subject, email_at_login)
VALUES ($1, $2, $3, $4)
ON CONFLICT (provider, provider_subject) DO UPDATE SET
account_id = EXCLUDED.account_id,
email_at_login = EXCLUDED.email_at_login
`
	if _, err := q.ExecContext(ctx, stmt, accountID, provider, subject, email); err != nil {
		return err
	}
	return nil
}
// accountByIdentity looks up the account joined to the account_identities row
// matching provider and subject. The row is decoded by scanAccount, which
// reports ErrNotFound when nothing matches.
func accountByIdentity(ctx context.Context, q txQueryer, provider, subject string) (*Account, error) {
	const query = `
SELECT a.id, a.primary_email, a.display_name, a.stripe_customer_id,
a.subscription_status, a.created_at, a.updated_at
FROM accounts a
JOIN account_identities i ON i.account_id = a.id
WHERE i.provider = $1 AND i.provider_subject = $2
`
	row := q.QueryRowContext(ctx, query, provider, subject)
	return scanAccount(row)
}
// accountByID fetches the accounts row whose id equals accountID. The row is
// decoded by scanAccount, which reports ErrNotFound when nothing matches.
func accountByID(ctx context.Context, q txQueryer, accountID int64) (*Account, error) {
	const query = `
SELECT id, primary_email, display_name, stripe_customer_id, subscription_status, created_at, updated_at
FROM accounts WHERE id = $1
`
	row := q.QueryRowContext(ctx, query, accountID)
	return scanAccount(row)
}
// accountByEmail fetches the accounts row whose primary_email equals email
// (compared exactly as passed; no normalization happens here). The row is
// decoded by scanAccount, which reports ErrNotFound when nothing matches.
func accountByEmail(ctx context.Context, q txQueryer, email string) (*Account, error) {
	const query = `
SELECT id, primary_email, display_name, stripe_customer_id, subscription_status, created_at, updated_at
FROM accounts WHERE primary_email = $1
`
	row := q.QueryRowContext(ctx, query, email)
	return scanAccount(row)
}
// instanceByAccountID fetches an instances row whose account_id equals
// accountID. The row is decoded by scanInstance, which reports ErrNotFound
// when nothing matches.
//
// NOTE(review): the query has no ORDER BY or LIMIT, so if an account can own
// several instances the row returned is whichever the database yields first —
// confirm that account_id is unique in instances.
func instanceByAccountID(ctx context.Context, q txQueryer, accountID int64) (*Instance, error) {
	const query = `
SELECT id, account_id, slug, stack_name, customer_domain, state, last_deployed_at
FROM instances WHERE account_id = $1
`
	row := q.QueryRowContext(ctx, query, accountID)
	return scanInstance(row)
}
// instanceBySlug fetches the instances row whose slug equals slug. The row is
// decoded by scanInstance, which reports ErrNotFound when nothing matches.
func instanceBySlug(ctx context.Context, q txQueryer, slug string) (*Instance, error) {
	const query = `
SELECT id, account_id, slug, stack_name, customer_domain, state, last_deployed_at
FROM instances WHERE slug = $1
`
	row := q.QueryRowContext(ctx, query, slug)
	return scanInstance(row)
}
// eventProcessed reports whether billing_events already contains a row with
// the given eventID. A missing row yields (false, nil); any other query or
// scan failure yields (false, err).
func eventProcessed(ctx context.Context, q txQueryer, eventID string) (bool, error) {
	var found int
	row := q.QueryRowContext(ctx, `SELECT 1 FROM billing_events WHERE event_id = $1`, eventID)

	switch err := row.Scan(&found); {
	case err == nil:
		return true, nil
	case errors.Is(err, sql.ErrNoRows):
		// Absence of the row is the normal "not yet processed" answer.
		return false, nil
	default:
		return false, err
	}
}
// insertBillingEvent writes a billing_events row built from input's Stripe
// event, session, subscription and customer IDs together with eventType.
// If a row with the same event_id already exists the insert is a no-op and
// no error is returned.
func insertBillingEvent(ctx context.Context, q txQueryer, input CheckoutInput, eventType string) error {
	const stmt = `
INSERT INTO billing_events (
event_id, event_type, stripe_session_id, stripe_subscription_id, stripe_customer_id
) VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (event_id) DO NOTHING
`
	if _, err := q.ExecContext(
		ctx, stmt,
		input.StripeEventID,
		eventType,
		input.StripeSessionID,
		input.StripeSubscriptionID,
		input.StripeCustomerID,
	); err != nil {
		return err
	}
	return nil
}
// scanAccount decodes a single result row into an Account. The row must
// supply, in order: id, primary_email, display_name, stripe_customer_id,
// subscription_status, created_at, updated_at.
//
// It returns ErrNotFound when the row is empty (sql.ErrNoRows) and passes any
// other scan error through unchanged. A NULL stripe_customer_id is stored as
// the empty string.
func scanAccount(row *sql.Row) (*Account, error) {
	var (
		out        Account
		customerID sql.NullString // column is scanned as nullable
	)

	switch err := row.Scan(
		&out.ID,
		&out.PrimaryEmail,
		&out.DisplayName,
		&customerID,
		&out.SubscriptionStatus,
		&out.CreatedAt,
		&out.UpdatedAt,
	); {
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrNotFound
	case err != nil:
		return nil, err
	}

	out.StripeCustomerID = customerID.String
	return &out, nil
}
// scanInstance decodes a single result row into an Instance. The row must
// supply, in order: id, account_id, slug, stack_name, customer_domain, state,
// last_deployed_at.
//
// It returns ErrNotFound when the row is empty (sql.ErrNoRows) and passes any
// other scan error through unchanged.
func scanInstance(row *sql.Row) (*Instance, error) {
	var out Instance

	switch err := row.Scan(
		&out.ID,
		&out.AccountID,
		&out.Slug,
		&out.StackName,
		&out.CustomerDomain,
		&out.State,
		&out.LastDeployedAt,
	); {
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrNotFound
	case err != nil:
		return nil, err
	}

	return &out, nil
}