package server

import (
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net"
	"os"
	"strings"
	"sync"

	"golang.org/x/crypto/ssh"
)

// TunnelRequest is the JSON payload a client writes on a "tunnel-request"
// channel to describe the tunnel it wants.
type TunnelRequest struct {
	// Domain names the tunnel; a request with an empty Domain is ignored.
	Domain string `json:"domain"`
	// AuthUser and AuthPass are optional HTTP Basic Auth credentials.
	AuthUser string `json:"auth_user,omitempty"`
	AuthPass string `json:"auth_pass,omitempty"`
}

// SSHServer accepts SSH connections authenticated by public key and sets up
// reverse tunnels for them.
type SSHServer struct {
	config *ssh.ServerConfig
	pool   *PortPool
	labels *LabelManager

	// mu guards activeTuns.
	mu sync.Mutex
	// activeTuns holds per-tunnel state. The metadata a client sends on a
	// "tunnel-request" channel is stored here under the key connKey+"-meta".
	// NOTE(review): other entries are presumably keyed by sanitized domain —
	// confirm against the forward handler.
	activeTuns map[string]*activeTunnel
}

// activeTunnel is the server-side record kept for one tunnel.
type activeTunnel struct {
	domain   string
	port     int
	listener net.Listener
	done     chan struct{}
	connKey  string // remote address of the SSH connection that owns this tunnel
	authUser string // optional HTTP Basic Auth username
	authPass string // optional HTTP Basic Auth password
}

// NewSSHServer builds an SSHServer. The host key is read and parsed from
// hostKeyPath. Client keys come from the authorized_keys file at
// authorizedKeysPath, which is read once, here.
//
// It returns an error if the host key cannot be read or parsed. An unreadable
// authorized_keys file is only logged, and the resulting server rejects every
// client key.
func NewSSHServer(
	hostKeyPath, authorizedKeysPath string,
	pool *PortPool,
	labels *LabelManager,
) (*SSHServer, error) {
	srv := &SSHServer{
		pool:       pool,
		labels:     labels,
		activeTuns: make(map[string]*activeTunnel),
	}

	// The auth callback is built (and authorized_keys read) before the host key
	// is touched, so its log output appears even if the host key is bad.
	cfg := new(ssh.ServerConfig)
	cfg.PublicKeyCallback = srv.buildAuthCallback(authorizedKeysPath)

	pemBytes, err := os.ReadFile(hostKeyPath)
	if err != nil {
		return nil, fmt.Errorf("read host key: %w", err)
	}
	signer, err := ssh.ParsePrivateKey(pemBytes)
	if err != nil {
		return nil, fmt.Errorf("parse host key: %w", err)
	}
	cfg.AddHostKey(signer)

	srv.config = cfg
	return srv, nil
}

// buildAuthCallback loads authorized keys and returns a public key callback.
func (s *SSHServer) buildAuthCallback( path string, ) func(ssh.ConnMetadata, ssh.PublicKey) (*ssh.Permissions, error) { allowed := make(map[string]bool) data, err := os.ReadFile(path) if err != nil { log.Printf("WARN: cannot read authorized_keys at %s: %v", path, err) return func(_ ssh.ConnMetadata, _ ssh.PublicKey) (*ssh.Permissions, error) { return nil, fmt.Errorf("no authorized keys configured") } } for _, line := range strings.Split(string(data), "\n") { line = strings.TrimSpace(line) if line == "" || strings.HasPrefix(line, "#") { continue } pubKey, _, _, _, parseErr := ssh.ParseAuthorizedKey([]byte(line)) if parseErr != nil { log.Printf("WARN: skipping bad authorized key: %v", parseErr) continue } allowed[string(pubKey.Marshal())] = true } log.Printf("Loaded %d authorized key(s)", len(allowed)) return func(_ ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { if allowed[string(key.Marshal())] { return &ssh.Permissions{}, nil } return nil, fmt.Errorf("unknown public key") } } // ListenAndServe starts the SSH server on the given address. func (s *SSHServer) ListenAndServe(addr string) error { listener, err := net.Listen("tcp", addr) if err != nil { return fmt.Errorf("listen on %s: %w", addr, err) } defer listener.Close() log.Printf("SSH server listening on %s", addr) for { conn, err := listener.Accept() if err != nil { log.Printf("accept error: %v", err) continue } go s.handleConn(conn) } } // handleConn performs the SSH handshake and processes channels. 
func (s *SSHServer) handleConn(netConn net.Conn) { sshConn, chans, reqs, err := ssh.NewServerConn(netConn, s.config) if err != nil { log.Printf("SSH handshake failed from %s: %v", netConn.RemoteAddr(), err) netConn.Close() return } connKey := sshConn.RemoteAddr().String() log.Printf("SSH connection from %s (%s)", connKey, sshConn.User()) go s.handleGlobalRequests(reqs, sshConn, connKey) for newChan := range chans { switch newChan.ChannelType() { case "tunnel-request": go s.handleTunnelChannel(newChan, connKey) case "session": ch, _, chErr := newChan.Accept() if chErr == nil { ch.Close() } default: newChan.Reject(ssh.UnknownChannelType, "unsupported channel type") } } s.cleanupConnection(connKey) log.Printf("SSH connection closed from %s", connKey) } // handleTunnelChannel reads tunnel metadata from the custom channel. func (s *SSHServer) handleTunnelChannel(newChan ssh.NewChannel, connKey string) { ch, _, err := newChan.Accept() if err != nil { log.Printf("failed to accept tunnel channel: %v", err) return } buf := make([]byte, 4096) n, err := ch.Read(buf) if err != nil && err != io.EOF { log.Printf("failed to read tunnel metadata: %v", err) ch.Close() return } var req TunnelRequest if err := json.Unmarshal(buf[:n], &req); err != nil { log.Printf("invalid tunnel metadata: %v", err) ch.Close() return } if req.Domain == "" { ch.Close() return } log.Printf("Tunnel metadata received: domain=%s (conn=%s)", req.Domain, connKey) if req.AuthUser != "" && req.AuthPass != "" { log.Printf("Tunnel metadata includes basicauth for domain=%s", req.Domain) } // Store domain mapping for this connection so forward handler can use it. s.mu.Lock() s.activeTuns[connKey+"-meta"] = &activeTunnel{ domain: req.Domain, connKey: connKey, authUser: req.AuthUser, authPass: req.AuthPass, } s.mu.Unlock() // Keep channel open as heartbeat / disconnect signal. io.Copy(io.Discard, ch) } // SanitizeDomain converts a domain name into a safe label key. 
// Surrounding whitespace is trimmed, every '.', ':' and '/' is replaced with
// '-', and the result is lowercased. All other characters pass through
// unchanged, so the result is not guaranteed to satisfy any particular label
// syntax.
func SanitizeDomain(domain string) string {
	return strings.ToLower(domainLabelReplacer.Replace(strings.TrimSpace(domain)))
}

// domainLabelReplacer is built once instead of on every SanitizeDomain call.
// strings.Replacer is documented as safe for concurrent use.
var domainLabelReplacer = strings.NewReplacer(".", "-", ":", "-", "/", "-")