120 lines
3.1 KiB
Go
120 lines
3.1 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"os"
|
|
"os/signal"
|
|
"strconv"
|
|
"sync"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
// defaultDialTimeout bounds how long the dialer installed by NewConnHandler
// waits for the backend to accept a TCP connection.
const defaultDialTimeout = 10 * time.Second

// ConnHandler forwards client connections to a single backend address and
// closes each forwarded connection once MaxDuration has elapsed.
type ConnHandler struct {
	// BackendAddr is the TCP address every client connection is forwarded to.
	BackendAddr string
	// MaxDuration is how long a forwarded connection may stay open before
	// both sides are closed.
	MaxDuration time.Duration
	// Dialer opens the backend connection. It is a field rather than a
	// hard-coded call so it can be replaced.
	Dialer func(network, addr string) (net.Conn, error)
	// TimerFunc schedules the max-duration callback (time.AfterFunc by default).
	TimerFunc func(d time.Duration, f func()) *time.Timer
}

// NewConnHandler returns a ConnHandler that forwards to backendAddr and limits
// each connection to maxDuration.
//
// Backend dials are bounded by defaultDialTimeout instead of relying on the
// operating system's TCP connect timeout. The max-duration timer is only armed
// after the dial succeeds, so without this bound an unreachable backend would
// leave client connections hanging with no limit enforced by this program.
func NewConnHandler(backendAddr string, maxDuration time.Duration) *ConnHandler {
	dialer := &net.Dialer{Timeout: defaultDialTimeout}
	return &ConnHandler{
		BackendAddr: backendAddr,
		MaxDuration: maxDuration,
		Dialer:      dialer.Dial,
		TimerFunc:   time.AfterFunc,
	}
}
|
|
|
|
func (h *ConnHandler) HandleConnection(clientConn net.Conn) {
|
|
connID := uuid.New().String()
|
|
log.Printf("New connection [%s] started from %s", connID, clientConn.RemoteAddr())
|
|
defer clientConn.Close()
|
|
|
|
backendConn, err := h.Dialer("tcp", h.BackendAddr)
|
|
if err != nil {
|
|
log.Printf("Failed to connect to backend [%s]: %s", connID, err)
|
|
return
|
|
}
|
|
defer backendConn.Close()
|
|
|
|
timer := h.TimerFunc(h.MaxDuration, func() {
|
|
log.Printf("Connection [%s] exceeded max duration, terminating", connID)
|
|
clientConn.Close()
|
|
backendConn.Close()
|
|
})
|
|
defer timer.Stop()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
var wg sync.WaitGroup
|
|
wg.Add(2)
|
|
go func() {
|
|
defer wg.Done()
|
|
if _, err := io.Copy(backendConn, clientConn); err != nil {
|
|
log.Printf("Error forwarding from client to backend [%s]: %s", connID, err)
|
|
}
|
|
cancel()
|
|
}()
|
|
go func() {
|
|
defer wg.Done()
|
|
if _, err := io.Copy(clientConn, backendConn); err != nil {
|
|
log.Printf("Error forwarding from backend to client [%s]: %s", connID, err)
|
|
}
|
|
cancel()
|
|
}()
|
|
<-ctx.Done()
|
|
timer.Stop()
|
|
wg.Wait()
|
|
log.Printf("Connection [%s] terminated", connID)
|
|
}
|
|
|
|
func main() {
|
|
listenAddr := os.Getenv("LISTEN_ADDR")
|
|
if listenAddr == "" {
|
|
listenAddr = ":2222" // Default listen address
|
|
}
|
|
|
|
backendAddr := os.Getenv("SSH_BACKEND")
|
|
maxDuration, err := strconv.Atoi(os.Getenv("SSH_MAX_DURATION"))
|
|
if err != nil {
|
|
log.Fatalf("Invalid SSH_MAX_DURATION value: %s", err)
|
|
}
|
|
|
|
listener, err := net.Listen("tcp", listenAddr)
|
|
if err != nil {
|
|
log.Fatalf("Failed to open listener: %s", err)
|
|
}
|
|
defer listener.Close()
|
|
log.Println("Listening on", listenAddr)
|
|
|
|
handler := NewConnHandler(backendAddr, time.Duration(maxDuration)*time.Second)
|
|
|
|
// Handling graceful shutdown
|
|
stopChan := make(chan os.Signal, 1)
|
|
signal.Notify(stopChan, syscall.SIGINT, syscall.SIGTERM)
|
|
go func() {
|
|
<-stopChan
|
|
log.Println("Shutting down server...")
|
|
listener.Close()
|
|
}()
|
|
|
|
for {
|
|
clientConn, err := listener.Accept()
|
|
if err != nil {
|
|
if ne, ok := err.(net.Error); ok && ne.Temporary() {
|
|
log.Printf("Temporary listener error: %v", err)
|
|
continue
|
|
}
|
|
log.Printf("Listener closed, stopping accept loop: %v", err)
|
|
break
|
|
}
|
|
go handler.HandleConnection(clientConn)
|
|
}
|
|
}
|
|
|