From 7e4b5fce60993af30a6fd939c8013a2e284418a2 Mon Sep 17 00:00:00 2001 From: Leopere Date: Wed, 18 Feb 2026 14:14:48 -0500 Subject: [PATCH] tunnel-client: TUNNEL_HOST + dynamic backend TLS (try TLS, fallback to plain) Co-authored-by: Cursor --- cmd/client/main.go | 7 ++++--- internal/client/tunnel.go | 42 +++++++++++++++++++++++++++++---------- 2 files changed, 36 insertions(+), 13 deletions(-) diff --git a/cmd/client/main.go b/cmd/client/main.go index ef80a6a..9837b30 100644 --- a/cmd/client/main.go +++ b/cmd/client/main.go @@ -36,6 +36,7 @@ func main() { authUser := envOr("TUNNEL_AUTH_USER", "") authPass := envOr("TUNNEL_AUTH_PASS", "") + localHost := envOr("TUNNEL_HOST", "127.0.0.1") localPortStr := envOr("TUNNEL_PORT", "8080") localPort, err := strconv.Atoi(localPortStr) if err != nil { @@ -55,9 +56,9 @@ func main() { for { if authUser != "" { - log.Printf("Connecting to %s (domain=%s, local_port=%d, basicauth=enabled)", serverAddr, domain, localPort) + log.Printf("Connecting to %s (domain=%s, local=%s:%d, basicauth=enabled)", serverAddr, domain, localHost, localPort) } else { - log.Printf("Connecting to %s (domain=%s, local_port=%d)", serverAddr, domain, localPort) + log.Printf("Connecting to %s (domain=%s, local=%s:%d)", serverAddr, domain, localHost, localPort) } sshClient, err := client.Connect(serverAddr, signer) @@ -73,7 +74,7 @@ func main() { log.Printf("Connected to %s", serverAddr) // Set up the reverse tunnel (blocks until disconnected). - if err := client.SetupTunnel(sshClient, domain, localPort, authUser, authPass); err != nil { + if err := client.SetupTunnel(sshClient, domain, localHost, localPort, authUser, authPass); err != nil { log.Printf("Tunnel error: %v (reconnecting in %s)", err, backoff) } diff --git a/internal/client/tunnel.go b/internal/client/tunnel.go index 8837051..5fb9efd 100644 --- a/internal/client/tunnel.go +++ b/internal/client/tunnel.go @@ -1,12 +1,14 @@ package client import ( + "crypto/tls" "encoding/json" "fmt" "io" "log" "net" "sync" + "time" "golang.org/x/crypto/ssh" ) @@ -20,10 +22,11 @@ type TunnelRequest struct { // SetupTunnel sends domain metadata and establishes a reverse port forward. // The server will allocate a port and register Traefik routes for the domain. -// localPort is the port of the service running on the client side. +// localHost and localPort are the backend address (e.g. "127.0.0.1" or "192.168.0.2", 11001). +// Backend TLS is detected dynamically: TLS is tried first; on failure, plain TCP is used. // authUser and authPass are optional; if both are non-empty, the server will // add a Traefik basicauth middleware in front of this tunnel. -func SetupTunnel(client *ssh.Client, domain string, localPort int, authUser, authPass string) error { +func SetupTunnel(client *ssh.Client, domain string, localHost string, localPort int, authUser, authPass string) error { // Step 1: Open custom channel to send domain metadata. if err := sendMetadata(client, domain, authUser, authPass); err != nil { return fmt.Errorf("send metadata: %w", err) @@ -37,7 +40,7 @@ func SetupTunnel(client *ssh.Client, domain string, localPort int, authUser, aut } defer listener.Close() - log.Printf("Reverse tunnel established: %s -> localhost:%d", domain, localPort) + log.Printf("Reverse tunnel established: %s -> %s:%d", domain, localHost, localPort) // Step 3: Accept connections from the server and forward to local service. for { @@ -45,7 +48,7 @@ func SetupTunnel(client *ssh.Client, domain string, localPort int, authUser, aut if err != nil { return fmt.Errorf("tunnel accept: %w", err) } - go forwardToLocal(conn, localPort) + go forwardToLocal(conn, localHost, localPort) } } @@ -83,14 +86,14 @@ func sendMetadata(client *ssh.Client, domain, authUser, authPass string) error { return nil } -// forwardToLocal connects an incoming tunnel connection to the local service. -func forwardToLocal(remoteConn net.Conn, localPort int) { +// forwardToLocal connects an incoming tunnel connection to the backend. +// It tries TLS first; if the backend does not speak TLS, it falls back to plain TCP. +func forwardToLocal(remoteConn net.Conn, localHost string, localPort int) { defer remoteConn.Close() - localAddr := fmt.Sprintf("127.0.0.1:%d", localPort) - localConn, err := net.Dial("tcp", localAddr) - if err != nil { - log.Printf("failed to connect to local service at %s: %v", localAddr, err) + addr := fmt.Sprintf("%s:%d", localHost, localPort) + localConn, _ := dialBackend(addr, localHost) + if localConn == nil { return } defer localConn.Close() @@ -108,3 +111,22 @@ func forwardToLocal(remoteConn net.Conn, localPort int) { }() wg.Wait() } + +// dialBackend tries TLS first; on failure (e.g. backend is plain HTTP), uses plain TCP. +// Returns (conn, true) if TLS was used, (conn, false) if plain, (nil, false) on error. +func dialBackend(addr, serverName string) (net.Conn, bool) { + tlsConn, err := tls.DialWithDialer(&net.Dialer{Timeout: 5 * time.Second}, "tcp", addr, &tls.Config{ + ServerName: serverName, + InsecureSkipVerify: true, + }) + if err == nil { + return tlsConn, true + } + // Fall back to plain TCP (backend is HTTP or not TLS). + plainConn, err := net.DialTimeout("tcp", addr, 5*time.Second) + if err != nil { + log.Printf("failed to connect to backend at %s: %v", addr, err) + return nil, false + } + return plainConn, false +}