forked from Nixius/authelia
379 lines
10 KiB
Go
379 lines
10 KiB
Go
package ldap
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
|
|
"git.nixc.us/a250/ss-atlas/internal/config"
|
|
goldap "github.com/go-ldap/ldap/v3"
|
|
)
|
|
|
|
// Client manages directory users. It uses an LDAP connection (cfg.LDAPUrl,
// bound as cfg.LDAPAdminDN) to create entries and set passwords, and a GraphQL
// client (created against cfg.LLDAPHttpURL) for group membership and custom
// user attributes.
type Client struct {
	// cfg supplies connection URLs, the base DN and admin credentials.
	cfg *config.Config
	// gql issues the GraphQL queries/mutations used by the group and
	// attribute helpers.
	gql *gqlClient
}
|
|
|
|
// ProvisionResult describes the outcome of ProvisionUser.
type ProvisionResult struct {
	// Username is the uid that was provisioned or found to already exist.
	Username string
	// Password is the generated initial password. It is only populated when
	// IsNew is true; for pre-existing users it is empty.
	Password string
	// IsNew reports whether the user entry was created by this call.
	IsNew bool
}
|
|
|
|
func New(cfg *config.Config) *Client {
|
|
adminUID := "admin"
|
|
return &Client{
|
|
cfg: cfg,
|
|
gql: newGQLClient(cfg.LLDAPHttpURL, adminUID, cfg.LDAPAdminPassword),
|
|
}
|
|
}
|
|
|
|
func (c *Client) connect() (*goldap.Conn, error) {
|
|
conn, err := goldap.DialURL(c.cfg.LDAPUrl)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("ldap dial: %w", err)
|
|
}
|
|
if err := conn.Bind(c.cfg.LDAPAdminDN, c.cfg.LDAPAdminPassword); err != nil {
|
|
conn.Close()
|
|
return nil, fmt.Errorf("ldap bind: %w", err)
|
|
}
|
|
return conn, nil
|
|
}
|
|
|
|
func (c *Client) ProvisionUser(username, email, stripeCustomerID, phone string) (*ProvisionResult, error) {
|
|
conn, err := c.connect()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer conn.Close()
|
|
|
|
exists, err := c.userExists(conn, username)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if exists {
|
|
log.Printf("ldap user %s already exists", username)
|
|
if phone != "" {
|
|
_ = c.SetCustomerPhone(username, phone)
|
|
}
|
|
_ = c.ensureDisplayName(conn, username, email)
|
|
return &ProvisionResult{Username: username, IsNew: false}, nil
|
|
}
|
|
|
|
password := generatePassword()
|
|
userDN := fmt.Sprintf("uid=%s,ou=people,%s", username, c.cfg.LDAPBaseDN)
|
|
|
|
addReq := goldap.NewAddRequest(userDN, nil)
|
|
addReq.Attribute("objectClass", []string{"inetOrgPerson"})
|
|
addReq.Attribute("cn", []string{username})
|
|
addReq.Attribute("sn", []string{username})
|
|
addReq.Attribute("uid", []string{username})
|
|
addReq.Attribute("mail", []string{email})
|
|
addReq.Attribute("displayName", []string{email})
|
|
if phone != "" {
|
|
addReq.Attribute("telephoneNumber", []string{phone})
|
|
}
|
|
|
|
if err := conn.Add(addReq); err != nil {
|
|
return nil, fmt.Errorf("ldap add user %s: %w", username, err)
|
|
}
|
|
|
|
pwReq := goldap.NewPasswordModifyRequest(userDN, "", password)
|
|
if _, err := conn.PasswordModify(pwReq); err != nil {
|
|
return nil, fmt.Errorf("ldap set password for %s: %w", username, err)
|
|
}
|
|
|
|
if stripeCustomerID != "" {
|
|
if err := c.SetStripeCustomerID(username, stripeCustomerID); err != nil {
|
|
log.Printf("warning: failed to set stripe customer id for %s: %v", username, err)
|
|
}
|
|
}
|
|
if phone != "" {
|
|
if err := c.SetCustomerPhone(username, phone); err != nil {
|
|
log.Printf("warning: failed to set customer phone for %s: %v", username, err)
|
|
}
|
|
}
|
|
|
|
log.Printf("created ldap user %s (%s)", username, email)
|
|
return &ProvisionResult{Username: username, Password: password, IsNew: true}, nil
|
|
}
|
|
|
|
func (c *Client) ensureDisplayName(conn *goldap.Conn, username, email string) error {
|
|
userDN := fmt.Sprintf("uid=%s,ou=people,%s", username, c.cfg.LDAPBaseDN)
|
|
modReq := goldap.NewModifyRequest(userDN, nil)
|
|
modReq.Replace("displayName", []string{email})
|
|
if err := conn.Modify(modReq); err != nil {
|
|
log.Printf("ldap ensure displayName for %s: %v (may already be set)", username, err)
|
|
return err
|
|
}
|
|
log.Printf("ldap set displayName for %s to %s", username, email)
|
|
return nil
|
|
}
|
|
|
|
func (c *Client) EnsureUser(username, email, stripeCustomerID, phone string) error {
|
|
_, err := c.ProvisionUser(username, email, stripeCustomerID, phone)
|
|
return err
|
|
}
|
|
|
|
func (c *Client) AddToGroup(username, groupName string) error {
|
|
groupID, err := c.getGroupID(groupName)
|
|
if err != nil {
|
|
return fmt.Errorf("resolve group %s: %w", groupName, err)
|
|
}
|
|
|
|
query := `mutation($userId: String!, $groupId: Int!) { addUserToGroup(userId: $userId, groupId: $groupId) { ok } }`
|
|
_, err = c.gql.exec(query, map[string]any{"userId": username, "groupId": groupID})
|
|
if err != nil {
|
|
return fmt.Errorf("add %s to group %s: %w", username, groupName, err)
|
|
}
|
|
log.Printf("added %s to group %s", username, groupName)
|
|
return nil
|
|
}
|
|
|
|
func (c *Client) RemoveFromGroup(username, groupName string) error {
|
|
groupID, err := c.getGroupID(groupName)
|
|
if err != nil {
|
|
return fmt.Errorf("resolve group %s: %w", groupName, err)
|
|
}
|
|
|
|
query := `mutation($userId: String!, $groupId: Int!) { removeUserFromGroup(userId: $userId, groupId: $groupId) { ok } }`
|
|
_, err = c.gql.exec(query, map[string]any{"userId": username, "groupId": groupID})
|
|
if err != nil {
|
|
return fmt.Errorf("remove %s from group %s: %w", username, groupName, err)
|
|
}
|
|
log.Printf("removed %s from group %s", username, groupName)
|
|
return nil
|
|
}
|
|
|
|
func (c *Client) IsInGroup(username, groupName string) (bool, error) {
|
|
query := `query($userId: String!) { user(userId: $userId) { groups { displayName } } }`
|
|
data, err := c.gql.exec(query, map[string]any{"userId": username})
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
var result struct {
|
|
User struct {
|
|
Groups []struct {
|
|
DisplayName string `json:"displayName"`
|
|
} `json:"groups"`
|
|
} `json:"user"`
|
|
}
|
|
if err := json.Unmarshal(data, &result); err != nil {
|
|
return false, err
|
|
}
|
|
|
|
for _, g := range result.User.Groups {
|
|
if g.DisplayName == groupName {
|
|
return true, nil
|
|
}
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
func (c *Client) SetStripeCustomerID(username, customerID string) error {
|
|
query := `mutation($userId: String!, $attrs: [AttributeValueInput!]!) {
|
|
updateUser(user: { id: $userId, insertAttributes: $attrs }) { ok }
|
|
}`
|
|
attrs := []map[string]any{
|
|
{"name": "stripe-customer-id", "value": []string{customerID}},
|
|
}
|
|
_, err := c.gql.exec(query, map[string]any{"userId": username, "attrs": attrs})
|
|
return err
|
|
}
|
|
|
|
func (c *Client) SetCustomerPhone(username, phone string) error {
|
|
if phone == "" {
|
|
return nil
|
|
}
|
|
query := `mutation($userId: String!, $attrs: [AttributeValueInput!]!) {
|
|
updateUser(user: { id: $userId, insertAttributes: $attrs }) { ok }
|
|
}`
|
|
attrs := []map[string]any{
|
|
{"name": "customer-phone", "value": []string{phone}},
|
|
}
|
|
_, err := c.gql.exec(query, map[string]any{"userId": username, "attrs": attrs})
|
|
return err
|
|
}
|
|
|
|
func (c *Client) SetCustomerDomain(username, domain string) error {
|
|
if domain == "" {
|
|
return nil
|
|
}
|
|
query := `mutation($userId: String!, $attrs: [AttributeValueInput!]!) {
|
|
updateUser(user: { id: $userId, insertAttributes: $attrs }) { ok }
|
|
}`
|
|
attrs := []map[string]any{
|
|
{"name": "customer-domain", "value": []string{domain}},
|
|
}
|
|
_, err := c.gql.exec(query, map[string]any{"userId": username, "attrs": attrs})
|
|
return err
|
|
}
|
|
|
|
func (c *Client) GetCustomerDomain(username string) (string, error) {
|
|
query := `query($userId: String!) { user(userId: $userId) { attributes { name value } } }`
|
|
data, err := c.gql.exec(query, map[string]any{"userId": username})
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
var result struct {
|
|
User struct {
|
|
Attributes []struct {
|
|
Name string `json:"name"`
|
|
Value []string `json:"value"`
|
|
} `json:"attributes"`
|
|
} `json:"user"`
|
|
}
|
|
if err := json.Unmarshal(data, &result); err != nil {
|
|
return "", err
|
|
}
|
|
for _, attr := range result.User.Attributes {
|
|
if attr.Name == "customer-domain" && len(attr.Value) > 0 {
|
|
return attr.Value[0], nil
|
|
}
|
|
}
|
|
return "", nil
|
|
}
|
|
|
|
func (c *Client) GetStripeCustomerID(username string) (string, error) {
|
|
query := `query($userId: String!) { user(userId: $userId) { attributes { name value } } }`
|
|
data, err := c.gql.exec(query, map[string]any{"userId": username})
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
var result struct {
|
|
User struct {
|
|
Attributes []struct {
|
|
Name string `json:"name"`
|
|
Value []string `json:"value"`
|
|
} `json:"attributes"`
|
|
} `json:"user"`
|
|
}
|
|
if err := json.Unmarshal(data, &result); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
for _, attr := range result.User.Attributes {
|
|
if attr.Name == "stripe-customer-id" && len(attr.Value) > 0 {
|
|
return attr.Value[0], nil
|
|
}
|
|
}
|
|
return "", nil
|
|
}
|
|
|
|
func (c *Client) FindUserByStripeID(stripeCustomerID string) (string, error) {
|
|
query := `query { users(filters: {}) { id attributes { name value } } }`
|
|
data, err := c.gql.exec(query, nil)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
var result struct {
|
|
Users []struct {
|
|
ID string `json:"id"`
|
|
Attributes []struct {
|
|
Name string `json:"name"`
|
|
Value []string `json:"value"`
|
|
} `json:"attributes"`
|
|
} `json:"users"`
|
|
}
|
|
if err := json.Unmarshal(data, &result); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
for _, u := range result.Users {
|
|
for _, attr := range u.Attributes {
|
|
if attr.Name == "stripe-customer-id" && len(attr.Value) > 0 && attr.Value[0] == stripeCustomerID {
|
|
return u.ID, nil
|
|
}
|
|
}
|
|
}
|
|
return "", fmt.Errorf("no user found with stripe customer %s", stripeCustomerID)
|
|
}
|
|
|
|
// CountCustomers returns the number of users who have completed checkout (have stripe-customer-id).
|
|
// Used to enforce signup limits.
|
|
func (c *Client) CountCustomers() (int, error) {
|
|
query := `query { users(filters: {}) { id attributes { name value } } }`
|
|
data, err := c.gql.exec(query, nil)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
var result struct {
|
|
Users []struct {
|
|
Attributes []struct {
|
|
Name string `json:"name"`
|
|
Value []string `json:"value"`
|
|
} `json:"attributes"`
|
|
} `json:"users"`
|
|
}
|
|
if err := json.Unmarshal(data, &result); err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
n := 0
|
|
for _, u := range result.Users {
|
|
for _, attr := range u.Attributes {
|
|
if attr.Name == "stripe-customer-id" && len(attr.Value) > 0 && attr.Value[0] != "" {
|
|
n++
|
|
break
|
|
}
|
|
}
|
|
}
|
|
return n, nil
|
|
}
|
|
|
|
func (c *Client) getGroupID(groupName string) (int, error) {
|
|
query := `query { groups { id displayName } }`
|
|
data, err := c.gql.exec(query, nil)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
var result struct {
|
|
Groups []struct {
|
|
ID int `json:"id"`
|
|
DisplayName string `json:"displayName"`
|
|
} `json:"groups"`
|
|
}
|
|
if err := json.Unmarshal(data, &result); err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
for _, g := range result.Groups {
|
|
if g.DisplayName == groupName {
|
|
return g.ID, nil
|
|
}
|
|
}
|
|
return 0, fmt.Errorf("group %s not found", groupName)
|
|
}
|
|
|
|
func (c *Client) userExists(conn *goldap.Conn, username string) (bool, error) {
|
|
searchReq := goldap.NewSearchRequest(
|
|
fmt.Sprintf("ou=people,%s", c.cfg.LDAPBaseDN),
|
|
goldap.ScopeWholeSubtree, goldap.NeverDerefAliases, 1, 0, false,
|
|
fmt.Sprintf("(uid=%s)", goldap.EscapeFilter(username)),
|
|
[]string{"uid"},
|
|
nil,
|
|
)
|
|
|
|
result, err := conn.Search(searchReq)
|
|
if err != nil {
|
|
return false, fmt.Errorf("ldap search: %w", err)
|
|
}
|
|
|
|
return len(result.Entries) > 0, nil
|
|
}
|
|
|
|
// generatePassword returns a random password: 18 bytes from crypto/rand,
// encoded with URL-safe base64. Because 18 is a multiple of 3, the result is
// always 24 characters with no padding. It panics if the random source fails.
func generatePassword() string {
	var raw [18]byte
	if _, err := rand.Read(raw[:]); err != nil {
		panic("crypto/rand failed: " + err.Error())
	}
	return base64.URLEncoding.EncodeToString(raw[:])
}
|