Add streaming api
This commit is contained in:
parent
7eed1703c5
commit
38378d2c8c
|
|
@ -0,0 +1,95 @@
|
||||||
|
package stream
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/jmattheis/memo/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
writeWait = 2 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
type client struct {
|
||||||
|
conn *websocket.Conn
|
||||||
|
onClose func(*client)
|
||||||
|
write chan *model.Message
|
||||||
|
userID uint
|
||||||
|
once sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
func newClient(conn *websocket.Conn, userID uint, onClose func(*client)) *client {
|
||||||
|
return &client{
|
||||||
|
conn: conn,
|
||||||
|
write: make(chan *model.Message, 1),
|
||||||
|
userID: userID,
|
||||||
|
onClose: onClose,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the connection.
|
||||||
|
func (c *client) Close() {
|
||||||
|
c.once.Do(func() {
|
||||||
|
c.conn.Close()
|
||||||
|
close(c.write)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// NotifyClose closes the connection and notifies that the connection was closed.
|
||||||
|
func (c *client) NotifyClose() {
|
||||||
|
c.once.Do(func() {
|
||||||
|
c.conn.Close()
|
||||||
|
close(c.write)
|
||||||
|
c.onClose(c)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// startWriteHandler starts listening on the client connection. As we do not need anything from the client,
|
||||||
|
// we ignore incoming messages. Leaves the loop on errors.
|
||||||
|
func (c *client) startReading(pongWait time.Duration) {
|
||||||
|
defer c.NotifyClose()
|
||||||
|
c.conn.SetReadLimit(64)
|
||||||
|
c.conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||||
|
c.conn.SetPongHandler(func(appData string) error {
|
||||||
|
c.conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
for {
|
||||||
|
if _, _, err := c.conn.NextReader(); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// startWriteHandler starts the write loop. The method has the following tasks:
|
||||||
|
// * ping the client in the interval provided as parameter
|
||||||
|
// * write messages send by the channel to the client
|
||||||
|
// * on errors exit the loop
|
||||||
|
func (c *client) startWriteHandler(pingPeriod time.Duration) {
|
||||||
|
pingTicker := time.NewTicker(pingPeriod)
|
||||||
|
defer func() {
|
||||||
|
c.NotifyClose()
|
||||||
|
pingTicker.Stop()
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case message, ok := <-c.write:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||||
|
if err := c.conn.WriteJSON(message); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-pingTicker.C:
|
||||||
|
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||||
|
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,102 @@
|
||||||
|
package stream
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/jmattheis/memo/auth"
|
||||||
|
"github.com/jmattheis/memo/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
var upgrader = websocket.Upgrader{
|
||||||
|
ReadBufferSize: 1024,
|
||||||
|
WriteBufferSize: 1024,
|
||||||
|
}
|
||||||
|
|
||||||
|
// The API provides a handler for a WebSocket stream API.
|
||||||
|
type API struct {
|
||||||
|
clients map[uint][]*client
|
||||||
|
lock sync.RWMutex
|
||||||
|
pingPeriod time.Duration
|
||||||
|
pongTimeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new instance of API.
|
||||||
|
// pingPeriod: is the interval, in which is server sends the a ping to the client.
|
||||||
|
// pongTimeout: is the duration after the connection will be terminated, when the client does not respond with the
|
||||||
|
// pong command.
|
||||||
|
func New(pingPeriod, pongTimeout time.Duration) *API {
|
||||||
|
return &API{
|
||||||
|
clients: make(map[uint][]*client),
|
||||||
|
pingPeriod: pingPeriod,
|
||||||
|
pongTimeout: pingPeriod + pongTimeout,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *API) getClients(userID uint) ([]*client, bool) {
|
||||||
|
a.lock.RLock()
|
||||||
|
defer a.lock.RUnlock()
|
||||||
|
clients, ok := a.clients[userID]
|
||||||
|
return clients, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// Notify notifies the clients with the given userID that a new messages was created.
|
||||||
|
func (a *API) Notify(userID uint, msg *model.Message) {
|
||||||
|
if clients, ok := a.getClients(userID); ok {
|
||||||
|
go func() {
|
||||||
|
for _, c := range clients {
|
||||||
|
c.write <- msg
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *API) remove(remove *client) {
|
||||||
|
a.lock.Lock()
|
||||||
|
defer a.lock.Unlock()
|
||||||
|
if userIDClients, ok := a.clients[remove.userID]; ok {
|
||||||
|
for i, client := range userIDClients {
|
||||||
|
if client == remove {
|
||||||
|
a.clients[remove.userID] = append(userIDClients[:i], userIDClients[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *API) register(client *client) {
|
||||||
|
a.lock.Lock()
|
||||||
|
defer a.lock.Unlock()
|
||||||
|
a.clients[client.userID] = append(a.clients[client.userID], client)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle handles incoming requests. First it upgrades the protocol to the WebSocket protocol and then starts listening
|
||||||
|
// for read and writes.
|
||||||
|
func (a *API) Handle(ctx *gin.Context) {
|
||||||
|
conn, err := upgrader.Upgrade(ctx.Writer, ctx.Request, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
client := newClient(conn, auth.GetUserID(ctx), a.remove)
|
||||||
|
a.register(client)
|
||||||
|
go client.startReading(a.pongTimeout)
|
||||||
|
go client.startWriteHandler(a.pingPeriod)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes all client connections and stops answering new connections.
|
||||||
|
func (a *API) Close() {
|
||||||
|
a.lock.Lock()
|
||||||
|
defer a.lock.Unlock()
|
||||||
|
|
||||||
|
for _, clients := range a.clients {
|
||||||
|
for _, client := range clients {
|
||||||
|
client.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for k := range a.clients {
|
||||||
|
delete(a.clients, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,318 @@
|
||||||
|
package stream
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"reflect"
|
||||||
|
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/bouk/monkey"
|
||||||
|
"github.com/fortytw2/leaktest"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/jmattheis/memo/auth"
|
||||||
|
"github.com/jmattheis/memo/model"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFailureOnNormalHttpRequest(t *testing.T) {
|
||||||
|
defer leaktest.Check(t)()
|
||||||
|
|
||||||
|
server, api := bootTestServer(staticUserID())
|
||||||
|
defer server.Close()
|
||||||
|
defer api.Close()
|
||||||
|
|
||||||
|
resp, err := http.Get(server.URL)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, 400, resp.StatusCode)
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteMessageFails(t *testing.T) {
|
||||||
|
defer leaktest.Check(t)()
|
||||||
|
|
||||||
|
server, api := bootTestServer(func(context *gin.Context) {
|
||||||
|
auth.RegisterAuthentication(context, nil, 1, "")
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
defer api.Close()
|
||||||
|
|
||||||
|
wsURL := wsURL(server.URL)
|
||||||
|
user := testClient(t, wsURL)
|
||||||
|
|
||||||
|
// the server may take some time to register the client
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
client, _ := api.getClients(1)
|
||||||
|
assert.NotEmpty(t, client)
|
||||||
|
|
||||||
|
// try emulate an write error, mostly this should kill the ReadMessage goroutine first but you'll never know.
|
||||||
|
patch := monkey.PatchInstanceMethod(reflect.TypeOf(client[0].conn), "WriteJSON", func(*websocket.Conn, interface{}) error {
|
||||||
|
return errors.New("could not do something")
|
||||||
|
})
|
||||||
|
defer patch.Unpatch()
|
||||||
|
|
||||||
|
api.Notify(1, &model.Message{Message: "HI"})
|
||||||
|
user.expectNoMessage()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWritePingFails(t *testing.T) {
|
||||||
|
defer leaktest.CheckTimeout(t, 10*time.Second)()
|
||||||
|
|
||||||
|
server, api := bootTestServer(staticUserID())
|
||||||
|
defer api.Close()
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
wsURL := wsURL(server.URL)
|
||||||
|
user := testClient(t, wsURL)
|
||||||
|
defer user.conn.Close()
|
||||||
|
|
||||||
|
// the server may take some time to register the client
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
client, _ := api.getClients(1)
|
||||||
|
assert.NotEmpty(t, client)
|
||||||
|
// try emulate an write error, mostly this should kill the ReadMessage gorouting first but you'll never know.
|
||||||
|
patch := monkey.PatchInstanceMethod(reflect.TypeOf(client[0].conn), "WriteMessage", func(*websocket.Conn, int, []byte) error {
|
||||||
|
return errors.New("could not do something")
|
||||||
|
})
|
||||||
|
defer patch.Unpatch()
|
||||||
|
|
||||||
|
time.Sleep(5 * time.Second) // waiting for ping
|
||||||
|
|
||||||
|
api.Notify(1, &model.Message{Message: "HI"})
|
||||||
|
user.expectNoMessage()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPing(t *testing.T) {
|
||||||
|
server, api := bootTestServer(staticUserID())
|
||||||
|
defer server.Close()
|
||||||
|
defer api.Close()
|
||||||
|
|
||||||
|
wsURL := wsURL(server.URL)
|
||||||
|
|
||||||
|
user := testClient(t, wsURL)
|
||||||
|
defer user.conn.Close()
|
||||||
|
|
||||||
|
ping := make(chan bool)
|
||||||
|
oldPingHandler := user.conn.PingHandler()
|
||||||
|
user.conn.SetPingHandler(func(appData string) error {
|
||||||
|
err := oldPingHandler(appData)
|
||||||
|
ping <- true
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
|
||||||
|
expectNoMessage(user)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
assert.Fail(t, "Expected ping but there was one :(")
|
||||||
|
case <-ping:
|
||||||
|
// expected
|
||||||
|
}
|
||||||
|
|
||||||
|
expectNoMessage(user)
|
||||||
|
api.Notify(1, &model.Message{Message: "HI"})
|
||||||
|
user.expectMessage(&model.Message{Message: "HI"})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCloseClientOnNotReading(t *testing.T) {
|
||||||
|
server, api := bootTestServer(staticUserID())
|
||||||
|
defer server.Close()
|
||||||
|
defer api.Close()
|
||||||
|
|
||||||
|
wsURL := wsURL(server.URL)
|
||||||
|
|
||||||
|
ws, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
defer ws.Close()
|
||||||
|
|
||||||
|
// the server may take some time to register the client
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
clients, _ := api.getClients(1)
|
||||||
|
assert.NotEmpty(t, clients)
|
||||||
|
|
||||||
|
time.Sleep(7 * time.Second)
|
||||||
|
|
||||||
|
clients, _ = api.getClients(1)
|
||||||
|
assert.Empty(t, clients)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMessageDirectlyAfterConnect(t *testing.T) {
|
||||||
|
defer leaktest.Check(t)()
|
||||||
|
server, api := bootTestServer(staticUserID())
|
||||||
|
defer server.Close()
|
||||||
|
defer api.Close()
|
||||||
|
|
||||||
|
wsURL := wsURL(server.URL)
|
||||||
|
|
||||||
|
user := testClient(t, wsURL)
|
||||||
|
defer user.conn.Close()
|
||||||
|
// the server may take some time to register the client
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
api.Notify(1, &model.Message{Message: "msg"})
|
||||||
|
user.expectMessage(&model.Message{Message: "msg"})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultipleClients(t *testing.T) {
|
||||||
|
defer leaktest.Check(t)()
|
||||||
|
userIDs := []uint{1, 1, 1, 2, 2, 3}
|
||||||
|
i := 0
|
||||||
|
server, api := bootTestServer(func(context *gin.Context) {
|
||||||
|
auth.RegisterAuthentication(context, nil, userIDs[i], "")
|
||||||
|
i++
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
wsURL := wsURL(server.URL)
|
||||||
|
|
||||||
|
userOneIPhone := testClient(t, wsURL)
|
||||||
|
defer userOneIPhone.conn.Close()
|
||||||
|
userOneAndroid := testClient(t, wsURL)
|
||||||
|
defer userOneAndroid.conn.Close()
|
||||||
|
userOneBrowser := testClient(t, wsURL)
|
||||||
|
defer userOneBrowser.conn.Close()
|
||||||
|
userOne := []*testingClient{userOneAndroid, userOneBrowser, userOneIPhone}
|
||||||
|
|
||||||
|
userTwoBrowser := testClient(t, wsURL)
|
||||||
|
defer userTwoBrowser.conn.Close()
|
||||||
|
userTwoAndroid := testClient(t, wsURL)
|
||||||
|
defer userTwoAndroid.conn.Close()
|
||||||
|
userTwo := []*testingClient{userTwoAndroid, userTwoBrowser}
|
||||||
|
|
||||||
|
userThreeAndroid := testClient(t, wsURL)
|
||||||
|
defer userThreeAndroid.conn.Close()
|
||||||
|
userThree := []*testingClient{userThreeAndroid}
|
||||||
|
|
||||||
|
// the server may take some time to register the client
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// there should not be messages at the beginning
|
||||||
|
expectNoMessage(userOne...)
|
||||||
|
expectNoMessage(userTwo...)
|
||||||
|
expectNoMessage(userThree...)
|
||||||
|
|
||||||
|
api.Notify(1, &model.Message{ID: 1, Message: "hello"})
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
expectMessage(&model.Message{ID: 1, Message: "hello"}, userOne...)
|
||||||
|
expectNoMessage(userTwo...)
|
||||||
|
expectNoMessage(userThree...)
|
||||||
|
|
||||||
|
api.Notify(2, &model.Message{ID: 2, Message: "there"})
|
||||||
|
expectNoMessage(userOne...)
|
||||||
|
expectMessage(&model.Message{ID: 2, Message: "there"}, userTwo...)
|
||||||
|
expectNoMessage(userThree...)
|
||||||
|
|
||||||
|
userOneIPhone.conn.Close()
|
||||||
|
|
||||||
|
expectNoMessage(userOne...)
|
||||||
|
expectNoMessage(userTwo...)
|
||||||
|
expectNoMessage(userThree...)
|
||||||
|
|
||||||
|
api.Notify(1, &model.Message{ID: 3, Message: "how"})
|
||||||
|
expectMessage(&model.Message{ID: 3, Message: "how"}, userOneAndroid, userOneBrowser)
|
||||||
|
expectNoMessage(userOneIPhone)
|
||||||
|
expectNoMessage(userTwo...)
|
||||||
|
expectNoMessage(userThree...)
|
||||||
|
|
||||||
|
api.Notify(2, &model.Message{ID: 4, Message: "are"})
|
||||||
|
|
||||||
|
expectNoMessage(userOne...)
|
||||||
|
expectMessage(&model.Message{ID: 4, Message: "are"}, userTwo...)
|
||||||
|
expectNoMessage(userThree...)
|
||||||
|
|
||||||
|
api.Close()
|
||||||
|
|
||||||
|
api.Notify(2, &model.Message{ID: 5, Message: "you"})
|
||||||
|
|
||||||
|
expectNoMessage(userOne...)
|
||||||
|
expectNoMessage(userTwo...)
|
||||||
|
expectNoMessage(userThree...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testClient(t *testing.T, url string) *testingClient {
|
||||||
|
ws, _, err := websocket.DefaultDialer.Dial(url, nil)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
readMessages := make(chan model.Message)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
_, payload, err := ws.ReadMessage()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
actual := &model.Message{}
|
||||||
|
json.NewDecoder(bytes.NewBuffer(payload)).Decode(actual)
|
||||||
|
readMessages <- *actual
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return &testingClient{conn: ws, readMessage: readMessages, t: t}
|
||||||
|
}
|
||||||
|
|
||||||
|
type testingClient struct {
|
||||||
|
conn *websocket.Conn
|
||||||
|
readMessage <-chan model.Message
|
||||||
|
t *testing.T
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testingClient) expectMessage(expected *model.Message) {
|
||||||
|
select {
|
||||||
|
case <-time.After(50 * time.Millisecond):
|
||||||
|
assert.Fail(c.t, "Expected message but none was send :(")
|
||||||
|
case actual := <-c.readMessage:
|
||||||
|
assert.Equal(c.t, *expected, actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func expectMessage(expected *model.Message, clients ...*testingClient) {
|
||||||
|
for _, client := range clients {
|
||||||
|
client.expectMessage(expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func expectNoMessage(clients ...*testingClient) {
|
||||||
|
for _, client := range clients {
|
||||||
|
client.expectNoMessage()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testingClient) expectNoMessage() {
|
||||||
|
select {
|
||||||
|
case <-time.After(50 * time.Millisecond):
|
||||||
|
// no message == as expected
|
||||||
|
case msg := <-c.readMessage:
|
||||||
|
assert.Fail(c.t, "Expected NO message but there was one :(", fmt.Sprint(msg))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func bootTestServer(handlerFunc gin.HandlerFunc) (*httptest.Server, *API) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(handlerFunc)
|
||||||
|
// all 4 seconds a ping, and the client has 1 second to respond
|
||||||
|
api := New(4*time.Second, 1*time.Second)
|
||||||
|
r.GET("/", api.Handle)
|
||||||
|
server := httptest.NewServer(r)
|
||||||
|
return server, api
|
||||||
|
}
|
||||||
|
|
||||||
|
func wsURL(httpURL string) string {
|
||||||
|
return "ws" + strings.TrimPrefix(httpURL, "http")
|
||||||
|
}
|
||||||
|
|
||||||
|
func staticUserID() gin.HandlerFunc {
|
||||||
|
return func(context *gin.Context) {
|
||||||
|
auth.RegisterAuthentication(context, nil, 1, "")
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue