ssh-timeout/main.go

120 lines
3.1 KiB
Go

package main
import (
"context"
"io"
"log"
"net"
"os"
"os/signal"
"strconv"
"sync"
"syscall"
"time"
"github.com/google/uuid"
)
// ConnHandler proxies a single client connection to a backend address and
// bounds how long the proxied session may last.
//
// Dialer and TimerFunc are plain function fields so that callers (tests in
// particular) can substitute their own implementations; NewConnHandler fills
// them with net.Dial and time.AfterFunc.
type ConnHandler struct {
	// BackendAddr is the address handed to Dialer for every new connection.
	BackendAddr string

	// MaxDuration is the delay passed to TimerFunc; once it elapses the
	// proxied connection is terminated.
	MaxDuration time.Duration

	// Dialer opens the backend connection. HandleConnection calls it with
	// the network "tcp" and BackendAddr.
	Dialer func(network, addr string) (net.Conn, error)

	// TimerFunc schedules f to run after d and returns the timer so the
	// caller can stop it. It has the same signature as time.AfterFunc.
	TimerFunc func(d time.Duration, f func()) *time.Timer
}
// NewConnHandler returns a ConnHandler that forwards connections to
// backendAddr and limits each one to maxDuration.
//
// The returned handler dials with net.Dial and schedules its timeout with
// time.AfterFunc; both can be replaced afterwards through the exported
// Dialer and TimerFunc fields.
func NewConnHandler(backendAddr string, maxDuration time.Duration) *ConnHandler {
	handler := new(ConnHandler)

	// Caller-supplied configuration.
	handler.BackendAddr = backendAddr
	handler.MaxDuration = maxDuration

	// Production defaults for the injectable hooks.
	handler.Dialer = net.Dial
	handler.TimerFunc = time.AfterFunc

	return handler
}
// HandleConnection proxies traffic between clientConn and a backend
// connection obtained from h.Dialer, and tears the session down as soon as
// one of the following happens:
//
//   - either copy direction ends (EOF or error), or
//   - the timer created by h.TimerFunc fires after h.MaxDuration.
//
// clientConn is always closed before the method returns, including when the
// backend cannot be dialed. The method blocks until both copy goroutines have
// exited, so no goroutine outlives the call.
//
// Teardown closes both connections in full; a peer that only half-closes its
// side therefore ends the whole session.
func (h *ConnHandler) HandleConnection(clientConn net.Conn) {
	connID := uuid.New().String()
	log.Printf("New connection [%s] started from %s", connID, clientConn.RemoteAddr())
	defer clientConn.Close()

	backendConn, err := h.Dialer("tcp", h.BackendAddr)
	if err != nil {
		log.Printf("Failed to connect to backend [%s]: %s", connID, err)
		return
	}
	defer backendConn.Close()

	// ctx is the single "session must end" signal. It is cancelled by
	// whichever copy direction finishes first or by the max-duration timer.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	timer := h.TimerFunc(h.MaxDuration, func() {
		log.Printf("Connection [%s] exceeded max duration, terminating", connID)
		// The connections are closed by the teardown below, which runs as
		// soon as ctx is cancelled.
		cancel()
	})
	defer timer.Stop()

	var wg sync.WaitGroup
	forward := func(dst, src net.Conn, direction string) {
		// Deferred calls run last-in-first-out: cancel fires before Done.
		defer wg.Done()
		defer cancel()
		// Once ctx is cancelled the connections are being closed on purpose,
		// so a copy error at that point is expected and not worth logging.
		if _, err := io.Copy(dst, src); err != nil && ctx.Err() == nil {
			log.Printf("Error forwarding from %s [%s]: %s", direction, connID, err)
		}
	}

	wg.Add(2)
	go forward(backendConn, clientConn, "client to backend")
	go forward(clientConn, backendConn, "backend to client")

	<-ctx.Done()
	timer.Stop()

	// Close both ends before waiting. The copy that is still running is
	// blocked in Read on a connection whose peer may never hang up; with the
	// timer already stopped, waiting without closing could block forever.
	clientConn.Close()
	backendConn.Close()

	wg.Wait()
	log.Printf("Connection [%s] terminated", connID)
}
// main configures the proxy from the environment and runs the accept loop.
//
// Environment variables:
//
//   - LISTEN_ADDR: address to listen on; defaults to ":2222" when unset.
//   - SSH_BACKEND: backend address every connection is forwarded to; required.
//   - SSH_MAX_DURATION: maximum session length in whole seconds; required and
//     must be positive.
//
// SIGINT or SIGTERM closes the listener, which ends the accept loop and lets
// main return. Connections that are still being proxied at that point are
// not waited for; they end when the process exits.
func main() {
	listenAddr := os.Getenv("LISTEN_ADDR")
	if listenAddr == "" {
		listenAddr = ":2222" // Default listen address
	}

	// Fail fast on an empty backend: otherwise the server would start and
	// then fail to dial for every single connection.
	backendAddr := os.Getenv("SSH_BACKEND")
	if backendAddr == "" {
		log.Fatal("SSH_BACKEND must be set")
	}

	maxDuration, err := strconv.Atoi(os.Getenv("SSH_MAX_DURATION"))
	if err != nil {
		log.Fatalf("Invalid SSH_MAX_DURATION value: %s", err)
	}
	// A zero or negative duration makes the timeout fire immediately, which
	// would terminate every session right after it is established.
	if maxDuration <= 0 {
		log.Fatalf("Invalid SSH_MAX_DURATION value: %d (must be a positive number of seconds)", maxDuration)
	}

	listener, err := net.Listen("tcp", listenAddr)
	if err != nil {
		log.Fatalf("Failed to open listener: %s", err)
	}
	defer listener.Close()
	log.Println("Listening on", listenAddr)

	handler := NewConnHandler(backendAddr, time.Duration(maxDuration)*time.Second)

	// Shutdown: closing the listener makes the blocked Accept below return a
	// non-temporary error, which breaks the loop.
	stopChan := make(chan os.Signal, 1)
	signal.Notify(stopChan, syscall.SIGINT, syscall.SIGTERM)
	go func() {
		<-stopChan
		log.Println("Shutting down server...")
		listener.Close()
	}()

	// Back off between retries of temporary accept errors (for example when
	// the process is out of file descriptors) instead of spinning on them.
	const (
		minAcceptDelay = 5 * time.Millisecond
		maxAcceptDelay = time.Second
	)
	var acceptDelay time.Duration

	for {
		clientConn, err := listener.Accept()
		if err != nil {
			if ne, ok := err.(net.Error); ok && ne.Temporary() {
				if acceptDelay == 0 {
					acceptDelay = minAcceptDelay
				} else {
					acceptDelay *= 2
				}
				if acceptDelay > maxAcceptDelay {
					acceptDelay = maxAcceptDelay
				}
				log.Printf("Temporary listener error: %v", err)
				time.Sleep(acceptDelay)
				continue
			}
			log.Printf("Listener closed, stopping accept loop: %v", err)
			break
		}
		acceptDelay = 0 // a successful accept resets the back-off

		go handler.HandleConnection(clientConn)
	}
}