better-argo-tunnels/internal/server/tunnel.go

295 lines
7.1 KiB
Go

package server
import (
"fmt"
"io"
"log"
"net"
"strings"
"sync"
"golang.org/x/crypto/ssh"
)
// PortPool manages a fixed, inclusive range of TCP ports handed out to
// reverse tunnels. All methods take pp.mu, so a pool may be shared between
// goroutines.
type PortPool struct {
	mu        sync.Mutex
	available map[int]bool // one entry per port in [start, end]; true = free
	start     int
	end       int
}

// NewPortPool creates a port pool for the inclusive range [start, end] with
// every port initially free. If end < start the pool is empty and Allocate
// always fails.
func NewPortPool(start, end int) *PortPool {
	available := make(map[int]bool, end-start+1)
	for p := start; p <= end; p++ {
		available[p] = true
	}
	return &PortPool{
		available: available,
		start:     start,
		end:       end,
	}
}

// Allocate claims a free port from the pool and marks it as in use.
// Which free port is returned is unspecified (it follows map iteration
// order). An error is returned when every port in the range is in use.
func (pp *PortPool) Allocate() (int, error) {
	pp.mu.Lock()
	defer pp.mu.Unlock()
	for port, free := range pp.available {
		if free {
			pp.available[port] = false
			return port, nil
		}
	}
	return 0, fmt.Errorf("no ports available in range %d-%d", pp.start, pp.end)
}

// Release returns a port to the pool. Ports outside the pool's range are
// ignored, so a stray value can never be handed out later by Allocate.
// Releasing a port that is already free is a no-op.
func (pp *PortPool) Release(port int) {
	pp.mu.Lock()
	defer pp.mu.Unlock()
	// Only ports created by NewPortPool have a map entry; writing
	// unconditionally would add foreign ports to the pool.
	if _, ok := pp.available[port]; ok {
		pp.available[port] = true
	}
}
// Wire payloads for SSH TCP/IP port forwarding (RFC 4254, section 7).
// ssh.Marshal / ssh.Unmarshal encode struct fields in declaration order, so
// the field order of each struct below is part of the protocol.
type (
	// tcpipForwardRequest is the payload of a "tcpip-forward" global request.
	tcpipForwardRequest struct {
		BindAddr string
		BindPort uint32
	}

	// tcpipForwardResponse is the success-reply payload carrying the port
	// the server actually bound.
	tcpipForwardResponse struct {
		BoundPort uint32
	}

	// forwardedTCPPayload is the extra data sent when the server opens a
	// "forwarded-tcpip" channel back to the client for an accepted connection.
	forwardedTCPPayload struct {
		Addr       string
		Port       uint32
		OriginAddr string
		OriginPort uint32
	}
)
// handleGlobalRequests services the SSH global requests of one connection
// until the request channel is closed. Only "tcpip-forward" is supported;
// any other request type is declined if the client asked for a reply.
// Requests are handled one at a time, so a slow forward setup delays the
// requests queued behind it.
func (s *SSHServer) handleGlobalRequests(
	reqs <-chan *ssh.Request,
	sshConn *ssh.ServerConn,
	connKey string,
) {
	for r := range reqs {
		if r.Type == "tcpip-forward" {
			s.handleForwardRequest(r, sshConn, connKey)
			continue
		}
		if !r.WantReply {
			continue
		}
		r.Reply(false, nil)
	}
}
// handleForwardRequest services a single "tcpip-forward" global request.
// It allocates a port from s.pool, listens on it, replies to the client with
// the bound port, starts relaying accepted connections over sshConn, records
// the tunnel in s.activeTuns and registers Traefik labels for it. Failures
// before the tunnel goes live reject the request (when a reply is wanted)
// and release everything acquired so far.
func (s *SSHServer) handleForwardRequest(
	req *ssh.Request,
	sshConn *ssh.ServerConn,
	connKey string,
) {
	var fwdReq tcpipForwardRequest
	if err := ssh.Unmarshal(req.Payload, &fwdReq); err != nil {
		log.Printf("invalid tcpip-forward payload: %v", err)
		req.Reply(false, nil)
		return
	}
	log.Printf("tcpip-forward request: bind=%s:%d from %s", fwdReq.BindAddr, fwdReq.BindPort, connKey)

	// The requested BindPort is only logged: the server always picks its own
	// port from the pool and reports it back in the reply.
	port, err := s.pool.Allocate()
	if err != nil {
		log.Printf("port allocation failed: %v", err)
		req.Reply(false, nil)
		return
	}

	// NOTE(review): this listens on all interfaces. If the port is only meant
	// to be reached through Traefik, binding it publicly may let clients skip
	// the basicauth middleware — confirm the intended exposure.
	listenAddr := fmt.Sprintf("0.0.0.0:%d", port)
	listener, err := net.Listen("tcp", listenAddr)
	if err != nil {
		log.Printf("failed to listen on %s: %v", listenAddr, err)
		s.pool.Release(port)
		req.Reply(false, nil)
		return
	}

	// Tell the client which port was bound. If the reply cannot be delivered
	// the SSH connection is already unusable: tear down here instead of
	// registering a tunnel (and labels) that nobody can serve, which would
	// leak the listener and port if connection cleanup has already run.
	resp := tcpipForwardResponse{BoundPort: uint32(port)}
	if err := req.Reply(true, ssh.Marshal(&resp)); err != nil {
		log.Printf("failed to reply to tcpip-forward (conn=%s): %v", connKey, err)
		listener.Close()
		s.pool.Release(port)
		return
	}
	log.Printf("Allocated port %d for forwarding (conn=%s)", port, connKey)

	// Relay connections accepted on the port through SSH; done is closed once
	// the accept loop returns (i.e. the listener stopped accepting).
	done := make(chan struct{})
	go func() {
		defer close(done)
		acceptForwardedConnections(listener, sshConn, fwdReq.BindAddr, uint32(port))
	}()

	// Pick the domain for Traefik label registration: prefer the metadata
	// stored under connKey+"-meta" (which also carries basic-auth
	// credentials), falling back to the bind address sent by the client.
	domain := fwdReq.BindAddr
	var authUser, authPass string
	s.mu.Lock()
	if meta, ok := s.activeTuns[connKey+"-meta"]; ok {
		domain = meta.domain
		authUser = meta.authUser
		authPass = meta.authPass
	}
	s.mu.Unlock()

	tunKey := SanitizeDomain(domain)
	if tunKey == "" {
		tunKey = fmt.Sprintf("port-%d", port)
	}
	tun := &activeTunnel{
		domain:   domain,
		port:     port,
		listener: listener,
		done:     done,
		connKey:  connKey,
		authUser: authUser,
		authPass: authPass,
	}

	// If a previous tunnel exists for this key (e.g. a reconnect), tear it
	// down so we don't leak ports/listeners and so Traefik labels stay
	// consistent.
	s.mu.Lock()
	if old, exists := s.activeTuns[tunKey]; exists && old.listener != nil {
		log.Printf("Replacing stale tunnel %s (old port %d, new port %d)", tunKey, old.port, port)
		old.listener.Close()
		s.pool.Release(old.port)
	}
	s.activeTuns[tunKey] = tun
	s.mu.Unlock()

	// Register Traefik labels (with optional basicauth middleware).
	if err := s.labels.Add(tunKey, domain, port, authUser, authPass); err != nil {
		log.Printf("WARN: failed to add Traefik labels for %s: %v", domain, err)
	} else {
		log.Printf("Traefik labels added for %s -> port %d", domain, port)
	}
}
// acceptForwardedConnections accepts TCP connections on listener and hands
// each one to forwardConnection in its own goroutine. It returns as soon as
// Accept fails for any reason — normally because the listener was closed,
// but any other accept error also ends the loop, and the error is not
// reported.
func acceptForwardedConnections(
	listener net.Listener,
	sshConn *ssh.ServerConn,
	bindAddr string,
	bindPort uint32,
) {
	var (
		c   net.Conn
		err error
	)
	for c, err = listener.Accept(); err == nil; c, err = listener.Accept() {
		// c is passed by value, so each goroutine gets its own connection.
		go forwardConnection(c, sshConn, bindAddr, bindPort)
	}
}
// forwardConnection relays one accepted TCP connection over a new
// "forwarded-tcpip" SSH channel opened on sshConn. The channel payload
// reports the forwarded address (bindAddr:bindPort) and conn's remote
// address as the originator. Data is copied in both directions. When one
// direction hits EOF, the write half of the other side is closed so its
// peer sees end-of-stream. Both conn and the channel are closed before
// returning.
func forwardConnection(
	conn net.Conn,
	sshConn *ssh.ServerConn,
	bindAddr string,
	bindPort uint32,
) {
	defer conn.Close()

	// Origin info is informational for the client; if the remote address
	// cannot be parsed the fields fall back to "" and 0.
	originAddr, originPortStr, _ := net.SplitHostPort(conn.RemoteAddr().String())
	var originPort int
	fmt.Sscanf(originPortStr, "%d", &originPort)

	payload := forwardedTCPPayload{
		Addr:       bindAddr,
		Port:       bindPort,
		OriginAddr: originAddr,
		OriginPort: uint32(originPort),
	}
	ch, reqs, err := sshConn.OpenChannel("forwarded-tcpip", ssh.Marshal(&payload))
	if err != nil {
		log.Printf("failed to open forwarded-tcpip channel: %v", err)
		return
	}
	// Close the channel from this side once both directions are done, so it
	// does not linger until the client decides to close it.
	defer ch.Close()
	go ssh.DiscardRequests(reqs)

	// Bidirectional copy.
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		io.Copy(ch, conn)
		ch.CloseWrite() // signal EOF to the SSH client
	}()
	go func() {
		defer wg.Done()
		io.Copy(conn, ch)
		// Propagate EOF to the TCP peer too; otherwise the peer never learns
		// the backend finished and the conn->ch copy can block until the peer
		// times out.
		if cw, ok := conn.(interface{ CloseWrite() error }); ok {
			cw.CloseWrite()
		}
	}()
	wg.Wait()
}
// cleanupConnection tears down every s.activeTuns entry owned by a closed
// SSH connection (entries whose connKey matches). The entries are removed
// from the map under s.mu. Listener shutdown, port release and the
// potentially slow Traefik label removal then run outside the lock.
func (s *SSHServer) cleanupConnection(connKey string) {
	// cleanupItem is the per-entry work collected while holding s.mu.
	type cleanupItem struct {
		key      string
		port     int
		listener net.Listener
	}

	s.mu.Lock()
	var items []cleanupItem
	for key, tun := range s.activeTuns {
		if tun.connKey != connKey {
			continue
		}
		items = append(items, cleanupItem{
			key:      key,
			port:     tun.port,
			listener: tun.listener,
		})
		delete(s.activeTuns, key) // deleting during range is safe in Go
	}
	s.mu.Unlock()

	for _, item := range items {
		// Only entries that own a listener are treated as holding a pool port.
		if item.listener != nil {
			item.listener.Close()
			s.pool.Release(item.port)
		}

		// Re-check ownership: a reconnecting client may have inserted a new
		// entry for this key after it was removed above; that entry's labels
		// must survive.
		// NOTE(review): the check below and labels.Remove are not atomic, so a
		// tunnel registered between them can still lose its freshly added
		// labels.
		s.mu.Lock()
		current, replaced := s.activeTuns[item.key]
		s.mu.Unlock()
		if replaced && current.connKey != connKey {
			log.Printf("Skipping label removal for %s — replaced by conn %s", item.key, current.connKey)
			continue
		}
		// Metadata entries have no Traefik labels to remove.
		if strings.HasSuffix(item.key, "-meta") {
			continue
		}
		if err := s.labels.Remove(item.key); err != nil {
			log.Printf("WARN: failed to remove labels for %s: %v", item.key, err)
		}
		log.Printf("Cleaned up tunnel %s (port %d, conn=%s)", item.key, item.port, connKey)
	}
}