diff --git a/api/stream/client.go b/api/stream/client.go index 9c6af9d..d123483 100644 --- a/api/stream/client.go +++ b/api/stream/client.go @@ -17,14 +17,16 @@ type client struct { onClose func(*client) write chan *model.Message userID uint + token string once sync.Once } -func newClient(conn *websocket.Conn, userID uint, onClose func(*client)) *client { +func newClient(conn *websocket.Conn, userID uint, token string, onClose func(*client)) *client { return &client{ conn: conn, write: make(chan *model.Message, 1), userID: userID, + token: token, onClose: onClose, } } diff --git a/api/stream/stream.go b/api/stream/stream.go index 2d08325..be8ae80 100644 --- a/api/stream/stream.go +++ b/api/stream/stream.go @@ -57,21 +57,30 @@ func New(pingPeriod, pongTimeout time.Duration) *API { } } -func (a *API) getClients(userID uint) ([]*client, bool) { - a.lock.RLock() - defer a.lock.RUnlock() - clients, ok := a.clients[userID] - return clients, ok +// NotifyDeleted closes existing connections with the given token. +func (a *API) NotifyDeleted(userID uint, token string) { + a.lock.Lock() + defer a.lock.Unlock() + if clients, ok := a.clients[userID]; ok { + for i := len(clients) - 1; i >= 0; i-- { + client := clients[i] + if client.token == token { + client.Close() + clients = append(clients[:i], clients[i+1:]...) + } + } + a.clients[userID] = clients + } } // Notify notifies the clients with the given userID that a new messages was created. func (a *API) Notify(userID uint, msg *model.Message) { - if clients, ok := a.getClients(userID); ok { - go func() { - for _, c := range clients { - c.write <- msg - } - }() + a.lock.RLock() + defer a.lock.RUnlock() + if clients, ok := a.clients[userID]; ok { + for _, c := range clients { + c.write <- msg + } } } @@ -102,7 +111,7 @@ func (a *API) Handle(ctx *gin.Context) { return } - client := newClient(conn, auth.GetUserID(ctx), a.remove) + client := newClient(conn, auth.GetUserID(ctx), auth.GetTokenID(ctx), a.remove) a.register(client) go client.startReading(a.pongTimeout) go client.startWriteHandler(a.pingPeriod) diff --git a/api/stream/stream_test.go b/api/stream/stream_test.go index ef82892..d801151 100644 --- a/api/stream/stream_test.go +++ b/api/stream/stream_test.go @@ -55,11 +55,11 @@ func TestWriteMessageFails(t *testing.T) { // the server may take some time to register the client time.Sleep(100 * time.Millisecond) - client, _ := api.getClients(1) - assert.NotEmpty(t, client) + clients := clients(api, 1) + assert.NotEmpty(t, clients) // try emulate an write error, mostly this should kill the ReadMessage goroutine first but you'll never know. - patch := monkey.PatchInstanceMethod(reflect.TypeOf(client[0].conn), "WriteJSON", func(*websocket.Conn, interface{}) error { + patch := monkey.PatchInstanceMethod(reflect.TypeOf(clients[0].conn), "WriteJSON", func(*websocket.Conn, interface{}) error { return errors.New("could not do something") }) defer patch.Unpatch() @@ -83,10 +83,11 @@ func TestWritePingFails(t *testing.T) { // the server may take some time to register the client time.Sleep(100 * time.Millisecond) - client, _ := api.getClients(1) - assert.NotEmpty(t, client) + clients := clients(api, 1) + + assert.NotEmpty(t, clients) // try emulate an write error, mostly this should kill the ReadMessage gorouting first but you'll never know. - patch := monkey.PatchInstanceMethod(reflect.TypeOf(client[0].conn), "WriteMessage", func(*websocket.Conn, int, []byte) error { + patch := monkey.PatchInstanceMethod(reflect.TypeOf(clients[0].conn), "WriteMessage", func(*websocket.Conn, int, []byte) error { return errors.New("could not do something") }) defer patch.Unpatch() @@ -146,13 +147,11 @@ func TestCloseClientOnNotReading(t *testing.T) { // the server may take some time to register the client time.Sleep(100 * time.Millisecond) - clients, _ := api.getClients(1) - assert.NotEmpty(t, clients) + assert.NotEmpty(t, clients(api, 1)) time.Sleep(7 * time.Second) - clients, _ = api.getClients(1) - assert.Empty(t, clients) + assert.Empty(t, clients(api, 1)) } func TestMessageDirectlyAfterConnect(t *testing.T) { @@ -172,6 +171,92 @@ func TestMessageDirectlyAfterConnect(t *testing.T) { user.expectMessage(&model.Message{Message: "msg"}) } +func TestDeleteClientShouldCloseConnection(t *testing.T) { + mode.Set(mode.Prod) + defer leaktest.Check(t)() + server, api := bootTestServer(staticUserID()) + defer server.Close() + defer api.Close() + + wsURL := wsURL(server.URL) + + user := testClient(t, wsURL) + defer user.conn.Close() + // the server may take some time to register the client + time.Sleep(100 * time.Millisecond) + api.Notify(1, &model.Message{Message: "msg"}) + user.expectMessage(&model.Message{Message: "msg"}) + + api.NotifyDeleted(1, "customtoken") + + api.Notify(1, &model.Message{Message: "msg"}) + user.expectNoMessage() +} + +func TestDeleteMultipleClients(t *testing.T) { + mode.Set(mode.TestDev) + + defer leaktest.Check(t)() + userIDs := []uint{1, 1, 1, 1, 2, 2, 3} + tokens := []string{"1-1", "1-2", "1-2", "1-3", "2-1", "2-2", "3"} + i := 0 + server, api := bootTestServer(func(context *gin.Context) { + auth.RegisterAuthentication(context, nil, userIDs[i], tokens[i]) + i++ + }) + defer server.Close() + + wsURL := wsURL(server.URL) + + userOneIPhone := testClient(t, wsURL) + defer userOneIPhone.conn.Close() + userOneAndroid := testClient(t, wsURL) + defer userOneAndroid.conn.Close() + userOneBrowser := testClient(t, wsURL) + defer userOneBrowser.conn.Close() + userOneOther := testClient(t, wsURL) + defer userOneOther.conn.Close() + userOne := []*testingClient{userOneAndroid, userOneBrowser, userOneIPhone, userOneOther} + + userTwoBrowser := testClient(t, wsURL) + defer userTwoBrowser.conn.Close() + userTwoAndroid := testClient(t, wsURL) + defer userTwoAndroid.conn.Close() + userTwo := []*testingClient{userTwoAndroid, userTwoBrowser} + + userThreeAndroid := testClient(t, wsURL) + defer userThreeAndroid.conn.Close() + userThree := []*testingClient{userThreeAndroid} + + // the server may take some time to register the client + time.Sleep(100 * time.Millisecond) + + api.Notify(1, &model.Message{ID: 4, Message: "there"}) + expectMessage(&model.Message{ID: 4, Message: "there"}, userOne...) + expectNoMessage(userTwo...) + expectNoMessage(userThree...) + + api.NotifyDeleted(1, "1-2") + + api.Notify(1, &model.Message{ID: 2, Message: "there"}) + expectMessage(&model.Message{ID: 2, Message: "there"}, userOneIPhone, userOneOther) + expectNoMessage(userOneBrowser, userOneAndroid) + expectNoMessage(userThree...) + expectNoMessage(userTwo...) + + api.Notify(2, &model.Message{ID: 2, Message: "there"}) + expectNoMessage(userOne...) + expectMessage(&model.Message{ID: 2, Message: "there"}, userTwo...) + expectNoMessage(userThree...) + + api.Notify(3, &model.Message{ID: 5, Message: "there"}) + expectNoMessage(userOne...) + expectNoMessage(userTwo...) + expectMessage(&model.Message{ID: 5, Message: "there"}, userThree...) + + api.Close() +} + func TestMultipleClients(t *testing.T) { mode.Set(mode.TestDev) @@ -179,7 +264,7 @@ func TestMultipleClients(t *testing.T) { userIDs := []uint{1, 1, 1, 2, 2, 3} i := 0 server, api := bootTestServer(func(context *gin.Context) { - auth.RegisterAuthentication(context, nil, userIDs[i], "") + auth.RegisterAuthentication(context, nil, userIDs[i], "t"+string(userIDs[i])) i++ }) defer server.Close() @@ -281,6 +366,13 @@ func Test_invalidOrigin_returnsFalse(t *testing.T) { assert.False(t, actual) } +func clients(api *API, user uint) []*client { + api.lock.RLock() + defer api.lock.RUnlock() + + return api.clients[user] +} + func testClient(t *testing.T, url string) *testingClient { ws, _, err := websocket.DefaultDialer.Dial(url, nil) assert.Nil(t, err) @@ -357,6 +449,6 @@ func wsURL(httpURL string) string { func staticUserID() gin.HandlerFunc { return func(context *gin.Context) { - auth.RegisterAuthentication(context, nil, 1, "") + auth.RegisterAuthentication(context, nil, 1, "customtoken") } } diff --git a/api/token.go b/api/token.go index 266f430..3ef9212 100644 --- a/api/token.go +++ b/api/token.go @@ -34,8 +34,9 @@ type TokenDatabase interface { // The TokenAPI provides handlers for managing clients and applications. type TokenAPI struct { - DB TokenDatabase - ImageDir string + DB TokenDatabase + ImageDir string + NotifyDeleted func(uint, string) } // CreateApplication creates an application and returns the access token. @@ -95,6 +96,7 @@ func (a *TokenAPI) DeleteApplication(ctx *gin.Context) { func (a *TokenAPI) DeleteClient(ctx *gin.Context) { withID(ctx, "id", func(id uint) { if client := a.DB.GetClientByID(id); client != nil && client.UserID == auth.GetUserID(ctx) { + a.NotifyDeleted(client.UserID, client.Token) a.DB.DeleteClientByID(id) } else { ctx.AbortWithError(404, fmt.Errorf("client with id %d doesn't exists", id)) diff --git a/api/token_test.go b/api/token_test.go index 77cdef3..ecf01a4 100644 --- a/api/token_test.go +++ b/api/token_test.go @@ -43,6 +43,7 @@ type TokenSuite struct { a *TokenAPI ctx *gin.Context recorder *httptest.ResponseRecorder + notified bool } func (s *TokenSuite) BeforeTest(suiteName, testName string) { @@ -52,7 +53,12 @@ func (s *TokenSuite) BeforeTest(suiteName, testName string) { s.db = test.NewDB(s.T()) s.ctx, _ = gin.CreateTestContext(s.recorder) withURL(s.ctx, "http", "example.com") - s.a = &TokenAPI{DB: s.db} + s.notified = false + s.a = &TokenAPI{DB: s.db, NotifyDeleted: s.notify} +} + +func (s *TokenSuite) notify(uint, string) { + s.notified = true } func (s *TokenSuite) AfterTest(suiteName, testName string) { @@ -306,10 +312,13 @@ func (s *TokenSuite) Test_DeleteClient() { s.ctx.Request = httptest.NewRequest("DELETE", "/token/"+firstClientToken, nil) s.ctx.Params = gin.Params{{Key: "id", Value: "8"}} + assert.False(s.T(), s.notified) + s.a.DeleteClient(s.ctx) assert.Equal(s.T(), 200, s.recorder.Code) s.db.AssertClientNotExist(8) + assert.True(s.T(), s.notified) } func (s *TokenSuite) Test_UploadAppImage_NoImageProvided_expectBadRequest() { diff --git a/router/router.go b/router/router.go index b7c1b8d..5b6900b 100644 --- a/router/router.go +++ b/router/router.go @@ -26,7 +26,7 @@ func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Co streamHandler := stream.New(200*time.Second, 15*time.Second) authentication := auth.Auth{DB: db} messageHandler := api.MessageAPI{Notifier: streamHandler, DB: db} - tokenHandler := api.TokenAPI{DB: db, ImageDir: conf.UploadedImagesDir} + tokenHandler := api.TokenAPI{DB: db, ImageDir: conf.UploadedImagesDir, NotifyDeleted: streamHandler.NotifyDeleted} userHandler := api.UserAPI{DB: db, PasswordStrength: conf.PassStrength} g := gin.New() diff --git a/ui/src/actions/MessageAction.js b/ui/src/actions/MessageAction.js index bfd8e41..2b04666 100644 --- a/ui/src/actions/MessageAction.js +++ b/ui/src/actions/MessageAction.js @@ -49,7 +49,7 @@ export function listenToWebSocket() { setTimeout(listenToWebSocket, 60000); }; - ws.onmessage = (data) => { - dispatcher.dispatch({type: 'ONE_MESSAGE', payload: JSON.parse(data.data)}); - }; + ws.onmessage = (data) => dispatcher.dispatch({type: 'ONE_MESSAGE', payload: JSON.parse(data.data)}); + + ws.onclose = (data) => console.log('WebSocket closed, this normally means the client was deleted.', data); }