package stream

import (
	"net/http"
	"net/url"
	"regexp"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/gorilla/websocket"
	"github.com/gotify/server/v2/auth"
	"github.com/gotify/server/v2/mode"
	"github.com/gotify/server/v2/model"
)

// shard holds clients for a subset of users to reduce lock contention.
// Each shard has its own lock, so operations on users that map to different
// shards do not contend with each other.
type shard struct {
	clients map[uint]map[string][]*client // userID -> token -> []client (multiple clients can share same token)
	lock    sync.RWMutex                  // guards clients
}

// API provides a handler for a WebSocket stream API.
type API struct {
	shards      []*shard            // client storage, indexed by userID modulo shardCount
	shardCount  int                 // number of entries in shards
	pingPeriod  time.Duration       // interval handed to each client's write handler
	pongTimeout time.Duration       // timeout handed to each client's reader; New stores pingPeriod + pongTimeout here
	upgrader    *websocket.Upgrader // upgrades incoming HTTP requests to WebSocket connections
	channelBuf  int                 // Buffer size for client write channels
}

// New creates a new instance of API.
// pingPeriod: is the interval in which the server sends a ping to the client.
// pongTimeout: is the duration after which the connection will be terminated when the client does not respond with
// the pong command.
// allowedWebSocketOrigins: regular expressions matched against the lowercased hostname of the Origin header.
// shardCount: number of shards for client storage; values below 1 fall back to 256 and the value is rounded up to
// the next power of 2.
// readBufferSize: WebSocket read buffer size in bytes.
// writeBufferSize: WebSocket write buffer size in bytes.
// channelBufferSize: buffer size for client write channels in messages.
func New(pingPeriod, pongTimeout time.Duration, allowedWebSocketOrigins []string, shardCount, readBufferSize, writeBufferSize, channelBufferSize int) *API { // Ensure shardCount is at least 1 and is a power of 2 for optimal hashing if shardCount < 1 { shardCount = 256 } // Round up to next power of 2 shardCount = nextPowerOf2(shardCount) shards := make([]*shard, shardCount) for i := range shards { shards[i] = &shard{ clients: make(map[uint]map[string][]*client), } } return &API{ shards: shards, shardCount: shardCount, pingPeriod: pingPeriod, pongTimeout: pingPeriod + pongTimeout, upgrader: newUpgrader(allowedWebSocketOrigins, readBufferSize, writeBufferSize), channelBuf: channelBufferSize, } } // nextPowerOf2 returns the next power of 2 >= n. func nextPowerOf2(n int) int { if n <= 0 { return 1 } n-- n |= n >> 1 n |= n >> 2 n |= n >> 4 n |= n >> 8 n |= n >> 16 n++ return n } // getShard returns the shard for a given userID using fast modulo operation. func (a *API) getShard(userID uint) *shard { return a.shards[userID%uint(a.shardCount)] } // CollectConnectedClientTokens returns all tokens of the connected clients. func (a *API) CollectConnectedClientTokens() []string { var allClients []string for _, shard := range a.shards { shard.lock.RLock() for _, userClients := range shard.clients { for _, tokenClients := range userClients { for _, c := range tokenClients { allClients = append(allClients, c.token) } } } shard.lock.RUnlock() } return uniq(allClients) } // NotifyDeletedUser closes existing connections for the given user. func (a *API) NotifyDeletedUser(userID uint) error { shard := a.getShard(userID) shard.lock.Lock() defer shard.lock.Unlock() if userClients, ok := shard.clients[userID]; ok { for _, tokenClients := range userClients { for _, client := range tokenClients { client.Close() } } delete(shard.clients, userID) } return nil } // NotifyDeletedClient closes existing connections with the given token. 
func (a *API) NotifyDeletedClient(userID uint, token string) { shard := a.getShard(userID) shard.lock.Lock() defer shard.lock.Unlock() if userClients, ok := shard.clients[userID]; ok { if tokenClients, exists := userClients[token]; exists { for _, client := range tokenClients { client.Close() } delete(userClients, token) // Clean up empty user map if len(userClients) == 0 { delete(shard.clients, userID) } } } } // Notify notifies the clients with the given userID that a new messages was created. func (a *API) Notify(userID uint, msg *model.MessageExternal) { shard := a.getShard(userID) shard.lock.RLock() userClients, ok := shard.clients[userID] if !ok { shard.lock.RUnlock() return } // Create a snapshot of clients to avoid holding lock during send clients := make([]*client, 0) for _, tokenClients := range userClients { for _, c := range tokenClients { clients = append(clients, c) } } shard.lock.RUnlock() // Send messages without holding the lock to prevent blocking other shards // The channel buffer (default 10) helps prevent blocking in most cases for _, c := range clients { c.write <- msg } } func (a *API) remove(c *client) { shard := a.getShard(c.userID) shard.lock.Lock() defer shard.lock.Unlock() if userClients, ok := shard.clients[c.userID]; ok { if tokenClients, exists := userClients[c.token]; exists { // Remove the specific client from the slice for i, client := range tokenClients { if client == c { userClients[c.token] = append(tokenClients[:i], tokenClients[i+1:]...) 
// Clean up empty token slice if len(userClients[c.token]) == 0 { delete(userClients, c.token) } // Clean up empty user map if len(userClients) == 0 { delete(shard.clients, c.userID) } break } } } } } func (a *API) register(c *client) { shard := a.getShard(c.userID) shard.lock.Lock() defer shard.lock.Unlock() if shard.clients[c.userID] == nil { shard.clients[c.userID] = make(map[string][]*client) } if shard.clients[c.userID][c.token] == nil { shard.clients[c.userID][c.token] = make([]*client, 0, 1) } shard.clients[c.userID][c.token] = append(shard.clients[c.userID][c.token], c) } // Handle handles incoming requests. First it upgrades the protocol to the WebSocket protocol and then starts listening // for read and writes. // swagger:operation GET /stream message streamMessages // // Websocket, return newly created messages. // // --- // schema: ws, wss // produces: [application/json] // security: [clientTokenAuthorizationHeader: [], clientTokenHeader: [], clientTokenQuery: [], basicAuth: []] // responses: // 200: // description: Ok // schema: // $ref: "#/definitions/Message" // 400: // description: Bad Request // schema: // $ref: "#/definitions/Error" // 401: // description: Unauthorized // schema: // $ref: "#/definitions/Error" // 403: // description: Forbidden // schema: // $ref: "#/definitions/Error" // 500: // description: Server Error // schema: // $ref: "#/definitions/Error" func (a *API) Handle(ctx *gin.Context) { conn, err := a.upgrader.Upgrade(ctx.Writer, ctx.Request, nil) if err != nil { ctx.Error(err) return } client := newClient(conn, auth.GetUserID(ctx), auth.GetTokenID(ctx), a.remove, a.channelBuf) a.register(client) go client.startReading(a.pongTimeout) go client.startWriteHandler(a.pingPeriod) } // Close closes all client connections and stops answering new connections. 
func (a *API) Close() { for _, shard := range a.shards { shard.lock.Lock() for _, userClients := range shard.clients { for _, tokenClients := range userClients { for _, client := range tokenClients { client.Close() } } } for k := range shard.clients { delete(shard.clients, k) } shard.lock.Unlock() } } func uniq[T comparable](s []T) []T { m := make(map[T]struct{}, len(s)) r := make([]T, 0, len(s)) for _, v := range s { if _, ok := m[v]; !ok { m[v] = struct{}{} r = append(r, v) } } return r } func isAllowedOrigin(r *http.Request, allowedOrigins []*regexp.Regexp) bool { origin := r.Header.Get("origin") if origin == "" { return true } u, err := url.Parse(origin) if err != nil { return false } if strings.EqualFold(u.Host, r.Host) { return true } for _, allowedOrigin := range allowedOrigins { if allowedOrigin.MatchString(strings.ToLower(u.Hostname())) { return true } } return false } func newUpgrader(allowedWebSocketOrigins []string, readBufferSize, writeBufferSize int) *websocket.Upgrader { compiledAllowedOrigins := compileAllowedWebSocketOrigins(allowedWebSocketOrigins) return &websocket.Upgrader{ ReadBufferSize: readBufferSize, WriteBufferSize: writeBufferSize, CheckOrigin: func(r *http.Request) bool { if mode.IsDev() { return true } return isAllowedOrigin(r, compiledAllowedOrigins) }, } } func compileAllowedWebSocketOrigins(allowedOrigins []string) []*regexp.Regexp { var compiledAllowedOrigins []*regexp.Regexp for _, origin := range allowedOrigins { compiledAllowedOrigins = append(compiledAllowedOrigins, regexp.MustCompile(origin)) } return compiledAllowedOrigins }