package server

import (
	"fmt"
	"io"
	"log"
	"net"
	"sync"

	"golang.org/x/crypto/ssh"
)

// PortPool manages a pool of available ports for reverse tunnels.
//
// Ports are handed out from the inclusive range [start, end]. The zero value
// is not usable; construct a pool with NewPortPool. Allocate and Release take
// the pool's mutex, so they may be called from multiple goroutines.
type PortPool struct {
	mu        sync.Mutex
	available map[int]bool // port -> true while the port is free
	start     int
	end       int
}

// NewPortPool creates a port pool for the given range [start, end].
// If end < start the pool is empty and every Allocate call fails.
func NewPortPool(start, end int) *PortPool {
	size := 0
	if end >= start {
		size = end - start + 1
	}
	available := make(map[int]bool, size)
	for p := start; p <= end; p++ {
		available[p] = true
	}
	return &PortPool{
		available: available,
		start:     start,
		end:       end,
	}
}

// Allocate claims an available port from the pool.
//
// The lowest free port in [start, end] is returned, so assignments are
// deterministic rather than depending on Go's randomized map iteration order.
// An error is returned when every port in the range is in use.
func (pp *PortPool) Allocate() (int, error) {
	pp.mu.Lock()
	defer pp.mu.Unlock()

	for port := pp.start; port <= pp.end; port++ {
		if pp.available[port] {
			pp.available[port] = false
			return port, nil
		}
	}
	return 0, fmt.Errorf("no ports available in range %d-%d", pp.start, pp.end)
}

// Release returns a port to the pool.
//
// Releasing an already-free port is a no-op. Ports outside [start, end] are
// ignored (with a warning) so a stray Release cannot add a port to the pool
// that Allocate would later hand out despite it being outside the range.
func (pp *PortPool) Release(port int) {
	pp.mu.Lock()
	defer pp.mu.Unlock()

	if port < pp.start || port > pp.end {
		log.Printf("WARN: ignoring release of port %d outside pool range %d-%d", port, pp.start, pp.end)
		return
	}
	pp.available[port] = true
}

// tcpipForwardRequest matches the SSH tcpip-forward request payload
// (RFC 4254 section 7.1).
type tcpipForwardRequest struct {
	BindAddr string
	BindPort uint32
}

// tcpipForwardResponse matches the SSH tcpip-forward response payload
// carrying the port the server actually bound.
type tcpipForwardResponse struct {
	BoundPort uint32
}

// forwardedTCPPayload matches the SSH forwarded-tcpip channel open data
// (RFC 4254 section 7.2): the forwarded address/port followed by the
// originator's address/port.
type forwardedTCPPayload struct {
	Addr       string
	Port       uint32
	OriginAddr string
	OriginPort uint32
}

// handleGlobalRequests processes SSH global requests (tcpip-forward).
//
// It loops until reqs is closed, handling requests one at a time:
// tcpip-forward is delegated to handleForwardRequest, and any other request
// type is refused when the peer asked for a reply.
func (s *SSHServer) handleGlobalRequests(
	reqs <-chan *ssh.Request,
	sshConn *ssh.ServerConn,
	connKey string,
) {
	for req := range reqs {
		switch req.Type {
		case "tcpip-forward":
			s.handleForwardRequest(req, sshConn, connKey)
		default:
			if req.WantReply {
				req.Reply(false, nil)
			}
		}
	}
}

// handleForwardRequest handles a tcpip-forward global request.
func (s *SSHServer) handleForwardRequest( req *ssh.Request, sshConn *ssh.ServerConn, connKey string, ) { var fwdReq tcpipForwardRequest if err := ssh.Unmarshal(req.Payload, &fwdReq); err != nil { log.Printf("invalid tcpip-forward payload: %v", err) req.Reply(false, nil) return } log.Printf("tcpip-forward request: bind=%s:%d from %s", fwdReq.BindAddr, fwdReq.BindPort, connKey) // Allocate a port from the pool. port, err := s.pool.Allocate() if err != nil { log.Printf("port allocation failed: %v", err) req.Reply(false, nil) return } // Start a listener on the allocated port. listenAddr := fmt.Sprintf("0.0.0.0:%d", port) listener, err := net.Listen("tcp", listenAddr) if err != nil { log.Printf("failed to listen on %s: %v", listenAddr, err) s.pool.Release(port) req.Reply(false, nil) return } // Reply with the bound port. resp := tcpipForwardResponse{BoundPort: uint32(port)} req.Reply(true, ssh.Marshal(&resp)) log.Printf("Allocated port %d for forwarding (conn=%s)", port, connKey) // Accept connections on the allocated port and forward through SSH. done := make(chan struct{}) go func() { defer close(done) acceptForwardedConnections(listener, sshConn, fwdReq.BindAddr, uint32(port)) }() // Determine the domain for Traefik label registration. // Look up the metadata channel first, fall back to bind address. domain := fwdReq.BindAddr var authUser, authPass string s.mu.Lock() if meta, ok := s.activeTuns[connKey+"-meta"]; ok { domain = meta.domain authUser = meta.authUser authPass = meta.authPass } s.mu.Unlock() tunKey := SanitizeDomain(domain) if tunKey == "" { tunKey = fmt.Sprintf("port-%d", port) } tun := &activeTunnel{ domain: domain, port: port, listener: listener, done: done, connKey: connKey, authUser: authUser, authPass: authPass, } s.mu.Lock() s.activeTuns[tunKey] = tun s.mu.Unlock() // Register Traefik labels (with optional basicauth middleware). 
if err := s.labels.Add(tunKey, domain, port, authUser, authPass); err != nil { log.Printf("WARN: failed to add Traefik labels for %s: %v", domain, err) } else { log.Printf("Traefik labels added for %s -> port %d", domain, port) } } // acceptForwardedConnections accepts TCP connections and opens SSH channels. func acceptForwardedConnections( listener net.Listener, sshConn *ssh.ServerConn, bindAddr string, bindPort uint32, ) { for { conn, err := listener.Accept() if err != nil { return // listener closed } go forwardConnection(conn, sshConn, bindAddr, bindPort) } } // forwardConnection forwards a single TCP connection through the SSH channel. func forwardConnection( conn net.Conn, sshConn *ssh.ServerConn, bindAddr string, bindPort uint32, ) { defer conn.Close() originAddr, originPortStr, _ := net.SplitHostPort(conn.RemoteAddr().String()) var originPort int fmt.Sscanf(originPortStr, "%d", &originPort) payload := forwardedTCPPayload{ Addr: bindAddr, Port: bindPort, OriginAddr: originAddr, OriginPort: uint32(originPort), } ch, reqs, err := sshConn.OpenChannel("forwarded-tcpip", ssh.Marshal(&payload)) if err != nil { log.Printf("failed to open forwarded-tcpip channel: %v", err) return } go ssh.DiscardRequests(reqs) // Bidirectional copy. var wg sync.WaitGroup wg.Add(2) go func() { defer wg.Done() io.Copy(ch, conn) ch.CloseWrite() }() go func() { defer wg.Done() io.Copy(conn, ch) }() wg.Wait() } // cleanupConnection removes all tunnels associated with a closed SSH connection. func (s *SSHServer) cleanupConnection(connKey string) { s.mu.Lock() defer s.mu.Unlock() for key, tun := range s.activeTuns { if tun.connKey != connKey { continue } if tun.listener != nil { tun.listener.Close() s.pool.Release(tun.port) } if err := s.labels.Remove(key); err != nil { log.Printf("WARN: failed to remove labels for %s: %v", key, err) } log.Printf("Cleaned up tunnel %s (port %d, conn=%s)", key, tun.port, connKey) delete(s.activeTuns, key) } }