diff --git a/.gitignore b/.gitignore
index 207808b..3934bba 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,3 +11,4 @@ env.yaml
 node_modules
 out
 sentinel.h
+agent/agent
diff --git a/agent/logging.go b/agent/logging.go
new file mode 100644
index 0000000..7906905
--- /dev/null
+++ b/agent/logging.go
@@ -0,0 +1,30 @@
+package main
+
+import (
+	"fmt"
+	"io"
+	"log"
+)
+
+func logWarn(err error) {
+	log.Println(err.Error())
+}
+
+func logWarnf(format string, arg ...interface{}) {
+	logWarn(fmt.Errorf(format, arg...))
+}
+
+func logError(err error) {
+	log.Println(err.Error())
+}
+
+func logErrorf(format string, arg ...interface{}) {
+	logError(fmt.Errorf(format, arg...))
+}
+
+func tryClose(obj io.Closer, objName string) {
+	err := obj.Close()
+	if err != nil {
+		logErrorf("error closing %s: %w", objName, err)
+	}
+}
diff --git a/agent/main.go b/agent/main.go
index 7b0dc44..6d27742 100644
--- a/agent/main.go
+++ b/agent/main.go
@@ -3,11 +3,9 @@ package main
 import (
 	"encoding/json"
 	"fmt"
-	"io"
-	"log"
 	"net/http"
 	"os"
-	"syscall"
+	"os/exec"
 	"time"
 
 	"github.com/gorilla/websocket"
@@ -31,117 +29,64 @@ type serverMessage struct {
 	ExitStatus *int `json:"exitStatus,omitempty"`
 }
 
-const (
-	pingPeriod    = 15 * time.Second
-	pongWait      = 10 * time.Second
-	writeDeadline = 1 * time.Second
-)
-
 var upgrader = websocket.Upgrader{}
 
-func logWarn(err error) {
-	log.Println(err.Error())
+func closeWs(ms *ManagedWebsocket) {
+	ms.CloseChan <- struct{}{}
 }
 
-func logWarnf(format string, arg ...interface{}) {
-	logWarn(fmt.Errorf(format, arg...))
-}
-
-func logError(err error) {
-	log.Println(err.Error())
-}
-
-func logErrorf(format string, arg ...interface{}) {
-	logError(fmt.Errorf(format, arg...))
-}
-
-func tryClose(obj io.Closer, objName string) {
-	err := obj.Close()
-	if err != nil {
-		logErrorf("error closing %s: %w", objName, err)
-	}
-}
-
-func closeWs(ws *websocket.Conn) {
-	err := ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
-	if err != nil {
-		logErrorf("sending close message: %w", err)
-	}
-	tryClose(ws, "websocket")
-}
-
-func send(ws *websocket.Conn, msg *serverMessage) {
+func send(ms *ManagedWebsocket, msg *serverMessage) {
 	data, err := json.Marshal(msg)
 	if err != nil {
 		logErrorf("marshaling message: %w", err)
-		closeWs(ws)
-		return
-	}
-	err = ws.WriteMessage(websocket.TextMessage, data)
-	if err != nil {
-		logErrorf("sending message: %w", err)
-		closeWs(ws)
+		closeWs(ms)
 		return
 	}
+	ms.OutgoingChan <- data
 }
 
-func fatal(ws *websocket.Conn, err error) {
-	send(ws, &serverMessage{
+func fatal(ms *ManagedWebsocket, err error) {
+	send(ms, &serverMessage{
 		Event: "fatal",
 		Text:  err.Error(),
 	})
 }
 
-func fatalf(ws *websocket.Conn, format string, arg ...interface{}) {
-	fatal(ws, fmt.Errorf(format, arg...))
+func fatalf(ms *ManagedWebsocket, format string, arg ...interface{}) {
+	fatal(ms, fmt.Errorf(format, arg...))
 }
 
-func warn(ws *websocket.Conn, err error) {
-	send(ws, &serverMessage{
+func warn(ms *ManagedWebsocket, err error) {
+	send(ms, &serverMessage{
 		Event: "warn",
 		Text:  err.Error(),
 	})
 }
 
-func warnf(ws *websocket.Conn, format string, arg ...interface{}) {
-	warn(ws, fmt.Errorf(format, arg...))
+func warnf(ms *ManagedWebsocket, format string, arg ...interface{}) {
+	warn(ms, fmt.Errorf(format, arg...))
 }
 
-func handleClientMessages(ws *websocket.Conn, stdinChan chan<- []byte) {
-	// Close channel after we exit
-	defer close(stdinChan)
-	// Stop processing reads some time after we stop receiving
-	// timely responses to our pings.
-	ws.SetReadDeadline(time.Now().Add(pongWait))
-	ws.SetPongHandler(func(string) error {
-		ws.SetReadDeadline(time.Now().Add(pongWait))
-		return nil
-	})
-	// Read data and dispatch appropriately. Return on timeout or
-	// error. Caller is responsible for cleanup.
-	for {
-		msgtype, data, err := ws.ReadMessage()
-		if err != nil {
-			fatalf(ws, "reading message: %w", err)
-			return
+func handleClientMessages(ms *ManagedWebsocket) <-chan []byte {
+	stdinChan := make(chan []byte, 16)
+	go func() {
+		defer close(stdinChan)
+		for data := range ms.IncomingChan {
+			msg := clientMessage{}
+			err := json.Unmarshal(data, &msg)
+			if err != nil {
+				fatalf(ms, "parsing json: %w", err)
+				return
+			}
+			switch msg.Event {
+			case "stdin":
+				stdinChan <- msg.Data
+			default:
+				logWarnf("received unknown event type %s", msg.Event)
+			}
 		}
-		if msgtype != websocket.TextMessage {
-			fatalf(ws, "received non-text message type %d", msgtype)
-			return
-		}
-		msg := clientMessage{}
-		err = json.Unmarshal(data, &msg)
-		if err != nil {
-			fatalf(ws, "parsing json: %w", err)
-			return
-		}
-		switch msg.Event {
-		case "stdin":
-			stdinChan <- msg.Data
-		default:
-			logWarnf("received unknown event type %s", msg.Event)
-		}
-	}
+	}()
+	return stdinChan
 }
 
 // https://github.com/gorilla/websocket/blob/76ecc29eff79f0cedf70c530605e486fc32131d1/examples/command/main.go
@@ -152,197 +97,106 @@ func handler(w http.ResponseWriter, r *http.Request) {
 		logErrorf("upgrading connection: %w", err)
 		return
 	}
-	// Close websocket on error or when we exit
-	defer closeWs(ws)
+	// Set up channels to handle incoming and outgoing websocket
+	// messages more conveniently, and also to handle closing the
+	// websocket on error or when we ask.
+	ms := &ManagedWebsocket{
+		Socket: ws,
+
+		MessageType:  websocket.TextMessage,
+		PingPeriod:   5 * time.Second,
+		ReadTimeout:  10 * time.Second,
+		WriteTimeout: 10 * time.Second,
+	}
+	ms.Init()
+	// Ensure that websocket will be closed eventually when we
+	// exit.
+	defer closeWs(ms)
 	// Parse request query parameters; do this after upgrading to
 	// websocket so that we can send errors back on the websocket
 	// which is easier for clients to parse
 	cmdline := r.URL.Query()["cmdline"]
 	if len(cmdline) == 0 {
-		fatalf(ws, "cmdline query parameter missing")
+		fatalf(ms, "cmdline query parameter missing")
 		return
 	}
-	// Create pipes for communicating with subprocess
-	stdinRead, stdinWrite, err := os.Pipe()
+	binary, err := exec.LookPath(cmdline[0])
 	if err != nil {
-		fatalf(ws, "creating stdin pipe: %w", err)
+		fatalf(ms, "searching for executable: %w", err)
 		return
 	}
-	defer tryClose(stdinRead, "read end of stdin pipe")
-	defer tryClose(stdinWrite, "write end of stdin pipe")
-	stdoutRead, stdoutWrite, err := os.Pipe()
-	if err != nil {
-		fatalf(ws, "creating stdout pipe: %w", err)
-		return
-	}
-	defer tryClose(stdoutRead, "read end of stdout pipe")
-	defer tryClose(stdoutWrite, "write end of stdout pipe")
-	stderrRead, stderrWrite, err := os.Pipe()
-	if err != nil {
-		fatalf(ws, "creating stderr pipe: %w", err)
-		return
-	}
-	defer tryClose(stderrRead, "read end of stderr pipe")
-	defer tryClose(stderrWrite, "write end of stderr pipe")
 	// Spawn subprocess
-	proc, err := os.StartProcess(cmdline[0], cmdline, &os.ProcAttr{
-		Files: []*os.File{stdinRead, stdoutWrite, stderrWrite},
-	})
+	mp, err := NewManagedProcess(binary, cmdline, nil)
 	if err != nil {
-		fatalf(ws, "spawning process: %w", err)
+		fatalf(ms, "spawning process: %w", err)
 		return
 	}
-	// Setup a way for other goroutines to report a fatal error,
-	// use increased capacity to avoid blockage with large number
-	// of write callsites
-	doneChan := make(chan struct{}, 10)
-	// Setup channels and variables to monitor process state
-	waitChan := make(chan struct{}, 1)
-	state := (*os.ProcessState)(nil)
-	// Monitor the process to see when it exits
-	go func() {
-		s, err := proc.Wait()
-		if err != nil {
-			fatalf(ws, "waiting for process to exit: %w", err)
-		} else {
-			state = s
-		}
-		waitChan <- struct{}{}
-		doneChan <- struct{}{}
-	}()
-	// Arrange to send information about the process exit status
-	// if we have obtained it by the time we return
+	// Ensure eventual process termination
 	defer func() {
-		if state != nil {
-			status := state.ExitCode()
-			send(ws, &serverMessage{
-				Event:      "exit",
-				ExitStatus: &status,
-			})
-		}
+		mp.CloseChan <- struct{}{}
 	}()
-	// Arrange for subprocess to be killed when we exit
-	defer func() {
-		// See if process has already exited or is about to
-		select {
-		case <-waitChan:
-			return
-		case <-time.NewTimer(500 * time.Millisecond).C:
-			//
-		}
-		// Try killing the process by closing stdin
-		tryClose(stdinWrite, "stdin to child")
-		select {
-		case <-waitChan:
-			return
-		case <-time.NewTimer(500 * time.Millisecond).C:
-			//
-		}
-		// Try killing the process with SIGTERM, SIGINT, then
-		// finally SIGKILL
-		for _, sig := range []os.Signal{syscall.SIGTERM, syscall.SIGINT, syscall.SIGKILL} {
-			err = proc.Signal(sig)
-			if err != nil {
-				logWarnf("sending %s to child: %w", sig.String(), err)
-			}
-			select {
-			case <-waitChan:
-				return
-			case <-time.NewTimer(500 * time.Millisecond).C:
-				//
-			}
-		}
-		// We are unable to kill the process
-		fatalf(ws, "unable to kill child")
-	}()
-	// Close our copies of pipe ends passed to subprocess
-	err = stdinRead.Close()
-	if err != nil {
-		fatalf(ws, "closing read end of stdin pipe from parent")
-		return
-	}
-	err = stdoutWrite.Close()
-	if err != nil {
-		fatalf(ws, "closing write end of stdout pipe from parent")
-		return
-	}
-	err = stderrWrite.Close()
-	if err != nil {
-		fatalf(ws, "closing write end of stderr pipe from parent")
-		return
-	}
 	// Handle received messages from client
-	stdinChan := make(chan []byte)
 	go func() {
-		handleClientMessages(ws, stdinChan)
-		doneChan <- struct{}{}
-	}()
-	go func() {
-		for data := range stdinChan {
-			_, err := stdinWrite.Write(data)
+		for data := range ms.IncomingChan {
+			msg := clientMessage{}
+			err := json.Unmarshal(data, &msg)
 			if err != nil {
-				warnf(ws, "writing to stdin: %w", err)
-				return
-			}
-		}
-	}()
-	// Send regular pings to ensure we get regular pongs to
-	// satisfy the read deadline on handleClientMessages
-	pingDoneChan := make(chan struct{}, 1)
-	defer func() {
-		pingDoneChan <- struct{}{}
-	}()
-	go func() {
-		ticker := time.NewTicker(pingPeriod)
-		defer ticker.Stop()
-		for {
-			select {
-			case <-ticker.C:
-				err := ws.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeDeadline))
-				if err != nil {
-					logErrorf("sending ping: %w", err)
-					doneChan <- struct{}{}
-					return
-				}
-			case <-pingDoneChan:
-				return
-			}
-		}
-	}()
-	// Proxy stdout and stderr back to client
-	go func() {
-		for {
-			buf := make([]byte, 1024)
-			nr, err := stdoutRead.Read(buf)
-			if err != nil {
-				warnf(ws, "reading from stdout: %w", err)
-				return
-			}
-			if nr == 0 {
+				warnf(ms, "parsing json: %w", err)
 				continue
 			}
-			data, err := json.Marshal(&serverMessage{
-				Event: "stdout",
-				Data:  buf[:nr],
-			})
-			if err != nil {
-				fatalf(ws, "wrapping stdout in json: %w", err)
-				doneChan <- struct{}{}
-				return
-			}
-			ws.SetWriteDeadline(time.Now().Add(writeDeadline))
-			err = ws.WriteMessage(websocket.TextMessage, data)
-			if err != nil {
-				fatalf(ws, "sending message: %w", err)
-				doneChan <- struct{}{}
-				return
+			switch msg.Event {
+			case "stdin":
+				mp.StdinChan <- msg.Data
+			default:
+				logWarnf("received unknown event type %s", msg.Event)
 			}
 		}
 	}()
-	// Wait until either process is exited or a websocket
-	// operation fails
-	<-doneChan
-	// Process and websocket will be cleaned up after return
+	// Proxy stdout and stderr from subprocess
+	go func() {
+		for data := range mp.StdoutChan {
+			msg, err := json.Marshal(&serverMessage{
+				Event: "stdout",
+				Data:  data,
+			})
+			if err != nil {
+				warnf(ms, "wrapping stdout in json: %w", err)
+				return
+			}
+			ms.OutgoingChan <- msg
+		}
+	}()
+	go func() {
+		for data := range mp.StderrChan {
+			msg, err := json.Marshal(&serverMessage{
+				Event: "stderr",
+				Data:  data,
+			})
+			if err != nil {
+				warnf(ms, "wrapping stderr in json: %w", err)
+				return
+			}
+			ms.OutgoingChan <- msg
+		}
+	}()
+	// Send info about process exit status
+	exitChan2 := make(chan struct{}, 16)
+	go func() {
+		for status := range mp.ExitChan {
+			exitChan2 <- struct{}{}
+			code := status.ExitCode()
+			send(ms, &serverMessage{
+				Event:      "exit",
+				ExitStatus: &code,
+			})
+		}
+	}()
+	// Wait until one of subprocess or websocket exits. The other
+	// one will be cleaned up on return.
+	select {
+	case <-exitChan2:
+	case <-ms.ClosedChan:
+	}
 	return
 }
 
diff --git a/agent/process.go b/agent/process.go
new file mode 100644
index 0000000..ca20c07
--- /dev/null
+++ b/agent/process.go
@@ -0,0 +1,169 @@
+package main
+
+import (
+	"fmt"
+	"os"
+	"syscall"
+	"time"
+)
+
+type managedProcess struct {
+	proc *os.Process
+
+	stdinRead   *os.File
+	stdinWrite  *os.File
+	stdoutRead  *os.File
+	stdoutWrite *os.File
+	stderrRead  *os.File
+	stderrWrite *os.File
+
+	internalExitChan chan struct{}
+
+	StdinChan  chan []byte
+	StdoutChan chan []byte
+	StderrChan chan []byte
+	ExitChan   chan *os.ProcessState
+	CloseChan  chan struct{}
+}
+
+// NewManagedProcess starts name with argv and wires the child's standard
+// streams to pipes serviced through StdinChan, StdoutChan and StderrChan.
+// The exit state is delivered on ExitChan; send on CloseChan to release the
+// pipes and terminate the child.
+func NewManagedProcess(name string, argv []string, attr *os.ProcAttr) (*managedProcess, error) {
+	mp := &managedProcess{
+		StdinChan:  make(chan []byte, 16),
+		StdoutChan: make(chan []byte, 16),
+		StderrChan: make(chan []byte, 16),
+		CloseChan:  make(chan struct{}, 16),
+		// Both exit channels must exist and be buffered: a send on a nil
+		// channel blocks forever, which would wedge handleWait.
+		internalExitChan: make(chan struct{}, 16),
+		ExitChan:         make(chan *os.ProcessState, 16),
+	}
+	done := false
+	go mp.handleClose()
+	defer func() {
+		if !done {
+			mp.CloseChan <- struct{}{}
+		}
+	}()
+	var err error
+	mp.stdinRead, mp.stdinWrite, err = os.Pipe()
+	if err != nil {
+		return mp, fmt.Errorf("creating stdin pipe: %w", err)
+	}
+	mp.stdoutRead, mp.stdoutWrite, err = os.Pipe()
+	if err != nil {
+		return mp, fmt.Errorf("creating stdout pipe: %w", err)
+	}
+	mp.stderrRead, mp.stderrWrite, err = os.Pipe()
+	if err != nil {
+		return mp, fmt.Errorf("creating stderr pipe: %w", err)
+	}
+	newAttr := &os.ProcAttr{}
+	if attr != nil {
+		*newAttr = *attr
+	}
+	if len(newAttr.Files) < 3 {
+		newAttr.Files = append(newAttr.Files, make([]*os.File, 3-len(newAttr.Files))...)
+		newAttr.Files[0] = mp.stdinRead
+		newAttr.Files[1] = mp.stdoutWrite
+		newAttr.Files[2] = mp.stderrWrite
+	}
+	mp.proc, err = os.StartProcess(name, argv, newAttr)
+	if err != nil {
+		return mp, fmt.Errorf("spawning process: %w", err)
+	}
+	go mp.handleWait()
+	go mp.handleInput(mp.StdinChan, mp.stdinWrite, "stdin")
+	go mp.handleOutput(mp.StdoutChan, mp.stdoutRead, "stdout")
+	go mp.handleOutput(mp.StderrChan, mp.stderrRead, "stderr")
+	done = true
+	return mp, nil
+}
+
+func (mp *managedProcess) handleInput(ch chan []byte, f *os.File, name string) {
+	for data := range ch {
+		nw, err := f.Write(data)
+		if err != nil {
+			logWarnf("writing to %s: got error after %d of %d byte(s): %w", name, nw, len(data), err)
+			return
+		}
+	}
+}
+
+func (mp *managedProcess) handleOutput(ch chan []byte, f *os.File, name string) {
+	for {
+		buf := make([]byte, 1024)
+		nr, err := f.Read(buf)
+		if err != nil {
+			logWarnf("reading from %s: got error after %d byte(s): %w", name, nr, err)
+			return
+		}
+		if nr == 0 {
+			continue
+		}
+		ch <- buf[:nr]
+	}
+}
+
+func (mp *managedProcess) handleWait() {
+	s, err := mp.proc.Wait()
+	if err != nil {
+		logErrorf("waiting on process: %w", err)
+	}
+	mp.internalExitChan <- struct{}{}
+	mp.ExitChan <- s
+}
+
+func (mp *managedProcess) killProc() {
+	// See if process has already exited or is about to
+	select {
+	case <-mp.internalExitChan:
+		return
+	case <-time.NewTimer(500 * time.Millisecond).C:
+		//
+	}
+	// Try killing the process by closing stdin
+	mp.stdinWrite.Close()
+	select {
+	case <-mp.internalExitChan:
+		return
+	case <-time.NewTimer(500 * time.Millisecond).C:
+		//
+	}
+	// Try killing the process with SIGTERM, SIGINT, then
+	// finally SIGKILL
+	for _, sig := range []os.Signal{syscall.SIGTERM, syscall.SIGINT, syscall.SIGKILL} {
+		err := mp.proc.Signal(sig)
+		if err != nil {
+			logErrorf("sending %s to child: %w", sig.String(), err)
+		}
+		select {
+		case <-mp.internalExitChan:
+			return
+		case <-time.NewTimer(500 * time.Millisecond).C:
+			//
+		}
+	}
+	// We are unable to kill the process
+	logErrorf("unable to kill child process (pid %d)", mp.proc.Pid)
+}
+
+func (mp *managedProcess) handleClose() {
+	<-mp.CloseChan
+	for _, p := range []*os.File{
+		mp.stdinRead, mp.stdinWrite,
+		mp.stdoutRead, mp.stdoutWrite,
+		mp.stderrRead, mp.stderrWrite,
+	} {
+		if p != nil {
+			p.Close()
+		}
+	}
+	if mp.proc != nil {
+		// Make sure the child is gone; killProc returns at once if it already exited.
+		mp.killProc()
+	}
+}
diff --git a/agent/websocket.go b/agent/websocket.go
new file mode 100644
index 0000000..dbe816b
--- /dev/null
+++ b/agent/websocket.go
@@ -0,0 +1,121 @@
+package main
+
+import (
+	"io"
+	"time"
+
+	"github.com/gorilla/websocket"
+)
+
+type ManagedWebsocket struct {
+	Socket *websocket.Conn
+
+	MessageType  int
+	PingPeriod   time.Duration
+	ReadTimeout  time.Duration
+	WriteTimeout time.Duration
+
+	IncomingChan chan []byte
+	OutgoingChan chan []byte
+	CloseChan    chan struct{}
+	ClosedChan   chan struct{}
+}
+
+// handleIncoming reads messages from the socket and forwards every message
+// of the configured type to IncomingChan. IncomingChan is closed once the
+// socket can no longer be read.
+func (m *ManagedWebsocket) handleIncoming() {
+	pongChan := make(chan struct{}, 16)
+	m.Socket.SetPongHandler(func(string) error {
+		pongChan <- struct{}{}
+		return nil
+	})
+	msgChan := make(chan []byte, 16)
+	go func() {
+		defer close(msgChan)
+		for {
+			msgtype, data, err := m.Socket.ReadMessage()
+			if err != nil {
+				if err != io.EOF {
+					logErrorf("reading message: %w", err)
+				}
+				m.Socket.Close()
+				return
+			}
+			if msgtype != m.MessageType {
+				logWarnf("ignoring message of unexpected type %d", msgtype)
+				continue
+			}
+			msgChan <- data
+
+		}
+	}()
+	for {
+		m.Socket.SetReadDeadline(time.Now().Add(m.ReadTimeout))
+		var msgtype int
+		var msgdata []byte
+		select {
+		case <-pongChan:
+			msgtype = websocket.PongMessage
+		case data, ok := <-msgChan:
+			if !ok {
+				// The reader goroutine has exited; stop here instead of
+				// spinning on the closed channel, and tell consumers.
+				close(m.IncomingChan)
+				return
+			}
+			msgtype = m.MessageType
+			msgdata = data
+		}
+		if msgtype != m.MessageType {
+			continue
+		}
+		m.IncomingChan <- msgdata
+	}
+}
+
+func (m *ManagedWebsocket) handleOutgoing() {
+	pingTicker := time.NewTicker(m.PingPeriod)
+	defer pingTicker.Stop()
+	defer func() {
+		m.ClosedChan <- struct{}{}
+	}()
+	for {
+		var msgtype int
+		var msgdata []byte
+		select {
+		case <-pingTicker.C:
+			msgtype = websocket.PingMessage
+			msgdata = []byte{}
+		case data := <-m.OutgoingChan:
+			msgtype = m.MessageType
+			msgdata = data
+		case <-m.CloseChan:
+			msgtype = websocket.CloseMessage
+			msgdata = websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
+		}
+		wd := time.Now().Add(m.WriteTimeout)
+		m.Socket.SetWriteDeadline(wd)
+		err := m.Socket.WriteMessage(msgtype, msgdata)
+		if err != nil {
+			logErrorf("writing message: %w", err)
+			m.Socket.Close()
+			return
+		}
+		if msgtype == websocket.CloseMessage {
+			time.Sleep(time.Until(wd))
+			m.Socket.Close()
+			return
+		}
+	}
+}
+
+func (m *ManagedWebsocket) Init() {
+	m.IncomingChan = make(chan []byte, 16)
+	m.OutgoingChan = make(chan []byte, 16)
+	m.CloseChan = make(chan struct{}, 16)
+	m.ClosedChan = make(chan struct{}, 16)
+
+	go m.handleIncoming()
+	go m.handleOutgoing()
+}