Close web socket connection on delete client
This commit is contained in:
parent
c912bb8cba
commit
6954fb5adf
|
|
@ -17,14 +17,16 @@ type client struct {
|
||||||
onClose func(*client)
|
onClose func(*client)
|
||||||
write chan *model.Message
|
write chan *model.Message
|
||||||
userID uint
|
userID uint
|
||||||
|
token string
|
||||||
once sync.Once
|
once sync.Once
|
||||||
}
|
}
|
||||||
|
|
||||||
func newClient(conn *websocket.Conn, userID uint, onClose func(*client)) *client {
|
func newClient(conn *websocket.Conn, userID uint, token string, onClose func(*client)) *client {
|
||||||
return &client{
|
return &client{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
write: make(chan *model.Message, 1),
|
write: make(chan *model.Message, 1),
|
||||||
userID: userID,
|
userID: userID,
|
||||||
|
token: token,
|
||||||
onClose: onClose,
|
onClose: onClose,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -57,21 +57,30 @@ func New(pingPeriod, pongTimeout time.Duration) *API {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *API) getClients(userID uint) ([]*client, bool) {
|
// NotifyDeleted closes existing connections with the given token.
|
||||||
a.lock.RLock()
|
func (a *API) NotifyDeleted(userID uint, token string) {
|
||||||
defer a.lock.RUnlock()
|
a.lock.Lock()
|
||||||
clients, ok := a.clients[userID]
|
defer a.lock.Unlock()
|
||||||
return clients, ok
|
if clients, ok := a.clients[userID]; ok {
|
||||||
|
for i := len(clients) - 1; i >= 0; i-- {
|
||||||
|
client := clients[i]
|
||||||
|
if client.token == token {
|
||||||
|
client.Close()
|
||||||
|
clients = append(clients[:i], clients[i+1:]...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
a.clients[userID] = clients
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Notify notifies the clients with the given userID that a new messages was created.
|
// Notify notifies the clients with the given userID that a new messages was created.
|
||||||
func (a *API) Notify(userID uint, msg *model.Message) {
|
func (a *API) Notify(userID uint, msg *model.Message) {
|
||||||
if clients, ok := a.getClients(userID); ok {
|
a.lock.RLock()
|
||||||
go func() {
|
defer a.lock.RUnlock()
|
||||||
|
if clients, ok := a.clients[userID]; ok {
|
||||||
for _, c := range clients {
|
for _, c := range clients {
|
||||||
c.write <- msg
|
c.write <- msg
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -102,7 +111,7 @@ func (a *API) Handle(ctx *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
client := newClient(conn, auth.GetUserID(ctx), a.remove)
|
client := newClient(conn, auth.GetUserID(ctx), auth.GetTokenID(ctx), a.remove)
|
||||||
a.register(client)
|
a.register(client)
|
||||||
go client.startReading(a.pongTimeout)
|
go client.startReading(a.pongTimeout)
|
||||||
go client.startWriteHandler(a.pingPeriod)
|
go client.startWriteHandler(a.pingPeriod)
|
||||||
|
|
|
||||||
|
|
@ -55,11 +55,11 @@ func TestWriteMessageFails(t *testing.T) {
|
||||||
|
|
||||||
// the server may take some time to register the client
|
// the server may take some time to register the client
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
client, _ := api.getClients(1)
|
clients := clients(api, 1)
|
||||||
assert.NotEmpty(t, client)
|
assert.NotEmpty(t, clients)
|
||||||
|
|
||||||
// try emulate an write error, mostly this should kill the ReadMessage goroutine first but you'll never know.
|
// try emulate an write error, mostly this should kill the ReadMessage goroutine first but you'll never know.
|
||||||
patch := monkey.PatchInstanceMethod(reflect.TypeOf(client[0].conn), "WriteJSON", func(*websocket.Conn, interface{}) error {
|
patch := monkey.PatchInstanceMethod(reflect.TypeOf(clients[0].conn), "WriteJSON", func(*websocket.Conn, interface{}) error {
|
||||||
return errors.New("could not do something")
|
return errors.New("could not do something")
|
||||||
})
|
})
|
||||||
defer patch.Unpatch()
|
defer patch.Unpatch()
|
||||||
|
|
@ -83,10 +83,11 @@ func TestWritePingFails(t *testing.T) {
|
||||||
|
|
||||||
// the server may take some time to register the client
|
// the server may take some time to register the client
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
client, _ := api.getClients(1)
|
clients := clients(api, 1)
|
||||||
assert.NotEmpty(t, client)
|
|
||||||
|
assert.NotEmpty(t, clients)
|
||||||
// try emulate an write error, mostly this should kill the ReadMessage gorouting first but you'll never know.
|
// try emulate an write error, mostly this should kill the ReadMessage gorouting first but you'll never know.
|
||||||
patch := monkey.PatchInstanceMethod(reflect.TypeOf(client[0].conn), "WriteMessage", func(*websocket.Conn, int, []byte) error {
|
patch := monkey.PatchInstanceMethod(reflect.TypeOf(clients[0].conn), "WriteMessage", func(*websocket.Conn, int, []byte) error {
|
||||||
return errors.New("could not do something")
|
return errors.New("could not do something")
|
||||||
})
|
})
|
||||||
defer patch.Unpatch()
|
defer patch.Unpatch()
|
||||||
|
|
@ -146,13 +147,11 @@ func TestCloseClientOnNotReading(t *testing.T) {
|
||||||
|
|
||||||
// the server may take some time to register the client
|
// the server may take some time to register the client
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
clients, _ := api.getClients(1)
|
assert.NotEmpty(t, clients(api, 1))
|
||||||
assert.NotEmpty(t, clients)
|
|
||||||
|
|
||||||
time.Sleep(7 * time.Second)
|
time.Sleep(7 * time.Second)
|
||||||
|
|
||||||
clients, _ = api.getClients(1)
|
assert.Empty(t, clients(api, 1))
|
||||||
assert.Empty(t, clients)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMessageDirectlyAfterConnect(t *testing.T) {
|
func TestMessageDirectlyAfterConnect(t *testing.T) {
|
||||||
|
|
@ -172,6 +171,92 @@ func TestMessageDirectlyAfterConnect(t *testing.T) {
|
||||||
user.expectMessage(&model.Message{Message: "msg"})
|
user.expectMessage(&model.Message{Message: "msg"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDeleteClientShouldCloseConnection(t *testing.T) {
|
||||||
|
mode.Set(mode.Prod)
|
||||||
|
defer leaktest.Check(t)()
|
||||||
|
server, api := bootTestServer(staticUserID())
|
||||||
|
defer server.Close()
|
||||||
|
defer api.Close()
|
||||||
|
|
||||||
|
wsURL := wsURL(server.URL)
|
||||||
|
|
||||||
|
user := testClient(t, wsURL)
|
||||||
|
defer user.conn.Close()
|
||||||
|
// the server may take some time to register the client
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
api.Notify(1, &model.Message{Message: "msg"})
|
||||||
|
user.expectMessage(&model.Message{Message: "msg"})
|
||||||
|
|
||||||
|
api.NotifyDeleted(1, "customtoken")
|
||||||
|
|
||||||
|
api.Notify(1, &model.Message{Message: "msg"})
|
||||||
|
user.expectNoMessage()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteMultipleClients(t *testing.T) {
|
||||||
|
mode.Set(mode.TestDev)
|
||||||
|
|
||||||
|
defer leaktest.Check(t)()
|
||||||
|
userIDs := []uint{1, 1, 1, 1, 2, 2, 3}
|
||||||
|
tokens := []string{"1-1", "1-2", "1-2", "1-3", "2-1", "2-2", "3"}
|
||||||
|
i := 0
|
||||||
|
server, api := bootTestServer(func(context *gin.Context) {
|
||||||
|
auth.RegisterAuthentication(context, nil, userIDs[i], tokens[i])
|
||||||
|
i++
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
wsURL := wsURL(server.URL)
|
||||||
|
|
||||||
|
userOneIPhone := testClient(t, wsURL)
|
||||||
|
defer userOneIPhone.conn.Close()
|
||||||
|
userOneAndroid := testClient(t, wsURL)
|
||||||
|
defer userOneAndroid.conn.Close()
|
||||||
|
userOneBrowser := testClient(t, wsURL)
|
||||||
|
defer userOneBrowser.conn.Close()
|
||||||
|
userOneOther := testClient(t, wsURL)
|
||||||
|
defer userOneOther.conn.Close()
|
||||||
|
userOne := []*testingClient{userOneAndroid, userOneBrowser, userOneIPhone, userOneOther}
|
||||||
|
|
||||||
|
userTwoBrowser := testClient(t, wsURL)
|
||||||
|
defer userTwoBrowser.conn.Close()
|
||||||
|
userTwoAndroid := testClient(t, wsURL)
|
||||||
|
defer userTwoAndroid.conn.Close()
|
||||||
|
userTwo := []*testingClient{userTwoAndroid, userTwoBrowser}
|
||||||
|
|
||||||
|
userThreeAndroid := testClient(t, wsURL)
|
||||||
|
defer userThreeAndroid.conn.Close()
|
||||||
|
userThree := []*testingClient{userThreeAndroid}
|
||||||
|
|
||||||
|
// the server may take some time to register the client
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
api.Notify(1, &model.Message{ID: 4, Message: "there"})
|
||||||
|
expectMessage(&model.Message{ID: 4, Message: "there"}, userOne...)
|
||||||
|
expectNoMessage(userTwo...)
|
||||||
|
expectNoMessage(userThree...)
|
||||||
|
|
||||||
|
api.NotifyDeleted(1, "1-2")
|
||||||
|
|
||||||
|
api.Notify(1, &model.Message{ID: 2, Message: "there"})
|
||||||
|
expectMessage(&model.Message{ID: 2, Message: "there"}, userOneIPhone, userOneOther)
|
||||||
|
expectNoMessage(userOneBrowser, userOneAndroid)
|
||||||
|
expectNoMessage(userThree...)
|
||||||
|
expectNoMessage(userTwo...)
|
||||||
|
|
||||||
|
api.Notify(2, &model.Message{ID: 2, Message: "there"})
|
||||||
|
expectNoMessage(userOne...)
|
||||||
|
expectMessage(&model.Message{ID: 2, Message: "there"}, userTwo...)
|
||||||
|
expectNoMessage(userThree...)
|
||||||
|
|
||||||
|
api.Notify(3, &model.Message{ID: 5, Message: "there"})
|
||||||
|
expectNoMessage(userOne...)
|
||||||
|
expectNoMessage(userTwo...)
|
||||||
|
expectMessage(&model.Message{ID: 5, Message: "there"}, userThree...)
|
||||||
|
|
||||||
|
api.Close()
|
||||||
|
}
|
||||||
|
|
||||||
func TestMultipleClients(t *testing.T) {
|
func TestMultipleClients(t *testing.T) {
|
||||||
mode.Set(mode.TestDev)
|
mode.Set(mode.TestDev)
|
||||||
|
|
||||||
|
|
@ -179,7 +264,7 @@ func TestMultipleClients(t *testing.T) {
|
||||||
userIDs := []uint{1, 1, 1, 2, 2, 3}
|
userIDs := []uint{1, 1, 1, 2, 2, 3}
|
||||||
i := 0
|
i := 0
|
||||||
server, api := bootTestServer(func(context *gin.Context) {
|
server, api := bootTestServer(func(context *gin.Context) {
|
||||||
auth.RegisterAuthentication(context, nil, userIDs[i], "")
|
auth.RegisterAuthentication(context, nil, userIDs[i], "t"+string(userIDs[i]))
|
||||||
i++
|
i++
|
||||||
})
|
})
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
@ -281,6 +366,13 @@ func Test_invalidOrigin_returnsFalse(t *testing.T) {
|
||||||
assert.False(t, actual)
|
assert.False(t, actual)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func clients(api *API, user uint) []*client {
|
||||||
|
api.lock.RLock()
|
||||||
|
defer api.lock.RUnlock()
|
||||||
|
|
||||||
|
return api.clients[user]
|
||||||
|
}
|
||||||
|
|
||||||
func testClient(t *testing.T, url string) *testingClient {
|
func testClient(t *testing.T, url string) *testingClient {
|
||||||
ws, _, err := websocket.DefaultDialer.Dial(url, nil)
|
ws, _, err := websocket.DefaultDialer.Dial(url, nil)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
@ -357,6 +449,6 @@ func wsURL(httpURL string) string {
|
||||||
|
|
||||||
func staticUserID() gin.HandlerFunc {
|
func staticUserID() gin.HandlerFunc {
|
||||||
return func(context *gin.Context) {
|
return func(context *gin.Context) {
|
||||||
auth.RegisterAuthentication(context, nil, 1, "")
|
auth.RegisterAuthentication(context, nil, 1, "customtoken")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,7 @@ type TokenDatabase interface {
|
||||||
type TokenAPI struct {
|
type TokenAPI struct {
|
||||||
DB TokenDatabase
|
DB TokenDatabase
|
||||||
ImageDir string
|
ImageDir string
|
||||||
|
NotifyDeleted func(uint, string)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateApplication creates an application and returns the access token.
|
// CreateApplication creates an application and returns the access token.
|
||||||
|
|
@ -95,6 +96,7 @@ func (a *TokenAPI) DeleteApplication(ctx *gin.Context) {
|
||||||
func (a *TokenAPI) DeleteClient(ctx *gin.Context) {
|
func (a *TokenAPI) DeleteClient(ctx *gin.Context) {
|
||||||
withID(ctx, "id", func(id uint) {
|
withID(ctx, "id", func(id uint) {
|
||||||
if client := a.DB.GetClientByID(id); client != nil && client.UserID == auth.GetUserID(ctx) {
|
if client := a.DB.GetClientByID(id); client != nil && client.UserID == auth.GetUserID(ctx) {
|
||||||
|
a.NotifyDeleted(client.UserID, client.Token)
|
||||||
a.DB.DeleteClientByID(id)
|
a.DB.DeleteClientByID(id)
|
||||||
} else {
|
} else {
|
||||||
ctx.AbortWithError(404, fmt.Errorf("client with id %d doesn't exists", id))
|
ctx.AbortWithError(404, fmt.Errorf("client with id %d doesn't exists", id))
|
||||||
|
|
|
||||||
|
|
@ -43,6 +43,7 @@ type TokenSuite struct {
|
||||||
a *TokenAPI
|
a *TokenAPI
|
||||||
ctx *gin.Context
|
ctx *gin.Context
|
||||||
recorder *httptest.ResponseRecorder
|
recorder *httptest.ResponseRecorder
|
||||||
|
notified bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *TokenSuite) BeforeTest(suiteName, testName string) {
|
func (s *TokenSuite) BeforeTest(suiteName, testName string) {
|
||||||
|
|
@ -52,7 +53,12 @@ func (s *TokenSuite) BeforeTest(suiteName, testName string) {
|
||||||
s.db = test.NewDB(s.T())
|
s.db = test.NewDB(s.T())
|
||||||
s.ctx, _ = gin.CreateTestContext(s.recorder)
|
s.ctx, _ = gin.CreateTestContext(s.recorder)
|
||||||
withURL(s.ctx, "http", "example.com")
|
withURL(s.ctx, "http", "example.com")
|
||||||
s.a = &TokenAPI{DB: s.db}
|
s.notified = false
|
||||||
|
s.a = &TokenAPI{DB: s.db, NotifyDeleted: s.notify}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TokenSuite) notify(uint, string) {
|
||||||
|
s.notified = true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *TokenSuite) AfterTest(suiteName, testName string) {
|
func (s *TokenSuite) AfterTest(suiteName, testName string) {
|
||||||
|
|
@ -306,10 +312,13 @@ func (s *TokenSuite) Test_DeleteClient() {
|
||||||
s.ctx.Request = httptest.NewRequest("DELETE", "/token/"+firstClientToken, nil)
|
s.ctx.Request = httptest.NewRequest("DELETE", "/token/"+firstClientToken, nil)
|
||||||
s.ctx.Params = gin.Params{{Key: "id", Value: "8"}}
|
s.ctx.Params = gin.Params{{Key: "id", Value: "8"}}
|
||||||
|
|
||||||
|
assert.False(s.T(), s.notified)
|
||||||
|
|
||||||
s.a.DeleteClient(s.ctx)
|
s.a.DeleteClient(s.ctx)
|
||||||
|
|
||||||
assert.Equal(s.T(), 200, s.recorder.Code)
|
assert.Equal(s.T(), 200, s.recorder.Code)
|
||||||
s.db.AssertClientNotExist(8)
|
s.db.AssertClientNotExist(8)
|
||||||
|
assert.True(s.T(), s.notified)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *TokenSuite) Test_UploadAppImage_NoImageProvided_expectBadRequest() {
|
func (s *TokenSuite) Test_UploadAppImage_NoImageProvided_expectBadRequest() {
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Co
|
||||||
streamHandler := stream.New(200*time.Second, 15*time.Second)
|
streamHandler := stream.New(200*time.Second, 15*time.Second)
|
||||||
authentication := auth.Auth{DB: db}
|
authentication := auth.Auth{DB: db}
|
||||||
messageHandler := api.MessageAPI{Notifier: streamHandler, DB: db}
|
messageHandler := api.MessageAPI{Notifier: streamHandler, DB: db}
|
||||||
tokenHandler := api.TokenAPI{DB: db, ImageDir: conf.UploadedImagesDir}
|
tokenHandler := api.TokenAPI{DB: db, ImageDir: conf.UploadedImagesDir, NotifyDeleted: streamHandler.NotifyDeleted}
|
||||||
userHandler := api.UserAPI{DB: db, PasswordStrength: conf.PassStrength}
|
userHandler := api.UserAPI{DB: db, PasswordStrength: conf.PassStrength}
|
||||||
g := gin.New()
|
g := gin.New()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -49,7 +49,7 @@ export function listenToWebSocket() {
|
||||||
setTimeout(listenToWebSocket, 60000);
|
setTimeout(listenToWebSocket, 60000);
|
||||||
};
|
};
|
||||||
|
|
||||||
ws.onmessage = (data) => {
|
ws.onmessage = (data) => dispatcher.dispatch({type: 'ONE_MESSAGE', payload: JSON.parse(data.data)});
|
||||||
dispatcher.dispatch({type: 'ONE_MESSAGE', payload: JSON.parse(data.data)});
|
|
||||||
};
|
ws.onclose = (data) => console.log('WebSocket closed, this normally means the client was deleted.', data);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue