From 1262f43846ede2c814486fda2dd64bae2570daed Mon Sep 17 00:00:00 2001 From: Jannis Mattheis Date: Mon, 2 Apr 2018 10:50:06 +0200 Subject: [PATCH] Close web socket connection on delete user --- api/stream/stream.go | 16 ++++++++-- api/stream/stream_test.go | 67 +++++++++++++++++++++++++++++++++++++-- api/user.go | 2 ++ api/user_test.go | 11 ++++++- router/router.go | 4 +-- 5 files changed, 93 insertions(+), 7 deletions(-) diff --git a/api/stream/stream.go b/api/stream/stream.go index be8ae80..4a46fba 100644 --- a/api/stream/stream.go +++ b/api/stream/stream.go @@ -57,8 +57,20 @@ func New(pingPeriod, pongTimeout time.Duration) *API { } } -// NotifyDeleted closes existing connections with the given token. -func (a *API) NotifyDeleted(userID uint, token string) { +// NotifyDeletedUser closes existing connections for the given user. +func (a *API) NotifyDeletedUser(userID uint) { + a.lock.Lock() + defer a.lock.Unlock() + if clients, ok := a.clients[userID]; ok { + for _, client := range clients { + client.Close() + } + delete(a.clients, userID) + } +} + +// NotifyDeletedClient closes existing connections with the given token. +func (a *API) NotifyDeletedClient(userID uint, token string) { a.lock.Lock() defer a.lock.Unlock() if clients, ok := a.clients[userID]; ok { diff --git a/api/stream/stream_test.go b/api/stream/stream_test.go index d801151..601484f 100644 --- a/api/stream/stream_test.go +++ b/api/stream/stream_test.go @@ -187,7 +187,7 @@ func TestDeleteClientShouldCloseConnection(t *testing.T) { api.Notify(1, &model.Message{Message: "msg"}) user.expectMessage(&model.Message{Message: "msg"}) - api.NotifyDeleted(1, "customtoken") + api.NotifyDeletedClient(1, "customtoken") api.Notify(1, &model.Message{Message: "msg"}) user.expectNoMessage() @@ -236,7 +236,7 @@ func TestDeleteMultipleClients(t *testing.T) { expectNoMessage(userTwo...) expectNoMessage(userThree...) - api.NotifyDeleted(1, "1-2") + api.NotifyDeletedClient(1, "1-2") api.Notify(1, &model.Message{ID: 2, Message: "there"}) expectMessage(&model.Message{ID: 2, Message: "there"}, userOneIPhone, userOneOther) @@ -257,6 +257,69 @@ func TestDeleteMultipleClients(t *testing.T) { api.Close() } +func TestDeleteUser(t *testing.T) { + mode.Set(mode.TestDev) + + defer leaktest.Check(t)() + userIDs := []uint{1, 1, 1, 1, 2, 2, 3} + tokens := []string{"1-1", "1-2", "1-2", "1-3", "2-1", "2-2", "3"} + i := 0 + server, api := bootTestServer(func(context *gin.Context) { + auth.RegisterAuthentication(context, nil, userIDs[i], tokens[i]) + i++ + }) + defer server.Close() + + wsURL := wsURL(server.URL) + + userOneIPhone := testClient(t, wsURL) + defer userOneIPhone.conn.Close() + userOneAndroid := testClient(t, wsURL) + defer userOneAndroid.conn.Close() + userOneBrowser := testClient(t, wsURL) + defer userOneBrowser.conn.Close() + userOneOther := testClient(t, wsURL) + defer userOneOther.conn.Close() + userOne := []*testingClient{userOneAndroid, userOneBrowser, userOneIPhone, userOneOther} + + userTwoBrowser := testClient(t, wsURL) + defer userTwoBrowser.conn.Close() + userTwoAndroid := testClient(t, wsURL) + defer userTwoAndroid.conn.Close() + userTwo := []*testingClient{userTwoAndroid, userTwoBrowser} + + userThreeAndroid := testClient(t, wsURL) + defer userThreeAndroid.conn.Close() + userThree := []*testingClient{userThreeAndroid} + + // the server may take some time to register the client + time.Sleep(100 * time.Millisecond) + + api.Notify(1, &model.Message{ID: 4, Message: "there"}) + expectMessage(&model.Message{ID: 4, Message: "there"}, userOne...) + expectNoMessage(userTwo...) + expectNoMessage(userThree...) + + api.NotifyDeletedUser(1) + + api.Notify(1, &model.Message{ID: 2, Message: "there"}) + expectNoMessage(userOne...) + expectNoMessage(userThree...) + expectNoMessage(userTwo...) + + api.Notify(2, &model.Message{ID: 2, Message: "there"}) + expectNoMessage(userOne...) + expectMessage(&model.Message{ID: 2, Message: "there"}, userTwo...) + expectNoMessage(userThree...) + + api.Notify(3, &model.Message{ID: 5, Message: "there"}) + expectNoMessage(userOne...) + expectNoMessage(userTwo...) + expectMessage(&model.Message{ID: 5, Message: "there"}, userThree...) + + api.Close() +} + func TestMultipleClients(t *testing.T) { mode.Set(mode.TestDev) diff --git a/api/user.go b/api/user.go index 918592d..0da81d2 100644 --- a/api/user.go +++ b/api/user.go @@ -23,6 +23,7 @@ type UserDatabase interface { type UserAPI struct { DB UserDatabase PasswordStrength int + NotifyDeleted func(uint) } // GetUsers returns all the users @@ -72,6 +73,7 @@ func (a *UserAPI) GetUserByID(ctx *gin.Context) { func (a *UserAPI) DeleteUserByID(ctx *gin.Context) { withID(ctx, "id", func(id uint) { if user := a.DB.GetUserByID(id); user != nil { + a.NotifyDeleted(id) a.DB.DeleteUserByID(id) } else { ctx.AbortWithError(404, errors.New("user does not exist")) diff --git a/api/user_test.go b/api/user_test.go index 3f3bc4d..04bbd14 100644 --- a/api/user_test.go +++ b/api/user_test.go @@ -24,6 +24,7 @@ type UserSuite struct { a *UserAPI ctx *gin.Context recorder *httptest.ResponseRecorder + notified bool } func (s *UserSuite) BeforeTest(suiteName, testName string) { @@ -31,7 +32,12 @@ func (s *UserSuite) BeforeTest(suiteName, testName string) { s.recorder = httptest.NewRecorder() s.ctx, _ = gin.CreateTestContext(s.recorder) s.db = test.NewDB(s.T()) - s.a = &UserAPI{DB: s.db} + s.notified = false + s.a = &UserAPI{DB: s.db, NotifyDeleted: s.notify} +} + +func (s *UserSuite) notify(uint) { + s.notified = true } func (s *UserSuite) AfterTest(suiteName, testName string) { @@ -107,6 +113,8 @@ func (s *UserSuite) Test_DeleteUserByID_UnknownUser() { } func (s *UserSuite) Test_DeleteUserByID() { + assert.False(s.T(), s.notified) + s.db.User(2) s.ctx.Params = gin.Params{{Key: "id", Value: "2"}} @@ -115,6 +123,7 @@ func (s *UserSuite) Test_DeleteUserByID() { assert.Equal(s.T(), 200, s.recorder.Code) s.db.AssertUserNotExist(2) + assert.True(s.T(), s.notified) } func (s *UserSuite) Test_CreateUser() { diff --git a/router/router.go b/router/router.go index 5b6900b..e6d678e 100644 --- a/router/router.go +++ b/router/router.go @@ -26,8 +26,8 @@ func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Co streamHandler := stream.New(200*time.Second, 15*time.Second) authentication := auth.Auth{DB: db} messageHandler := api.MessageAPI{Notifier: streamHandler, DB: db} - tokenHandler := api.TokenAPI{DB: db, ImageDir: conf.UploadedImagesDir, NotifyDeleted: streamHandler.NotifyDeleted} - userHandler := api.UserAPI{DB: db, PasswordStrength: conf.PassStrength} + tokenHandler := api.TokenAPI{DB: db, ImageDir: conf.UploadedImagesDir, NotifyDeleted: streamHandler.NotifyDeletedClient} + userHandler := api.UserAPI{DB: db, PasswordStrength: conf.PassStrength, NotifyDeleted: streamHandler.NotifyDeletedUser} g := gin.New() g.Use(gin.Logger(), gin.Recovery(), error.Handler(), location.Default())