Close web socket connection on delete client

This commit is contained in:
Jannis Mattheis 2018-04-01 21:10:14 +02:00 committed by Jannis Mattheis
parent c912bb8cba
commit 6954fb5adf
7 changed files with 146 additions and 32 deletions

View File

@ -17,14 +17,16 @@ type client struct {
onClose func(*client) onClose func(*client)
write chan *model.Message write chan *model.Message
userID uint userID uint
token string
once sync.Once once sync.Once
} }
func newClient(conn *websocket.Conn, userID uint, onClose func(*client)) *client { func newClient(conn *websocket.Conn, userID uint, token string, onClose func(*client)) *client {
return &client{ return &client{
conn: conn, conn: conn,
write: make(chan *model.Message, 1), write: make(chan *model.Message, 1),
userID: userID, userID: userID,
token: token,
onClose: onClose, onClose: onClose,
} }
} }

View File

@ -57,21 +57,30 @@ func New(pingPeriod, pongTimeout time.Duration) *API {
} }
} }
func (a *API) getClients(userID uint) ([]*client, bool) { // NotifyDeleted closes existing connections with the given token.
a.lock.RLock() func (a *API) NotifyDeleted(userID uint, token string) {
defer a.lock.RUnlock() a.lock.Lock()
clients, ok := a.clients[userID] defer a.lock.Unlock()
return clients, ok if clients, ok := a.clients[userID]; ok {
for i := len(clients) - 1; i >= 0; i-- {
client := clients[i]
if client.token == token {
client.Close()
clients = append(clients[:i], clients[i+1:]...)
}
}
a.clients[userID] = clients
}
} }
// Notify notifies the clients with the given userID that a new messages was created. // Notify notifies the clients with the given userID that a new messages was created.
func (a *API) Notify(userID uint, msg *model.Message) { func (a *API) Notify(userID uint, msg *model.Message) {
if clients, ok := a.getClients(userID); ok { a.lock.RLock()
go func() { defer a.lock.RUnlock()
if clients, ok := a.clients[userID]; ok {
for _, c := range clients { for _, c := range clients {
c.write <- msg c.write <- msg
} }
}()
} }
} }
@ -102,7 +111,7 @@ func (a *API) Handle(ctx *gin.Context) {
return return
} }
client := newClient(conn, auth.GetUserID(ctx), a.remove) client := newClient(conn, auth.GetUserID(ctx), auth.GetTokenID(ctx), a.remove)
a.register(client) a.register(client)
go client.startReading(a.pongTimeout) go client.startReading(a.pongTimeout)
go client.startWriteHandler(a.pingPeriod) go client.startWriteHandler(a.pingPeriod)

View File

@ -55,11 +55,11 @@ func TestWriteMessageFails(t *testing.T) {
// the server may take some time to register the client // the server may take some time to register the client
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
client, _ := api.getClients(1) clients := clients(api, 1)
assert.NotEmpty(t, client) assert.NotEmpty(t, clients)
// try emulate an write error, mostly this should kill the ReadMessage goroutine first but you'll never know. // try emulate an write error, mostly this should kill the ReadMessage goroutine first but you'll never know.
patch := monkey.PatchInstanceMethod(reflect.TypeOf(client[0].conn), "WriteJSON", func(*websocket.Conn, interface{}) error { patch := monkey.PatchInstanceMethod(reflect.TypeOf(clients[0].conn), "WriteJSON", func(*websocket.Conn, interface{}) error {
return errors.New("could not do something") return errors.New("could not do something")
}) })
defer patch.Unpatch() defer patch.Unpatch()
@ -83,10 +83,11 @@ func TestWritePingFails(t *testing.T) {
// the server may take some time to register the client // the server may take some time to register the client
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
client, _ := api.getClients(1) clients := clients(api, 1)
assert.NotEmpty(t, client)
assert.NotEmpty(t, clients)
// try emulate an write error, mostly this should kill the ReadMessage gorouting first but you'll never know. // try emulate an write error, mostly this should kill the ReadMessage gorouting first but you'll never know.
patch := monkey.PatchInstanceMethod(reflect.TypeOf(client[0].conn), "WriteMessage", func(*websocket.Conn, int, []byte) error { patch := monkey.PatchInstanceMethod(reflect.TypeOf(clients[0].conn), "WriteMessage", func(*websocket.Conn, int, []byte) error {
return errors.New("could not do something") return errors.New("could not do something")
}) })
defer patch.Unpatch() defer patch.Unpatch()
@ -146,13 +147,11 @@ func TestCloseClientOnNotReading(t *testing.T) {
// the server may take some time to register the client // the server may take some time to register the client
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
clients, _ := api.getClients(1) assert.NotEmpty(t, clients(api, 1))
assert.NotEmpty(t, clients)
time.Sleep(7 * time.Second) time.Sleep(7 * time.Second)
clients, _ = api.getClients(1) assert.Empty(t, clients(api, 1))
assert.Empty(t, clients)
} }
func TestMessageDirectlyAfterConnect(t *testing.T) { func TestMessageDirectlyAfterConnect(t *testing.T) {
@ -172,6 +171,92 @@ func TestMessageDirectlyAfterConnect(t *testing.T) {
user.expectMessage(&model.Message{Message: "msg"}) user.expectMessage(&model.Message{Message: "msg"})
} }
func TestDeleteClientShouldCloseConnection(t *testing.T) {
mode.Set(mode.Prod)
defer leaktest.Check(t)()
server, api := bootTestServer(staticUserID())
defer server.Close()
defer api.Close()
wsURL := wsURL(server.URL)
user := testClient(t, wsURL)
defer user.conn.Close()
// the server may take some time to register the client
time.Sleep(100 * time.Millisecond)
api.Notify(1, &model.Message{Message: "msg"})
user.expectMessage(&model.Message{Message: "msg"})
api.NotifyDeleted(1, "customtoken")
api.Notify(1, &model.Message{Message: "msg"})
user.expectNoMessage()
}
func TestDeleteMultipleClients(t *testing.T) {
mode.Set(mode.TestDev)
defer leaktest.Check(t)()
userIDs := []uint{1, 1, 1, 1, 2, 2, 3}
tokens := []string{"1-1", "1-2", "1-2", "1-3", "2-1", "2-2", "3"}
i := 0
server, api := bootTestServer(func(context *gin.Context) {
auth.RegisterAuthentication(context, nil, userIDs[i], tokens[i])
i++
})
defer server.Close()
wsURL := wsURL(server.URL)
userOneIPhone := testClient(t, wsURL)
defer userOneIPhone.conn.Close()
userOneAndroid := testClient(t, wsURL)
defer userOneAndroid.conn.Close()
userOneBrowser := testClient(t, wsURL)
defer userOneBrowser.conn.Close()
userOneOther := testClient(t, wsURL)
defer userOneOther.conn.Close()
userOne := []*testingClient{userOneAndroid, userOneBrowser, userOneIPhone, userOneOther}
userTwoBrowser := testClient(t, wsURL)
defer userTwoBrowser.conn.Close()
userTwoAndroid := testClient(t, wsURL)
defer userTwoAndroid.conn.Close()
userTwo := []*testingClient{userTwoAndroid, userTwoBrowser}
userThreeAndroid := testClient(t, wsURL)
defer userThreeAndroid.conn.Close()
userThree := []*testingClient{userThreeAndroid}
// the server may take some time to register the client
time.Sleep(100 * time.Millisecond)
api.Notify(1, &model.Message{ID: 4, Message: "there"})
expectMessage(&model.Message{ID: 4, Message: "there"}, userOne...)
expectNoMessage(userTwo...)
expectNoMessage(userThree...)
api.NotifyDeleted(1, "1-2")
api.Notify(1, &model.Message{ID: 2, Message: "there"})
expectMessage(&model.Message{ID: 2, Message: "there"}, userOneIPhone, userOneOther)
expectNoMessage(userOneBrowser, userOneAndroid)
expectNoMessage(userThree...)
expectNoMessage(userTwo...)
api.Notify(2, &model.Message{ID: 2, Message: "there"})
expectNoMessage(userOne...)
expectMessage(&model.Message{ID: 2, Message: "there"}, userTwo...)
expectNoMessage(userThree...)
api.Notify(3, &model.Message{ID: 5, Message: "there"})
expectNoMessage(userOne...)
expectNoMessage(userTwo...)
expectMessage(&model.Message{ID: 5, Message: "there"}, userThree...)
api.Close()
}
func TestMultipleClients(t *testing.T) { func TestMultipleClients(t *testing.T) {
mode.Set(mode.TestDev) mode.Set(mode.TestDev)
@ -179,7 +264,7 @@ func TestMultipleClients(t *testing.T) {
userIDs := []uint{1, 1, 1, 2, 2, 3} userIDs := []uint{1, 1, 1, 2, 2, 3}
i := 0 i := 0
server, api := bootTestServer(func(context *gin.Context) { server, api := bootTestServer(func(context *gin.Context) {
auth.RegisterAuthentication(context, nil, userIDs[i], "") auth.RegisterAuthentication(context, nil, userIDs[i], "t"+string(userIDs[i]))
i++ i++
}) })
defer server.Close() defer server.Close()
@ -281,6 +366,13 @@ func Test_invalidOrigin_returnsFalse(t *testing.T) {
assert.False(t, actual) assert.False(t, actual)
} }
func clients(api *API, user uint) []*client {
api.lock.RLock()
defer api.lock.RUnlock()
return api.clients[user]
}
func testClient(t *testing.T, url string) *testingClient { func testClient(t *testing.T, url string) *testingClient {
ws, _, err := websocket.DefaultDialer.Dial(url, nil) ws, _, err := websocket.DefaultDialer.Dial(url, nil)
assert.Nil(t, err) assert.Nil(t, err)
@ -357,6 +449,6 @@ func wsURL(httpURL string) string {
func staticUserID() gin.HandlerFunc { func staticUserID() gin.HandlerFunc {
return func(context *gin.Context) { return func(context *gin.Context) {
auth.RegisterAuthentication(context, nil, 1, "") auth.RegisterAuthentication(context, nil, 1, "customtoken")
} }
} }

View File

@ -36,6 +36,7 @@ type TokenDatabase interface {
type TokenAPI struct { type TokenAPI struct {
DB TokenDatabase DB TokenDatabase
ImageDir string ImageDir string
NotifyDeleted func(uint, string)
} }
// CreateApplication creates an application and returns the access token. // CreateApplication creates an application and returns the access token.
@ -95,6 +96,7 @@ func (a *TokenAPI) DeleteApplication(ctx *gin.Context) {
func (a *TokenAPI) DeleteClient(ctx *gin.Context) { func (a *TokenAPI) DeleteClient(ctx *gin.Context) {
withID(ctx, "id", func(id uint) { withID(ctx, "id", func(id uint) {
if client := a.DB.GetClientByID(id); client != nil && client.UserID == auth.GetUserID(ctx) { if client := a.DB.GetClientByID(id); client != nil && client.UserID == auth.GetUserID(ctx) {
a.NotifyDeleted(client.UserID, client.Token)
a.DB.DeleteClientByID(id) a.DB.DeleteClientByID(id)
} else { } else {
ctx.AbortWithError(404, fmt.Errorf("client with id %d doesn't exists", id)) ctx.AbortWithError(404, fmt.Errorf("client with id %d doesn't exists", id))

View File

@ -43,6 +43,7 @@ type TokenSuite struct {
a *TokenAPI a *TokenAPI
ctx *gin.Context ctx *gin.Context
recorder *httptest.ResponseRecorder recorder *httptest.ResponseRecorder
notified bool
} }
func (s *TokenSuite) BeforeTest(suiteName, testName string) { func (s *TokenSuite) BeforeTest(suiteName, testName string) {
@ -52,7 +53,12 @@ func (s *TokenSuite) BeforeTest(suiteName, testName string) {
s.db = test.NewDB(s.T()) s.db = test.NewDB(s.T())
s.ctx, _ = gin.CreateTestContext(s.recorder) s.ctx, _ = gin.CreateTestContext(s.recorder)
withURL(s.ctx, "http", "example.com") withURL(s.ctx, "http", "example.com")
s.a = &TokenAPI{DB: s.db} s.notified = false
s.a = &TokenAPI{DB: s.db, NotifyDeleted: s.notify}
}
func (s *TokenSuite) notify(uint, string) {
s.notified = true
} }
func (s *TokenSuite) AfterTest(suiteName, testName string) { func (s *TokenSuite) AfterTest(suiteName, testName string) {
@ -306,10 +312,13 @@ func (s *TokenSuite) Test_DeleteClient() {
s.ctx.Request = httptest.NewRequest("DELETE", "/token/"+firstClientToken, nil) s.ctx.Request = httptest.NewRequest("DELETE", "/token/"+firstClientToken, nil)
s.ctx.Params = gin.Params{{Key: "id", Value: "8"}} s.ctx.Params = gin.Params{{Key: "id", Value: "8"}}
assert.False(s.T(), s.notified)
s.a.DeleteClient(s.ctx) s.a.DeleteClient(s.ctx)
assert.Equal(s.T(), 200, s.recorder.Code) assert.Equal(s.T(), 200, s.recorder.Code)
s.db.AssertClientNotExist(8) s.db.AssertClientNotExist(8)
assert.True(s.T(), s.notified)
} }
func (s *TokenSuite) Test_UploadAppImage_NoImageProvided_expectBadRequest() { func (s *TokenSuite) Test_UploadAppImage_NoImageProvided_expectBadRequest() {

View File

@ -26,7 +26,7 @@ func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Co
streamHandler := stream.New(200*time.Second, 15*time.Second) streamHandler := stream.New(200*time.Second, 15*time.Second)
authentication := auth.Auth{DB: db} authentication := auth.Auth{DB: db}
messageHandler := api.MessageAPI{Notifier: streamHandler, DB: db} messageHandler := api.MessageAPI{Notifier: streamHandler, DB: db}
tokenHandler := api.TokenAPI{DB: db, ImageDir: conf.UploadedImagesDir} tokenHandler := api.TokenAPI{DB: db, ImageDir: conf.UploadedImagesDir, NotifyDeleted: streamHandler.NotifyDeleted}
userHandler := api.UserAPI{DB: db, PasswordStrength: conf.PassStrength} userHandler := api.UserAPI{DB: db, PasswordStrength: conf.PassStrength}
g := gin.New() g := gin.New()

View File

@ -49,7 +49,7 @@ export function listenToWebSocket() {
setTimeout(listenToWebSocket, 60000); setTimeout(listenToWebSocket, 60000);
}; };
ws.onmessage = (data) => { ws.onmessage = (data) => dispatcher.dispatch({type: 'ONE_MESSAGE', payload: JSON.parse(data.data)});
dispatcher.dispatch({type: 'ONE_MESSAGE', payload: JSON.parse(data.data)});
}; ws.onclose = (data) => console.log('WebSocket closed, this normally means the client was deleted.', data);
} }