better handling and brakes
This commit is contained in:
parent
b5f0ab4908
commit
8dde1883cf
19
build.sh
19
build.sh
|
@ -11,11 +11,12 @@ prepare_build() {
|
|||
# Create necessary directories if they don't exist
|
||||
mkdir -p dist
|
||||
mkdir -p build_logs
|
||||
mkdir -p test_logs # Directory for test logs
|
||||
|
||||
# Initialize go modules if go.mod does not exist
|
||||
if [ ! -f go.mod ]; then
|
||||
echo "Initializing Go modules"
|
||||
go mod init yourmodule # Replace 'yourmodule' with your actual module name or path
|
||||
go mod init ssh-timeout # Replace 'yourmodule' with your actual module name or path
|
||||
fi
|
||||
|
||||
# Fetch and ensure all dependencies are up to date
|
||||
|
@ -44,8 +45,24 @@ build_binary() {
|
|||
fi
|
||||
}
|
||||
|
||||
# Test function
|
||||
run_tests() {
|
||||
if ls *.go | grep '_test.go$' >/dev/null; then
|
||||
echo "Running tests..."
|
||||
go test ./... > test_logs/test_output.log 2>&1
|
||||
if [ $? -eq 0 ]; then
|
||||
echo "Tests completed successfully."
|
||||
else
|
||||
echo "Tests failed. Check test_logs/test_output.log for details."
|
||||
fi
|
||||
else
|
||||
echo "No test files found, skipping tests."
|
||||
fi
|
||||
}
|
||||
|
||||
# Main Build Process
|
||||
prepare_build
|
||||
run_tests # Run tests optionally before the build
|
||||
for arch in "${ARCHITECTURES[@]}"; do
|
||||
IFS='/' read -r -a parts <<< "$arch" # Split architecture string
|
||||
os=${parts[0]}
|
||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
133
main.go
133
main.go
|
@ -1,22 +1,86 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func main() {
|
||||
listenAddr := ":2222" // The local port on which to listen for incoming SSH connections.
|
||||
backendAddr := os.Getenv("SSH_BACKEND") // Backend SSH server address from environment variable.
|
||||
maxDuration := os.Getenv("SSH_MAX_DURATION") // Max connection duration from environment variable.
|
||||
type ConnHandler struct {
|
||||
BackendAddr string
|
||||
MaxDuration time.Duration
|
||||
Dialer func(network, addr string) (net.Conn, error)
|
||||
TimerFunc func(d time.Duration, f func()) *time.Timer
|
||||
}
|
||||
|
||||
duration, err := strconv.Atoi(maxDuration)
|
||||
func NewConnHandler(backendAddr string, maxDuration time.Duration) *ConnHandler {
|
||||
return &ConnHandler{
|
||||
BackendAddr: backendAddr,
|
||||
MaxDuration: maxDuration,
|
||||
Dialer: net.Dial,
|
||||
TimerFunc: time.AfterFunc,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ConnHandler) HandleConnection(clientConn net.Conn) {
|
||||
connID := uuid.New().String()
|
||||
log.Printf("New connection [%s] started from %s", connID, clientConn.RemoteAddr())
|
||||
defer clientConn.Close()
|
||||
|
||||
backendConn, err := h.Dialer("tcp", h.BackendAddr)
|
||||
if err != nil {
|
||||
log.Printf("Failed to connect to backend [%s]: %s", connID, err)
|
||||
return
|
||||
}
|
||||
defer backendConn.Close()
|
||||
|
||||
timer := h.TimerFunc(h.MaxDuration, func() {
|
||||
log.Printf("Connection [%s] exceeded max duration, terminating", connID)
|
||||
clientConn.Close()
|
||||
backendConn.Close()
|
||||
})
|
||||
defer timer.Stop()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if _, err := io.Copy(backendConn, clientConn); err != nil {
|
||||
log.Printf("Error forwarding from client to backend [%s]: %s", connID, err)
|
||||
}
|
||||
cancel()
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if _, err := io.Copy(clientConn, backendConn); err != nil {
|
||||
log.Printf("Error forwarding from backend to client [%s]: %s", connID, err)
|
||||
}
|
||||
cancel()
|
||||
}()
|
||||
<-ctx.Done()
|
||||
timer.Stop()
|
||||
wg.Wait()
|
||||
log.Printf("Connection [%s] terminated", connID)
|
||||
}
|
||||
|
||||
func main() {
|
||||
listenAddr := os.Getenv("LISTEN_ADDR")
|
||||
if listenAddr == "" {
|
||||
listenAddr = ":2222" // Default listen address
|
||||
}
|
||||
|
||||
backendAddr := os.Getenv("SSH_BACKEND")
|
||||
maxDuration, err := strconv.Atoi(os.Getenv("SSH_MAX_DURATION"))
|
||||
if err != nil {
|
||||
log.Fatalf("Invalid SSH_MAX_DURATION value: %s", err)
|
||||
}
|
||||
|
@ -28,51 +92,28 @@ func main() {
|
|||
defer listener.Close()
|
||||
log.Println("Listening on", listenAddr)
|
||||
|
||||
handler := NewConnHandler(backendAddr, time.Duration(maxDuration)*time.Second)
|
||||
|
||||
// Handling graceful shutdown
|
||||
stopChan := make(chan os.Signal, 1)
|
||||
signal.Notify(stopChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-stopChan
|
||||
log.Println("Shutting down server...")
|
||||
listener.Close()
|
||||
}()
|
||||
|
||||
for {
|
||||
clientConn, err := listener.Accept()
|
||||
if err != nil {
|
||||
log.Println("Failed to accept connection:", err)
|
||||
continue
|
||||
if ne, ok := err.(net.Error); ok && ne.Temporary() {
|
||||
log.Printf("Temporary listener error: %v", err)
|
||||
continue
|
||||
}
|
||||
log.Printf("Listener closed, stopping accept loop: %v", err)
|
||||
break
|
||||
}
|
||||
|
||||
go handleConnection(clientConn, backendAddr, time.Duration(duration)*time.Second)
|
||||
go handler.HandleConnection(clientConn)
|
||||
}
|
||||
}
|
||||
|
||||
func handleConnection(clientConn net.Conn, backendAddr string, maxDuration time.Duration) {
|
||||
connID := uuid.New().String()
|
||||
log.Printf("New connection [%s] started from %s", connID, clientConn.RemoteAddr())
|
||||
|
||||
defer clientConn.Close()
|
||||
|
||||
backendConn, err := net.Dial("tcp", backendAddr)
|
||||
if err != nil {
|
||||
log.Printf("Failed to connect to backend [%s]: %s", connID, err)
|
||||
return
|
||||
}
|
||||
defer backendConn.Close()
|
||||
|
||||
// Set up a timer to close both connections when the maxDuration is exceeded
|
||||
timer := time.AfterFunc(maxDuration, func() {
|
||||
log.Printf("Connection [%s] exceeded max duration, terminating", connID)
|
||||
clientConn.Close()
|
||||
backendConn.Close()
|
||||
})
|
||||
defer timer.Stop()
|
||||
|
||||
// Forward traffic between the client and the backend
|
||||
go func() {
|
||||
_, err := io.Copy(backendConn, clientConn)
|
||||
if err != nil {
|
||||
log.Printf("Error forwarding from client to backend [%s]: %s", connID, err)
|
||||
}
|
||||
}()
|
||||
|
||||
_, err = io.Copy(clientConn, backendConn)
|
||||
if err != nil {
|
||||
log.Printf("Error forwarding from backend to client [%s]: %s", connID, err)
|
||||
}
|
||||
|
||||
log.Printf("Connection [%s] terminated", connID)
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
--- FAIL: TestHandleConnection_Success (0.00s)
|
||||
panic: runtime error: invalid memory address or nil pointer dereference [recovered]
|
||||
panic: runtime error: invalid memory address or nil pointer dereference
|
||||
[signal SIGSEGV: segmentation violation code=0x2 addr=0x18 pc=0x10069ca6c]
|
||||
|
||||
goroutine 19 [running]:
|
||||
testing.tRunner.func1.2({0x1007ffee0, 0x100956250})
|
||||
/usr/local/go/src/testing/testing.go:1545 +0x1c8
|
||||
testing.tRunner.func1()
|
||||
/usr/local/go/src/testing/testing.go:1548 +0x360
|
||||
panic({0x1007ffee0?, 0x100956250?})
|
||||
/usr/local/go/src/runtime/panic.go:914 +0x218
|
||||
io.(*nopCloser).Read(0x140000a8af0?, {0x140000a8af0?, 0x14000056c48?, 0x1007589f0?})
|
||||
<autogenerated>:1 +0x2c
|
||||
io.ReadAtLeast({0x100d918c8, 0x14000098500}, {0x140000a8af0, 0x10, 0x10}, 0x10)
|
||||
/usr/local/go/src/io/io.go:335 +0xa0
|
||||
io.ReadFull(...)
|
||||
/usr/local/go/src/io/io.go:354
|
||||
github.com/google/uuid.NewRandomFromReader({0x100d918c8, 0x14000098500})
|
||||
/Users/aedev/go/pkg/mod/github.com/google/uuid@v1.6.0/version4.go:49 +0x54
|
||||
github.com/google/uuid.NewRandom()
|
||||
/Users/aedev/go/pkg/mod/github.com/google/uuid@v1.6.0/version4.go:41 +0x5c
|
||||
github.com/google/uuid.New(...)
|
||||
/Users/aedev/go/pkg/mod/github.com/google/uuid@v1.6.0/version4.go:14
|
||||
ssh-timeout.(*ConnHandler).HandleConnection(0x140000e1f30, {0x1008353f0?, 0x140000ae180})
|
||||
/Users/aedev/dev/ssh-timeout/main.go:31 +0x44
|
||||
ssh-timeout.TestHandleConnection_Success(0x0?)
|
||||
/Users/aedev/dev/ssh-timeout/main_test.go:57 +0x250
|
||||
testing.tRunner(0x14000083040, 0x1008322c8)
|
||||
/usr/local/go/src/testing/testing.go:1595 +0xe8
|
||||
created by testing.(*T).Run in goroutine 1
|
||||
/usr/local/go/src/testing/testing.go:1648 +0x33c
|
||||
FAIL ssh-timeout 0.505s
|
||||
FAIL
|
Loading…
Reference in New Issue