v2 for server agent
This commit is contained in:
parent
d6134d470c
commit
1477281cff
|
@ -11,3 +11,4 @@ env.yaml
|
|||
node_modules
|
||||
out
|
||||
sentinel.h
|
||||
agent/agent
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
)
|
||||
|
||||
func logWarn(err error) {
|
||||
log.Println(err.Error())
|
||||
}
|
||||
|
||||
func logWarnf(format string, arg ...interface{}) {
|
||||
logWarn(fmt.Errorf(format, arg...))
|
||||
}
|
||||
|
||||
func logError(err error) {
|
||||
log.Println(err.Error())
|
||||
}
|
||||
|
||||
func logErrorf(format string, arg ...interface{}) {
|
||||
logError(fmt.Errorf(format, arg...))
|
||||
}
|
||||
|
||||
func tryClose(obj io.Closer, objName string) {
|
||||
err := obj.Close()
|
||||
if err != nil {
|
||||
logErrorf("error closing %s: %w", objName, err)
|
||||
}
|
||||
}
|
364
agent/main.go
364
agent/main.go
|
@ -3,11 +3,9 @@ package main
|
|||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"syscall"
|
||||
"os/exec"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
@ -31,117 +29,64 @@ type serverMessage struct {
|
|||
ExitStatus *int `json:"exitStatus,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
pingPeriod = 15 * time.Second
|
||||
pongWait = 10 * time.Second
|
||||
writeDeadline = 1 * time.Second
|
||||
)
|
||||
|
||||
var upgrader = websocket.Upgrader{}
|
||||
|
||||
func logWarn(err error) {
|
||||
log.Println(err.Error())
|
||||
func closeWs(ms *ManagedWebsocket) {
|
||||
ms.CloseChan <- struct{}{}
|
||||
}
|
||||
|
||||
func logWarnf(format string, arg ...interface{}) {
|
||||
logWarn(fmt.Errorf(format, arg...))
|
||||
}
|
||||
|
||||
func logError(err error) {
|
||||
log.Println(err.Error())
|
||||
}
|
||||
|
||||
func logErrorf(format string, arg ...interface{}) {
|
||||
logError(fmt.Errorf(format, arg...))
|
||||
}
|
||||
|
||||
func tryClose(obj io.Closer, objName string) {
|
||||
err := obj.Close()
|
||||
if err != nil {
|
||||
logErrorf("error closing %s: %w", objName, err)
|
||||
}
|
||||
}
|
||||
|
||||
func closeWs(ws *websocket.Conn) {
|
||||
err := ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
||||
if err != nil {
|
||||
logErrorf("sending close message: %w", err)
|
||||
}
|
||||
tryClose(ws, "websocket")
|
||||
}
|
||||
|
||||
func send(ws *websocket.Conn, msg *serverMessage) {
|
||||
func send(ms *ManagedWebsocket, msg *serverMessage) {
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
logErrorf("marshaling message: %w", err)
|
||||
closeWs(ws)
|
||||
return
|
||||
}
|
||||
err = ws.WriteMessage(websocket.TextMessage, data)
|
||||
if err != nil {
|
||||
logErrorf("sending message: %w", err)
|
||||
closeWs(ws)
|
||||
closeWs(ms)
|
||||
return
|
||||
}
|
||||
ms.OutgoingChan <- data
|
||||
}
|
||||
|
||||
func fatal(ws *websocket.Conn, err error) {
|
||||
send(ws, &serverMessage{
|
||||
func fatal(ms *ManagedWebsocket, err error) {
|
||||
send(ms, &serverMessage{
|
||||
Event: "fatal",
|
||||
Text: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
func fatalf(ws *websocket.Conn, format string, arg ...interface{}) {
|
||||
fatal(ws, fmt.Errorf(format, arg...))
|
||||
func fatalf(ms *ManagedWebsocket, format string, arg ...interface{}) {
|
||||
fatal(ms, fmt.Errorf(format, arg...))
|
||||
}
|
||||
|
||||
func warn(ws *websocket.Conn, err error) {
|
||||
send(ws, &serverMessage{
|
||||
func warn(ms *ManagedWebsocket, err error) {
|
||||
send(ms, &serverMessage{
|
||||
Event: "warn",
|
||||
Text: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
func warnf(ws *websocket.Conn, format string, arg ...interface{}) {
|
||||
warn(ws, fmt.Errorf(format, arg...))
|
||||
func warnf(ms *ManagedWebsocket, format string, arg ...interface{}) {
|
||||
warn(ms, fmt.Errorf(format, arg...))
|
||||
}
|
||||
|
||||
func handleClientMessages(ws *websocket.Conn, stdinChan chan<- []byte) {
|
||||
// Close channel after we exit
|
||||
defer close(stdinChan)
|
||||
// Stop processing reads some time after we stop receiving
|
||||
// timely responses to our pings.
|
||||
ws.SetReadDeadline(time.Now().Add(pongWait))
|
||||
ws.SetPongHandler(func(string) error {
|
||||
ws.SetReadDeadline(time.Now().Add(pongWait))
|
||||
return nil
|
||||
})
|
||||
// Read data and dispatch appropriately. Return on timeout or
|
||||
// error. Caller is responsible for cleanup.
|
||||
for {
|
||||
msgtype, data, err := ws.ReadMessage()
|
||||
if err != nil {
|
||||
fatalf(ws, "reading message: %w", err)
|
||||
return
|
||||
func handleClientMessages(ms *ManagedWebsocket) <-chan []byte {
|
||||
stdinChan := make(chan []byte, 16)
|
||||
go func() {
|
||||
defer close(stdinChan)
|
||||
for data := range ms.IncomingChan {
|
||||
msg := clientMessage{}
|
||||
err := json.Unmarshal(data, &msg)
|
||||
if err != nil {
|
||||
fatalf(ms, "parsing json: %w", err)
|
||||
return
|
||||
}
|
||||
switch msg.Event {
|
||||
case "stdin":
|
||||
stdinChan <- msg.Data
|
||||
default:
|
||||
logWarnf("received unknown event type %s", msg.Event)
|
||||
}
|
||||
}
|
||||
if msgtype != websocket.TextMessage {
|
||||
fatalf(ws, "received non-text message type %d", msgtype)
|
||||
return
|
||||
}
|
||||
msg := clientMessage{}
|
||||
err = json.Unmarshal(data, &msg)
|
||||
if err != nil {
|
||||
fatalf(ws, "parsing json: %w", err)
|
||||
return
|
||||
}
|
||||
switch msg.Event {
|
||||
case "stdin":
|
||||
stdinChan <- msg.Data
|
||||
default:
|
||||
logWarnf("received unknown event type %s", msg.Event)
|
||||
}
|
||||
}
|
||||
}()
|
||||
return stdinChan
|
||||
}
|
||||
|
||||
// https://github.com/gorilla/websocket/blob/76ecc29eff79f0cedf70c530605e486fc32131d1/examples/command/main.go
|
||||
|
@ -152,197 +97,106 @@ func handler(w http.ResponseWriter, r *http.Request) {
|
|||
logErrorf("upgrading connection: %w", err)
|
||||
return
|
||||
}
|
||||
// Close websocket on error or when we exit
|
||||
defer closeWs(ws)
|
||||
// Set up channels to handle incoming and outgoing websocket
|
||||
// messages more conveniently, and also to handle closing the
|
||||
// websocket on error or when we ask.
|
||||
ms := &ManagedWebsocket{
|
||||
Socket: ws,
|
||||
|
||||
MessageType: websocket.TextMessage,
|
||||
PingPeriod: 5 * time.Second,
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
}
|
||||
ms.Init()
|
||||
// Ensure that websocket will be closed eventually when we
|
||||
// exit.
|
||||
defer closeWs(ms)
|
||||
// Parse request query parameters; do this after upgrading to
|
||||
// websocket so that we can send errors back on the websocket
|
||||
// which is easier for clients to parse
|
||||
cmdline := r.URL.Query()["cmdline"]
|
||||
if len(cmdline) == 0 {
|
||||
fatalf(ws, "cmdline query parameter missing")
|
||||
fatalf(ms, "cmdline query parameter missing")
|
||||
return
|
||||
}
|
||||
// Create pipes for communicating with subprocess
|
||||
stdinRead, stdinWrite, err := os.Pipe()
|
||||
binary, err := exec.LookPath(cmdline[0])
|
||||
if err != nil {
|
||||
fatalf(ws, "creating stdin pipe: %w", err)
|
||||
fatalf(ms, "searching for executable: %w", err)
|
||||
return
|
||||
}
|
||||
defer tryClose(stdinRead, "read end of stdin pipe")
|
||||
defer tryClose(stdinWrite, "write end of stdin pipe")
|
||||
stdoutRead, stdoutWrite, err := os.Pipe()
|
||||
if err != nil {
|
||||
fatalf(ws, "creating stdout pipe: %w", err)
|
||||
return
|
||||
}
|
||||
defer tryClose(stdoutRead, "read end of stdout pipe")
|
||||
defer tryClose(stdoutWrite, "write end of stdout pipe")
|
||||
stderrRead, stderrWrite, err := os.Pipe()
|
||||
if err != nil {
|
||||
fatalf(ws, "creating stderr pipe: %w", err)
|
||||
return
|
||||
}
|
||||
defer tryClose(stderrRead, "read end of stderr pipe")
|
||||
defer tryClose(stderrWrite, "write end of stderr pipe")
|
||||
// Spawn subprocess
|
||||
proc, err := os.StartProcess(cmdline[0], cmdline, &os.ProcAttr{
|
||||
Files: []*os.File{stdinRead, stdoutWrite, stderrWrite},
|
||||
})
|
||||
mp, err := NewManagedProcess(binary, cmdline, nil)
|
||||
if err != nil {
|
||||
fatalf(ws, "spawning process: %w", err)
|
||||
fatalf(ms, "spawning process: %w", err)
|
||||
return
|
||||
}
|
||||
// Setup a way for other goroutines to report a fatal error,
|
||||
// use increased capacity to avoid blockage with large number
|
||||
// of write callsites
|
||||
doneChan := make(chan struct{}, 10)
|
||||
// Setup channels and variables to monitor process state
|
||||
waitChan := make(chan struct{}, 1)
|
||||
state := (*os.ProcessState)(nil)
|
||||
// Monitor the process to see when it exits
|
||||
go func() {
|
||||
s, err := proc.Wait()
|
||||
if err != nil {
|
||||
fatalf(ws, "waiting for process to exit: %w", err)
|
||||
} else {
|
||||
state = s
|
||||
}
|
||||
waitChan <- struct{}{}
|
||||
doneChan <- struct{}{}
|
||||
}()
|
||||
// Arrange to send information about the process exit status
|
||||
// if we have obtained it by the time we return
|
||||
// Ensure eventual process termination
|
||||
defer func() {
|
||||
if state != nil {
|
||||
status := state.ExitCode()
|
||||
send(ws, &serverMessage{
|
||||
Event: "exit",
|
||||
ExitStatus: &status,
|
||||
})
|
||||
}
|
||||
mp.CloseChan <- struct{}{}
|
||||
}()
|
||||
// Arrange for subprocess to be killed when we exit
|
||||
defer func() {
|
||||
// See if process has already exited or is about to
|
||||
select {
|
||||
case <-waitChan:
|
||||
return
|
||||
case <-time.NewTimer(500 * time.Millisecond).C:
|
||||
//
|
||||
}
|
||||
// Try killing the process by closing stdin
|
||||
tryClose(stdinWrite, "stdin to child")
|
||||
select {
|
||||
case <-waitChan:
|
||||
return
|
||||
case <-time.NewTimer(500 * time.Millisecond).C:
|
||||
//
|
||||
}
|
||||
// Try killing the process with SIGTERM, SIGINT, then
|
||||
// finally SIGKILL
|
||||
for _, sig := range []os.Signal{syscall.SIGTERM, syscall.SIGINT, syscall.SIGKILL} {
|
||||
err = proc.Signal(sig)
|
||||
if err != nil {
|
||||
logWarnf("sending %s to child: %w", sig.String(), err)
|
||||
}
|
||||
select {
|
||||
case <-waitChan:
|
||||
return
|
||||
case <-time.NewTimer(500 * time.Millisecond).C:
|
||||
//
|
||||
}
|
||||
}
|
||||
// We are unable to kill the process
|
||||
fatalf(ws, "unable to kill child")
|
||||
}()
|
||||
// Close our copies of pipe ends passed to subprocess
|
||||
err = stdinRead.Close()
|
||||
if err != nil {
|
||||
fatalf(ws, "closing read end of stdin pipe from parent")
|
||||
return
|
||||
}
|
||||
err = stdoutWrite.Close()
|
||||
if err != nil {
|
||||
fatalf(ws, "closing write end of stdout pipe from parent")
|
||||
return
|
||||
}
|
||||
err = stderrWrite.Close()
|
||||
if err != nil {
|
||||
fatalf(ws, "closing write end of stderr pipe from parent")
|
||||
return
|
||||
}
|
||||
// Handle received messages from client
|
||||
stdinChan := make(chan []byte)
|
||||
go func() {
|
||||
handleClientMessages(ws, stdinChan)
|
||||
doneChan <- struct{}{}
|
||||
}()
|
||||
go func() {
|
||||
for data := range stdinChan {
|
||||
_, err := stdinWrite.Write(data)
|
||||
for data := range ms.IncomingChan {
|
||||
msg := clientMessage{}
|
||||
err := json.Unmarshal(data, &msg)
|
||||
if err != nil {
|
||||
warnf(ws, "writing to stdin: %w", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
// Send regular pings to ensure we get regular pongs to
|
||||
// satisfy the read deadline on handleClientMessages
|
||||
pingDoneChan := make(chan struct{}, 1)
|
||||
defer func() {
|
||||
pingDoneChan <- struct{}{}
|
||||
}()
|
||||
go func() {
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
err := ws.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeDeadline))
|
||||
if err != nil {
|
||||
logErrorf("sending ping: %w", err)
|
||||
doneChan <- struct{}{}
|
||||
return
|
||||
}
|
||||
case <-pingDoneChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
// Proxy stdout and stderr back to client
|
||||
go func() {
|
||||
for {
|
||||
buf := make([]byte, 1024)
|
||||
nr, err := stdoutRead.Read(buf)
|
||||
if err != nil {
|
||||
warnf(ws, "reading from stdout: %w", err)
|
||||
return
|
||||
}
|
||||
if nr == 0 {
|
||||
warnf(ms, "parsing json: %w", err)
|
||||
continue
|
||||
}
|
||||
data, err := json.Marshal(&serverMessage{
|
||||
Event: "stdout",
|
||||
Data: buf[:nr],
|
||||
})
|
||||
if err != nil {
|
||||
fatalf(ws, "wrapping stdout in json: %w", err)
|
||||
doneChan <- struct{}{}
|
||||
return
|
||||
}
|
||||
ws.SetWriteDeadline(time.Now().Add(writeDeadline))
|
||||
err = ws.WriteMessage(websocket.TextMessage, data)
|
||||
if err != nil {
|
||||
fatalf(ws, "sending message: %w", err)
|
||||
doneChan <- struct{}{}
|
||||
return
|
||||
switch msg.Event {
|
||||
case "stdin":
|
||||
mp.StdinChan <- msg.Data
|
||||
default:
|
||||
logWarnf("received unknown event type %s", msg.Event)
|
||||
}
|
||||
}
|
||||
}()
|
||||
// Wait until either process is exited or a websocket
|
||||
// operation fails
|
||||
<-doneChan
|
||||
// Process and websocket will be cleaned up after return
|
||||
// Proxy stdout and stderr from subprocess
|
||||
go func() {
|
||||
for data := range mp.StdoutChan {
|
||||
msg, err := json.Marshal(&serverMessage{
|
||||
Event: "stdout",
|
||||
Data: data,
|
||||
})
|
||||
if err != nil {
|
||||
warnf(ms, "wrapping stdout in json: %w", err)
|
||||
return
|
||||
}
|
||||
ms.OutgoingChan <- msg
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
for data := range mp.StderrChan {
|
||||
msg, err := json.Marshal(&serverMessage{
|
||||
Event: "stderr",
|
||||
Data: data,
|
||||
})
|
||||
if err != nil {
|
||||
warnf(ms, "wrapping stderr in json: %w", err)
|
||||
return
|
||||
}
|
||||
ms.OutgoingChan <- msg
|
||||
}
|
||||
}()
|
||||
// Send info about process exit status
|
||||
exitChan2 := make(chan struct{}, 16)
|
||||
go func() {
|
||||
for status := range mp.ExitChan {
|
||||
exitChan2 <- struct{}{}
|
||||
code := status.ExitCode()
|
||||
send(ms, &serverMessage{
|
||||
Event: "exit",
|
||||
ExitStatus: &code,
|
||||
})
|
||||
}
|
||||
}()
|
||||
// Wait until one of subprocess or websocket exits. The other
|
||||
// one will be cleaned up on return.
|
||||
select {
|
||||
case <-exitChan2:
|
||||
case <-ms.ClosedChan:
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,160 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
type managedProcess struct {
|
||||
proc *os.Process
|
||||
|
||||
stdinRead *os.File
|
||||
stdinWrite *os.File
|
||||
stdoutRead *os.File
|
||||
stdoutWrite *os.File
|
||||
stderrRead *os.File
|
||||
stderrWrite *os.File
|
||||
|
||||
internalExitChan chan struct{}
|
||||
|
||||
StdinChan chan []byte
|
||||
StdoutChan chan []byte
|
||||
StderrChan chan []byte
|
||||
ExitChan chan *os.ProcessState
|
||||
CloseChan chan struct{}
|
||||
}
|
||||
|
||||
func NewManagedProcess(name string, argv []string, attr *os.ProcAttr) (*managedProcess, error) {
|
||||
mp := &managedProcess{
|
||||
StdinChan: make(chan []byte, 16),
|
||||
StdoutChan: make(chan []byte, 16),
|
||||
StderrChan: make(chan []byte, 16),
|
||||
CloseChan: make(chan struct{}, 16),
|
||||
}
|
||||
done := false
|
||||
go mp.handleClose()
|
||||
defer func() {
|
||||
if !done {
|
||||
mp.CloseChan <- struct{}{}
|
||||
}
|
||||
}()
|
||||
var err error
|
||||
mp.stdinRead, mp.stdinWrite, err = os.Pipe()
|
||||
if err != nil {
|
||||
return mp, fmt.Errorf("creating stdin pipe: %w", err)
|
||||
}
|
||||
mp.stdoutRead, mp.stdoutWrite, err = os.Pipe()
|
||||
if err != nil {
|
||||
return mp, fmt.Errorf("creating stdout pipe: %w", err)
|
||||
}
|
||||
mp.stderrRead, mp.stderrWrite, err = os.Pipe()
|
||||
if err != nil {
|
||||
return mp, fmt.Errorf("creating stderr pipe: %w", err)
|
||||
}
|
||||
newAttr := &os.ProcAttr{}
|
||||
if attr != nil {
|
||||
*newAttr = *attr
|
||||
}
|
||||
if len(newAttr.Files) < 3 {
|
||||
newAttr.Files = append(newAttr.Files, make([]*os.File, 3-len(newAttr.Files))...)
|
||||
newAttr.Files[0] = mp.stdinRead
|
||||
newAttr.Files[1] = mp.stdoutWrite
|
||||
newAttr.Files[2] = mp.stderrWrite
|
||||
}
|
||||
mp.proc, err = os.StartProcess(name, argv, newAttr)
|
||||
if err != nil {
|
||||
return mp, fmt.Errorf("spawning process: %w", err)
|
||||
}
|
||||
go mp.handleWait()
|
||||
go mp.handleInput(mp.StdinChan, mp.stdinWrite, "stdin")
|
||||
go mp.handleOutput(mp.StdoutChan, mp.stdoutRead, "stdout")
|
||||
go mp.handleOutput(mp.StderrChan, mp.stderrRead, "stderr")
|
||||
done = true
|
||||
return mp, nil
|
||||
}
|
||||
|
||||
func (mp *managedProcess) handleInput(ch chan []byte, f *os.File, name string) {
|
||||
for data := range ch {
|
||||
nw, err := f.Write(data)
|
||||
if err != nil {
|
||||
logWarnf("writing to %s: got error after %d of %d byte(s): %w", name, nw, len(data), err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mp *managedProcess) handleOutput(ch chan []byte, f *os.File, name string) {
|
||||
for {
|
||||
buf := make([]byte, 1024)
|
||||
nr, err := f.Read(buf)
|
||||
if err != nil {
|
||||
logWarnf("reading from %s: got error after %d byte(s): %w", name, nr, err)
|
||||
return
|
||||
}
|
||||
if nr == 0 {
|
||||
continue
|
||||
}
|
||||
ch <- buf[:nr]
|
||||
}
|
||||
}
|
||||
|
||||
func (mp *managedProcess) handleWait() {
|
||||
s, err := mp.proc.Wait()
|
||||
if err != nil {
|
||||
logErrorf("waiting on process: %w", err)
|
||||
}
|
||||
mp.internalExitChan <- struct{}{}
|
||||
mp.ExitChan <- s
|
||||
}
|
||||
|
||||
func (mp *managedProcess) killProc() {
|
||||
// See if process has already exited or is about to
|
||||
select {
|
||||
case <-mp.internalExitChan:
|
||||
return
|
||||
case <-time.NewTimer(500 * time.Millisecond).C:
|
||||
//
|
||||
}
|
||||
// Try killing the process by closing stdin
|
||||
mp.stdinWrite.Close()
|
||||
select {
|
||||
case <-mp.internalExitChan:
|
||||
return
|
||||
case <-time.NewTimer(500 * time.Millisecond).C:
|
||||
//
|
||||
}
|
||||
// Try killing the process with SIGTERM, SIGINT, then
|
||||
// finally SIGKILL
|
||||
for _, sig := range []os.Signal{syscall.SIGTERM, syscall.SIGINT, syscall.SIGKILL} {
|
||||
err := mp.proc.Signal(sig)
|
||||
if err != nil {
|
||||
logErrorf("sending %s to child: %w", sig.String(), err)
|
||||
}
|
||||
select {
|
||||
case <-mp.internalExitChan:
|
||||
return
|
||||
case <-time.NewTimer(500 * time.Millisecond).C:
|
||||
//
|
||||
}
|
||||
}
|
||||
// We are unable to kill the process
|
||||
logErrorf("unable to kill child process (pid %d)", mp.proc.Pid)
|
||||
}
|
||||
|
||||
func (mp *managedProcess) handleClose() {
|
||||
<-mp.CloseChan
|
||||
for _, p := range []*os.File{
|
||||
mp.stdinRead, mp.stdinWrite,
|
||||
mp.stdoutRead, mp.stdoutWrite,
|
||||
mp.stderrRead, mp.stderrWrite,
|
||||
} {
|
||||
if p != nil {
|
||||
p.Close()
|
||||
}
|
||||
}
|
||||
if mp.proc != nil {
|
||||
//
|
||||
}
|
||||
}
|
|
@ -0,0 +1,112 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type ManagedWebsocket struct {
|
||||
Socket *websocket.Conn
|
||||
|
||||
MessageType int
|
||||
PingPeriod time.Duration
|
||||
ReadTimeout time.Duration
|
||||
WriteTimeout time.Duration
|
||||
|
||||
IncomingChan chan []byte
|
||||
OutgoingChan chan []byte
|
||||
CloseChan chan struct{}
|
||||
ClosedChan chan struct{}
|
||||
}
|
||||
|
||||
func (m *ManagedWebsocket) handleIncoming() {
|
||||
pongChan := make(chan struct{}, 16)
|
||||
m.Socket.SetPongHandler(func(string) error {
|
||||
pongChan <- struct{}{}
|
||||
return nil
|
||||
})
|
||||
msgChan := make(chan []byte, 16)
|
||||
go func() {
|
||||
defer close(msgChan)
|
||||
for {
|
||||
msgtype, data, err := m.Socket.ReadMessage()
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
logErrorf("reading message: %w", err)
|
||||
}
|
||||
m.Socket.Close()
|
||||
return
|
||||
}
|
||||
if msgtype != m.MessageType {
|
||||
logWarnf("ignoring message of unexpected type %d", msgtype)
|
||||
continue
|
||||
}
|
||||
msgChan <- data
|
||||
|
||||
}
|
||||
}()
|
||||
for {
|
||||
m.Socket.SetReadDeadline(time.Now().Add(m.ReadTimeout))
|
||||
var msgtype int
|
||||
var msgdata []byte
|
||||
select {
|
||||
case <-pongChan:
|
||||
msgtype = websocket.PongMessage
|
||||
case data := <-msgChan:
|
||||
msgtype = m.MessageType
|
||||
msgdata = data
|
||||
}
|
||||
if msgtype != m.MessageType {
|
||||
continue
|
||||
}
|
||||
m.OutgoingChan <- msgdata
|
||||
}
|
||||
}
|
||||
|
||||
func (m *ManagedWebsocket) handleOutgoing() {
|
||||
pingTicker := time.NewTicker(m.PingPeriod)
|
||||
defer pingTicker.Stop()
|
||||
defer func() {
|
||||
m.ClosedChan <- struct{}{}
|
||||
}()
|
||||
for {
|
||||
var msgtype int
|
||||
var msgdata []byte
|
||||
select {
|
||||
case <-pingTicker.C:
|
||||
msgtype = websocket.PingMessage
|
||||
msgdata = []byte{}
|
||||
case data := <-m.OutgoingChan:
|
||||
msgtype = m.MessageType
|
||||
msgdata = data
|
||||
case <-m.CloseChan:
|
||||
msgtype = websocket.CloseMessage
|
||||
msgdata = websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
|
||||
}
|
||||
wd := time.Now().Add(m.WriteTimeout)
|
||||
m.Socket.SetWriteDeadline(wd)
|
||||
err := m.Socket.WriteMessage(msgtype, msgdata)
|
||||
if err != nil {
|
||||
logErrorf("writing message: %w", err)
|
||||
m.Socket.Close()
|
||||
return
|
||||
}
|
||||
if msgtype == websocket.CloseMessage {
|
||||
time.Sleep(wd.Sub(time.Now()))
|
||||
m.Socket.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *ManagedWebsocket) Init() {
|
||||
m.IncomingChan = make(chan []byte, 16)
|
||||
m.OutgoingChan = make(chan []byte, 16)
|
||||
m.CloseChan = make(chan struct{}, 16)
|
||||
m.ClosedChan = make(chan struct{}, 16)
|
||||
|
||||
go m.handleIncoming()
|
||||
go m.handleOutgoing()
|
||||
}
|
Loading…
Reference in New Issue