Check origin on web socket requests in prod mode

This commit is contained in:
Jannis Mattheis 2018-03-18 14:00:28 +01:00 committed by Jannis Mattheis
parent f1aa490035
commit 090142c281
2 changed files with 61 additions and 2 deletions

View File

@ -10,16 +10,32 @@ import (
"github.com/gotify/server/model" "github.com/gotify/server/model"
"net/http" "net/http"
"github.com/gotify/server/mode" "github.com/gotify/server/mode"
"net/url"
) )
var upgrader = websocket.Upgrader{ var upgrader = websocket.Upgrader{
ReadBufferSize: 1024, ReadBufferSize: 1024,
WriteBufferSize: 1024, WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool { CheckOrigin: func(r *http.Request) bool {
return mode.IsDev(); if mode.IsDev() {
return true
}
return checkSameOrigin(r)
}, },
} }
func checkSameOrigin(r *http.Request) bool {
origin := r.Header["Origin"]
if len(origin) == 0 {
return true
}
u, err := url.Parse(origin[0])
if err != nil {
return false
}
return u.Host == r.Host
}
// The API provides a handler for a WebSocket stream API. // The API provides a handler for a WebSocket stream API.
type API struct { type API struct {
clients map[uint][]*client clients map[uint][]*client

View File

@ -25,6 +25,8 @@ import (
) )
func TestFailureOnNormalHttpRequest(t *testing.T) { func TestFailureOnNormalHttpRequest(t *testing.T) {
mode.Set(mode.TestDev)
defer leaktest.Check(t)() defer leaktest.Check(t)()
server, api := bootTestServer(staticUserID()) server, api := bootTestServer(staticUserID())
@ -38,6 +40,8 @@ func TestFailureOnNormalHttpRequest(t *testing.T) {
} }
func TestWriteMessageFails(t *testing.T) { func TestWriteMessageFails(t *testing.T) {
mode.Set(mode.TestDev)
defer leaktest.Check(t)() defer leaktest.Check(t)()
server, api := bootTestServer(func(context *gin.Context) { server, api := bootTestServer(func(context *gin.Context) {
@ -65,6 +69,8 @@ func TestWriteMessageFails(t *testing.T) {
} }
func TestWritePingFails(t *testing.T) { func TestWritePingFails(t *testing.T) {
mode.Set(mode.TestDev)
defer leaktest.CheckTimeout(t, 10*time.Second)() defer leaktest.CheckTimeout(t, 10*time.Second)()
server, api := bootTestServer(staticUserID()) server, api := bootTestServer(staticUserID())
@ -92,6 +98,8 @@ func TestWritePingFails(t *testing.T) {
} }
func TestPing(t *testing.T) { func TestPing(t *testing.T) {
mode.Set(mode.TestDev)
server, api := bootTestServer(staticUserID()) server, api := bootTestServer(staticUserID())
defer server.Close() defer server.Close()
defer api.Close() defer api.Close()
@ -124,6 +132,8 @@ func TestPing(t *testing.T) {
} }
func TestCloseClientOnNotReading(t *testing.T) { func TestCloseClientOnNotReading(t *testing.T) {
mode.Set(mode.TestDev)
server, api := bootTestServer(staticUserID()) server, api := bootTestServer(staticUserID())
defer server.Close() defer server.Close()
defer api.Close() defer api.Close()
@ -146,6 +156,7 @@ func TestCloseClientOnNotReading(t *testing.T) {
} }
func TestMessageDirectlyAfterConnect(t *testing.T) { func TestMessageDirectlyAfterConnect(t *testing.T) {
mode.Set(mode.Prod)
defer leaktest.Check(t)() defer leaktest.Check(t)()
server, api := bootTestServer(staticUserID()) server, api := bootTestServer(staticUserID())
defer server.Close() defer server.Close()
@ -162,6 +173,8 @@ func TestMessageDirectlyAfterConnect(t *testing.T) {
} }
func TestMultipleClients(t *testing.T) { func TestMultipleClients(t *testing.T) {
mode.Set(mode.TestDev)
defer leaktest.Check(t)() defer leaktest.Check(t)()
userIDs := []uint{1, 1, 1, 2, 2, 3} userIDs := []uint{1, 1, 1, 2, 2, 3}
i := 0 i := 0
@ -237,6 +250,37 @@ func TestMultipleClients(t *testing.T) {
expectNoMessage(userThree...) expectNoMessage(userThree...)
} }
func Test_sameOrigin_returnsTrue(t *testing.T) {
mode.Set(mode.Prod)
req := httptest.NewRequest("GET", "http://example.com/stream", nil)
req.Header.Set("Origin", "http://example.com")
actual := checkSameOrigin(req)
assert.True(t, actual)
}
func Test_emptyOrigin_returnsTrue(t *testing.T) {
mode.Set(mode.Prod)
req := httptest.NewRequest("GET", "http://example.com/stream", nil)
actual := checkSameOrigin(req)
assert.True(t, actual)
}
func Test_otherOrigin_returnsFalse(t *testing.T) {
mode.Set(mode.Prod)
req := httptest.NewRequest("GET", "http://example.com/stream", nil)
req.Header.Set("Origin", "http://otherexample.de")
actual := checkSameOrigin(req)
assert.False(t, actual)
}
func Test_invalidOrigin_returnsFalse(t *testing.T) {
mode.Set(mode.Prod)
req := httptest.NewRequest("GET", "http://example.com/stream", nil)
req.Header.Set("Origin", "http\\://otherexample.de")
actual := checkSameOrigin(req)
assert.False(t, actual)
}
func testClient(t *testing.T, url string) *testingClient { func testClient(t *testing.T, url string) *testingClient {
ws, _, err := websocket.DefaultDialer.Dial(url, nil) ws, _, err := websocket.DefaultDialer.Dial(url, nil)
assert.Nil(t, err) assert.Nil(t, err)
@ -297,7 +341,6 @@ func (c *testingClient) expectNoMessage() {
} }
func bootTestServer(handlerFunc gin.HandlerFunc) (*httptest.Server, *API) { func bootTestServer(handlerFunc gin.HandlerFunc) (*httptest.Server, *API) {
mode.Set(mode.TestDev)
r := gin.New() r := gin.New()
r.Use(handlerFunc) r.Use(handlerFunc)