better-argo-tunnels/internal/server/ssh.go

215 lines
5.4 KiB
Go

package server
import (
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net"
	"os"
	"strings"
	"sync"
	"time"

	"golang.org/x/crypto/ssh"
)
// TunnelRequest is the JSON object a client writes on a "tunnel-request"
// channel to describe the tunnel it wants opened.
//
// Domain must be non-empty. The two auth fields are optional and are left out
// of the encoded JSON when empty.
type TunnelRequest struct {
	// Domain is the domain the client asks the tunnel to be registered for.
	Domain string `json:"domain"`
	// AuthUser is an optional HTTP Basic Auth user name for the tunnel.
	AuthUser string `json:"auth_user,omitempty"`
	// AuthPass is an optional HTTP Basic Auth password for the tunnel.
	AuthPass string `json:"auth_pass,omitempty"`
}
// SSHServer accepts SSH connections from tunnel clients and tracks the
// tunnels they request.
type SSHServer struct {
	config *ssh.ServerConfig // handshake settings: host key + public-key auth callback
	pool   *PortPool
	labels *LabelManager

	// mu guards activeTuns.
	mu sync.Mutex
	// activeTuns holds tunnels keyed by sanitized domain. handleTunnelChannel
	// also stores not-yet-forwarded per-connection metadata here under the
	// key connKey+"-meta".
	activeTuns map[string]*activeTunnel
}
// activeTunnel is one entry of SSHServer.activeTuns.
type activeTunnel struct {
	domain   string
	port     int
	listener net.Listener
	done     chan struct{} // NOTE(review): presumably closed when the tunnel shuts down — confirm
	// connKey is the remote address of the SSH connection that owns this
	// tunnel; it is used to find the tunnel again when that connection ends.
	connKey string
	// authUser and authPass are optional HTTP Basic Auth credentials; both
	// are empty when the client did not ask for auth.
	authUser, authPass string
}
// NewSSHServer builds an SSHServer that authenticates clients against the
// OpenSSH authorized_keys file at authorizedKeysPath and identifies itself
// with the private host key stored at hostKeyPath.
//
// The authorized keys are loaded first. An unreadable authorized_keys file is
// not an error here; it only produces a server that rejects every key. An
// error is returned when the host key cannot be read or parsed.
func NewSSHServer(
	hostKeyPath, authorizedKeysPath string,
	pool *PortPool,
	labels *LabelManager,
) (*SSHServer, error) {
	srv := &SSHServer{
		pool:       pool,
		labels:     labels,
		activeTuns: make(map[string]*activeTunnel),
	}
	cfg := &ssh.ServerConfig{}
	cfg.PublicKeyCallback = srv.buildAuthCallback(authorizedKeysPath)

	rawKey, readErr := os.ReadFile(hostKeyPath)
	if readErr != nil {
		return nil, fmt.Errorf("read host key: %w", readErr)
	}
	signer, parseErr := ssh.ParsePrivateKey(rawKey)
	if parseErr != nil {
		return nil, fmt.Errorf("parse host key: %w", parseErr)
	}
	cfg.AddHostKey(signer)

	srv.config = cfg
	return srv, nil
}
// buildAuthCallback reads the OpenSSH authorized_keys file at path once and
// returns a PublicKeyCallback that admits exactly the keys listed there.
//
// Blank lines and lines starting with '#' are skipped. Lines that fail to
// parse are logged and skipped. Per-key options (from="...", command="...",
// restrict, ...) are parsed but NOT enforced by this callback. A warning is
// logged for every key that carries options, so an operator does not assume a
// restriction is in effect when it is not.
//
// If the file cannot be read, a warning is logged and the returned callback
// rejects every key. Later edits to the file are not picked up.
func (s *SSHServer) buildAuthCallback(
	path string,
) func(ssh.ConnMetadata, ssh.PublicKey) (*ssh.Permissions, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		log.Printf("WARN: cannot read authorized_keys at %s: %v", path, err)
		return func(_ ssh.ConnMetadata, _ ssh.PublicKey) (*ssh.Permissions, error) {
			return nil, fmt.Errorf("no authorized keys configured")
		}
	}

	// Set of allowed keys, indexed by their SSH wire encoding. It is only read
	// after this function returns, so concurrent handshakes can share it.
	allowed := make(map[string]struct{})
	for _, line := range strings.Split(string(data), "\n") {
		line = strings.TrimSpace(line)
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		pubKey, _, options, _, parseErr := ssh.ParseAuthorizedKey([]byte(line))
		if parseErr != nil {
			log.Printf("WARN: skipping bad authorized key: %v", parseErr)
			continue
		}
		if len(options) > 0 {
			log.Printf("WARN: authorized key %s has options %q that are not enforced",
				ssh.FingerprintSHA256(pubKey), options)
		}
		allowed[string(pubKey.Marshal())] = struct{}{}
	}
	log.Printf("Loaded %d authorized key(s)", len(allowed))

	return func(_ ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
		if _, ok := allowed[string(key.Marshal())]; ok {
			return &ssh.Permissions{}, nil
		}
		return nil, fmt.Errorf("unknown public key")
	}
}
// ListenAndServe starts the SSH server on the given address.
func (s *SSHServer) ListenAndServe(addr string) error {
listener, err := net.Listen("tcp", addr)
if err != nil {
return fmt.Errorf("listen on %s: %w", addr, err)
}
defer listener.Close()
log.Printf("SSH server listening on %s", addr)
for {
conn, err := listener.Accept()
if err != nil {
log.Printf("accept error: %v", err)
continue
}
go s.handleConn(conn)
}
}
// handleConn runs the SSH handshake on netConn and then dispatches the
// channels the client opens until the connection ends:
//   - "tunnel-request" channels are handled asynchronously by handleTunnelChannel;
//   - "session" channels are accepted and closed right away, so interactive
//     clients get no shell;
//   - any other channel type is rejected with ssh.UnknownChannelType.
//
// Global requests go to handleGlobalRequests in a separate goroutine. The
// connection's remote address is the key for all per-connection state. Once
// the channel stream closes (the client disconnected), cleanupConnection is
// called with that key.
func (s *SSHServer) handleConn(netConn net.Conn) {
	sshConn, chans, reqs, err := ssh.NewServerConn(netConn, s.config)
	if err != nil {
		log.Printf("SSH handshake failed from %s: %v", netConn.RemoteAddr(), err)
		netConn.Close()
		return
	}
	connKey := sshConn.RemoteAddr().String()
	log.Printf("SSH connection from %s (%s)", connKey, sshConn.User())

	go s.handleGlobalRequests(reqs, sshConn, connKey)

	for newChan := range chans {
		switch newChan.ChannelType() {
		case "tunnel-request":
			go s.handleTunnelChannel(newChan, connKey)
		case "session":
			ch, chReqs, chErr := newChan.Accept()
			if chErr != nil {
				continue
			}
			// x/crypto/ssh requires a channel's request stream to be serviced.
			// Left unread it can stall the connection. DiscardRequests also
			// answers want-reply requests (pty, shell, ...) with a failure, and
			// it exits when the channel is torn down.
			go ssh.DiscardRequests(chReqs)
			ch.Close()
		default:
			newChan.Reject(ssh.UnknownChannelType, "unsupported channel type")
		}
	}

	s.cleanupConnection(connKey)
	log.Printf("SSH connection closed from %s", connKey)
}
// handleTunnelChannel accepts a "tunnel-request" channel and decodes the
// TunnelRequest JSON object the client writes on it.
//
// A valid request (non-blank domain) is recorded in s.activeTuns under the key
// connKey+"-meta", so the forward handler can look it up. The channel is then
// held open and drained until the client closes it or the connection drops,
// which serves as a heartbeat / disconnect signal.
//
// If the metadata cannot be read or decoded, exceeds 4 KiB, or has a blank
// domain, this is logged and nothing is stored. The channel is always closed
// before the function returns.
func (s *SSHServer) handleTunnelChannel(newChan ssh.NewChannel, connKey string) {
	ch, chanReqs, err := newChan.Accept()
	if err != nil {
		log.Printf("failed to accept tunnel channel: %v", err)
		return
	}
	defer ch.Close()
	// x/crypto/ssh requires the per-channel request stream to be serviced.
	// Left unread it eventually blocks the whole connection.
	go ssh.DiscardRequests(chanReqs)

	// A single Read may return only part of the object when it is split across
	// SSH data packets, so decode from the stream instead. For a JSON object,
	// Decode returns as soon as the closing brace arrives; it does not wait for
	// the client to close the channel. The 4 KiB cap matches the size of the
	// previous fixed read buffer.
	const maxTunnelMetadata = 4096
	var req TunnelRequest
	dec := json.NewDecoder(io.LimitReader(ch, maxTunnelMetadata))
	if err := dec.Decode(&req); err != nil {
		log.Printf("invalid tunnel metadata: %v", err)
		return
	}
	// A whitespace-only domain would sanitize to an empty label key.
	if strings.TrimSpace(req.Domain) == "" {
		log.Printf("tunnel metadata without domain (conn=%s)", connKey)
		return
	}

	// The domain is client-supplied. %q escapes control characters so a
	// crafted domain cannot forge extra log lines.
	log.Printf("Tunnel metadata received: domain=%q (conn=%s)", req.Domain, connKey)
	if req.AuthUser != "" && req.AuthPass != "" {
		log.Printf("Tunnel metadata includes basicauth for domain=%q", req.Domain)
	}

	// Store domain mapping for this connection so forward handler can use it.
	s.mu.Lock()
	s.activeTuns[connKey+"-meta"] = &activeTunnel{
		domain:   req.Domain,
		connKey:  connKey,
		authUser: req.AuthUser,
		authPass: req.AuthPass,
	}
	s.mu.Unlock()

	// Keep channel open as heartbeat / disconnect signal. Copy ends when the
	// channel or the connection ends, which is exactly the event being waited
	// for, so its error is not interesting.
	_, _ = io.Copy(io.Discard, ch)
}
// domainLabelReplacer maps the separators that can occur in a domain or
// host:port string ('.', ':', '/') to '-'. A strings.Replacer is safe for
// concurrent use, so one shared instance is built once instead of per call.
var domainLabelReplacer = strings.NewReplacer(".", "-", ":", "-", "/", "-")

// SanitizeDomain converts a domain name into a label key. It trims
// surrounding whitespace, replaces every '.', ':' and '/' with '-', and
// lower-cases the result. For example, "API.Example.com:443" becomes
// "api-example-com-443".
//
// Only those three separators are rewritten. Any other character is passed
// through (lower-cased), so callers that need a stricter character set must
// validate separately.
func SanitizeDomain(domain string) string {
	return strings.ToLower(domainLabelReplacer.Replace(strings.TrimSpace(domain)))
}