205 lines
5.0 KiB
Go
205 lines
5.0 KiB
Go
package server
|
|
|
|
import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"os"
	"strings"
	"sync"
	"time"

	"golang.org/x/crypto/ssh"
)
|
|
|
|
// TunnelRequest is the JSON document a client writes on a "tunnel-request"
// channel to say which domain it wants to expose, e.g.
// {"domain":"app.example.com"}.
type TunnelRequest struct{ Domain string `json:"domain"` }
|
|
|
|
// SSHServer handles incoming SSH connections and sets up reverse tunnels.
|
|
type SSHServer struct {
|
|
config *ssh.ServerConfig
|
|
pool *PortPool
|
|
labels *LabelManager
|
|
mu sync.Mutex
|
|
activeTuns map[string]*activeTunnel // keyed by sanitized domain
|
|
}
|
|
|
|
type activeTunnel struct {
|
|
domain string
|
|
port int
|
|
listener net.Listener
|
|
done chan struct{}
|
|
connKey string // tracks which SSH connection owns this tunnel
|
|
}
|
|
|
|
// NewSSHServer returns an SSHServer that authenticates clients against the
// public keys in authorizedKeysPath and identifies itself with the private
// key stored at hostKeyPath.
//
// The authorized_keys file is loaded first; a missing or unreadable file is
// only logged and leaves a server that rejects every key. An error is
// returned if the host key file cannot be read or parsed.
func NewSSHServer(
	hostKeyPath, authorizedKeysPath string,
	pool *PortPool,
	labels *LabelManager,
) (*SSHServer, error) {
	srv := &SSHServer{
		pool:       pool,
		labels:     labels,
		activeTuns: make(map[string]*activeTunnel),
	}
	cfg := &ssh.ServerConfig{
		PublicKeyCallback: srv.buildAuthCallback(authorizedKeysPath),
	}

	pemBytes, readErr := os.ReadFile(hostKeyPath)
	if readErr != nil {
		return nil, fmt.Errorf("read host key: %w", readErr)
	}
	signer, parseErr := ssh.ParsePrivateKey(pemBytes)
	if parseErr != nil {
		return nil, fmt.Errorf("parse host key: %w", parseErr)
	}
	cfg.AddHostKey(signer)

	srv.config = cfg
	return srv, nil
}
|
|
|
|
// buildAuthCallback reads the authorized_keys file at path once and returns a
// public-key callback that accepts exactly the keys found in it.
//
// Blank lines and lines starting with '#' are ignored, and lines that
// ssh.ParseAuthorizedKey rejects are logged and skipped. Keys are compared by
// their wire encoding (PublicKey.Marshal), so comments and options on a line
// do not matter. If the file cannot be read, the returned callback rejects
// every key.
func (s *SSHServer) buildAuthCallback(
	path string,
) func(ssh.ConnMetadata, ssh.PublicKey) (*ssh.Permissions, error) {
	contents, readErr := os.ReadFile(path)
	if readErr != nil {
		log.Printf("WARN: cannot read authorized_keys at %s: %v", path, readErr)
		return func(_ ssh.ConnMetadata, _ ssh.PublicKey) (*ssh.Permissions, error) {
			return nil, fmt.Errorf("no authorized keys configured")
		}
	}

	known := make(map[string]struct{})
	entries := strings.Split(string(contents), "\n")
	for i := range entries {
		entry := strings.TrimSpace(entries[i])
		if entry == "" || strings.HasPrefix(entry, "#") {
			continue
		}
		key, _, _, _, parseErr := ssh.ParseAuthorizedKey([]byte(entry))
		if parseErr != nil {
			log.Printf("WARN: skipping bad authorized key: %v", parseErr)
			continue
		}
		known[string(key.Marshal())] = struct{}{}
	}

	log.Printf("Loaded %d authorized key(s)", len(known))

	return func(_ ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
		if _, ok := known[string(key.Marshal())]; !ok {
			return nil, fmt.Errorf("unknown public key")
		}
		return &ssh.Permissions{}, nil
	}
}
|
|
|
|
// ListenAndServe starts the SSH server on the given address.
|
|
func (s *SSHServer) ListenAndServe(addr string) error {
|
|
listener, err := net.Listen("tcp", addr)
|
|
if err != nil {
|
|
return fmt.Errorf("listen on %s: %w", addr, err)
|
|
}
|
|
defer listener.Close()
|
|
|
|
log.Printf("SSH server listening on %s", addr)
|
|
|
|
for {
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
log.Printf("accept error: %v", err)
|
|
continue
|
|
}
|
|
go s.handleConn(conn)
|
|
}
|
|
}
|
|
|
|
// handleConn performs the SSH handshake and processes channels.
|
|
func (s *SSHServer) handleConn(netConn net.Conn) {
|
|
sshConn, chans, reqs, err := ssh.NewServerConn(netConn, s.config)
|
|
if err != nil {
|
|
log.Printf("SSH handshake failed from %s: %v", netConn.RemoteAddr(), err)
|
|
netConn.Close()
|
|
return
|
|
}
|
|
|
|
connKey := sshConn.RemoteAddr().String()
|
|
log.Printf("SSH connection from %s (%s)", connKey, sshConn.User())
|
|
|
|
go s.handleGlobalRequests(reqs, sshConn, connKey)
|
|
|
|
for newChan := range chans {
|
|
switch newChan.ChannelType() {
|
|
case "tunnel-request":
|
|
go s.handleTunnelChannel(newChan, connKey)
|
|
case "session":
|
|
ch, _, chErr := newChan.Accept()
|
|
if chErr == nil {
|
|
ch.Close()
|
|
}
|
|
default:
|
|
newChan.Reject(ssh.UnknownChannelType, "unsupported channel type")
|
|
}
|
|
}
|
|
|
|
s.cleanupConnection(connKey)
|
|
log.Printf("SSH connection closed from %s", connKey)
|
|
}
|
|
|
|
// handleTunnelChannel accepts a "tunnel-request" channel and decodes the JSON
// TunnelRequest the client writes on it. It records the requested domain in
// s.activeTuns under the key connKey+"-meta", for use by the forward handler.
//
// After that the channel is held open, and any further data is discarded,
// until the client closes it or the connection drops. The channel acts as a
// heartbeat and disconnect signal. It is always closed when this function
// returns. Invalid or empty metadata is logged and the channel is closed
// without recording anything.
func (s *SSHServer) handleTunnelChannel(newChan ssh.NewChannel, connKey string) {
	ch, reqs, err := newChan.Accept()
	if err != nil {
		log.Printf("failed to accept tunnel channel: %v", err)
		return
	}
	defer ch.Close()
	// Out-of-band requests on this channel must be serviced, otherwise the
	// ssh package can block once its request buffer fills.
	go ssh.DiscardRequests(reqs)

	// Decode exactly one JSON value, reading as many packets as needed (a
	// single Read may return only part of it). The LimitReader caps how much
	// a client can make us buffer while waiting for the value to complete.
	const maxMetadataBytes = 4096
	var req TunnelRequest
	if err := json.NewDecoder(io.LimitReader(ch, maxMetadataBytes)).Decode(&req); err != nil {
		log.Printf("invalid tunnel metadata: %v", err)
		return
	}

	if req.Domain == "" {
		log.Printf("tunnel metadata without domain (conn=%s)", connKey)
		return
	}

	// %q: the domain is client-supplied, so quote it to keep control
	// characters from forging log lines.
	log.Printf("Tunnel metadata received: domain=%q (conn=%s)", req.Domain, connKey)

	// Store domain mapping for this connection so forward handler can use it.
	s.mu.Lock()
	s.activeTuns[connKey+"-meta"] = &activeTunnel{
		domain:  req.Domain,
		connKey: connKey,
	}
	s.mu.Unlock()

	// Keep channel open as heartbeat / disconnect signal. The copy ends when
	// the channel is closed or fails; either way the tunnel is over, so its
	// error carries no extra information.
	_, _ = io.Copy(io.Discard, ch)
}
|
|
|
|
// domainKeyReplacer maps '.', ':' and '/' to '-'. A strings.Replacer is safe
// for concurrent use, so a single instance is built once here rather than on
// every SanitizeDomain call.
var domainKeyReplacer = strings.NewReplacer(".", "-", ":", "-", "/", "-")

// SanitizeDomain converts a domain name into a safe label key. It trims
// surrounding whitespace, replaces '.', ':' and '/' with '-', and lower-cases
// the result. For example, " App.Example.com:8080 " becomes
// "app-example-com-8080". Empty or all-whitespace input yields "".
func SanitizeDomain(domain string) string {
	return strings.ToLower(domainKeyReplacer.Replace(strings.TrimSpace(domain)))
}
|