diff --git a/build.sh b/build.sh index 293a43a..b80a04e 100755 --- a/build.sh +++ b/build.sh @@ -11,11 +11,12 @@ prepare_build() { # Create necessary directories if they don't exist mkdir -p dist mkdir -p build_logs + mkdir -p test_logs # Directory for test logs # Initialize go modules if go.mod does not exist if [ ! -f go.mod ]; then echo "Initializing Go modules" - go mod init yourmodule # Replace 'yourmodule' with your actual module name or path + go mod init ssh-timeout # Replace 'yourmodule' with your actual module name or path fi # Fetch and ensure all dependencies are up to date @@ -44,8 +45,24 @@ build_binary() { fi } +# Test function +run_tests() { + if ls *.go | grep '_test.go$' >/dev/null; then + echo "Running tests..." + go test ./... > test_logs/test_output.log 2>&1 + if [ $? -eq 0 ]; then + echo "Tests completed successfully." + else + echo "Tests failed. Check test_logs/test_output.log for details." + fi + else + echo "No test files found, skipping tests." + fi +} + # Main Build Process prepare_build +run_tests # Run tests optionally before the build for arch in "${ARCHITECTURES[@]}"; do IFS='/' read -r -a parts <<< "$arch" # Split architecture string os=${parts[0]} diff --git a/dist/ssh-timeout b/dist/ssh-timeout index acdb95b..9695346 100755 Binary files a/dist/ssh-timeout and b/dist/ssh-timeout differ diff --git a/dist/ssh-timeout_darwin_amd64 b/dist/ssh-timeout_darwin_amd64 index 5718021..849fcb1 100755 Binary files a/dist/ssh-timeout_darwin_amd64 and b/dist/ssh-timeout_darwin_amd64 differ diff --git a/dist/ssh-timeout_darwin_arm64 b/dist/ssh-timeout_darwin_arm64 index 19674ff..5fc1305 100755 Binary files a/dist/ssh-timeout_darwin_arm64 and b/dist/ssh-timeout_darwin_arm64 differ diff --git a/dist/ssh-timeout_linux_arm b/dist/ssh-timeout_linux_arm index e93a00d..1251dbd 100755 Binary files a/dist/ssh-timeout_linux_arm and b/dist/ssh-timeout_linux_arm differ diff --git a/dist/ssh-timeout_linux_arm64 b/dist/ssh-timeout_linux_arm64 index 144e361..277cc3b 100755 Binary files a/dist/ssh-timeout_linux_arm64 and b/dist/ssh-timeout_linux_arm64 differ diff --git a/go.mod b/go.mod index f9030c8..ae3c448 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module yourmodule +module ssh-timeout go 1.21.1 diff --git a/main.go b/main.go index 371671c..e3cc855 100644 --- a/main.go +++ b/main.go @@ -1,22 +1,86 @@ package main import ( + "context" "io" "log" "net" "os" + "os/signal" "strconv" + "sync" + "syscall" "time" "github.com/google/uuid" ) -func main() { - listenAddr := ":2222" // The local port on which to listen for incoming SSH connections. - backendAddr := os.Getenv("SSH_BACKEND") // Backend SSH server address from environment variable. - maxDuration := os.Getenv("SSH_MAX_DURATION") // Max connection duration from environment variable. +type ConnHandler struct { + BackendAddr string + MaxDuration time.Duration + Dialer func(network, addr string) (net.Conn, error) + TimerFunc func(d time.Duration, f func()) *time.Timer +} - duration, err := strconv.Atoi(maxDuration) +func NewConnHandler(backendAddr string, maxDuration time.Duration) *ConnHandler { + return &ConnHandler{ + BackendAddr: backendAddr, + MaxDuration: maxDuration, + Dialer: net.Dial, + TimerFunc: time.AfterFunc, + } +} + +func (h *ConnHandler) HandleConnection(clientConn net.Conn) { + connID := uuid.New().String() + log.Printf("New connection [%s] started from %s", connID, clientConn.RemoteAddr()) + defer clientConn.Close() + + backendConn, err := h.Dialer("tcp", h.BackendAddr) + if err != nil { + log.Printf("Failed to connect to backend [%s]: %s", connID, err) + return + } + defer backendConn.Close() + + timer := h.TimerFunc(h.MaxDuration, func() { + log.Printf("Connection [%s] exceeded max duration, terminating", connID) + clientConn.Close() + backendConn.Close() + }) + defer timer.Stop() + + ctx, cancel := context.WithCancel(context.Background()) + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + if _, err := io.Copy(backendConn, clientConn); err != nil { + log.Printf("Error forwarding from client to backend [%s]: %s", connID, err) + } + cancel() + }() + go func() { + defer wg.Done() + if _, err := io.Copy(clientConn, backendConn); err != nil { + log.Printf("Error forwarding from backend to client [%s]: %s", connID, err) + } + cancel() + }() + <-ctx.Done() + timer.Stop() + wg.Wait() + log.Printf("Connection [%s] terminated", connID) +} + +func main() { + listenAddr := os.Getenv("LISTEN_ADDR") + if listenAddr == "" { + listenAddr = ":2222" // Default listen address + } + + backendAddr := os.Getenv("SSH_BACKEND") + maxDuration, err := strconv.Atoi(os.Getenv("SSH_MAX_DURATION")) if err != nil { log.Fatalf("Invalid SSH_MAX_DURATION value: %s", err) } @@ -28,51 +92,28 @@ func main() { defer listener.Close() log.Println("Listening on", listenAddr) + handler := NewConnHandler(backendAddr, time.Duration(maxDuration)*time.Second) + + // Handling graceful shutdown + stopChan := make(chan os.Signal, 1) + signal.Notify(stopChan, syscall.SIGINT, syscall.SIGTERM) + go func() { + <-stopChan + log.Println("Shutting down server...") + listener.Close() + }() + for { clientConn, err := listener.Accept() if err != nil { - log.Println("Failed to accept connection:", err) - continue + if ne, ok := err.(net.Error); ok && ne.Temporary() { + log.Printf("Temporary listener error: %v", err) + continue + } + log.Printf("Listener closed, stopping accept loop: %v", err) + break } - - go handleConnection(clientConn, backendAddr, time.Duration(duration)*time.Second) + go handler.HandleConnection(clientConn) } } -func handleConnection(clientConn net.Conn, backendAddr string, maxDuration time.Duration) { - connID := uuid.New().String() - log.Printf("New connection [%s] started from %s", connID, clientConn.RemoteAddr()) - - defer clientConn.Close() - - backendConn, err := net.Dial("tcp", backendAddr) - if err != nil { - log.Printf("Failed to connect to backend [%s]: %s", connID, err) - return - } - defer backendConn.Close() - - // Set up a timer to close both connections when the maxDuration is exceeded - timer := time.AfterFunc(maxDuration, func() { - log.Printf("Connection [%s] exceeded max duration, terminating", connID) - clientConn.Close() - backendConn.Close() - }) - defer timer.Stop() - - // Forward traffic between the client and the backend - go func() { - _, err := io.Copy(backendConn, clientConn) - if err != nil { - log.Printf("Error forwarding from client to backend [%s]: %s", connID, err) - } - }() - - _, err = io.Copy(clientConn, backendConn) - if err != nil { - log.Printf("Error forwarding from backend to client [%s]: %s", connID, err) - } - - log.Printf("Connection [%s] terminated", connID) -} - diff --git a/test_logs/test_output.log b/test_logs/test_output.log new file mode 100644 index 0000000..a134877 --- /dev/null +++ b/test_logs/test_output.log @@ -0,0 +1,34 @@ +--- FAIL: TestHandleConnection_Success (0.00s) +panic: runtime error: invalid memory address or nil pointer dereference [recovered] + panic: runtime error: invalid memory address or nil pointer dereference +[signal SIGSEGV: segmentation violation code=0x2 addr=0x18 pc=0x10069ca6c] + +goroutine 19 [running]: +testing.tRunner.func1.2({0x1007ffee0, 0x100956250}) + /usr/local/go/src/testing/testing.go:1545 +0x1c8 +testing.tRunner.func1() + /usr/local/go/src/testing/testing.go:1548 +0x360 +panic({0x1007ffee0?, 0x100956250?}) + /usr/local/go/src/runtime/panic.go:914 +0x218 +io.(*nopCloser).Read(0x140000a8af0?, {0x140000a8af0?, 0x14000056c48?, 0x1007589f0?}) + :1 +0x2c +io.ReadAtLeast({0x100d918c8, 0x14000098500}, {0x140000a8af0, 0x10, 0x10}, 0x10) + /usr/local/go/src/io/io.go:335 +0xa0 +io.ReadFull(...) + /usr/local/go/src/io/io.go:354 +github.com/google/uuid.NewRandomFromReader({0x100d918c8, 0x14000098500}) + /Users/aedev/go/pkg/mod/github.com/google/uuid@v1.6.0/version4.go:49 +0x54 +github.com/google/uuid.NewRandom() + /Users/aedev/go/pkg/mod/github.com/google/uuid@v1.6.0/version4.go:41 +0x5c +github.com/google/uuid.New(...) + /Users/aedev/go/pkg/mod/github.com/google/uuid@v1.6.0/version4.go:14 +ssh-timeout.(*ConnHandler).HandleConnection(0x140000e1f30, {0x1008353f0?, 0x140000ae180}) + /Users/aedev/dev/ssh-timeout/main.go:31 +0x44 +ssh-timeout.TestHandleConnection_Success(0x0?) + /Users/aedev/dev/ssh-timeout/main_test.go:57 +0x250 +testing.tRunner(0x14000083040, 0x1008322c8) + /usr/local/go/src/testing/testing.go:1595 +0xe8 +created by testing.(*T).Run in goroutine 1 + /usr/local/go/src/testing/testing.go:1648 +0x33c +FAIL ssh-timeout 0.505s +FAIL