package main import ( "context" "io" "log" "net" "os" "os/signal" "strconv" "sync" "syscall" "time" "github.com/google/uuid" ) type ConnHandler struct { BackendAddr string MaxDuration time.Duration Dialer func(network, addr string) (net.Conn, error) TimerFunc func(d time.Duration, f func()) *time.Timer } func NewConnHandler(backendAddr string, maxDuration time.Duration) *ConnHandler { return &ConnHandler{ BackendAddr: backendAddr, MaxDuration: maxDuration, Dialer: net.Dial, TimerFunc: time.AfterFunc, } } func (h *ConnHandler) HandleConnection(clientConn net.Conn) { connID := uuid.New().String() log.Printf("New connection [%s] started from %s", connID, clientConn.RemoteAddr()) defer clientConn.Close() backendConn, err := h.Dialer("tcp", h.BackendAddr) if err != nil { log.Printf("Failed to connect to backend [%s]: %s", connID, err) return } defer backendConn.Close() timer := h.TimerFunc(h.MaxDuration, func() { log.Printf("Connection [%s] exceeded max duration, terminating", connID) clientConn.Close() backendConn.Close() }) defer timer.Stop() ctx, cancel := context.WithCancel(context.Background()) var wg sync.WaitGroup wg.Add(2) go func() { defer wg.Done() if _, err := io.Copy(backendConn, clientConn); err != nil { log.Printf("Error forwarding from client to backend [%s]: %s", connID, err) } cancel() }() go func() { defer wg.Done() if _, err := io.Copy(clientConn, backendConn); err != nil { log.Printf("Error forwarding from backend to client [%s]: %s", connID, err) } cancel() }() <-ctx.Done() timer.Stop() wg.Wait() log.Printf("Connection [%s] terminated", connID) } func main() { listenAddr := os.Getenv("LISTEN_ADDR") if listenAddr == "" { listenAddr = ":2222" // Default listen address } backendAddr := os.Getenv("SSH_BACKEND") maxDuration, err := strconv.Atoi(os.Getenv("SSH_MAX_DURATION")) if err != nil { log.Fatalf("Invalid SSH_MAX_DURATION value: %s", err) } listener, err := net.Listen("tcp", listenAddr) if err != nil { log.Fatalf("Failed to open listener: %s", err) } defer listener.Close() log.Println("Listening on", listenAddr) handler := NewConnHandler(backendAddr, time.Duration(maxDuration)*time.Second) // Handling graceful shutdown stopChan := make(chan os.Signal, 1) signal.Notify(stopChan, syscall.SIGINT, syscall.SIGTERM) go func() { <-stopChan log.Println("Shutting down server...") listener.Close() }() for { clientConn, err := listener.Accept() if err != nil { if ne, ok := err.(net.Error); ok && ne.Temporary() { log.Printf("Temporary listener error: %v", err) continue } log.Printf("Listener closed, stopping accept loop: %v", err) break } go handler.HandleConnection(clientConn) } }