forked from Nixius/authelia
173 lines
4.7 KiB
Go
173 lines
4.7 KiB
Go
package accounts
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"strings"
|
|
|
|
_ "github.com/lib/pq"
|
|
)
|
|
|
|
var (
	// ErrNotFound reports that no account could be resolved: either a lookup
	// matched nothing, or UpsertFromIdentity was given too little information
	// (no subject, or an unlinked identity without an email) to find or create one.
	ErrNotFound = errors.New("account not found")
)
|
|
|
|
// Store persists accounts, their linked login identities, provisioned
// instances and processed billing events in Postgres. Obtain one from New;
// the zero value carries no database handle, so only Close is safe on it.
type Store struct {
	db *sql.DB // connection pool opened by New and released by Close
}
|
|
|
|
func New(ctx context.Context, databaseURL string) (*Store, error) {
|
|
if strings.TrimSpace(databaseURL) == "" {
|
|
return nil, errors.New("DATABASE_URL is required")
|
|
}
|
|
db, err := sql.Open("postgres", databaseURL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if err := db.PingContext(ctx); err != nil {
|
|
db.Close()
|
|
return nil, err
|
|
}
|
|
s := &Store{db: db}
|
|
if err := s.Migrate(ctx); err != nil {
|
|
db.Close()
|
|
return nil, err
|
|
}
|
|
return s, nil
|
|
}
|
|
|
|
func (s *Store) Close() error {
|
|
if s == nil || s.db == nil {
|
|
return nil
|
|
}
|
|
return s.db.Close()
|
|
}
|
|
|
|
func (s *Store) Migrate(ctx context.Context) error {
|
|
_, err := s.db.ExecContext(ctx, `
|
|
CREATE TABLE IF NOT EXISTS accounts (
|
|
id BIGSERIAL PRIMARY KEY,
|
|
primary_email TEXT NOT NULL UNIQUE,
|
|
display_name TEXT NOT NULL DEFAULT '',
|
|
phone TEXT NOT NULL DEFAULT '',
|
|
stripe_customer_id TEXT UNIQUE,
|
|
subscription_status TEXT NOT NULL DEFAULT 'none',
|
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
|
);
|
|
CREATE TABLE IF NOT EXISTS account_identities (
|
|
id BIGSERIAL PRIMARY KEY,
|
|
account_id BIGINT NOT NULL REFERENCES accounts(id) ON DELETE CASCADE,
|
|
provider TEXT NOT NULL,
|
|
provider_subject TEXT NOT NULL,
|
|
email_at_login TEXT NOT NULL DEFAULT '',
|
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
|
UNIQUE(provider, provider_subject)
|
|
);
|
|
CREATE TABLE IF NOT EXISTS instances (
|
|
id BIGSERIAL PRIMARY KEY,
|
|
account_id BIGINT NOT NULL UNIQUE REFERENCES accounts(id) ON DELETE CASCADE,
|
|
slug TEXT NOT NULL UNIQUE,
|
|
stack_name TEXT NOT NULL UNIQUE,
|
|
customer_domain TEXT NOT NULL DEFAULT '',
|
|
state TEXT NOT NULL DEFAULT 'pending',
|
|
last_deployed_at TIMESTAMPTZ,
|
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
|
);
|
|
CREATE TABLE IF NOT EXISTS billing_events (
|
|
event_id TEXT PRIMARY KEY,
|
|
event_type TEXT NOT NULL,
|
|
stripe_session_id TEXT NOT NULL DEFAULT '',
|
|
stripe_subscription_id TEXT NOT NULL DEFAULT '',
|
|
stripe_customer_id TEXT NOT NULL DEFAULT '',
|
|
processed_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
|
);
|
|
CREATE INDEX IF NOT EXISTS accounts_subscription_status_idx ON accounts(subscription_status);
|
|
CREATE INDEX IF NOT EXISTS instances_account_id_idx ON instances(account_id);
|
|
`)
|
|
return err
|
|
}
|
|
|
|
func (s *Store) CountCustomers(ctx context.Context) (int, error) {
|
|
var count int
|
|
err := s.db.QueryRowContext(ctx, `
|
|
SELECT count(*) FROM accounts
|
|
WHERE stripe_customer_id IS NOT NULL AND stripe_customer_id <> ''
|
|
`).Scan(&count)
|
|
return count, err
|
|
}
|
|
|
|
func (s *Store) UpsertFromIdentity(ctx context.Context, identity Identity) (*Account, error) {
|
|
if identity.Provider == "" {
|
|
identity.Provider = "authentik"
|
|
}
|
|
if identity.Subject == "" {
|
|
identity.Subject = firstNonEmpty(identity.Username, identity.Email)
|
|
}
|
|
if identity.Subject == "" {
|
|
return nil, ErrNotFound
|
|
}
|
|
email := strings.ToLower(strings.TrimSpace(identity.Email))
|
|
displayName := strings.TrimSpace(firstNonEmpty(identity.Name, identity.Username, email))
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
acct, err := accountByIdentity(ctx, tx, identity.Provider, identity.Subject)
|
|
if err == nil {
|
|
if email != "" && !strings.EqualFold(acct.PrimaryEmail, email) {
|
|
emailAcct, emailErr := accountByEmail(ctx, tx, email)
|
|
if emailErr == nil && emailAcct.ID != acct.ID {
|
|
if err := linkIdentity(ctx, tx, emailAcct.ID, identity.Provider, identity.Subject, email); err != nil {
|
|
return nil, err
|
|
}
|
|
if displayName != "" {
|
|
if err := updateAccountProfile(ctx, tx, emailAcct.ID, "", displayName); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
acct, err := accountByID(ctx, tx, emailAcct.ID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return acct, tx.Commit()
|
|
}
|
|
if emailErr != nil && !errors.Is(emailErr, ErrNotFound) {
|
|
return nil, emailErr
|
|
}
|
|
}
|
|
if email != "" || displayName != "" {
|
|
if err := updateAccountProfile(ctx, tx, acct.ID, email, displayName); err != nil {
|
|
return nil, err
|
|
}
|
|
acct, _ = accountByID(ctx, tx, acct.ID)
|
|
}
|
|
return acct, tx.Commit()
|
|
}
|
|
if !errors.Is(err, ErrNotFound) {
|
|
return nil, err
|
|
}
|
|
if email == "" {
|
|
return nil, ErrNotFound
|
|
}
|
|
acct, err = upsertAccountByEmail(ctx, tx, email, displayName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if err := linkIdentity(ctx, tx, acct.ID, identity.Provider, identity.Subject, email); err != nil {
|
|
return nil, err
|
|
}
|
|
return acct, tx.Commit()
|
|
}
|
|
|
|
// firstNonEmpty walks values in order and returns the first one that is not
// blank once surrounding whitespace is removed, in that trimmed form. It
// returns "" when no values are given or all of them are blank.
func firstNonEmpty(values ...string) string {
	for i := range values {
		if trimmed := strings.TrimSpace(values[i]); len(trimmed) > 0 {
			return trimmed
		}
	}
	return ""
}
|