package stream import ( "sync" "time" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "github.com/jmattheis/memo/auth" "github.com/jmattheis/memo/model" ) var upgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, } // The API provides a handler for a WebSocket stream API. type API struct { clients map[uint][]*client lock sync.RWMutex pingPeriod time.Duration pongTimeout time.Duration } // New creates a new instance of API. // pingPeriod: is the interval, in which is server sends the a ping to the client. // pongTimeout: is the duration after the connection will be terminated, when the client does not respond with the // pong command. func New(pingPeriod, pongTimeout time.Duration) *API { return &API{ clients: make(map[uint][]*client), pingPeriod: pingPeriod, pongTimeout: pingPeriod + pongTimeout, } } func (a *API) getClients(userID uint) ([]*client, bool) { a.lock.RLock() defer a.lock.RUnlock() clients, ok := a.clients[userID] return clients, ok } // Notify notifies the clients with the given userID that a new messages was created. func (a *API) Notify(userID uint, msg *model.Message) { if clients, ok := a.getClients(userID); ok { go func() { for _, c := range clients { c.write <- msg } }() } } func (a *API) remove(remove *client) { a.lock.Lock() defer a.lock.Unlock() if userIDClients, ok := a.clients[remove.userID]; ok { for i, client := range userIDClients { if client == remove { a.clients[remove.userID] = append(userIDClients[:i], userIDClients[i+1:]...) break } } } } func (a *API) register(client *client) { a.lock.Lock() defer a.lock.Unlock() a.clients[client.userID] = append(a.clients[client.userID], client) } // Handle handles incoming requests. First it upgrades the protocol to the WebSocket protocol and then starts listening // for read and writes. func (a *API) Handle(ctx *gin.Context) { conn, err := upgrader.Upgrade(ctx.Writer, ctx.Request, nil) if err != nil { return } client := newClient(conn, auth.GetUserID(ctx), a.remove) a.register(client) go client.startReading(a.pongTimeout) go client.startWriteHandler(a.pingPeriod) } // Close closes all client connections and stops answering new connections. func (a *API) Close() { a.lock.Lock() defer a.lock.Unlock() for _, clients := range a.clients { for _, client := range clients { client.Close() } } for k := range a.clients { delete(a.clients, k) } }