package ldap import ( "crypto/rand" "encoding/base64" "encoding/json" "fmt" "log" "git.nixc.us/a250/ss-atlas/internal/config" goldap "github.com/go-ldap/ldap/v3" ) type Client struct { cfg *config.Config gql *gqlClient } type ProvisionResult struct { Username string Password string IsNew bool } func New(cfg *config.Config) *Client { adminUID := "admin" return &Client{ cfg: cfg, gql: newGQLClient(cfg.LLDAPHttpURL, adminUID, cfg.LDAPAdminPassword), } } func (c *Client) connect() (*goldap.Conn, error) { conn, err := goldap.DialURL(c.cfg.LDAPUrl) if err != nil { return nil, fmt.Errorf("ldap dial: %w", err) } if err := conn.Bind(c.cfg.LDAPAdminDN, c.cfg.LDAPAdminPassword); err != nil { conn.Close() return nil, fmt.Errorf("ldap bind: %w", err) } return conn, nil } func (c *Client) ProvisionUser(username, email, stripeCustomerID, phone string) (*ProvisionResult, error) { conn, err := c.connect() if err != nil { return nil, err } defer conn.Close() exists, err := c.userExists(conn, username) if err != nil { return nil, err } if exists { log.Printf("ldap user %s already exists", username) if phone != "" { _ = c.SetCustomerPhone(username, phone) } _ = c.ensureDisplayName(conn, username, email) return &ProvisionResult{Username: username, IsNew: false}, nil } password := generatePassword() userDN := fmt.Sprintf("uid=%s,ou=people,%s", username, c.cfg.LDAPBaseDN) addReq := goldap.NewAddRequest(userDN, nil) addReq.Attribute("objectClass", []string{"inetOrgPerson"}) addReq.Attribute("cn", []string{username}) addReq.Attribute("sn", []string{username}) addReq.Attribute("uid", []string{username}) addReq.Attribute("mail", []string{email}) addReq.Attribute("displayName", []string{email}) if phone != "" { addReq.Attribute("telephoneNumber", []string{phone}) } if err := conn.Add(addReq); err != nil { return nil, fmt.Errorf("ldap add user %s: %w", username, err) } pwReq := goldap.NewPasswordModifyRequest(userDN, "", password) if _, err := conn.PasswordModify(pwReq); err != nil { return nil, fmt.Errorf("ldap set password for %s: %w", username, err) } if stripeCustomerID != "" { if err := c.SetStripeCustomerID(username, stripeCustomerID); err != nil { log.Printf("warning: failed to set stripe customer id for %s: %v", username, err) } } if phone != "" { if err := c.SetCustomerPhone(username, phone); err != nil { log.Printf("warning: failed to set customer phone for %s: %v", username, err) } } log.Printf("created ldap user %s (%s)", username, email) return &ProvisionResult{Username: username, Password: password, IsNew: true}, nil } func (c *Client) ensureDisplayName(conn *goldap.Conn, username, email string) error { userDN := fmt.Sprintf("uid=%s,ou=people,%s", username, c.cfg.LDAPBaseDN) modReq := goldap.NewModifyRequest(userDN, nil) modReq.Replace("displayName", []string{email}) if err := conn.Modify(modReq); err != nil { log.Printf("ldap ensure displayName for %s: %v (may already be set)", username, err) return err } log.Printf("ldap set displayName for %s to %s", username, email) return nil } func (c *Client) EnsureUser(username, email, stripeCustomerID, phone string) error { _, err := c.ProvisionUser(username, email, stripeCustomerID, phone) return err } func (c *Client) AddToGroup(username, groupName string) error { groupID, err := c.getGroupID(groupName) if err != nil { return fmt.Errorf("resolve group %s: %w", groupName, err) } query := `mutation($userId: String!, $groupId: Int!) { addUserToGroup(userId: $userId, groupId: $groupId) { ok } }` _, err = c.gql.exec(query, map[string]any{"userId": username, "groupId": groupID}) if err != nil { return fmt.Errorf("add %s to group %s: %w", username, groupName, err) } log.Printf("added %s to group %s", username, groupName) return nil } func (c *Client) RemoveFromGroup(username, groupName string) error { groupID, err := c.getGroupID(groupName) if err != nil { return fmt.Errorf("resolve group %s: %w", groupName, err) } query := `mutation($userId: String!, $groupId: Int!) { removeUserFromGroup(userId: $userId, groupId: $groupId) { ok } }` _, err = c.gql.exec(query, map[string]any{"userId": username, "groupId": groupID}) if err != nil { return fmt.Errorf("remove %s from group %s: %w", username, groupName, err) } log.Printf("removed %s from group %s", username, groupName) return nil } func (c *Client) IsInGroup(username, groupName string) (bool, error) { query := `query($userId: String!) { user(userId: $userId) { groups { displayName } } }` data, err := c.gql.exec(query, map[string]any{"userId": username}) if err != nil { return false, err } var result struct { User struct { Groups []struct { DisplayName string `json:"displayName"` } `json:"groups"` } `json:"user"` } if err := json.Unmarshal(data, &result); err != nil { return false, err } for _, g := range result.User.Groups { if g.DisplayName == groupName { return true, nil } } return false, nil } func (c *Client) SetStripeCustomerID(username, customerID string) error { query := `mutation($userId: String!, $attrs: [AttributeValueInput!]!) { updateUser(user: { id: $userId, insertAttributes: $attrs }) { ok } }` attrs := []map[string]any{ {"name": "stripe-customer-id", "value": []string{customerID}}, } _, err := c.gql.exec(query, map[string]any{"userId": username, "attrs": attrs}) return err } func (c *Client) SetCustomerPhone(username, phone string) error { if phone == "" { return nil } query := `mutation($userId: String!, $attrs: [AttributeValueInput!]!) { updateUser(user: { id: $userId, insertAttributes: $attrs }) { ok } }` attrs := []map[string]any{ {"name": "customer-phone", "value": []string{phone}}, } _, err := c.gql.exec(query, map[string]any{"userId": username, "attrs": attrs}) return err } func (c *Client) SetCustomerDomain(username, domain string) error { if domain == "" { return nil } query := `mutation($userId: String!, $attrs: [AttributeValueInput!]!) { updateUser(user: { id: $userId, insertAttributes: $attrs }) { ok } }` attrs := []map[string]any{ {"name": "customer-domain", "value": []string{domain}}, } _, err := c.gql.exec(query, map[string]any{"userId": username, "attrs": attrs}) return err } func (c *Client) GetCustomerDomain(username string) (string, error) { query := `query($userId: String!) { user(userId: $userId) { attributes { name value } } }` data, err := c.gql.exec(query, map[string]any{"userId": username}) if err != nil { return "", err } var result struct { User struct { Attributes []struct { Name string `json:"name"` Value []string `json:"value"` } `json:"attributes"` } `json:"user"` } if err := json.Unmarshal(data, &result); err != nil { return "", err } for _, attr := range result.User.Attributes { if attr.Name == "customer-domain" && len(attr.Value) > 0 { return attr.Value[0], nil } } return "", nil } func (c *Client) GetStripeCustomerID(username string) (string, error) { query := `query($userId: String!) { user(userId: $userId) { attributes { name value } } }` data, err := c.gql.exec(query, map[string]any{"userId": username}) if err != nil { return "", err } var result struct { User struct { Attributes []struct { Name string `json:"name"` Value []string `json:"value"` } `json:"attributes"` } `json:"user"` } if err := json.Unmarshal(data, &result); err != nil { return "", err } for _, attr := range result.User.Attributes { if attr.Name == "stripe-customer-id" && len(attr.Value) > 0 { return attr.Value[0], nil } } return "", nil } func (c *Client) FindUserByStripeID(stripeCustomerID string) (string, error) { query := `query { users(filters: {}) { id attributes { name value } } }` data, err := c.gql.exec(query, nil) if err != nil { return "", err } var result struct { Users []struct { ID string `json:"id"` Attributes []struct { Name string `json:"name"` Value []string `json:"value"` } `json:"attributes"` } `json:"users"` } if err := json.Unmarshal(data, &result); err != nil { return "", err } for _, u := range result.Users { for _, attr := range u.Attributes { if attr.Name == "stripe-customer-id" && len(attr.Value) > 0 && attr.Value[0] == stripeCustomerID { return u.ID, nil } } } return "", fmt.Errorf("no user found with stripe customer %s", stripeCustomerID) } // CountCustomers returns the number of users who have completed checkout (have stripe-customer-id). // Used to enforce signup limits. func (c *Client) CountCustomers() (int, error) { query := `query { users(filters: {}) { id attributes { name value } } }` data, err := c.gql.exec(query, nil) if err != nil { return 0, err } var result struct { Users []struct { Attributes []struct { Name string `json:"name"` Value []string `json:"value"` } `json:"attributes"` } `json:"users"` } if err := json.Unmarshal(data, &result); err != nil { return 0, err } n := 0 for _, u := range result.Users { for _, attr := range u.Attributes { if attr.Name == "stripe-customer-id" && len(attr.Value) > 0 && attr.Value[0] != "" { n++ break } } } return n, nil } func (c *Client) getGroupID(groupName string) (int, error) { query := `query { groups { id displayName } }` data, err := c.gql.exec(query, nil) if err != nil { return 0, err } var result struct { Groups []struct { ID int `json:"id"` DisplayName string `json:"displayName"` } `json:"groups"` } if err := json.Unmarshal(data, &result); err != nil { return 0, err } for _, g := range result.Groups { if g.DisplayName == groupName { return g.ID, nil } } return 0, fmt.Errorf("group %s not found", groupName) } func (c *Client) userExists(conn *goldap.Conn, username string) (bool, error) { searchReq := goldap.NewSearchRequest( fmt.Sprintf("ou=people,%s", c.cfg.LDAPBaseDN), goldap.ScopeWholeSubtree, goldap.NeverDerefAliases, 1, 0, false, fmt.Sprintf("(uid=%s)", goldap.EscapeFilter(username)), []string{"uid"}, nil, ) result, err := conn.Search(searchReq) if err != nil { return false, fmt.Errorf("ldap search: %w", err) } return len(result.Entries) > 0, nil } func generatePassword() string { b := make([]byte, 18) if _, err := rand.Read(b); err != nil { panic("crypto/rand failed: " + err.Error()) } return base64.URLEncoding.EncodeToString(b) }