forked from Nixius/authelia
1
0
Fork 0
ATLAS/docker/ss-atlas/internal/accounts/store.go

173 lines
4.7 KiB
Go

package accounts
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"

	_ "github.com/lib/pq"
)
// ErrNotFound signals that no account could be resolved. Callers in this
// package test for it with errors.Is. UpsertFromIdentity also returns it when
// the identity yields no subject, or when an unlinked identity has no email
// to create an account from.
var ErrNotFound = errors.New("account not found")
// Store wraps the PostgreSQL database that holds accounts and their linked
// login identities (see Migrate for the full schema). Create one with New
// and release it with Close.
type Store struct {
	db *sql.DB // handle opened by New; closed by Close
}
// New opens the PostgreSQL database at databaseURL, verifies that it is
// reachable, and applies the schema migrations before returning the Store.
//
// Surrounding whitespace in databaseURL (such as a trailing newline) is
// ignored. An error is returned if the URL is blank, the database cannot be
// opened or reached, or migration fails. In the failure cases any handle that
// was opened is closed before returning.
func New(ctx context.Context, databaseURL string) (*Store, error) {
	// Trim once and use the trimmed value everywhere. Otherwise a URL that
	// passes this check could still fail to parse in the driver.
	databaseURL = strings.TrimSpace(databaseURL)
	if databaseURL == "" {
		return nil, errors.New("DATABASE_URL is required")
	}
	db, err := sql.Open("postgres", databaseURL)
	if err != nil {
		return nil, fmt.Errorf("opening database: %w", err)
	}
	// sql.Open may only validate its arguments; Ping forces a real connection.
	if err := db.PingContext(ctx); err != nil {
		db.Close() // best-effort cleanup; the ping error is the one to report
		return nil, fmt.Errorf("pinging database: %w", err)
	}
	s := &Store{db: db}
	if err := s.Migrate(ctx); err != nil {
		db.Close() // best-effort cleanup; the migration error is the one to report
		return nil, fmt.Errorf("migrating database schema: %w", err)
	}
	return s, nil
}
// Close releases the underlying database handle and returns its error.
// Calling it on a nil *Store, or on a Store with no database, does nothing
// and returns nil.
func (s *Store) Close() error {
	if s != nil && s.db != nil {
		return s.db.Close()
	}
	return nil
}
// Migrate creates the tables if they do not already exist: accounts,
// account_identities, instances and billing_events, plus their secondary
// indexes. Every statement is CREATE ... IF NOT EXISTS, so calling it on
// every start-up is harmless. It never alters existing tables, so column
// changes need a separate migration.
//
// The whole script is sent in one argument-less ExecContext call.
// NOTE(review): this relies on the driver accepting several statements per
// call (lib/pq does for queries without arguments). Confirm if the driver
// changes.
//
// NOTE(review): instances.account_id is already UNIQUE, which PostgreSQL
// backs with an index, so instances_account_id_idx is likely redundant. It
// is kept so fresh and existing deployments end up with the same schema.
func (s *Store) Migrate(ctx context.Context) error {
	const schema = `
CREATE TABLE IF NOT EXISTS accounts (
id BIGSERIAL PRIMARY KEY,
primary_email TEXT NOT NULL UNIQUE,
display_name TEXT NOT NULL DEFAULT '',
phone TEXT NOT NULL DEFAULT '',
stripe_customer_id TEXT UNIQUE,
subscription_status TEXT NOT NULL DEFAULT 'none',
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
CREATE TABLE IF NOT EXISTS account_identities (
id BIGSERIAL PRIMARY KEY,
account_id BIGINT NOT NULL REFERENCES accounts(id) ON DELETE CASCADE,
provider TEXT NOT NULL,
provider_subject TEXT NOT NULL,
email_at_login TEXT NOT NULL DEFAULT '',
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
UNIQUE(provider, provider_subject)
);
CREATE TABLE IF NOT EXISTS instances (
id BIGSERIAL PRIMARY KEY,
account_id BIGINT NOT NULL UNIQUE REFERENCES accounts(id) ON DELETE CASCADE,
slug TEXT NOT NULL UNIQUE,
stack_name TEXT NOT NULL UNIQUE,
customer_domain TEXT NOT NULL DEFAULT '',
state TEXT NOT NULL DEFAULT 'pending',
last_deployed_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
CREATE TABLE IF NOT EXISTS billing_events (
event_id TEXT PRIMARY KEY,
event_type TEXT NOT NULL,
stripe_session_id TEXT NOT NULL DEFAULT '',
stripe_subscription_id TEXT NOT NULL DEFAULT '',
stripe_customer_id TEXT NOT NULL DEFAULT '',
processed_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
CREATE INDEX IF NOT EXISTS accounts_subscription_status_idx ON accounts(subscription_status);
CREATE INDEX IF NOT EXISTS instances_account_id_idx ON instances(account_id);
`
	if _, err := s.db.ExecContext(ctx, schema); err != nil {
		return err
	}
	return nil
}
// CountCustomers returns the number of accounts that have a non-NULL,
// non-empty stripe_customer_id. On error the count is returned as scanned
// (zero unless Scan wrote it) alongside the error.
func (s *Store) CountCustomers(ctx context.Context) (int, error) {
	const query = `
SELECT count(*) FROM accounts
WHERE stripe_customer_id IS NOT NULL AND stripe_customer_id <> ''
`
	row := s.db.QueryRowContext(ctx, query)
	var total int
	scanErr := row.Scan(&total)
	return total, scanErr
}
// UpsertFromIdentity resolves the account for an external login identity.
// It creates, updates or links rows as needed, all inside one transaction.
//
// Defaults:
//   - An empty Provider becomes "authentik".
//   - An empty Subject falls back to Username, then Email.
//   - The email is trimmed and lower-cased.
//   - The display name falls back from Name to Username to that email.
//
// Resolution:
//   - If (Provider, Subject) is already linked to an account and the login
//     email belongs to a different existing account, linkIdentity is called
//     for that other account, its display name is refreshed, and that account
//     is returned. NOTE(review): this relies on linkIdentity coping with the
//     already-existing (provider, provider_subject) row; confirm.
//   - Otherwise, if the identity is linked, that account's email and display
//     name are refreshed and the re-read account is returned.
//   - If the identity is not linked, an account is upserted by email and the
//     identity is linked to it.
//
// ErrNotFound is returned when no subject can be derived, or when the
// identity is unlinked and has no email. On any error, including a failed
// commit, the returned *Account is nil.
func (s *Store) UpsertFromIdentity(ctx context.Context, identity Identity) (*Account, error) {
	if identity.Provider == "" {
		identity.Provider = "authentik"
	}
	if identity.Subject == "" {
		identity.Subject = firstNonEmpty(identity.Username, identity.Email)
	}
	if identity.Subject == "" {
		return nil, ErrNotFound
	}
	email := strings.ToLower(strings.TrimSpace(identity.Email))
	displayName := strings.TrimSpace(firstNonEmpty(identity.Name, identity.Username, email))

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, err
	}
	// Deferring Rollback unconditionally is safe: after a successful Commit it
	// changes nothing and merely reports that the transaction is done.
	defer tx.Rollback()

	// commit finalizes the transaction. It returns a nil account when the
	// commit fails, so callers never receive a row whose changes were lost.
	commit := func(a *Account) (*Account, error) {
		if err := tx.Commit(); err != nil {
			return nil, err
		}
		return a, nil
	}

	acct, err := accountByIdentity(ctx, tx, identity.Provider, identity.Subject)
	switch {
	case errors.Is(err, ErrNotFound):
		// Unlinked identity: an email is required to find or create the account.
		if email == "" {
			return nil, ErrNotFound
		}
		acct, err = upsertAccountByEmail(ctx, tx, email, displayName)
		if err != nil {
			return nil, err
		}
		if err := linkIdentity(ctx, tx, acct.ID, identity.Provider, identity.Subject, email); err != nil {
			return nil, err
		}
		return commit(acct)
	case err != nil:
		return nil, err
	}

	// From here on the identity is already linked to acct.
	if email != "" && !strings.EqualFold(acct.PrimaryEmail, email) {
		emailAcct, emailErr := accountByEmail(ctx, tx, email)
		if emailErr != nil && !errors.Is(emailErr, ErrNotFound) {
			return nil, emailErr
		}
		if emailErr == nil && emailAcct.ID != acct.ID {
			// Another account already owns this email: attach the identity to
			// that account rather than moving the email onto acct.
			if err := linkIdentity(ctx, tx, emailAcct.ID, identity.Provider, identity.Subject, email); err != nil {
				return nil, err
			}
			if displayName != "" {
				// Empty email argument: presumably "leave the primary email
				// unchanged". Confirm in updateAccountProfile.
				if err := updateAccountProfile(ctx, tx, emailAcct.ID, "", displayName); err != nil {
					return nil, err
				}
			}
			owner, err := accountByID(ctx, tx, emailAcct.ID)
			if err != nil {
				return nil, err
			}
			return commit(owner)
		}
	}

	if email != "" || displayName != "" {
		if err := updateAccountProfile(ctx, tx, acct.ID, email, displayName); err != nil {
			return nil, err
		}
		// Re-read so the caller sees the updated row. This error must not be
		// discarded, or a failed reload could yield a nil account with a nil
		// error.
		acct, err = accountByID(ctx, tx, acct.ID)
		if err != nil {
			return nil, err
		}
	}
	return commit(acct)
}
// firstNonEmpty returns the first argument that is not blank once
// surrounding whitespace is removed, and returns it already trimmed. It
// yields "" when there are no arguments or every argument is blank.
func firstNonEmpty(values ...string) string {
	for _, candidate := range values {
		if trimmed := strings.TrimSpace(candidate); trimmed != "" {
			return trimmed
		}
	}
	return ""
}