better-argo-tunnels/internal/client/tunnel.go

111 lines
3.1 KiB
Go

package client
import (
"encoding/json"
"fmt"
"io"
"log"
"net"
"sync"
"golang.org/x/crypto/ssh"
)
// TunnelRequest is the JSON payload the client writes on the
// "tunnel-request" channel to describe the tunnel it wants.
//
// AuthUser and AuthPass carry optional HTTP Basic Auth credentials; each is
// omitted from the encoded JSON when empty.
type TunnelRequest struct {
	// Domain is the hostname the tunnel should serve.
	Domain string `json:"domain"`
	// AuthUser is the optional HTTP Basic Auth username.
	AuthUser string `json:"auth_user,omitempty"`
	// AuthPass is the optional HTTP Basic Auth password.
	AuthPass string `json:"auth_pass,omitempty"`
}
// SetupTunnel announces domain to the server and then serves a reverse
// (remote) port forward for it until the forward stops.
//
// It first sends the tunnel metadata (domain plus optional basic-auth
// credentials) on a "tunnel-request" channel, then asks the server to listen
// on port 0 (server-chosen) and forwards every accepted connection to
// 127.0.0.1:localPort on this machine. The server will allocate a port and
// register Traefik routes for the domain.
//
// authUser and authPass are optional; if both are non-empty, the server will
// add a Traefik basicauth middleware in front of this tunnel.
//
// SetupTunnel blocks for the lifetime of the forward and always returns a
// non-nil error: either from one of the setup steps, or from Accept once the
// listener stops (for example when the SSH connection drops).
func SetupTunnel(client *ssh.Client, domain string, localPort int, authUser, authPass string) error {
	// Step 1: announce the domain (and optional credentials) to the server.
	if err := sendMetadata(client, domain, authUser, authPass); err != nil {
		return fmt.Errorf("send metadata: %w", err)
	}

	// Step 2: request the reverse port forward on a server-chosen port.
	// NOTE(review): x/crypto/ssh's Client.Listen resolves "tcp" addresses
	// locally (net.ResolveTCPAddr) before sending the tcpip-forward request,
	// so the domain must resolve on this machine and the server most likely
	// receives the resolved IP rather than the domain name. The domain itself
	// reaches the server through the metadata sent in step 1 — confirm the
	// server does not depend on the bind address. JoinHostPort keeps the
	// address well-formed even if domain is an IPv6 literal.
	listener, err := client.Listen("tcp", net.JoinHostPort(domain, "0"))
	if err != nil {
		return fmt.Errorf("reverse listen: %w", err)
	}
	defer listener.Close()
	log.Printf("Reverse tunnel established: %s -> localhost:%d", domain, localPort)

	// Step 3: hand each tunneled connection to its own forwarder. The loop
	// only ends when Accept fails, i.e. when the listener is gone.
	for {
		conn, err := listener.Accept()
		if err != nil {
			return fmt.Errorf("tunnel accept: %w", err)
		}
		go forwardToLocal(conn, localPort)
	}
}
// sendMetadata opens a "tunnel-request" channel on client and writes the
// JSON-encoded TunnelRequest for domain (plus optional basic-auth
// credentials) to it.
//
// On success the channel is intentionally left open: a background goroutine
// drains it and logs when the server side ends it, which acts as a disconnect
// signal for this tunnel. On any error the channel is closed before
// returning.
func sendMetadata(client *ssh.Client, domain, authUser, authPass string) error {
	ch, reqs, err := client.OpenChannel("tunnel-request", nil)
	if err != nil {
		return fmt.Errorf("open tunnel-request channel: %w", err)
	}
	// x/crypto/ssh requires the per-channel request stream to be serviced or
	// the connection can stall. No channel requests are expected here, so
	// reject them all; the goroutine exits when the channel is torn down.
	go ssh.DiscardRequests(reqs)

	req := TunnelRequest{Domain: domain, AuthUser: authUser, AuthPass: authPass}
	data, err := json.Marshal(req)
	if err != nil {
		ch.Close()
		return fmt.Errorf("marshal metadata: %w", err)
	}
	if _, err := ch.Write(data); err != nil {
		ch.Close()
		return fmt.Errorf("write metadata: %w", err)
	}

	// Basic auth is only applied when both credentials are present, so only
	// claim it in the log in that case.
	if authUser != "" && authPass != "" {
		log.Printf("Sent tunnel metadata: domain=%s (with basicauth)", domain)
	} else {
		log.Printf("Sent tunnel metadata: domain=%s", domain)
	}

	// Keep the channel open in a goroutine for disconnect detection.
	go func() {
		// Returns once the server sends EOF/close or the connection drops.
		_, _ = io.Copy(io.Discard, ch)
		log.Printf("Metadata channel closed for %s", domain)
		// Release our side as well; an error only means it is already closed.
		_ = ch.Close()
	}()
	return nil
}
// forwardToLocal proxies one tunneled connection to the local service on
// 127.0.0.1:localPort, copying bytes in both directions until both
// directions have finished. remoteConn is always closed before returning,
// including when the local service cannot be reached.
//
// When one direction reaches EOF (or fails), the write side of the opposite
// connection is shut down so that peer sees the end of the stream. Without
// this, a peer that waits for EOF before responding or closing would keep
// both copy goroutines, and both connections, alive forever.
func forwardToLocal(remoteConn net.Conn, localPort int) {
	defer remoteConn.Close()
	localAddr := fmt.Sprintf("127.0.0.1:%d", localPort)
	localConn, err := net.Dial("tcp", localAddr)
	if err != nil {
		log.Printf("failed to connect to local service at %s: %v", localAddr, err)
		return
	}
	defer localConn.Close()

	// Bidirectional copy. Copy errors are normal during teardown (reset or
	// already-closed connections), so they are ignored; the half-close runs
	// either way so the other side learns this direction is done.
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		_, _ = io.Copy(localConn, remoteConn)
		closeWrite(localConn)
	}()
	go func() {
		defer wg.Done()
		_, _ = io.Copy(remoteConn, localConn)
		closeWrite(remoteConn)
	}()
	wg.Wait()
}

// closeWrite signals end-of-stream to conn's peer: it shuts down only the
// write side when conn supports half-close (e.g. *net.TCPConn), and
// otherwise closes conn entirely.
func closeWrite(conn net.Conn) {
	if hc, ok := conn.(interface{ CloseWrite() error }); ok {
		_ = hc.CloseWrite()
		return
	}
	_ = conn.Close()
}