Check origin on web socket requests in prod mode
This commit is contained in:
parent
f1aa490035
commit
090142c281
|
|
@ -10,16 +10,32 @@ import (
|
|||
"github.com/gotify/server/model"
|
||||
"net/http"
|
||||
"github.com/gotify/server/mode"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return mode.IsDev();
|
||||
if mode.IsDev() {
|
||||
return true
|
||||
}
|
||||
return checkSameOrigin(r)
|
||||
},
|
||||
}
|
||||
|
||||
func checkSameOrigin(r *http.Request) bool {
|
||||
origin := r.Header["Origin"]
|
||||
if len(origin) == 0 {
|
||||
return true
|
||||
}
|
||||
u, err := url.Parse(origin[0])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return u.Host == r.Host
|
||||
}
|
||||
|
||||
// The API provides a handler for a WebSocket stream API.
|
||||
type API struct {
|
||||
clients map[uint][]*client
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@ import (
|
|||
)
|
||||
|
||||
func TestFailureOnNormalHttpRequest(t *testing.T) {
|
||||
mode.Set(mode.TestDev)
|
||||
|
||||
defer leaktest.Check(t)()
|
||||
|
||||
server, api := bootTestServer(staticUserID())
|
||||
|
|
@ -38,6 +40,8 @@ func TestFailureOnNormalHttpRequest(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestWriteMessageFails(t *testing.T) {
|
||||
mode.Set(mode.TestDev)
|
||||
|
||||
defer leaktest.Check(t)()
|
||||
|
||||
server, api := bootTestServer(func(context *gin.Context) {
|
||||
|
|
@ -65,6 +69,8 @@ func TestWriteMessageFails(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestWritePingFails(t *testing.T) {
|
||||
mode.Set(mode.TestDev)
|
||||
|
||||
defer leaktest.CheckTimeout(t, 10*time.Second)()
|
||||
|
||||
server, api := bootTestServer(staticUserID())
|
||||
|
|
@ -92,6 +98,8 @@ func TestWritePingFails(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestPing(t *testing.T) {
|
||||
mode.Set(mode.TestDev)
|
||||
|
||||
server, api := bootTestServer(staticUserID())
|
||||
defer server.Close()
|
||||
defer api.Close()
|
||||
|
|
@ -124,6 +132,8 @@ func TestPing(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCloseClientOnNotReading(t *testing.T) {
|
||||
mode.Set(mode.TestDev)
|
||||
|
||||
server, api := bootTestServer(staticUserID())
|
||||
defer server.Close()
|
||||
defer api.Close()
|
||||
|
|
@ -146,6 +156,7 @@ func TestCloseClientOnNotReading(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMessageDirectlyAfterConnect(t *testing.T) {
|
||||
mode.Set(mode.Prod)
|
||||
defer leaktest.Check(t)()
|
||||
server, api := bootTestServer(staticUserID())
|
||||
defer server.Close()
|
||||
|
|
@ -162,6 +173,8 @@ func TestMessageDirectlyAfterConnect(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMultipleClients(t *testing.T) {
|
||||
mode.Set(mode.TestDev)
|
||||
|
||||
defer leaktest.Check(t)()
|
||||
userIDs := []uint{1, 1, 1, 2, 2, 3}
|
||||
i := 0
|
||||
|
|
@ -237,6 +250,37 @@ func TestMultipleClients(t *testing.T) {
|
|||
expectNoMessage(userThree...)
|
||||
}
|
||||
|
||||
func Test_sameOrigin_returnsTrue(t *testing.T) {
|
||||
mode.Set(mode.Prod)
|
||||
req := httptest.NewRequest("GET", "http://example.com/stream", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
actual := checkSameOrigin(req)
|
||||
assert.True(t, actual)
|
||||
}
|
||||
|
||||
func Test_emptyOrigin_returnsTrue(t *testing.T) {
|
||||
mode.Set(mode.Prod)
|
||||
req := httptest.NewRequest("GET", "http://example.com/stream", nil)
|
||||
actual := checkSameOrigin(req)
|
||||
assert.True(t, actual)
|
||||
}
|
||||
|
||||
func Test_otherOrigin_returnsFalse(t *testing.T) {
|
||||
mode.Set(mode.Prod)
|
||||
req := httptest.NewRequest("GET", "http://example.com/stream", nil)
|
||||
req.Header.Set("Origin", "http://otherexample.de")
|
||||
actual := checkSameOrigin(req)
|
||||
assert.False(t, actual)
|
||||
}
|
||||
|
||||
func Test_invalidOrigin_returnsFalse(t *testing.T) {
|
||||
mode.Set(mode.Prod)
|
||||
req := httptest.NewRequest("GET", "http://example.com/stream", nil)
|
||||
req.Header.Set("Origin", "http\\://otherexample.de")
|
||||
actual := checkSameOrigin(req)
|
||||
assert.False(t, actual)
|
||||
}
|
||||
|
||||
func testClient(t *testing.T, url string) *testingClient {
|
||||
ws, _, err := websocket.DefaultDialer.Dial(url, nil)
|
||||
assert.Nil(t, err)
|
||||
|
|
@ -297,7 +341,6 @@ func (c *testingClient) expectNoMessage() {
|
|||
}
|
||||
|
||||
func bootTestServer(handlerFunc gin.HandlerFunc) (*httptest.Server, *API) {
|
||||
mode.Set(mode.TestDev)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(handlerFunc)
|
||||
|
|
|
|||
Loading…
Reference in New Issue