package accounts

import (
	"context"
	"database/sql"
	"errors"
)

// upsertAccountByEmail inserts an account keyed by primary_email or, when an
// account with that email already exists, refreshes it, and returns the stored
// row in either case.
//
// On conflict, display_name is only overwritten when displayName is non-empty,
// and updated_at is always bumped. An empty email is rejected: the rest of this
// file treats "" as "no value", and upserting on it would fold every caller
// that lacks an email into one shared account.
func upsertAccountByEmail(ctx context.Context, q txQueryer, email, displayName string) (*Account, error) {
	if email == "" {
		return nil, errors.New("accounts: upsert requires a non-empty email")
	}
	return scanAccount(q.QueryRowContext(ctx, `
		INSERT INTO accounts (primary_email, display_name)
		VALUES ($1, $2)
		ON CONFLICT (primary_email) DO UPDATE
		SET display_name = COALESCE(NULLIF(EXCLUDED.display_name, ''), accounts.display_name),
		    updated_at = now()
		RETURNING id, primary_email, display_name, stripe_customer_id,
		          subscription_status, created_at, updated_at
	`, email, displayName))
}

// updateAccountProfile changes the primary email and/or display name of the
// account identified by accountID. An empty email or displayName leaves that
// column unchanged. updated_at is bumped whenever the row exists.
//
// It returns ErrNotFound when no account has that ID, matching the lookup
// helpers in this file instead of silently reporting success.
func updateAccountProfile(ctx context.Context, q txQueryer, accountID int64, email, displayName string) error {
	res, err := q.ExecContext(ctx, `
		UPDATE accounts
		SET primary_email = COALESCE(NULLIF($2, ''), primary_email),
		    display_name  = COALESCE(NULLIF($3, ''), display_name),
		    updated_at    = now()
		WHERE id = $1
	`, accountID, email, displayName)
	if err != nil {
		return err
	}
	n, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if n == 0 {
		return ErrNotFound
	}
	return nil
}

// linkIdentity associates an external identity (provider, subject) with
// accountID and records the email seen at login.
//
// If the (provider, subject) pair already exists, whether for this account or
// another one, it is re-pointed at accountID and its email_at_login refreshed.
// NOTE(review): this silently moves an identity between accounts; confirm that
// callers only reach here after ownership has been verified.
func linkIdentity(ctx context.Context, q txQueryer, accountID int64, provider, subject, email string) error {
	_, err := q.ExecContext(ctx, `
		INSERT INTO account_identities (account_id, provider, provider_subject, email_at_login)
		VALUES ($1, $2, $3, $4)
		ON CONFLICT (provider, provider_subject) DO UPDATE
		SET account_id = EXCLUDED.account_id,
		    email_at_login = EXCLUDED.email_at_login
	`, accountID, provider, subject, email)
	return err
}

// accountByIdentity returns the account linked to the given external identity,
// or ErrNotFound if that identity is not linked to any account.
func accountByIdentity(ctx context.Context, q txQueryer, provider, subject string) (*Account, error) {
	return scanAccount(q.QueryRowContext(ctx, `
		SELECT a.id, a.primary_email, a.display_name, a.stripe_customer_id,
		       a.subscription_status, a.created_at, a.updated_at
		FROM accounts a
		JOIN account_identities i ON i.account_id = a.id
		WHERE i.provider = $1 AND i.provider_subject = $2
	`, provider, subject))
}

// accountByID returns the account with the given ID, or ErrNotFound.
func accountByID(ctx context.Context, q txQueryer, accountID int64) (*Account, error) {
	return scanAccount(q.QueryRowContext(ctx, `
		SELECT id, primary_email, display_name, stripe_customer_id,
		       subscription_status, created_at, updated_at
		FROM accounts
		WHERE id = $1
	`, accountID))
}

// accountByEmail returns the account whose primary_email equals email, or
// ErrNotFound. An empty email never matches: "" means "no email" throughout
// this file, so it must not resolve to a stored row.
func accountByEmail(ctx context.Context, q txQueryer, email string) (*Account, error) {
	if email == "" {
		return nil, ErrNotFound
	}
	return scanAccount(q.QueryRowContext(ctx, `
		SELECT id, primary_email, display_name, stripe_customer_id,
		       subscription_status, created_at, updated_at
		FROM accounts
		WHERE primary_email = $1
	`, email))
}

// instanceByAccountID returns the instance owned by accountID, or ErrNotFound.
//
// Only the first returned row is scanned. The query has no ORDER BY, so if an
// account ever owns several instances, which one is returned is unspecified.
// Presumably the schema allows at most one per account. TODO: confirm.
func instanceByAccountID(ctx context.Context, q txQueryer, accountID int64) (*Instance, error) {
	return scanInstance(q.QueryRowContext(ctx, `
		SELECT id, account_id, slug, stack_name, customer_domain, state, last_deployed_at
		FROM instances
		WHERE account_id = $1
	`, accountID))
}

// instanceBySlug returns the instance with the given slug, or ErrNotFound.
func instanceBySlug(ctx context.Context, q txQueryer, slug string) (*Instance, error) {
	return scanInstance(q.QueryRowContext(ctx, `
		SELECT id, account_id, slug, stack_name, customer_domain, state, last_deployed_at
		FROM instances
		WHERE slug = $1
	`, slug))
}

// eventProcessed reports whether a billing_events row with eventID exists.
// A missing row yields (false, nil). Any other query error is returned
// together with false.
func eventProcessed(ctx context.Context, q txQueryer, eventID string) (bool, error) {
	var one int
	err := q.QueryRowContext(ctx, `SELECT 1 FROM billing_events WHERE event_id = $1`, eventID).Scan(&one)
	if errors.Is(err, sql.ErrNoRows) {
		return false, nil
	}
	return err == nil, err
}

// insertBillingEvent records a billing event taken from input under eventType.
// A row whose event_id already exists is left untouched (ON CONFLICT DO
// NOTHING), so recording the same event twice is not an error.
func insertBillingEvent(ctx context.Context, q txQueryer, input CheckoutInput, eventType string) error {
	_, err := q.ExecContext(ctx, `
		INSERT INTO billing_events (
			event_id, event_type, stripe_session_id, stripe_subscription_id, stripe_customer_id
		) VALUES ($1, $2, $3, $4, $5)
		ON CONFLICT (event_id) DO NOTHING
	`, input.StripeEventID, eventType, input.StripeSessionID, input.StripeSubscriptionID, input.StripeCustomerID)
	return err
}

// scanAccount reads one accounts row into an Account. The column order must
// match the SELECT/RETURNING lists used above.
//
// sql.ErrNoRows is translated to ErrNotFound. A NULL stripe_customer_id
// becomes the empty string.
func scanAccount(row *sql.Row) (*Account, error) {
	var acct Account
	var stripe sql.NullString
	err := row.Scan(&acct.ID, &acct.PrimaryEmail, &acct.DisplayName, &stripe,
		&acct.SubscriptionStatus, &acct.CreatedAt, &acct.UpdatedAt)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, ErrNotFound
	}
	if err != nil {
		return nil, err
	}
	acct.StripeCustomerID = stripe.String
	return &acct, nil
}

// scanInstance reads one instances row into an Instance. The column order must
// match the SELECT lists used above. sql.ErrNoRows is translated to ErrNotFound.
func scanInstance(row *sql.Row) (*Instance, error) {
	var inst Instance
	err := row.Scan(&inst.ID, &inst.AccountID, &inst.Slug, &inst.StackName,
		&inst.CustomerDomain, &inst.State, &inst.LastDeployedAt)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, ErrNotFound
	}
	if err != nil {
		return nil, err
	}
	return &inst, nil
}