133 lines
4.0 KiB
Go
133 lines
4.0 KiB
Go
package client
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
// TunnelRequest is the JSON document a client writes to the server over the
// "tunnel-request" channel to describe the tunnel it wants.
type TunnelRequest struct {
	// Domain is the domain the server should route to this tunnel.
	Domain string `json:"domain"`

	// AuthUser is an optional HTTP Basic Auth username. It is left out of
	// the JSON entirely when empty.
	AuthUser string `json:"auth_user,omitempty"`

	// AuthPass is an optional HTTP Basic Auth password. It is left out of
	// the JSON entirely when empty.
	AuthPass string `json:"auth_pass,omitempty"`
}
|
|
|
|
// SetupTunnel sends domain metadata and establishes a reverse port forward.
|
|
// The server will allocate a port and register Traefik routes for the domain.
|
|
// localHost and localPort are the backend address (e.g. "127.0.0.1" or "192.168.0.2", 11001).
|
|
// Backend TLS is detected dynamically: TLS is tried first; on failure, plain TCP is used.
|
|
// authUser and authPass are optional; if both are non-empty, the server will
|
|
// add a Traefik basicauth middleware in front of this tunnel.
|
|
func SetupTunnel(client *ssh.Client, domain string, localHost string, localPort int, authUser, authPass string) error {
|
|
// Step 1: Open custom channel to send domain metadata.
|
|
if err := sendMetadata(client, domain, authUser, authPass); err != nil {
|
|
return fmt.Errorf("send metadata: %w", err)
|
|
}
|
|
|
|
// Step 2: Request reverse port forward.
|
|
// We use the domain as the bind address so the server can associate it.
|
|
listener, err := client.Listen("tcp", fmt.Sprintf("%s:0", domain))
|
|
if err != nil {
|
|
return fmt.Errorf("reverse listen: %w", err)
|
|
}
|
|
defer listener.Close()
|
|
|
|
log.Printf("Reverse tunnel established: %s -> %s:%d", domain, localHost, localPort)
|
|
|
|
// Step 3: Accept connections from the server and forward to local service.
|
|
for {
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
return fmt.Errorf("tunnel accept: %w", err)
|
|
}
|
|
go forwardToLocal(conn, localHost, localPort)
|
|
}
|
|
}
|
|
|
|
// sendMetadata opens a custom channel and sends the tunnel request JSON.
|
|
func sendMetadata(client *ssh.Client, domain, authUser, authPass string) error {
|
|
ch, _, err := client.OpenChannel("tunnel-request", nil)
|
|
if err != nil {
|
|
return fmt.Errorf("open tunnel-request channel: %w", err)
|
|
}
|
|
|
|
req := TunnelRequest{Domain: domain, AuthUser: authUser, AuthPass: authPass}
|
|
data, err := json.Marshal(req)
|
|
if err != nil {
|
|
ch.Close()
|
|
return fmt.Errorf("marshal metadata: %w", err)
|
|
}
|
|
|
|
if _, err := ch.Write(data); err != nil {
|
|
ch.Close()
|
|
return fmt.Errorf("write metadata: %w", err)
|
|
}
|
|
|
|
if authUser != "" {
|
|
log.Printf("Sent tunnel metadata: domain=%s (with basicauth)", domain)
|
|
} else {
|
|
log.Printf("Sent tunnel metadata: domain=%s", domain)
|
|
}
|
|
|
|
// Keep the channel open in a goroutine for disconnect detection.
|
|
go func() {
|
|
io.Copy(io.Discard, ch)
|
|
log.Printf("Metadata channel closed for %s", domain)
|
|
}()
|
|
|
|
return nil
|
|
}
|
|
|
|
// forwardToLocal connects an incoming tunnel connection to the backend.
|
|
// It tries TLS first; if the backend does not speak TLS, it falls back to plain TCP.
|
|
func forwardToLocal(remoteConn net.Conn, localHost string, localPort int) {
|
|
defer remoteConn.Close()
|
|
|
|
addr := fmt.Sprintf("%s:%d", localHost, localPort)
|
|
localConn, _ := dialBackend(addr, localHost)
|
|
if localConn == nil {
|
|
return
|
|
}
|
|
defer localConn.Close()
|
|
|
|
// Bidirectional copy.
|
|
var wg sync.WaitGroup
|
|
wg.Add(2)
|
|
go func() {
|
|
defer wg.Done()
|
|
io.Copy(localConn, remoteConn)
|
|
}()
|
|
go func() {
|
|
defer wg.Done()
|
|
io.Copy(remoteConn, localConn)
|
|
}()
|
|
wg.Wait()
|
|
}
|
|
|
|
// dialBackend connects to the backend at addr, preferring TLS.
//
// It first attempts a TLS connection, sending serverName as the TLS
// ServerName and skipping certificate verification. If that attempt fails
// for any reason, typically because the backend speaks plain HTTP, it dials
// plain TCP instead. Each attempt is bounded by a 5-second timeout.
//
// It returns the connection and whether TLS is in use. If both attempts
// fail it logs the plain-TCP error and returns (nil, false).
func dialBackend(addr, serverName string) (net.Conn, bool) {
	const dialTimeout = 5 * time.Second

	// NOTE(review): verification is disabled, presumably because local
	// backends often present self-signed certificates. Confirm this is
	// intended.
	cfg := &tls.Config{
		ServerName:         serverName,
		InsecureSkipVerify: true,
	}
	dialer := &net.Dialer{Timeout: dialTimeout}
	if secure, tlsErr := tls.DialWithDialer(dialer, "tcp", addr, cfg); tlsErr == nil {
		return secure, true
	}

	plain, err := net.DialTimeout("tcp", addr, dialTimeout)
	if err != nil {
		log.Printf("failed to connect to backend at %s: %v", addr, err)
		return nil, false
	}
	return plain, false
}
|