365 lines
8.7 KiB
Go
365 lines
8.7 KiB
Go
package main
|
|
|
|
import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"os"
	"sync"
	"syscall"
	"time"

	"github.com/gorilla/websocket"
)
|
|
|
|
// clientMessage is the JSON envelope of every message read from the client.
type clientMessage struct {
	Event string `json:"event"`          // kind of message; only "stdin" is acted on
	Data  []byte `json:"data,omitempty"` // bytes to feed to the child's stdin
}
|
|
|
|
// serverMessage is the JSON envelope of every message the server sends to
// the client over the websocket.
type serverMessage struct {
	// Kind of message: "stdout", "stderr", "exit", "warn" or "fatal".
	Event string `json:"event"`
	// Output bytes for "stdout"/"stderr" (encoding/json renders []byte
	// as base64).
	Data []byte `json:"data,omitempty"`
	// Human-readable message for "warn" and "fatal".
	Text string `json:"text,omitempty"`
	// Child's exit code for "exit". A pointer so that an exit code of 0
	// is still emitted despite omitempty.
	ExitStatus *int `json:"exitStatus,omitempty"`
}
|
|
|
|
const (
	// pingPeriod is how often the server pings the client.
	pingPeriod = 15 * time.Second

	// pongWait is how long reads may go without a pong before they time
	// out. The read deadline is extended only when a pong arrives, so
	// pongWait must exceed pingPeriod; otherwise the deadline expires
	// before the first ping is even sent. The extra 10s is the time
	// allowed for each ping's round trip.
	pongWait = pingPeriod + 10*time.Second

	// writeDeadline bounds how long a single websocket write may take.
	writeDeadline = 1 * time.Second
)
|
|
|
|
// upgrader turns incoming HTTP requests into websocket connections. Its
// zero value uses gorilla/websocket's default settings.
var upgrader websocket.Upgrader
|
|
|
|
// logWarn writes the message of err, followed by a newline, to the
// standard logger.
func logWarn(err error) {
	log.Printf("%s\n", err.Error())
}
|
|
|
|
func logWarnf(format string, arg ...interface{}) {
|
|
logWarn(fmt.Errorf(format, arg...))
|
|
}
|
|
|
|
// logError writes the message of err, followed by a newline, to the
// standard logger.
func logError(err error) {
	text := err.Error()
	log.Println(text)
}
|
|
|
|
func logErrorf(format string, arg ...interface{}) {
|
|
logError(fmt.Errorf(format, arg...))
|
|
}
|
|
|
|
func tryClose(obj io.Closer, objName string) {
|
|
err := obj.Close()
|
|
if err != nil {
|
|
logErrorf("error closing %s: %w", objName, err)
|
|
}
|
|
}
|
|
|
|
func closeWs(ws *websocket.Conn) {
|
|
err := ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
|
if err != nil {
|
|
logErrorf("sending close message: %w", err)
|
|
}
|
|
tryClose(ws, "websocket")
|
|
}
|
|
|
|
func send(ws *websocket.Conn, msg *serverMessage) {
|
|
data, err := json.Marshal(msg)
|
|
if err != nil {
|
|
logErrorf("marshaling message: %w", err)
|
|
closeWs(ws)
|
|
return
|
|
}
|
|
err = ws.WriteMessage(websocket.TextMessage, data)
|
|
if err != nil {
|
|
logErrorf("sending message: %w", err)
|
|
closeWs(ws)
|
|
return
|
|
}
|
|
}
|
|
|
|
func fatal(ws *websocket.Conn, err error) {
|
|
send(ws, &serverMessage{
|
|
Event: "fatal",
|
|
Text: err.Error(),
|
|
})
|
|
}
|
|
|
|
func fatalf(ws *websocket.Conn, format string, arg ...interface{}) {
|
|
fatal(ws, fmt.Errorf(format, arg...))
|
|
}
|
|
|
|
func warn(ws *websocket.Conn, err error) {
|
|
send(ws, &serverMessage{
|
|
Event: "warn",
|
|
Text: err.Error(),
|
|
})
|
|
}
|
|
|
|
func warnf(ws *websocket.Conn, format string, arg ...interface{}) {
|
|
warn(ws, fmt.Errorf(format, arg...))
|
|
}
|
|
|
|
func handleClientMessages(ws *websocket.Conn, stdinChan chan<- []byte) {
|
|
// Close channel after we exit
|
|
defer close(stdinChan)
|
|
// Stop processing reads some time after we stop receiving
|
|
// timely responses to our pings.
|
|
ws.SetReadDeadline(time.Now().Add(pongWait))
|
|
ws.SetPongHandler(func(string) error {
|
|
ws.SetReadDeadline(time.Now().Add(pongWait))
|
|
return nil
|
|
})
|
|
// Read data and dispatch appropriately. Return on timeout or
|
|
// error. Caller is responsible for cleanup.
|
|
for {
|
|
msgtype, data, err := ws.ReadMessage()
|
|
if err != nil {
|
|
fatalf(ws, "reading message: %w", err)
|
|
return
|
|
}
|
|
if msgtype != websocket.TextMessage {
|
|
fatalf(ws, "received non-text message type %d", msgtype)
|
|
return
|
|
}
|
|
msg := clientMessage{}
|
|
err = json.Unmarshal(data, &msg)
|
|
if err != nil {
|
|
fatalf(ws, "parsing json: %w", err)
|
|
return
|
|
}
|
|
switch msg.Event {
|
|
case "stdin":
|
|
stdinChan <- msg.Data
|
|
default:
|
|
logWarnf("received unknown event type %s", msg.Event)
|
|
}
|
|
}
|
|
}
|
|
|
|
// https://github.com/gorilla/websocket/blob/76ecc29eff79f0cedf70c530605e486fc32131d1/examples/command/main.go
|
|
func handler(w http.ResponseWriter, r *http.Request) {
|
|
// Upgrade http connection to websocket
|
|
ws, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
logErrorf("upgrading connection: %w", err)
|
|
return
|
|
}
|
|
// Close websocket on error or when we exit
|
|
defer closeWs(ws)
|
|
// Parse request query parameters; do this after upgrading to
|
|
// websocket so that we can send errors back on the websocket
|
|
// which is easier for clients to parse
|
|
cmdline := r.URL.Query()["cmdline"]
|
|
if len(cmdline) == 0 {
|
|
fatalf(ws, "cmdline query parameter missing")
|
|
return
|
|
}
|
|
// Create pipes for communicating with subprocess
|
|
stdinRead, stdinWrite, err := os.Pipe()
|
|
if err != nil {
|
|
fatalf(ws, "creating stdin pipe: %w", err)
|
|
return
|
|
}
|
|
defer tryClose(stdinRead, "read end of stdin pipe")
|
|
defer tryClose(stdinWrite, "write end of stdin pipe")
|
|
stdoutRead, stdoutWrite, err := os.Pipe()
|
|
if err != nil {
|
|
fatalf(ws, "creating stdout pipe: %w", err)
|
|
return
|
|
}
|
|
defer tryClose(stdoutRead, "read end of stdout pipe")
|
|
defer tryClose(stdoutWrite, "write end of stdout pipe")
|
|
stderrRead, stderrWrite, err := os.Pipe()
|
|
if err != nil {
|
|
fatalf(ws, "creating stderr pipe: %w", err)
|
|
return
|
|
}
|
|
defer tryClose(stderrRead, "read end of stderr pipe")
|
|
defer tryClose(stderrWrite, "write end of stderr pipe")
|
|
// Spawn subprocess
|
|
proc, err := os.StartProcess(cmdline[0], cmdline, &os.ProcAttr{
|
|
Files: []*os.File{stdinRead, stdoutWrite, stderrWrite},
|
|
})
|
|
if err != nil {
|
|
fatalf(ws, "spawning process: %w", err)
|
|
return
|
|
}
|
|
// Setup a way for other goroutines to report a fatal error,
|
|
// use increased capacity to avoid blockage with large number
|
|
// of write callsites
|
|
doneChan := make(chan struct{}, 10)
|
|
// Setup channels and variables to monitor process state
|
|
waitChan := make(chan struct{}, 1)
|
|
state := (*os.ProcessState)(nil)
|
|
// Monitor the process to see when it exits
|
|
go func() {
|
|
s, err := proc.Wait()
|
|
if err != nil {
|
|
fatalf(ws, "waiting for process to exit: %w", err)
|
|
} else {
|
|
state = s
|
|
}
|
|
waitChan <- struct{}{}
|
|
doneChan <- struct{}{}
|
|
}()
|
|
// Arrange to send information about the process exit status
|
|
// if we have obtained it by the time we return
|
|
defer func() {
|
|
if state != nil {
|
|
status := state.ExitCode()
|
|
send(ws, &serverMessage{
|
|
Event: "exit",
|
|
ExitStatus: &status,
|
|
})
|
|
}
|
|
}()
|
|
// Arrange for subprocess to be killed when we exit
|
|
defer func() {
|
|
// See if process has already exited or is about to
|
|
select {
|
|
case <-waitChan:
|
|
return
|
|
case <-time.NewTimer(500 * time.Millisecond).C:
|
|
//
|
|
}
|
|
// Try killing the process by closing stdin
|
|
tryClose(stdinWrite, "stdin to child")
|
|
select {
|
|
case <-waitChan:
|
|
return
|
|
case <-time.NewTimer(500 * time.Millisecond).C:
|
|
//
|
|
}
|
|
// Try killing the process with SIGTERM, SIGINT, then
|
|
// finally SIGKILL
|
|
for _, sig := range []os.Signal{syscall.SIGTERM, syscall.SIGINT, syscall.SIGKILL} {
|
|
err = proc.Signal(sig)
|
|
if err != nil {
|
|
logWarnf("sending %s to child: %w", sig.String(), err)
|
|
}
|
|
select {
|
|
case <-waitChan:
|
|
return
|
|
case <-time.NewTimer(500 * time.Millisecond).C:
|
|
//
|
|
}
|
|
}
|
|
// We are unable to kill the process
|
|
fatalf(ws, "unable to kill child")
|
|
}()
|
|
// Close our copies of pipe ends passed to subprocess
|
|
err = stdinRead.Close()
|
|
if err != nil {
|
|
fatalf(ws, "closing read end of stdin pipe from parent")
|
|
return
|
|
}
|
|
err = stdoutWrite.Close()
|
|
if err != nil {
|
|
fatalf(ws, "closing write end of stdout pipe from parent")
|
|
return
|
|
}
|
|
err = stderrWrite.Close()
|
|
if err != nil {
|
|
fatalf(ws, "closing write end of stderr pipe from parent")
|
|
return
|
|
}
|
|
// Handle received messages from client
|
|
stdinChan := make(chan []byte)
|
|
go func() {
|
|
handleClientMessages(ws, stdinChan)
|
|
doneChan <- struct{}{}
|
|
}()
|
|
go func() {
|
|
for data := range stdinChan {
|
|
_, err := stdinWrite.Write(data)
|
|
if err != nil {
|
|
warnf(ws, "writing to stdin: %w", err)
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
// Send regular pings to ensure we get regular pongs to
|
|
// satisfy the read deadline on handleClientMessages
|
|
pingDoneChan := make(chan struct{}, 1)
|
|
defer func() {
|
|
pingDoneChan <- struct{}{}
|
|
}()
|
|
go func() {
|
|
ticker := time.NewTicker(pingPeriod)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
err := ws.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeDeadline))
|
|
if err != nil {
|
|
logErrorf("sending ping: %w", err)
|
|
doneChan <- struct{}{}
|
|
return
|
|
}
|
|
case <-pingDoneChan:
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
// Proxy stdout and stderr back to client
|
|
go func() {
|
|
for {
|
|
buf := make([]byte, 1024)
|
|
nr, err := stdoutRead.Read(buf)
|
|
if err != nil {
|
|
warnf(ws, "reading from stdout: %w", err)
|
|
return
|
|
}
|
|
if nr == 0 {
|
|
continue
|
|
}
|
|
data, err := json.Marshal(&serverMessage{
|
|
Event: "stdout",
|
|
Data: buf[:nr],
|
|
})
|
|
if err != nil {
|
|
fatalf(ws, "wrapping stdout in json: %w", err)
|
|
doneChan <- struct{}{}
|
|
return
|
|
}
|
|
ws.SetWriteDeadline(time.Now().Add(writeDeadline))
|
|
err = ws.WriteMessage(websocket.TextMessage, data)
|
|
if err != nil {
|
|
fatalf(ws, "sending message: %w", err)
|
|
doneChan <- struct{}{}
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
// Wait until either process is exited or a websocket
|
|
// operation fails
|
|
<-doneChan
|
|
// Process and websocket will be cleaned up after return
|
|
return
|
|
}
|
|
|
|
func main() {
|
|
port := os.Getenv("RIJU_AGENT_PORT")
|
|
if port == "" {
|
|
port = "869"
|
|
}
|
|
host := os.Getenv("RIJU_AGENT_HOST")
|
|
if host == "" {
|
|
host = "0.0.0.0"
|
|
}
|
|
fmt.Printf("Listening on http://%s:%s\n", host, port)
|
|
err := http.ListenAndServe(fmt.Sprintf("%s:%s", host, port), http.HandlerFunc(handler))
|
|
if err != nil {
|
|
logError(err)
|
|
os.Exit(1)
|
|
}
|
|
}
|