diff --git a/stream/stream.go b/stream/stream.go index 7242df4..a4587a9 100644 --- a/stream/stream.go +++ b/stream/stream.go @@ -10,16 +10,32 @@ import ( "github.com/gotify/server/model" "net/http" "github.com/gotify/server/mode" + "net/url" ) var upgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, CheckOrigin: func(r *http.Request) bool { - return mode.IsDev(); + if mode.IsDev() { + return true + } + return checkSameOrigin(r) }, } +func checkSameOrigin(r *http.Request) bool { + origin := r.Header["Origin"] + if len(origin) == 0 { + return true + } + u, err := url.Parse(origin[0]) + if err != nil { + return false + } + return u.Host == r.Host +} + // The API provides a handler for a WebSocket stream API. type API struct { clients map[uint][]*client diff --git a/stream/stream_test.go b/stream/stream_test.go index d253611..5d23bc7 100644 --- a/stream/stream_test.go +++ b/stream/stream_test.go @@ -25,6 +25,8 @@ import ( ) func TestFailureOnNormalHttpRequest(t *testing.T) { + mode.Set(mode.TestDev) + defer leaktest.Check(t)() server, api := bootTestServer(staticUserID()) @@ -38,6 +40,8 @@ func TestFailureOnNormalHttpRequest(t *testing.T) { } func TestWriteMessageFails(t *testing.T) { + mode.Set(mode.TestDev) + defer leaktest.Check(t)() server, api := bootTestServer(func(context *gin.Context) { @@ -65,6 +69,8 @@ func TestWriteMessageFails(t *testing.T) { } func TestWritePingFails(t *testing.T) { + mode.Set(mode.TestDev) + defer leaktest.CheckTimeout(t, 10*time.Second)() server, api := bootTestServer(staticUserID()) @@ -92,6 +98,8 @@ func TestWritePingFails(t *testing.T) { } func TestPing(t *testing.T) { + mode.Set(mode.TestDev) + server, api := bootTestServer(staticUserID()) defer server.Close() defer api.Close() @@ -124,6 +132,8 @@ func TestPing(t *testing.T) { } func TestCloseClientOnNotReading(t *testing.T) { + mode.Set(mode.TestDev) + server, api := bootTestServer(staticUserID()) defer server.Close() defer api.Close() @@ -146,6 +156,7 @@ func TestCloseClientOnNotReading(t *testing.T) { } func TestMessageDirectlyAfterConnect(t *testing.T) { + mode.Set(mode.Prod) defer leaktest.Check(t)() server, api := bootTestServer(staticUserID()) defer server.Close() @@ -162,6 +173,8 @@ func TestMessageDirectlyAfterConnect(t *testing.T) { } func TestMultipleClients(t *testing.T) { + mode.Set(mode.TestDev) + defer leaktest.Check(t)() userIDs := []uint{1, 1, 1, 2, 2, 3} i := 0 @@ -237,6 +250,37 @@ func TestMultipleClients(t *testing.T) { expectNoMessage(userThree...) } +func Test_sameOrigin_returnsTrue(t *testing.T) { + mode.Set(mode.Prod) + req := httptest.NewRequest("GET", "http://example.com/stream", nil) + req.Header.Set("Origin", "http://example.com") + actual := checkSameOrigin(req) + assert.True(t, actual) +} + +func Test_emptyOrigin_returnsTrue(t *testing.T) { + mode.Set(mode.Prod) + req := httptest.NewRequest("GET", "http://example.com/stream", nil) + actual := checkSameOrigin(req) + assert.True(t, actual) +} + +func Test_otherOrigin_returnsFalse(t *testing.T) { + mode.Set(mode.Prod) + req := httptest.NewRequest("GET", "http://example.com/stream", nil) + req.Header.Set("Origin", "http://otherexample.de") + actual := checkSameOrigin(req) + assert.False(t, actual) +} + +func Test_invalidOrigin_returnsFalse(t *testing.T) { + mode.Set(mode.Prod) + req := httptest.NewRequest("GET", "http://example.com/stream", nil) + req.Header.Set("Origin", "http\\://otherexample.de") + actual := checkSameOrigin(req) + assert.False(t, actual) +} + func testClient(t *testing.T, url string) *testingClient { ws, _, err := websocket.DefaultDialer.Dial(url, nil) assert.Nil(t, err) @@ -297,7 +341,6 @@ func (c *testingClient) expectNoMessage() { } func bootTestServer(handlerFunc gin.HandlerFunc) (*httptest.Server, *API) { - mode.Set(mode.TestDev) r := gin.New() r.Use(handlerFunc)