tunnel-client: TUNNEL_HOST + dynamic backend TLS (try TLS, fallback to plain)
ci/woodpecker/push/woodpecker Pipeline was successful Details

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
Leopere 2026-02-18 14:14:48 -05:00
parent 29acb88398
commit 7e4b5fce60
Signed by: colin
SSH Key Fingerprint: SHA256:nRPCQTeMFLdGytxRQmPVK9VXY3/ePKQ5lGRyJhT5DY8
2 changed files with 36 additions and 13 deletions

View File

@ -36,6 +36,7 @@ func main() {
authUser := envOr("TUNNEL_AUTH_USER", "") authUser := envOr("TUNNEL_AUTH_USER", "")
authPass := envOr("TUNNEL_AUTH_PASS", "") authPass := envOr("TUNNEL_AUTH_PASS", "")
localHost := envOr("TUNNEL_HOST", "127.0.0.1")
localPortStr := envOr("TUNNEL_PORT", "8080") localPortStr := envOr("TUNNEL_PORT", "8080")
localPort, err := strconv.Atoi(localPortStr) localPort, err := strconv.Atoi(localPortStr)
if err != nil { if err != nil {
@ -55,9 +56,9 @@ func main() {
for { for {
if authUser != "" { if authUser != "" {
log.Printf("Connecting to %s (domain=%s, local_port=%d, basicauth=enabled)", serverAddr, domain, localPort) log.Printf("Connecting to %s (domain=%s, local=%s:%d, basicauth=enabled)", serverAddr, domain, localHost, localPort)
} else { } else {
log.Printf("Connecting to %s (domain=%s, local_port=%d)", serverAddr, domain, localPort) log.Printf("Connecting to %s (domain=%s, local=%s:%d)", serverAddr, domain, localHost, localPort)
} }
sshClient, err := client.Connect(serverAddr, signer) sshClient, err := client.Connect(serverAddr, signer)
@ -73,7 +74,7 @@ func main() {
log.Printf("Connected to %s", serverAddr) log.Printf("Connected to %s", serverAddr)
// Set up the reverse tunnel (blocks until disconnected). // Set up the reverse tunnel (blocks until disconnected).
if err := client.SetupTunnel(sshClient, domain, localPort, authUser, authPass); err != nil { if err := client.SetupTunnel(sshClient, domain, localHost, localPort, authUser, authPass); err != nil {
log.Printf("Tunnel error: %v (reconnecting in %s)", err, backoff) log.Printf("Tunnel error: %v (reconnecting in %s)", err, backoff)
} }

View File

@ -1,12 +1,14 @@
package client package client
import ( import (
"crypto/tls"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"log" "log"
"net" "net"
"sync" "sync"
"time"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@ -20,10 +22,11 @@ type TunnelRequest struct {
// SetupTunnel sends domain metadata and establishes a reverse port forward. // SetupTunnel sends domain metadata and establishes a reverse port forward.
// The server will allocate a port and register Traefik routes for the domain. // The server will allocate a port and register Traefik routes for the domain.
// localPort is the port of the service running on the client side. // localHost and localPort are the backend address (e.g. "127.0.0.1" or "192.168.0.2", 11001).
// Backend TLS is detected dynamically: TLS is tried first; on failure, plain TCP is used.
// authUser and authPass are optional; if both are non-empty, the server will // authUser and authPass are optional; if both are non-empty, the server will
// add a Traefik basicauth middleware in front of this tunnel. // add a Traefik basicauth middleware in front of this tunnel.
func SetupTunnel(client *ssh.Client, domain string, localPort int, authUser, authPass string) error { func SetupTunnel(client *ssh.Client, domain string, localHost string, localPort int, authUser, authPass string) error {
// Step 1: Open custom channel to send domain metadata. // Step 1: Open custom channel to send domain metadata.
if err := sendMetadata(client, domain, authUser, authPass); err != nil { if err := sendMetadata(client, domain, authUser, authPass); err != nil {
return fmt.Errorf("send metadata: %w", err) return fmt.Errorf("send metadata: %w", err)
@ -37,7 +40,7 @@ func SetupTunnel(client *ssh.Client, domain string, localPort int, authUser, aut
} }
defer listener.Close() defer listener.Close()
log.Printf("Reverse tunnel established: %s -> localhost:%d", domain, localPort) log.Printf("Reverse tunnel established: %s -> %s:%d", domain, localHost, localPort)
// Step 3: Accept connections from the server and forward to local service. // Step 3: Accept connections from the server and forward to local service.
for { for {
@ -45,7 +48,7 @@ func SetupTunnel(client *ssh.Client, domain string, localPort int, authUser, aut
if err != nil { if err != nil {
return fmt.Errorf("tunnel accept: %w", err) return fmt.Errorf("tunnel accept: %w", err)
} }
go forwardToLocal(conn, localPort) go forwardToLocal(conn, localHost, localPort)
} }
} }
@ -83,14 +86,14 @@ func sendMetadata(client *ssh.Client, domain, authUser, authPass string) error {
return nil return nil
} }
// forwardToLocal connects an incoming tunnel connection to the local service. // forwardToLocal connects an incoming tunnel connection to the backend.
func forwardToLocal(remoteConn net.Conn, localPort int) { // It tries TLS first; if the backend does not speak TLS, it falls back to plain TCP.
func forwardToLocal(remoteConn net.Conn, localHost string, localPort int) {
defer remoteConn.Close() defer remoteConn.Close()
localAddr := fmt.Sprintf("127.0.0.1:%d", localPort) addr := fmt.Sprintf("%s:%d", localHost, localPort)
localConn, err := net.Dial("tcp", localAddr) localConn, _ := dialBackend(addr, localHost)
if err != nil { if localConn == nil {
log.Printf("failed to connect to local service at %s: %v", localAddr, err)
return return
} }
defer localConn.Close() defer localConn.Close()
@ -108,3 +111,22 @@ func forwardToLocal(remoteConn net.Conn, localPort int) {
}() }()
wg.Wait() wg.Wait()
} }
// dialBackend tries TLS first; on failure (e.g. backend is plain HTTP), uses plain TCP.
// Returns (conn, true) if TLS was used, (conn, false) if plain, (nil, false) on error.
func dialBackend(addr, serverName string) (net.Conn, bool) {
tlsConn, err := tls.DialWithDialer(&net.Dialer{Timeout: 5 * time.Second}, "tcp", addr, &tls.Config{
ServerName: serverName,
InsecureSkipVerify: true,
})
if err == nil {
return tlsConn, true
}
// Fall back to plain TCP (backend is HTTP or not TLS).
plainConn, err := net.DialTimeout("tcp", addr, 5*time.Second)
if err != nil {
log.Printf("failed to connect to backend at %s: %v", addr, err)
return nil, false
}
return plainConn, false
}