better handling and brakes

This commit is contained in:
Colin 2024-04-30 22:04:07 -04:00
parent b5f0ab4908
commit 8dde1883cf
9 changed files with 140 additions and 48 deletions

View File

@ -11,11 +11,12 @@ prepare_build() {
# Create necessary directories if they don't exist
mkdir -p dist
mkdir -p build_logs
mkdir -p test_logs # Directory for test logs
# Initialize go modules if go.mod does not exist
if [ ! -f go.mod ]; then
echo "Initializing Go modules"
go mod init yourmodule # Replace 'yourmodule' with your actual module name or path
go mod init ssh-timeout # Replace 'yourmodule' with your actual module name or path
fi
# Fetch and ensure all dependencies are up to date
@ -44,8 +45,24 @@ build_binary() {
fi
}
# Test function
run_tests() {
if ls *.go | grep '_test.go$' >/dev/null; then
echo "Running tests..."
go test ./... > test_logs/test_output.log 2>&1
if [ $? -eq 0 ]; then
echo "Tests completed successfully."
else
echo "Tests failed. Check test_logs/test_output.log for details."
fi
else
echo "No test files found, skipping tests."
fi
}
# Main Build Process
prepare_build
run_tests # Run tests optionally before the build
for arch in "${ARCHITECTURES[@]}"; do
IFS='/' read -r -a parts <<< "$arch" # Split architecture string
os=${parts[0]}

BIN
dist/ssh-timeout vendored

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

2
go.mod
View File

@ -1,4 +1,4 @@
module yourmodule
module ssh-timeout
go 1.21.1

131
main.go
View File

@ -1,22 +1,86 @@
package main
import (
"context"
"io"
"log"
"net"
"os"
"os/signal"
"strconv"
"sync"
"syscall"
"time"
"github.com/google/uuid"
)
func main() {
listenAddr := ":2222" // The local port on which to listen for incoming SSH connections.
backendAddr := os.Getenv("SSH_BACKEND") // Backend SSH server address from environment variable.
maxDuration := os.Getenv("SSH_MAX_DURATION") // Max connection duration from environment variable.
type ConnHandler struct {
BackendAddr string
MaxDuration time.Duration
Dialer func(network, addr string) (net.Conn, error)
TimerFunc func(d time.Duration, f func()) *time.Timer
}
duration, err := strconv.Atoi(maxDuration)
func NewConnHandler(backendAddr string, maxDuration time.Duration) *ConnHandler {
return &ConnHandler{
BackendAddr: backendAddr,
MaxDuration: maxDuration,
Dialer: net.Dial,
TimerFunc: time.AfterFunc,
}
}
func (h *ConnHandler) HandleConnection(clientConn net.Conn) {
connID := uuid.New().String()
log.Printf("New connection [%s] started from %s", connID, clientConn.RemoteAddr())
defer clientConn.Close()
backendConn, err := h.Dialer("tcp", h.BackendAddr)
if err != nil {
log.Printf("Failed to connect to backend [%s]: %s", connID, err)
return
}
defer backendConn.Close()
timer := h.TimerFunc(h.MaxDuration, func() {
log.Printf("Connection [%s] exceeded max duration, terminating", connID)
clientConn.Close()
backendConn.Close()
})
defer timer.Stop()
ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
if _, err := io.Copy(backendConn, clientConn); err != nil {
log.Printf("Error forwarding from client to backend [%s]: %s", connID, err)
}
cancel()
}()
go func() {
defer wg.Done()
if _, err := io.Copy(clientConn, backendConn); err != nil {
log.Printf("Error forwarding from backend to client [%s]: %s", connID, err)
}
cancel()
}()
<-ctx.Done()
timer.Stop()
wg.Wait()
log.Printf("Connection [%s] terminated", connID)
}
func main() {
listenAddr := os.Getenv("LISTEN_ADDR")
if listenAddr == "" {
listenAddr = ":2222" // Default listen address
}
backendAddr := os.Getenv("SSH_BACKEND")
maxDuration, err := strconv.Atoi(os.Getenv("SSH_MAX_DURATION"))
if err != nil {
log.Fatalf("Invalid SSH_MAX_DURATION value: %s", err)
}
@ -28,51 +92,28 @@ func main() {
defer listener.Close()
log.Println("Listening on", listenAddr)
handler := NewConnHandler(backendAddr, time.Duration(maxDuration)*time.Second)
// Handling graceful shutdown
stopChan := make(chan os.Signal, 1)
signal.Notify(stopChan, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-stopChan
log.Println("Shutting down server...")
listener.Close()
}()
for {
clientConn, err := listener.Accept()
if err != nil {
log.Println("Failed to accept connection:", err)
if ne, ok := err.(net.Error); ok && ne.Temporary() {
log.Printf("Temporary listener error: %v", err)
continue
}
go handleConnection(clientConn, backendAddr, time.Duration(duration)*time.Second)
log.Printf("Listener closed, stopping accept loop: %v", err)
break
}
go handler.HandleConnection(clientConn)
}
}
func handleConnection(clientConn net.Conn, backendAddr string, maxDuration time.Duration) {
connID := uuid.New().String()
log.Printf("New connection [%s] started from %s", connID, clientConn.RemoteAddr())
defer clientConn.Close()
backendConn, err := net.Dial("tcp", backendAddr)
if err != nil {
log.Printf("Failed to connect to backend [%s]: %s", connID, err)
return
}
defer backendConn.Close()
// Set up a timer to close both connections when the maxDuration is exceeded
timer := time.AfterFunc(maxDuration, func() {
log.Printf("Connection [%s] exceeded max duration, terminating", connID)
clientConn.Close()
backendConn.Close()
})
defer timer.Stop()
// Forward traffic between the client and the backend
go func() {
_, err := io.Copy(backendConn, clientConn)
if err != nil {
log.Printf("Error forwarding from client to backend [%s]: %s", connID, err)
}
}()
_, err = io.Copy(clientConn, backendConn)
if err != nil {
log.Printf("Error forwarding from backend to client [%s]: %s", connID, err)
}
log.Printf("Connection [%s] terminated", connID)
}

34
test_logs/test_output.log Normal file
View File

@ -0,0 +1,34 @@
--- FAIL: TestHandleConnection_Success (0.00s)
panic: runtime error: invalid memory address or nil pointer dereference [recovered]
panic: runtime error: invalid memory address or nil pointer dereference
[signal SIGSEGV: segmentation violation code=0x2 addr=0x18 pc=0x10069ca6c]
goroutine 19 [running]:
testing.tRunner.func1.2({0x1007ffee0, 0x100956250})
/usr/local/go/src/testing/testing.go:1545 +0x1c8
testing.tRunner.func1()
/usr/local/go/src/testing/testing.go:1548 +0x360
panic({0x1007ffee0?, 0x100956250?})
/usr/local/go/src/runtime/panic.go:914 +0x218
io.(*nopCloser).Read(0x140000a8af0?, {0x140000a8af0?, 0x14000056c48?, 0x1007589f0?})
<autogenerated>:1 +0x2c
io.ReadAtLeast({0x100d918c8, 0x14000098500}, {0x140000a8af0, 0x10, 0x10}, 0x10)
/usr/local/go/src/io/io.go:335 +0xa0
io.ReadFull(...)
/usr/local/go/src/io/io.go:354
github.com/google/uuid.NewRandomFromReader({0x100d918c8, 0x14000098500})
/Users/aedev/go/pkg/mod/github.com/google/uuid@v1.6.0/version4.go:49 +0x54
github.com/google/uuid.NewRandom()
/Users/aedev/go/pkg/mod/github.com/google/uuid@v1.6.0/version4.go:41 +0x5c
github.com/google/uuid.New(...)
/Users/aedev/go/pkg/mod/github.com/google/uuid@v1.6.0/version4.go:14
ssh-timeout.(*ConnHandler).HandleConnection(0x140000e1f30, {0x1008353f0?, 0x140000ae180})
/Users/aedev/dev/ssh-timeout/main.go:31 +0x44
ssh-timeout.TestHandleConnection_Success(0x0?)
/Users/aedev/dev/ssh-timeout/main_test.go:57 +0x250
testing.tRunner(0x14000083040, 0x1008322c8)
/usr/local/go/src/testing/testing.go:1595 +0xe8
created by testing.(*T).Run in goroutine 1
/usr/local/go/src/testing/testing.go:1648 +0x33c
FAIL ssh-timeout 0.505s
FAIL