257 lines
5.8 KiB
Go
257 lines
5.8 KiB
Go
package server
|
|
|
|
import (
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"sync"
	"time"

	"golang.org/x/crypto/ssh"
)
|
|
|
|
// PortPool manages a pool of available ports for reverse tunnels.
//
// The pool covers the inclusive range [start, end]. Every method takes mu, so
// a *PortPool may be shared between goroutines.
type PortPool struct {
	mu        sync.Mutex
	available map[int]bool // port -> true while free, false while allocated
	start     int          // first port of the range (inclusive)
	end       int          // last port of the range (inclusive)
}

// NewPortPool creates a port pool for the given range [start, end] with every
// port initially free. If end < start the pool is empty and Allocate always
// returns an error.
func NewPortPool(start, end int) *PortPool {
	size := end - start + 1
	if size < 0 {
		size = 0
	}
	available := make(map[int]bool, size)
	for p := start; p <= end; p++ {
		available[p] = true
	}
	return &PortPool{
		available: available,
		start:     start,
		end:       end,
	}
}

// Allocate claims an available port from the pool and marks it allocated.
// Which free port is returned is unspecified (it follows map iteration
// order). It returns an error when every port in the range is in use.
func (pp *PortPool) Allocate() (int, error) {
	pp.mu.Lock()
	defer pp.mu.Unlock()

	for port, free := range pp.available {
		if free {
			pp.available[port] = false
			return port, nil
		}
	}
	return 0, fmt.Errorf("no ports available in range %d-%d", pp.start, pp.end)
}

// Release returns a port to the pool. Releasing a port that is already free
// is a no-op.
//
// Ports outside [start, end] are ignored. They were never part of the pool,
// and recording them would let Allocate hand out ports that were never
// configured.
func (pp *PortPool) Release(port int) {
	pp.mu.Lock()
	defer pp.mu.Unlock()

	if port < pp.start || port > pp.end {
		return
	}
	pp.available[port] = true
}
|
|
|
|
// SSH wire payloads for remote port forwarding. They are encoded and decoded
// with ssh.Marshal / ssh.Unmarshal, which walk the fields in declaration
// order. Field order and types are therefore part of the wire format and must
// not be rearranged.
type (
	// tcpipForwardRequest is the payload of a "tcpip-forward" global request:
	// the address and port the client asks the server to bind.
	tcpipForwardRequest struct {
		BindAddr string
		BindPort uint32
	}

	// tcpipForwardResponse is the payload of a successful "tcpip-forward"
	// reply, carrying the port the server bound.
	tcpipForwardResponse struct {
		BoundPort uint32
	}

	// forwardedTCPPayload is the channel-open data of a "forwarded-tcpip"
	// channel: the forwarded address/port followed by the originator's
	// address/port.
	forwardedTCPPayload struct {
		Addr       string
		Port       uint32
		OriginAddr string
		OriginPort uint32
	}
)
|
|
|
|
// handleGlobalRequests processes SSH global requests (tcpip-forward).
|
|
func (s *SSHServer) handleGlobalRequests(
|
|
reqs <-chan *ssh.Request,
|
|
sshConn *ssh.ServerConn,
|
|
connKey string,
|
|
) {
|
|
for req := range reqs {
|
|
switch req.Type {
|
|
case "tcpip-forward":
|
|
s.handleForwardRequest(req, sshConn, connKey)
|
|
default:
|
|
if req.WantReply {
|
|
req.Reply(false, nil)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// handleForwardRequest handles a tcpip-forward global request.
|
|
func (s *SSHServer) handleForwardRequest(
|
|
req *ssh.Request,
|
|
sshConn *ssh.ServerConn,
|
|
connKey string,
|
|
) {
|
|
var fwdReq tcpipForwardRequest
|
|
if err := ssh.Unmarshal(req.Payload, &fwdReq); err != nil {
|
|
log.Printf("invalid tcpip-forward payload: %v", err)
|
|
req.Reply(false, nil)
|
|
return
|
|
}
|
|
|
|
log.Printf("tcpip-forward request: bind=%s:%d from %s", fwdReq.BindAddr, fwdReq.BindPort, connKey)
|
|
|
|
// Allocate a port from the pool.
|
|
port, err := s.pool.Allocate()
|
|
if err != nil {
|
|
log.Printf("port allocation failed: %v", err)
|
|
req.Reply(false, nil)
|
|
return
|
|
}
|
|
|
|
// Start a listener on the allocated port.
|
|
listenAddr := fmt.Sprintf("0.0.0.0:%d", port)
|
|
listener, err := net.Listen("tcp", listenAddr)
|
|
if err != nil {
|
|
log.Printf("failed to listen on %s: %v", listenAddr, err)
|
|
s.pool.Release(port)
|
|
req.Reply(false, nil)
|
|
return
|
|
}
|
|
|
|
// Reply with the bound port.
|
|
resp := tcpipForwardResponse{BoundPort: uint32(port)}
|
|
req.Reply(true, ssh.Marshal(&resp))
|
|
log.Printf("Allocated port %d for forwarding (conn=%s)", port, connKey)
|
|
|
|
// Accept connections on the allocated port and forward through SSH.
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
acceptForwardedConnections(listener, sshConn, fwdReq.BindAddr, uint32(port))
|
|
}()
|
|
|
|
// Determine the domain for Traefik label registration.
|
|
// Look up the metadata channel first, fall back to bind address.
|
|
domain := fwdReq.BindAddr
|
|
var authUser, authPass string
|
|
s.mu.Lock()
|
|
if meta, ok := s.activeTuns[connKey+"-meta"]; ok {
|
|
domain = meta.domain
|
|
authUser = meta.authUser
|
|
authPass = meta.authPass
|
|
}
|
|
s.mu.Unlock()
|
|
|
|
tunKey := SanitizeDomain(domain)
|
|
if tunKey == "" {
|
|
tunKey = fmt.Sprintf("port-%d", port)
|
|
}
|
|
|
|
tun := &activeTunnel{
|
|
domain: domain,
|
|
port: port,
|
|
listener: listener,
|
|
done: done,
|
|
connKey: connKey,
|
|
authUser: authUser,
|
|
authPass: authPass,
|
|
}
|
|
|
|
s.mu.Lock()
|
|
s.activeTuns[tunKey] = tun
|
|
s.mu.Unlock()
|
|
|
|
// Register Traefik labels (with optional basicauth middleware).
|
|
if err := s.labels.Add(tunKey, domain, port, authUser, authPass); err != nil {
|
|
log.Printf("WARN: failed to add Traefik labels for %s: %v", domain, err)
|
|
} else {
|
|
log.Printf("Traefik labels added for %s -> port %d", domain, port)
|
|
}
|
|
}
|
|
|
|
// acceptForwardedConnections accepts TCP connections and opens SSH channels.
|
|
func acceptForwardedConnections(
|
|
listener net.Listener,
|
|
sshConn *ssh.ServerConn,
|
|
bindAddr string,
|
|
bindPort uint32,
|
|
) {
|
|
for {
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
return // listener closed
|
|
}
|
|
go forwardConnection(conn, sshConn, bindAddr, bindPort)
|
|
}
|
|
}
|
|
|
|
// forwardConnection forwards a single TCP connection through the SSH channel.
|
|
func forwardConnection(
|
|
conn net.Conn,
|
|
sshConn *ssh.ServerConn,
|
|
bindAddr string,
|
|
bindPort uint32,
|
|
) {
|
|
defer conn.Close()
|
|
|
|
originAddr, originPortStr, _ := net.SplitHostPort(conn.RemoteAddr().String())
|
|
var originPort int
|
|
fmt.Sscanf(originPortStr, "%d", &originPort)
|
|
|
|
payload := forwardedTCPPayload{
|
|
Addr: bindAddr,
|
|
Port: bindPort,
|
|
OriginAddr: originAddr,
|
|
OriginPort: uint32(originPort),
|
|
}
|
|
|
|
ch, reqs, err := sshConn.OpenChannel("forwarded-tcpip", ssh.Marshal(&payload))
|
|
if err != nil {
|
|
log.Printf("failed to open forwarded-tcpip channel: %v", err)
|
|
return
|
|
}
|
|
go ssh.DiscardRequests(reqs)
|
|
|
|
// Bidirectional copy.
|
|
var wg sync.WaitGroup
|
|
wg.Add(2)
|
|
go func() {
|
|
defer wg.Done()
|
|
io.Copy(ch, conn)
|
|
ch.CloseWrite()
|
|
}()
|
|
go func() {
|
|
defer wg.Done()
|
|
io.Copy(conn, ch)
|
|
}()
|
|
wg.Wait()
|
|
}
|
|
|
|
// cleanupConnection removes all tunnels associated with a closed SSH connection.
|
|
func (s *SSHServer) cleanupConnection(connKey string) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
for key, tun := range s.activeTuns {
|
|
if tun.connKey != connKey {
|
|
continue
|
|
}
|
|
|
|
if tun.listener != nil {
|
|
tun.listener.Close()
|
|
s.pool.Release(tun.port)
|
|
}
|
|
|
|
if err := s.labels.Remove(key); err != nil {
|
|
log.Printf("WARN: failed to remove labels for %s: %v", key, err)
|
|
}
|
|
|
|
log.Printf("Cleaned up tunnel %s (port %d, conn=%s)", key, tun.port, connKey)
|
|
delete(s.activeTuns, key)
|
|
}
|
|
}
|