better-argo-tunnels/internal/server/tunnel.go

257 lines
5.8 KiB
Go

package server
import (
"fmt"
"io"
"log"
"net"
"sync"
"golang.org/x/crypto/ssh"
)
// PortPool hands out TCP ports for reverse tunnels from a fixed inclusive
// range [start, end]. Every method locks mu, so a *PortPool may be shared
// between goroutines.
type PortPool struct {
	mu sync.Mutex
	// available maps a port to true while it is free and false while it is
	// allocated. NewPortPool creates a key for every port in [start, end], and
	// Release never adds keys, so ports outside the range never appear here.
	available map[int]bool
	start     int
	end       int
}

// NewPortPool returns a pool covering the inclusive range [start, end] with
// every port initially free. When end < start the pool is empty and every
// Allocate call fails.
func NewPortPool(start, end int) *PortPool {
	size := end - start + 1
	if size < 0 {
		// Only a capacity hint; an inverted range simply yields an empty pool.
		size = 0
	}
	available := make(map[int]bool, size)
	for p := start; p <= end; p++ {
		available[p] = true
	}
	return &PortPool{
		available: available,
		start:     start,
		end:       end,
	}
}

// Allocate marks a free port as in use and returns it. Which free port is
// chosen is unspecified (it follows Go's map iteration order). An error is
// returned when no port in the range is free.
func (pp *PortPool) Allocate() (int, error) {
	pp.mu.Lock()
	defer pp.mu.Unlock()
	for port, free := range pp.available {
		if free {
			pp.available[port] = false
			return port, nil
		}
	}
	return 0, fmt.Errorf("no ports available in range %d-%d", pp.start, pp.end)
}

// Release returns port to the pool so a later Allocate may hand it out again.
// Ports the pool does not own (outside [start, end]) are ignored, so a bogus
// value can never be allocated later. Releasing an already-free port is a
// no-op.
func (pp *PortPool) Release(port int) {
	pp.mu.Lock()
	defer pp.mu.Unlock()
	// Only flip ports that already have a key. This also avoids writing to a
	// nil map on a zero-value PortPool.
	if _, owned := pp.available[port]; owned {
		pp.available[port] = true
	}
}
// SSH wire payloads used for remote (reverse) port forwarding, per RFC 4254
// sections 7.1 and 7.2. ssh.Marshal / ssh.Unmarshal encode struct fields in
// declaration order, so the field order below mirrors the protocol and must
// not be rearranged.
type (
	// tcpipForwardRequest is the payload of a "tcpip-forward" global request:
	// the address and port the client asks the server to bind.
	tcpipForwardRequest struct {
		BindAddr string
		BindPort uint32
	}

	// tcpipForwardResponse is the payload of a successful reply to
	// "tcpip-forward": the port the server actually bound.
	tcpipForwardResponse struct {
		BoundPort uint32
	}

	// forwardedTCPPayload is the extra data sent when opening a
	// "forwarded-tcpip" channel: the bound address/port that received the
	// connection, followed by the originator's address/port.
	forwardedTCPPayload struct {
		Addr       string
		Port       uint32
		OriginAddr string
		OriginPort uint32
	}
)
// handleGlobalRequests drains the global requests of one SSH connection until
// reqs is closed. Only "tcpip-forward" is served (via handleForwardRequest).
// Every other request type is refused, and a failure reply is sent only when
// the client asked for a reply.
func (s *SSHServer) handleGlobalRequests(
	reqs <-chan *ssh.Request,
	sshConn *ssh.ServerConn,
	connKey string,
) {
	for {
		r, open := <-reqs
		if !open {
			return
		}
		if r.Type != "tcpip-forward" {
			// Unsupported request type (this includes "cancel-tcpip-forward").
			if r.WantReply {
				r.Reply(false, nil)
			}
			continue
		}
		s.handleForwardRequest(r, sshConn, connKey)
	}
}
// handleForwardRequest services one "tcpip-forward" global request.
//
// The port requested by the client is ignored. Instead the handler:
//   - takes a port from s.pool and opens a TCP listener on 0.0.0.0:<port>;
//   - replies to the client with that port;
//   - starts a goroutine that relays accepted connections over sshConn.
//
// On failure (bad payload, no free port, listen error) a failure reply is
// sent and any allocated port is returned to the pool.
//
// The tunnel is recorded in s.activeTuns under SanitizeDomain(domain), or
// "port-<n>" when that is empty, and Traefik labels are registered for it.
// The domain and optional basic-auth credentials come from this connection's
// connKey+"-meta" entry in s.activeTuns when present. Otherwise the request's
// BindAddr is used as the domain, with no auth.
//
// If a tunnel is already registered under the same key (for example a client
// reconnecting before its old connection was cleaned up), the new tunnel
// replaces it. The old tunnel's listener is closed and its port is released
// so they are not leaked.
func (s *SSHServer) handleForwardRequest(
	req *ssh.Request,
	sshConn *ssh.ServerConn,
	connKey string,
) {
	var fwdReq tcpipForwardRequest
	if err := ssh.Unmarshal(req.Payload, &fwdReq); err != nil {
		log.Printf("invalid tcpip-forward payload: %v", err)
		req.Reply(false, nil)
		return
	}
	log.Printf("tcpip-forward request: bind=%s:%d from %s", fwdReq.BindAddr, fwdReq.BindPort, connKey)

	// Ports always come from the pool, regardless of fwdReq.BindPort.
	port, err := s.pool.Allocate()
	if err != nil {
		log.Printf("port allocation failed: %v", err)
		req.Reply(false, nil)
		return
	}

	listenAddr := fmt.Sprintf("0.0.0.0:%d", port)
	listener, err := net.Listen("tcp", listenAddr)
	if err != nil {
		log.Printf("failed to listen on %s: %v", listenAddr, err)
		s.pool.Release(port)
		req.Reply(false, nil)
		return
	}

	// Tell the client which port was actually bound.
	resp := tcpipForwardResponse{BoundPort: uint32(port)}
	req.Reply(true, ssh.Marshal(&resp))
	log.Printf("Allocated port %d for forwarding (conn=%s)", port, connKey)

	// done is closed once the accept loop returns (normally after the
	// listener is closed).
	done := make(chan struct{})
	go func() {
		defer close(done)
		acceptForwardedConnections(listener, sshConn, fwdReq.BindAddr, uint32(port))
	}()

	// Prefer the domain/auth recorded in this connection's metadata entry;
	// fall back to the bind address with no auth.
	domain := fwdReq.BindAddr
	var authUser, authPass string
	s.mu.Lock()
	if meta, ok := s.activeTuns[connKey+"-meta"]; ok {
		domain = meta.domain
		authUser = meta.authUser
		authPass = meta.authPass
	}
	s.mu.Unlock()

	tunKey := SanitizeDomain(domain)
	if tunKey == "" {
		tunKey = fmt.Sprintf("port-%d", port)
	}
	tun := &activeTunnel{
		domain:   domain,
		port:     port,
		listener: listener,
		done:     done,
		connKey:  connKey,
		authUser: authUser,
		authPass: authPass,
	}

	// Swap the entry under the lock so that exactly one party (this handler)
	// owns any previous tunnel that was registered under the same key.
	s.mu.Lock()
	prev, replaced := s.activeTuns[tunKey]
	s.activeTuns[tunKey] = tun
	s.mu.Unlock()

	// Overwriting the entry would otherwise orphan the previous tunnel. No
	// later cleanupConnection call could find it, so its listener would keep
	// accepting and its port would never return to the pool.
	if replaced && prev != nil && prev.listener != nil {
		prev.listener.Close()
		s.pool.Release(prev.port)
		log.Printf("Replaced tunnel %s: closed old port %d (conn=%s)", tunKey, prev.port, prev.connKey)
	}

	// Register Traefik labels (with optional basicauth middleware).
	if err := s.labels.Add(tunKey, domain, port, authUser, authPass); err != nil {
		log.Printf("WARN: failed to add Traefik labels for %s: %v", domain, err)
		return
	}
	log.Printf("Traefik labels added for %s -> port %d", domain, port)
}
// acceptForwardedConnections runs the accept loop for one reverse tunnel.
// Each accepted TCP connection is handed to its own forwardConnection
// goroutine, which relays it over sshConn tagged with bindAddr and bindPort.
//
// The loop returns on the first error from Accept. Closing the listener is
// the intended way to stop it, but any other Accept error ends it too.
// NOTE(review): transient errors such as EMFILE therefore stop the tunnel
// while its listener and port stay allocated. Consider retrying with a
// backoff for errors other than net.ErrClosed.
func acceptForwardedConnections(
	listener net.Listener,
	sshConn *ssh.ServerConn,
	bindAddr string,
	bindPort uint32,
) {
	for c, acceptErr := listener.Accept(); acceptErr == nil; c, acceptErr = listener.Accept() {
		go forwardConnection(c, sshConn, bindAddr, bindPort)
	}
}
// forwardConnection relays one accepted TCP connection through a new
// "forwarded-tcpip" channel on sshConn. The channel-open payload carries
// bindAddr/bindPort plus the TCP peer's address and port as the originator.
//
// Data is copied in both directions until each direction has finished, and
// end-of-stream is propagated both ways:
//   - when the TCP peer stops sending, the channel is half-closed with
//     CloseWrite;
//   - when the SSH side stops sending, the TCP connection's write half is
//     shut down, or the whole connection is closed if it cannot be
//     half-closed.
//
// The channel and the TCP connection are both closed before returning. If
// the channel cannot be opened, the error is logged and conn is closed.
func forwardConnection(
	conn net.Conn,
	sshConn *ssh.ServerConn,
	bindAddr string,
	bindPort uint32,
) {
	defer conn.Close()

	// Best effort: if the peer address cannot be parsed, the originator
	// fields are left empty/zero rather than refusing the connection.
	originAddr, originPortStr, _ := net.SplitHostPort(conn.RemoteAddr().String())
	var originPort int
	fmt.Sscanf(originPortStr, "%d", &originPort)

	payload := forwardedTCPPayload{
		Addr:       bindAddr,
		Port:       bindPort,
		OriginAddr: originAddr,
		OriginPort: uint32(originPort),
	}
	ch, reqs, err := sshConn.OpenChannel("forwarded-tcpip", ssh.Marshal(&payload))
	if err != nil {
		log.Printf("failed to open forwarded-tcpip channel: %v", err)
		return
	}
	// Send our close for the channel once relaying is done rather than
	// leaving it open until the peer (or the whole SSH connection) closes it.
	defer ch.Close()
	go ssh.DiscardRequests(reqs)

	var wg sync.WaitGroup
	wg.Add(2)
	// TCP -> SSH: when the TCP peer is done sending, signal EOF on the channel.
	go func() {
		defer wg.Done()
		io.Copy(ch, conn)
		ch.CloseWrite()
	}()
	// SSH -> TCP: when the SSH side is done sending, signal EOF to the TCP
	// peer. Without this, a peer that waits for EOF (e.g. a close-delimited
	// response) never gets it, and both copies block until a timeout.
	go func() {
		defer wg.Done()
		io.Copy(conn, ch)
		if hc, ok := conn.(interface{ CloseWrite() error }); ok {
			hc.CloseWrite()
		} else {
			conn.Close()
		}
	}()
	wg.Wait()
}
// cleanupConnection tears down every entry in s.activeTuns whose connKey
// equals the given connection key. For each match it:
//   - closes the listener and returns the port to the pool, if the entry has
//     a listener;
//   - asks s.labels to remove the labels stored under the entry's map key,
//     logging a warning on failure;
//   - logs the cleanup and deletes the entry.
//
// s.mu is held for the entire sweep.
func (s *SSHServer) cleanupConnection(connKey string) {
	s.mu.Lock()
	defer s.mu.Unlock()

	for tunKey, t := range s.activeTuns {
		if t.connKey == connKey {
			if l := t.listener; l != nil {
				l.Close()
				s.pool.Release(t.port)
			}
			if removeErr := s.labels.Remove(tunKey); removeErr != nil {
				log.Printf("WARN: failed to remove labels for %s: %v", tunKey, removeErr)
			}
			log.Printf("Cleaned up tunnel %s (port %d, conn=%s)", tunKey, t.port, connKey)
			// Deleting the current key while ranging over a map is permitted in Go.
			delete(s.activeTuns, tunKey)
		}
	}
}