riju/agent/websocket.go

108 lines
2.2 KiB
Go

package main
import (
"time"
"github.com/gorilla/websocket"
)
// ManagedWebsocket wraps a websocket connection with channel-based I/O.
// After Init, callers read data payloads from IncomingChan, queue payloads on
// OutgoingChan, request a normal close via CloseChan, and are notified via
// ClosedChan when the writer goroutine has exited. Keepalive pings and
// read/write deadlines are handled by the background goroutines.
type ManagedWebsocket struct {
	// Socket is the underlying connection; it is read by handleIncoming and
	// written by handleOutgoing.
	Socket *websocket.Conn
	// MessageType is the frame type used for outgoing data; incoming frames
	// of any other type are logged and dropped.
	MessageType int

	// PingPeriod is the interval between outgoing pings. ReadTimeout is the
	// read deadline, refreshed whenever a pong or data message arrives.
	// WriteTimeout is the deadline applied to each individual write.
	PingPeriod, ReadTimeout, WriteTimeout time.Duration

	// IncomingChan delivers received data payloads; OutgoingChan accepts
	// payloads to send. Both are allocated by Init.
	IncomingChan, OutgoingChan chan []byte

	// CloseChan requests a normal-closure close frame; ClosedChan receives a
	// value once the writer goroutine returns. Both are allocated by Init.
	CloseChan, ClosedChan chan struct{}
}
// handleIncoming pumps data messages from m.Socket into m.IncomingChan until
// the connection fails.
//
// A helper goroutine performs the blocking ReadMessage calls. It drops (with
// a warning) any message whose type differs from m.MessageType, and on the
// first read error it closes the socket and closes msgChan. This loop resets
// the read deadline to ReadTimeout from now after every pong or delivered
// data message. If neither arrives in time, the pending read fails.
//
// It returns once the reader goroutine has exited. IncomingChan is left open.
func (m *ManagedWebsocket) handleIncoming() {
	pongChan := make(chan struct{}, 16)
	m.Socket.SetPongHandler(func(string) error {
		pongChan <- struct{}{}
		return nil
	})

	msgChan := make(chan []byte, 16)
	go func() {
		defer close(msgChan)
		for {
			msgtype, data, err := m.Socket.ReadMessage()
			if err != nil {
				m.Socket.Close()
				return
			}
			if msgtype != m.MessageType {
				logWarnf("ignoring message of unexpected type %d", msgtype)
				continue
			}
			msgChan <- data
		}
	}()

	for {
		m.Socket.SetReadDeadline(time.Now().Add(m.ReadTimeout))
		select {
		case <-pongChan:
			// A pong only proves liveness; loop to refresh the deadline.
		case data, ok := <-msgChan:
			if !ok {
				// The reader goroutine has exited. Without this check, a
				// closed msgChan yields nil forever and we would flood
				// IncomingChan with empty messages, then block and leak.
				return
			}
			m.IncomingChan <- data
		}
	}
}
// handleOutgoing is the only goroutine that writes to m.Socket. It waits on
// three sources:
//   - the ping ticker, firing every PingPeriod, which sends an empty ping;
//   - OutgoingChan, whose payloads are sent as MessageType frames;
//   - CloseChan, which sends a normal-closure close frame.
//
// Each write gets a deadline of WriteTimeout from now. If a write fails, the
// socket is closed and the function returns. After a close frame is written,
// the function sleeps until that write deadline, then closes the socket and
// returns. In every case ClosedChan is signalled on the way out.
func (m *ManagedWebsocket) handleOutgoing() {
	ticker := time.NewTicker(m.PingPeriod)
	defer ticker.Stop()
	defer func() { m.ClosedChan <- struct{}{} }()

	for {
		var (
			kind    int
			payload []byte
		)
		select {
		case <-ticker.C:
			kind, payload = websocket.PingMessage, []byte{}
		case data := <-m.OutgoingChan:
			kind, payload = m.MessageType, data
		case <-m.CloseChan:
			kind = websocket.CloseMessage
			payload = websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
		}

		deadline := time.Now().Add(m.WriteTimeout)
		m.Socket.SetWriteDeadline(deadline)
		if err := m.Socket.WriteMessage(kind, payload); err != nil {
			m.Socket.Close()
			return
		}

		if kind == websocket.CloseMessage {
			// Linger until the write deadline before dropping the connection.
			// Presumably this gives the peer time to finish the close
			// handshake.
			time.Sleep(time.Until(deadline))
			m.Socket.Close()
			return
		}
	}
}
// Init allocates the four channels (each buffered to 16 entries) and starts
// the background reader (handleIncoming) and writer (handleOutgoing)
// goroutines.
//
// Socket, MessageType, PingPeriod, ReadTimeout and WriteTimeout must be set
// before calling Init, because the goroutines read them immediately.
// time.NewTicker panics if PingPeriod is not positive.
func (m *ManagedWebsocket) Init() {
	const queueDepth = 16
	m.IncomingChan, m.OutgoingChan = make(chan []byte, queueDepth), make(chan []byte, queueDepth)
	m.CloseChan, m.ClosedChan = make(chan struct{}, queueDepth), make(chan struct{}, queueDepth)

	go m.handleIncoming()
	go m.handleOutgoing()
}