[#23] Fix check same origin function

This commit is contained in:
Eugene Gavrilov 2018-12-06 01:25:44 +05:00 committed by Jannis Mattheis
parent 193dd67f2c
commit b5b2f19dc2
6 changed files with 117 additions and 44 deletions

View File

@ -82,6 +82,10 @@ server:
responseheaders: # response headers are added to every response (default: none) responseheaders: # response headers are added to every response (default: none)
Access-Control-Allow-Origin: "*" Access-Control-Allow-Origin: "*"
Access-Control-Allow-Methods: "GET,POST" Access-Control-Allow-Methods: "GET,POST"
stream:
allowedorigins: # allowed origins for websocket connections (same origin is always allowed)
- ".+.example.com"
- "otherdomain.com"
database: # for database see (configure database section) database: # for database see (configure database section)
dialect: sqlite3 dialect: sqlite3
connection: data/gotify.db connection: data/gotify.db
@ -111,6 +115,7 @@ GOTIFY_SERVER_SSL_LETSENCRYPT_CACHE=certs
# lists are a little weird but do-able (: # lists are a little weird but do-able (:
GOTIFY_SERVER_SSL_LETSENCRYPT_HOSTS=- mydomain.tld\n- myotherdomain.tld GOTIFY_SERVER_SSL_LETSENCRYPT_HOSTS=- mydomain.tld\n- myotherdomain.tld
GOTIFY_SERVER_RESPONSEHEADERS="Access-Control-Allow-Origin: \"*\"\nAccess-Control-Allow-Methods: \"GET,POST\"" GOTIFY_SERVER_RESPONSEHEADERS="Access-Control-Allow-Origin: \"*\"\nAccess-Control-Allow-Methods: \"GET,POST\""
GOTIFY_SERVER_STREAM_ALLOWEDORIGINS="- \".+.example.com\"\n- \"otherdomain.com\""
GOTIFY_DATABASE_DIALECT=sqlite3 GOTIFY_DATABASE_DIALECT=sqlite3
GOTIFY_DATABASE_CONNECTION=gotify.db GOTIFY_DATABASE_CONNECTION=gotify.db
GOTIFY_DEFAULTUSER_NAME=admin GOTIFY_DEFAULTUSER_NAME=admin

View File

@ -1,6 +1,7 @@
package stream package stream
import ( import (
"regexp"
"sync" "sync"
"time" "time"
@ -14,46 +15,25 @@ import (
"github.com/gotify/server/model" "github.com/gotify/server/model"
) )
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
if mode.IsDev() {
return true
}
return checkSameOrigin(r)
},
}
func checkSameOrigin(r *http.Request) bool {
origin := r.Header["Origin"]
if len(origin) == 0 {
return true
}
u, err := url.Parse(origin[0])
if err != nil {
return false
}
return u.Host == r.Host
}
// The API provides a handler for a WebSocket stream API. // The API provides a handler for a WebSocket stream API.
type API struct { type API struct {
clients map[uint][]*client clients map[uint][]*client
lock sync.RWMutex lock sync.RWMutex
pingPeriod time.Duration pingPeriod time.Duration
pongTimeout time.Duration pongTimeout time.Duration
upgrader *websocket.Upgrader
} }
// New creates a new instance of API. // New creates a new instance of API.
// pingPeriod: is the interval, in which is server sends the a ping to the client. // pingPeriod: is the interval, in which is server sends the a ping to the client.
// pongTimeout: is the duration after the connection will be terminated, when the client does not respond with the // pongTimeout: is the duration after the connection will be terminated, when the client does not respond with the
// pong command. // pong command.
func New(pingPeriod, pongTimeout time.Duration) *API { func New(pingPeriod, pongTimeout time.Duration, allowedWebSocketOrigins []string) *API {
return &API{ return &API{
clients: make(map[uint][]*client), clients: make(map[uint][]*client),
pingPeriod: pingPeriod, pingPeriod: pingPeriod,
pongTimeout: pingPeriod + pongTimeout, pongTimeout: pingPeriod + pongTimeout,
upgrader: newUpgrader(allowedWebSocketOrigins),
} }
} }
@ -147,7 +127,7 @@ func (a *API) register(client *client) {
// schema: // schema:
// $ref: "#/definitions/Error" // $ref: "#/definitions/Error"
func (a *API) Handle(ctx *gin.Context) { func (a *API) Handle(ctx *gin.Context) {
conn, err := upgrader.Upgrade(ctx.Writer, ctx.Request, nil) conn, err := a.upgrader.Upgrade(ctx.Writer, ctx.Request, nil)
if err != nil { if err != nil {
return return
} }
@ -172,3 +152,50 @@ func (a *API) Close() {
delete(a.clients, k) delete(a.clients, k)
} }
} }
func isAllowedOrigin(r *http.Request, allowedOrigins []*regexp.Regexp) bool {
origin := r.Header["Origin"]
if len(origin) == 0 {
return true
}
u, err := url.Parse(origin[0])
if err != nil {
return false
}
if u.Hostname() == r.Host {
return true
}
for _, allowedOrigin := range allowedOrigins {
if allowedOrigin.Match([]byte(u.Hostname())) {
return true
}
}
return false
}
func newUpgrader(allowedWebSocketOrigins []string) *websocket.Upgrader {
compiledAllowedOrigins := compileAllowedWebSocketOrigins(allowedWebSocketOrigins)
return &websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
if mode.IsDev() {
return true
}
return isAllowedOrigin(r, compiledAllowedOrigins)
},
}
}
func compileAllowedWebSocketOrigins(allowedOrigins []string) []*regexp.Regexp {
var compiledAllowedOrigins []*regexp.Regexp
for _, origin := range allowedOrigins {
compiledAllowedOrigins = append(compiledAllowedOrigins, regexp.MustCompile(origin))
}
return compiledAllowedOrigins
}

View File

@ -405,14 +405,37 @@ func Test_sameOrigin_returnsTrue(t *testing.T) {
mode.Set(mode.Prod) mode.Set(mode.Prod)
req := httptest.NewRequest("GET", "http://example.com/stream", nil) req := httptest.NewRequest("GET", "http://example.com/stream", nil)
req.Header.Set("Origin", "http://example.com") req.Header.Set("Origin", "http://example.com")
actual := checkSameOrigin(req) actual := isAllowedOrigin(req, nil)
assert.True(t, actual) assert.True(t, actual)
} }
func Test_isAllowedOrigin_withoutAllowedOrigins_failsWhenNotSameOrigin(t *testing.T) {
mode.Set(mode.Prod)
req := httptest.NewRequest("GET", "http://example.com/stream", nil)
req.Header.Set("Origin", "http://gorify.example.com")
actual := isAllowedOrigin(req, nil)
assert.False(t, actual)
}
func Test_isAllowedOriginMatching(t *testing.T) {
mode.Set(mode.Prod)
compiledAllowedOrigins := compileAllowedWebSocketOrigins([]string{"go.{4}\\.example\\.com", "go\\.example\\.com"})
req := httptest.NewRequest("GET", "http://example.me/stream", nil)
req.Header.Set("Origin", "http://gorify.example.com")
assert.True(t, isAllowedOrigin(req, compiledAllowedOrigins))
req.Header.Set("Origin", "http://go.example.com")
assert.True(t, isAllowedOrigin(req, compiledAllowedOrigins))
req.Header.Set("Origin", "http://hello.example.com")
assert.False(t, isAllowedOrigin(req, compiledAllowedOrigins))
}
func Test_emptyOrigin_returnsTrue(t *testing.T) { func Test_emptyOrigin_returnsTrue(t *testing.T) {
mode.Set(mode.Prod) mode.Set(mode.Prod)
req := httptest.NewRequest("GET", "http://example.com/stream", nil) req := httptest.NewRequest("GET", "http://example.com/stream", nil)
actual := checkSameOrigin(req) actual := isAllowedOrigin(req, nil)
assert.True(t, actual) assert.True(t, actual)
} }
@ -420,7 +443,7 @@ func Test_otherOrigin_returnsFalse(t *testing.T) {
mode.Set(mode.Prod) mode.Set(mode.Prod)
req := httptest.NewRequest("GET", "http://example.com/stream", nil) req := httptest.NewRequest("GET", "http://example.com/stream", nil)
req.Header.Set("Origin", "http://otherexample.de") req.Header.Set("Origin", "http://otherexample.de")
actual := checkSameOrigin(req) actual := isAllowedOrigin(req, nil)
assert.False(t, actual) assert.False(t, actual)
} }
@ -428,10 +451,15 @@ func Test_invalidOrigin_returnsFalse(t *testing.T) {
mode.Set(mode.Prod) mode.Set(mode.Prod)
req := httptest.NewRequest("GET", "http://example.com/stream", nil) req := httptest.NewRequest("GET", "http://example.com/stream", nil)
req.Header.Set("Origin", "http\\://otherexample.de") req.Header.Set("Origin", "http\\://otherexample.de")
actual := checkSameOrigin(req) actual := isAllowedOrigin(req, nil)
assert.False(t, actual) assert.False(t, actual)
} }
func Test_compileAllowedWebSocketOrigins(t *testing.T) {
assert.Equal(t, 0, len(compileAllowedWebSocketOrigins([]string{})))
assert.Equal(t, 3, len(compileAllowedWebSocketOrigins([]string{"^.*$", "", "abc"})))
}
func clients(api *API, user uint) []*client { func clients(api *API, user uint) []*client {
api.lock.RLock() api.lock.RLock()
defer api.lock.RUnlock() defer api.lock.RUnlock()
@ -439,7 +467,7 @@ func clients(api *API, user uint) []*client {
return api.clients[user] return api.clients[user]
} }
func testClient(t *testing.T, url string) *testingClient { func testClient(t *testing.T, url string) *testingClient {
client := createClient(t, url) client := createClient(t, url)
startReading(client) startReading(client)
return client return client
@ -507,11 +535,11 @@ func (c *testingClient) expectNoMessage() {
} }
func bootTestServer(handlerFunc gin.HandlerFunc) (*httptest.Server, *API) { func bootTestServer(handlerFunc gin.HandlerFunc) (*httptest.Server, *API) {
r := gin.New() r := gin.New()
r.Use(handlerFunc) r.Use(handlerFunc)
// all 4 seconds a ping, and the client has 1 second to respond // all 4 seconds a ping, and the client has 1 second to respond
api := New(4*time.Second, 1*time.Second) api := New(4*time.Second, 1*time.Second, []string{})
r.GET("/", api.Handle) r.GET("/", api.Handle)
server := httptest.NewServer(r) server := httptest.NewServer(r)
return server, api return server, api

View File

@ -25,6 +25,9 @@ type Configuration struct {
} }
} }
ResponseHeaders map[string]string ResponseHeaders map[string]string
Stream struct {
AllowedOrigins []string
}
} }
Database struct { Database struct {
Dialect string `default:"sqlite3"` Dialect string `default:"sqlite3"`

View File

@ -15,15 +15,20 @@ func TestConfigEnv(t *testing.T) {
os.Setenv("GOTIFY_SERVER_RESPONSEHEADERS", os.Setenv("GOTIFY_SERVER_RESPONSEHEADERS",
"Access-Control-Allow-Origin: \"*\"\nAccess-Control-Allow-Methods: \"GET,POST\"", "Access-Control-Allow-Origin: \"*\"\nAccess-Control-Allow-Methods: \"GET,POST\"",
) )
os.Setenv("GOTIFY_SERVER_STREAM_ALLOWEDORIGINS", "- \".+.example.com\"\n- \"otherdomain.com\"")
conf := Get() conf := Get()
assert.Equal(t, 80, conf.Server.Port, "should use defaults") assert.Equal(t, 80, conf.Server.Port, "should use defaults")
assert.Equal(t, "jmattheis", conf.DefaultUser.Name, "should not use default but env var") assert.Equal(t, "jmattheis", conf.DefaultUser.Name, "should not use default but env var")
assert.Equal(t, []string{"push.example.tld", "push.other.tld"}, conf.Server.SSL.LetsEncrypt.Hosts) assert.Equal(t, []string{"push.example.tld", "push.other.tld"}, conf.Server.SSL.LetsEncrypt.Hosts)
assert.Equal(t, "*", conf.Server.ResponseHeaders["Access-Control-Allow-Origin"]) assert.Equal(t, "*", conf.Server.ResponseHeaders["Access-Control-Allow-Origin"])
assert.Equal(t, "GET,POST", conf.Server.ResponseHeaders["Access-Control-Allow-Methods"]) assert.Equal(t, "GET,POST", conf.Server.ResponseHeaders["Access-Control-Allow-Methods"])
assert.Equal(t, []string{".+.example.com", "otherdomain.com"}, conf.Server.Stream.AllowedOrigins)
os.Unsetenv("GOTIFY_DEFAULTUSER_NAME") os.Unsetenv("GOTIFY_DEFAULTUSER_NAME")
os.Unsetenv("GOTIFY_SERVER_SSL_LETSENCRYPT_HOSTS") os.Unsetenv("GOTIFY_SERVER_SSL_LETSENCRYPT_HOSTS")
os.Unsetenv("GOTIFY_SERVER_RESPONSEHEADERS") os.Unsetenv("GOTIFY_SERVER_RESPONSEHEADERS")
os.Unsetenv("GOTIFY_SERVER_STREAM_ALLOWEDORIGINS")
} }
func TestAddSlash(t *testing.T) { func TestAddSlash(t *testing.T) {
@ -75,6 +80,10 @@ server:
responseheaders: responseheaders:
Access-Control-Allow-Origin: "*" Access-Control-Allow-Origin: "*"
Access-Control-Allow-Methods: "GET,POST" Access-Control-Allow-Methods: "GET,POST"
stream:
allowedorigins:
- ".+.example.com"
- "otherdomain.com"
database: database:
dialect: mysql dialect: mysql
connection: user name connection: user name
@ -94,6 +103,7 @@ defaultuser:
assert.Equal(t, "user name", conf.Database.Connection) assert.Equal(t, "user name", conf.Database.Connection)
assert.Equal(t, "*", conf.Server.ResponseHeaders["Access-Control-Allow-Origin"]) assert.Equal(t, "*", conf.Server.ResponseHeaders["Access-Control-Allow-Origin"])
assert.Equal(t, "GET,POST", conf.Server.ResponseHeaders["Access-Control-Allow-Methods"]) assert.Equal(t, "GET,POST", conf.Server.ResponseHeaders["Access-Control-Allow-Methods"])
assert.Equal(t, []string{".+.example.com", "otherdomain.com"}, conf.Server.Stream.AllowedOrigins)
assert.Nil(t, os.Remove("config.yml")) assert.Nil(t, os.Remove("config.yml"))
} }

View File

@ -23,7 +23,7 @@ import (
// Create creates the gin engine with all routes. // Create creates the gin engine with all routes.
func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Configuration) (*gin.Engine, func()) { func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Configuration) (*gin.Engine, func()) {
streamHandler := stream.New(200*time.Second, 15*time.Second) streamHandler := stream.New(200*time.Second, 15*time.Second, conf.Server.Stream.AllowedOrigins)
authentication := auth.Auth{DB: db} authentication := auth.Auth{DB: db}
messageHandler := api.MessageAPI{Notifier: streamHandler, DB: db} messageHandler := api.MessageAPI{Notifier: streamHandler, DB: db}
clientHandler := api.ClientAPI{ clientHandler := api.ClientAPI{