Close web socket connection on delete client

This commit is contained in:
Jannis Mattheis 2018-04-01 21:10:14 +02:00 committed by Jannis Mattheis
parent c912bb8cba
commit 6954fb5adf
7 changed files with 146 additions and 32 deletions

View File

@ -17,14 +17,16 @@ type client struct {
onClose func(*client)
write chan *model.Message
userID uint
token string
once sync.Once
}
func newClient(conn *websocket.Conn, userID uint, onClose func(*client)) *client {
func newClient(conn *websocket.Conn, userID uint, token string, onClose func(*client)) *client {
return &client{
conn: conn,
write: make(chan *model.Message, 1),
userID: userID,
token: token,
onClose: onClose,
}
}

View File

@ -57,21 +57,30 @@ func New(pingPeriod, pongTimeout time.Duration) *API {
}
}
func (a *API) getClients(userID uint) ([]*client, bool) {
a.lock.RLock()
defer a.lock.RUnlock()
clients, ok := a.clients[userID]
return clients, ok
// NotifyDeleted closes existing connections with the given token.
func (a *API) NotifyDeleted(userID uint, token string) {
a.lock.Lock()
defer a.lock.Unlock()
if clients, ok := a.clients[userID]; ok {
for i := len(clients) - 1; i >= 0; i-- {
client := clients[i]
if client.token == token {
client.Close()
clients = append(clients[:i], clients[i+1:]...)
}
}
a.clients[userID] = clients
}
}
// Notify notifies the clients with the given userID that a new messages was created.
func (a *API) Notify(userID uint, msg *model.Message) {
if clients, ok := a.getClients(userID); ok {
go func() {
a.lock.RLock()
defer a.lock.RUnlock()
if clients, ok := a.clients[userID]; ok {
for _, c := range clients {
c.write <- msg
}
}()
}
}
@ -102,7 +111,7 @@ func (a *API) Handle(ctx *gin.Context) {
return
}
client := newClient(conn, auth.GetUserID(ctx), a.remove)
client := newClient(conn, auth.GetUserID(ctx), auth.GetTokenID(ctx), a.remove)
a.register(client)
go client.startReading(a.pongTimeout)
go client.startWriteHandler(a.pingPeriod)

View File

@ -55,11 +55,11 @@ func TestWriteMessageFails(t *testing.T) {
// the server may take some time to register the client
time.Sleep(100 * time.Millisecond)
client, _ := api.getClients(1)
assert.NotEmpty(t, client)
clients := clients(api, 1)
assert.NotEmpty(t, clients)
// try emulate an write error, mostly this should kill the ReadMessage goroutine first but you'll never know.
patch := monkey.PatchInstanceMethod(reflect.TypeOf(client[0].conn), "WriteJSON", func(*websocket.Conn, interface{}) error {
patch := monkey.PatchInstanceMethod(reflect.TypeOf(clients[0].conn), "WriteJSON", func(*websocket.Conn, interface{}) error {
return errors.New("could not do something")
})
defer patch.Unpatch()
@ -83,10 +83,11 @@ func TestWritePingFails(t *testing.T) {
// the server may take some time to register the client
time.Sleep(100 * time.Millisecond)
client, _ := api.getClients(1)
assert.NotEmpty(t, client)
clients := clients(api, 1)
assert.NotEmpty(t, clients)
// try emulate an write error, mostly this should kill the ReadMessage gorouting first but you'll never know.
patch := monkey.PatchInstanceMethod(reflect.TypeOf(client[0].conn), "WriteMessage", func(*websocket.Conn, int, []byte) error {
patch := monkey.PatchInstanceMethod(reflect.TypeOf(clients[0].conn), "WriteMessage", func(*websocket.Conn, int, []byte) error {
return errors.New("could not do something")
})
defer patch.Unpatch()
@ -146,13 +147,11 @@ func TestCloseClientOnNotReading(t *testing.T) {
// the server may take some time to register the client
time.Sleep(100 * time.Millisecond)
clients, _ := api.getClients(1)
assert.NotEmpty(t, clients)
assert.NotEmpty(t, clients(api, 1))
time.Sleep(7 * time.Second)
clients, _ = api.getClients(1)
assert.Empty(t, clients)
assert.Empty(t, clients(api, 1))
}
func TestMessageDirectlyAfterConnect(t *testing.T) {
@ -172,6 +171,92 @@ func TestMessageDirectlyAfterConnect(t *testing.T) {
user.expectMessage(&model.Message{Message: "msg"})
}
func TestDeleteClientShouldCloseConnection(t *testing.T) {
mode.Set(mode.Prod)
defer leaktest.Check(t)()
server, api := bootTestServer(staticUserID())
defer server.Close()
defer api.Close()
wsURL := wsURL(server.URL)
user := testClient(t, wsURL)
defer user.conn.Close()
// the server may take some time to register the client
time.Sleep(100 * time.Millisecond)
api.Notify(1, &model.Message{Message: "msg"})
user.expectMessage(&model.Message{Message: "msg"})
api.NotifyDeleted(1, "customtoken")
api.Notify(1, &model.Message{Message: "msg"})
user.expectNoMessage()
}
func TestDeleteMultipleClients(t *testing.T) {
mode.Set(mode.TestDev)
defer leaktest.Check(t)()
userIDs := []uint{1, 1, 1, 1, 2, 2, 3}
tokens := []string{"1-1", "1-2", "1-2", "1-3", "2-1", "2-2", "3"}
i := 0
server, api := bootTestServer(func(context *gin.Context) {
auth.RegisterAuthentication(context, nil, userIDs[i], tokens[i])
i++
})
defer server.Close()
wsURL := wsURL(server.URL)
userOneIPhone := testClient(t, wsURL)
defer userOneIPhone.conn.Close()
userOneAndroid := testClient(t, wsURL)
defer userOneAndroid.conn.Close()
userOneBrowser := testClient(t, wsURL)
defer userOneBrowser.conn.Close()
userOneOther := testClient(t, wsURL)
defer userOneOther.conn.Close()
userOne := []*testingClient{userOneAndroid, userOneBrowser, userOneIPhone, userOneOther}
userTwoBrowser := testClient(t, wsURL)
defer userTwoBrowser.conn.Close()
userTwoAndroid := testClient(t, wsURL)
defer userTwoAndroid.conn.Close()
userTwo := []*testingClient{userTwoAndroid, userTwoBrowser}
userThreeAndroid := testClient(t, wsURL)
defer userThreeAndroid.conn.Close()
userThree := []*testingClient{userThreeAndroid}
// the server may take some time to register the client
time.Sleep(100 * time.Millisecond)
api.Notify(1, &model.Message{ID: 4, Message: "there"})
expectMessage(&model.Message{ID: 4, Message: "there"}, userOne...)
expectNoMessage(userTwo...)
expectNoMessage(userThree...)
api.NotifyDeleted(1, "1-2")
api.Notify(1, &model.Message{ID: 2, Message: "there"})
expectMessage(&model.Message{ID: 2, Message: "there"}, userOneIPhone, userOneOther)
expectNoMessage(userOneBrowser, userOneAndroid)
expectNoMessage(userThree...)
expectNoMessage(userTwo...)
api.Notify(2, &model.Message{ID: 2, Message: "there"})
expectNoMessage(userOne...)
expectMessage(&model.Message{ID: 2, Message: "there"}, userTwo...)
expectNoMessage(userThree...)
api.Notify(3, &model.Message{ID: 5, Message: "there"})
expectNoMessage(userOne...)
expectNoMessage(userTwo...)
expectMessage(&model.Message{ID: 5, Message: "there"}, userThree...)
api.Close()
}
func TestMultipleClients(t *testing.T) {
mode.Set(mode.TestDev)
@ -179,7 +264,7 @@ func TestMultipleClients(t *testing.T) {
userIDs := []uint{1, 1, 1, 2, 2, 3}
i := 0
server, api := bootTestServer(func(context *gin.Context) {
auth.RegisterAuthentication(context, nil, userIDs[i], "")
auth.RegisterAuthentication(context, nil, userIDs[i], "t"+string(userIDs[i]))
i++
})
defer server.Close()
@ -281,6 +366,13 @@ func Test_invalidOrigin_returnsFalse(t *testing.T) {
assert.False(t, actual)
}
func clients(api *API, user uint) []*client {
api.lock.RLock()
defer api.lock.RUnlock()
return api.clients[user]
}
func testClient(t *testing.T, url string) *testingClient {
ws, _, err := websocket.DefaultDialer.Dial(url, nil)
assert.Nil(t, err)
@ -357,6 +449,6 @@ func wsURL(httpURL string) string {
func staticUserID() gin.HandlerFunc {
return func(context *gin.Context) {
auth.RegisterAuthentication(context, nil, 1, "")
auth.RegisterAuthentication(context, nil, 1, "customtoken")
}
}

View File

@ -36,6 +36,7 @@ type TokenDatabase interface {
type TokenAPI struct {
DB TokenDatabase
ImageDir string
NotifyDeleted func(uint, string)
}
// CreateApplication creates an application and returns the access token.
@ -95,6 +96,7 @@ func (a *TokenAPI) DeleteApplication(ctx *gin.Context) {
func (a *TokenAPI) DeleteClient(ctx *gin.Context) {
withID(ctx, "id", func(id uint) {
if client := a.DB.GetClientByID(id); client != nil && client.UserID == auth.GetUserID(ctx) {
a.NotifyDeleted(client.UserID, client.Token)
a.DB.DeleteClientByID(id)
} else {
ctx.AbortWithError(404, fmt.Errorf("client with id %d doesn't exists", id))

View File

@ -43,6 +43,7 @@ type TokenSuite struct {
a *TokenAPI
ctx *gin.Context
recorder *httptest.ResponseRecorder
notified bool
}
func (s *TokenSuite) BeforeTest(suiteName, testName string) {
@ -52,7 +53,12 @@ func (s *TokenSuite) BeforeTest(suiteName, testName string) {
s.db = test.NewDB(s.T())
s.ctx, _ = gin.CreateTestContext(s.recorder)
withURL(s.ctx, "http", "example.com")
s.a = &TokenAPI{DB: s.db}
s.notified = false
s.a = &TokenAPI{DB: s.db, NotifyDeleted: s.notify}
}
func (s *TokenSuite) notify(uint, string) {
s.notified = true
}
func (s *TokenSuite) AfterTest(suiteName, testName string) {
@ -306,10 +312,13 @@ func (s *TokenSuite) Test_DeleteClient() {
s.ctx.Request = httptest.NewRequest("DELETE", "/token/"+firstClientToken, nil)
s.ctx.Params = gin.Params{{Key: "id", Value: "8"}}
assert.False(s.T(), s.notified)
s.a.DeleteClient(s.ctx)
assert.Equal(s.T(), 200, s.recorder.Code)
s.db.AssertClientNotExist(8)
assert.True(s.T(), s.notified)
}
func (s *TokenSuite) Test_UploadAppImage_NoImageProvided_expectBadRequest() {

View File

@ -26,7 +26,7 @@ func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Co
streamHandler := stream.New(200*time.Second, 15*time.Second)
authentication := auth.Auth{DB: db}
messageHandler := api.MessageAPI{Notifier: streamHandler, DB: db}
tokenHandler := api.TokenAPI{DB: db, ImageDir: conf.UploadedImagesDir}
tokenHandler := api.TokenAPI{DB: db, ImageDir: conf.UploadedImagesDir, NotifyDeleted: streamHandler.NotifyDeleted}
userHandler := api.UserAPI{DB: db, PasswordStrength: conf.PassStrength}
g := gin.New()

View File

@ -49,7 +49,7 @@ export function listenToWebSocket() {
setTimeout(listenToWebSocket, 60000);
};
ws.onmessage = (data) => {
dispatcher.dispatch({type: 'ONE_MESSAGE', payload: JSON.parse(data.data)});
};
ws.onmessage = (data) => dispatcher.dispatch({type: 'ONE_MESSAGE', payload: JSON.parse(data.data)});
ws.onclose = (data) => console.log('WebSocket closed, this normally means the client was deleted.', data);
}