From 3f04d50088c8eec70a46547ebaae2269a96a5436 Mon Sep 17 00:00:00 2001 From: Stewart Thomson Date: Sun, 26 Apr 2020 07:27:24 -0400 Subject: [PATCH] Allow multiple CORS origins --- auth/cors.go | 47 +++++++++++++++ auth/cors_test.go | 52 +++++++++++++++++ config.example.yml | 13 ++++- config/config.go | 5 ++ config/config_test.go | 22 +++++++ go.mod | 1 + go.sum | 2 + router/router.go | 11 +--- router/router_test.go | 130 ++++++++++++++++++++++++++++++++++++++++++ 9 files changed, 272 insertions(+), 11 deletions(-) create mode 100644 auth/cors.go create mode 100644 auth/cors_test.go diff --git a/auth/cors.go b/auth/cors.go new file mode 100644 index 0000000..60dfcae --- /dev/null +++ b/auth/cors.go @@ -0,0 +1,47 @@ +package auth + +import ( + "regexp" + "strings" + "time" + + "github.com/gin-contrib/cors" + "github.com/gotify/server/config" + "github.com/gotify/server/mode" +) + +// CorsConfig generates a config to use in gin cors middleware based on server configuration +func CorsConfig(conf *config.Configuration) cors.Config { + corsConf := cors.Config{ + MaxAge: 12 * time.Hour, + } + if mode.IsDev() { + corsConf.AllowAllOrigins = true + corsConf.AllowMethods = []string{"GET", "POST", "DELETE", "OPTIONS", "PUT"} + corsConf.AllowHeaders = []string{"X-Gotify-Key", "Authorization", "Content-Type", "Upgrade", "Origin", + "Connection", "Accept-Encoding", "Accept-Language", "Host"} + } else { + compiledOrigins := compileAllowedCORSOrigins(conf.Server.Cors.AllowOrigins) + corsConf.AllowMethods = conf.Server.Cors.AllowMethods + corsConf.AllowHeaders = conf.Server.Cors.AllowHeaders + corsConf.AllowOriginFunc = func(origin string) bool { + for _, compiledOrigin := range compiledOrigins { + if compiledOrigin.Match([]byte(strings.ToLower(origin))) { + return true + } + } + return false + } + } + + return corsConf +} + +func compileAllowedCORSOrigins(allowedOrigins []string) []*regexp.Regexp { + var compiledAllowedOrigins []*regexp.Regexp + for _, origin := range allowedOrigins { + compiledAllowedOrigins = append(compiledAllowedOrigins, regexp.MustCompile(origin)) + } + + return compiledAllowedOrigins +} diff --git a/auth/cors_test.go b/auth/cors_test.go new file mode 100644 index 0000000..966cac3 --- /dev/null +++ b/auth/cors_test.go @@ -0,0 +1,52 @@ +package auth + +import ( + "testing" + "time" + + "github.com/gin-contrib/cors" + "github.com/gotify/server/config" + "github.com/gotify/server/mode" + "github.com/stretchr/testify/assert" +) + +func TestCorsConfig(t *testing.T) { + mode.Set(mode.Prod) + serverConf := config.Configuration{} + serverConf.Server.Cors.AllowOrigins = []string{"http://test.com"} + serverConf.Server.Cors.AllowHeaders = []string{"content-type"} + serverConf.Server.Cors.AllowMethods = []string{"GET"} + + actual := CorsConfig(&serverConf) + allowF := actual.AllowOriginFunc + actual.AllowOriginFunc = nil // func cannot be checked with equal + + assert.Equal(t, cors.Config{ + AllowAllOrigins: false, + AllowHeaders: []string{"content-type"}, + AllowMethods: []string{"GET"}, + MaxAge: 12 * time.Hour, + }, actual) + assert.NotNil(t, allowF) + assert.True(t, allowF("http://test.com")) + assert.False(t, allowF("https://test.com")) + assert.False(t, allowF("https://other.com")) +} + +func TestDevCorsConfig(t *testing.T) { + mode.Set(mode.Dev) + serverConf := config.Configuration{} + serverConf.Server.Cors.AllowOrigins = []string{"http://test.com"} + serverConf.Server.Cors.AllowHeaders = []string{"content-type"} + serverConf.Server.Cors.AllowMethods = []string{"GET"} + + actual := CorsConfig(&serverConf) + + assert.Equal(t, cors.Config{ + AllowHeaders: []string{"X-Gotify-Key", "Authorization", "Content-Type", "Upgrade", "Origin", + "Connection", "Accept-Encoding", "Accept-Language", "Host"}, + AllowMethods: []string{"GET", "POST", "DELETE", "OPTIONS", "PUT"}, + MaxAge: 12 * time.Hour, + AllowAllOrigins: true, + }, actual) +} diff --git a/config.example.yml b/config.example.yml index 6ca3f6a..7192f9e 100644 --- a/config.example.yml +++ b/config.example.yml @@ -21,9 +21,18 @@ server: # - myotherdomain.tld responseheaders: # response headers are added to every response (default: none) -# Access-Control-Allow-Origin: "*" -# Access-Control-Allow-Methods: "GET,POST" +# X-Custom-Header: "custom value" + cors: # Sets cors headers only when needed and provides support for multiple allowed origins. Overrides Access-Control-* Headers in response headers. + alloworigins: +# - ".+.example.com" +# - "otherdomain.com" + allowmethods: +# - "GET" +# - "POST" + allowheaders: +# - "Authorization" +# - "content-type" stream: allowedorigins: # allowed origins for websocket connections (same origin is always allowed) # - ".+.example.com" diff --git a/config/config.go b/config/config.go index 8318c66..778239e 100644 --- a/config/config.go +++ b/config/config.go @@ -31,6 +31,11 @@ type Configuration struct { Stream struct { AllowedOrigins []string } + Cors struct { + AllowOrigins []string + AllowMethods []string + AllowHeaders []string + } } Database struct { Dialect string `default:"sqlite3"` diff --git a/config/config_test.go b/config/config_test.go index faff06e..03105bc 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -16,6 +16,9 @@ func TestConfigEnv(t *testing.T) { os.Setenv("GOTIFY_SERVER_RESPONSEHEADERS", "Access-Control-Allow-Origin: \"*\"\nAccess-Control-Allow-Methods: \"GET,POST\"", ) + os.Setenv("GOTIFY_SERVER_CORS_ALLOWORIGINS", "- \".+.example.com\"\n- \"otherdomain.com\"") + os.Setenv("GOTIFY_SERVER_CORS_ALLOWMETHODS", "- \"GET\"\n- \"POST\"") + os.Setenv("GOTIFY_SERVER_CORS_ALLOWHEADERS", "- \"Authorization\"\n- \"content-type\"") os.Setenv("GOTIFY_SERVER_STREAM_ALLOWEDORIGINS", "- \".+.example.com\"\n- \"otherdomain.com\"") conf := Get() @@ -24,11 +27,17 @@ func TestConfigEnv(t *testing.T) { assert.Equal(t, []string{"push.example.tld", "push.other.tld"}, conf.Server.SSL.LetsEncrypt.Hosts) assert.Equal(t, "*", conf.Server.ResponseHeaders["Access-Control-Allow-Origin"]) assert.Equal(t, "GET,POST", conf.Server.ResponseHeaders["Access-Control-Allow-Methods"]) + assert.Equal(t, []string{".+.example.com", "otherdomain.com"}, conf.Server.Cors.AllowOrigins) + assert.Equal(t, []string{"GET", "POST"}, conf.Server.Cors.AllowMethods) + assert.Equal(t, []string{"Authorization", "content-type"}, conf.Server.Cors.AllowHeaders) assert.Equal(t, []string{".+.example.com", "otherdomain.com"}, conf.Server.Stream.AllowedOrigins) os.Unsetenv("GOTIFY_DEFAULTUSER_NAME") os.Unsetenv("GOTIFY_SERVER_SSL_LETSENCRYPT_HOSTS") os.Unsetenv("GOTIFY_SERVER_RESPONSEHEADERS") + os.Unsetenv("GOTIFY_SERVER_CORS_ALLOWORIGINS") + os.Unsetenv("GOTIFY_SERVER_CORS_ALLOWMETHODS") + os.Unsetenv("GOTIFY_SERVER_CORS_ALLOWHEADERS") os.Unsetenv("GOTIFY_SERVER_STREAM_ALLOWEDORIGINS") } @@ -85,6 +94,16 @@ server: responseheaders: Access-Control-Allow-Origin: "*" Access-Control-Allow-Methods: "GET,POST" + cors: + alloworigins: + - ".*" + - ".+" + allowmethods: + - "GET" + - "POST" + allowheaders: + - "Authorization" + - "content-type" stream: allowedorigins: - ".+.example.com" @@ -109,6 +128,9 @@ pluginsdir: data/plugins assert.Equal(t, "user name", conf.Database.Connection) assert.Equal(t, "*", conf.Server.ResponseHeaders["Access-Control-Allow-Origin"]) assert.Equal(t, "GET,POST", conf.Server.ResponseHeaders["Access-Control-Allow-Methods"]) + assert.Equal(t, []string{".*", ".+"}, conf.Server.Cors.AllowOrigins) + assert.Equal(t, []string{"GET", "POST"}, conf.Server.Cors.AllowMethods) + assert.Equal(t, []string{"Authorization", "content-type"}, conf.Server.Cors.AllowHeaders) assert.Equal(t, []string{".+.example.com", "otherdomain.com"}, conf.Server.Stream.AllowedOrigins) assert.Equal(t, "data/plugins", conf.PluginsDir) diff --git a/go.mod b/go.mod index 874502f..d146987 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/gotify/server require ( github.com/Southclaws/configor v1.0.0 // indirect github.com/fortytw2/leaktest v1.3.0 + github.com/gin-contrib/cors v1.3.1 github.com/gin-contrib/gzip v0.0.1 github.com/gin-gonic/gin v1.5.0 github.com/go-playground/universal-translator v0.17.0 // indirect diff --git a/go.sum b/go.sum index 5be1604..3aae970 100644 --- a/go.sum +++ b/go.sum @@ -32,6 +32,8 @@ github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5/go.mod h1:a github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw= github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/gin-contrib/cors v1.3.1 h1:doAsuITavI4IOcd0Y19U4B+O0dNWihRyX//nn4sEmgA= +github.com/gin-contrib/cors v1.3.1/go.mod h1:jjEJ4268OPZUcU7k9Pm653S7lXUGcqMADzFA61xsmDk= github.com/gin-contrib/gzip v0.0.1 h1:ezvKOL6jH+jlzdHNE4h9h8q8uMpDQjyl0NN0Jd7jozc= github.com/gin-contrib/gzip v0.0.1/go.mod h1:fGBJBCdt6qCZuCAOwWuFhBB4OOq9EFqlo5dEaFhhu5w= github.com/gin-contrib/sse v0.0.0-20170109093832-22d885f9ecc7 h1:AzN37oI0cOS+cougNAV9szl6CVoj2RYwzS3DpUQNtlY= diff --git a/router/router.go b/router/router.go index f3cb87f..470096b 100644 --- a/router/router.go +++ b/router/router.go @@ -3,8 +3,8 @@ package router import ( "time" + "github.com/gin-contrib/cors" "github.com/gin-gonic/gin" - "github.com/gotify/location" "github.com/gotify/server/api" "github.com/gotify/server/api/stream" @@ -13,7 +13,6 @@ import ( "github.com/gotify/server/database" "github.com/gotify/server/docs" "github.com/gotify/server/error" - "github.com/gotify/server/mode" "github.com/gotify/server/model" "github.com/gotify/server/plugin" "github.com/gotify/server/ui" @@ -65,17 +64,11 @@ func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Co g.Use(func(ctx *gin.Context) { ctx.Header("Content-Type", "application/json") - - if mode.IsDev() { - ctx.Header("Access-Control-Allow-Origin", "*") - ctx.Header("Access-Control-Allow-Methods", "GET,POST,DELETE,OPTIONS,PUT") - ctx.Header("Access-Control-Allow-Headers", "X-Gotify-Key,Authorization,Content-Type,Upgrade,Origin,Connection,Accept-Encoding,Accept-Language,Host") - } - for header, value := range conf.Server.ResponseHeaders { ctx.Header(header, value) } }) + g.Use(cors.New(auth.CorsConfig(conf))) { g.GET("/plugin", authentication.RequireClient(), pluginHandler.GetPlugins) diff --git a/router/router_test.go b/router/router_test.go index 9d889c9..46a6a66 100644 --- a/router/router_test.go +++ b/router/router_test.go @@ -61,6 +61,8 @@ func (s *IntegrationSuite) TestVersionInfo() { func (s *IntegrationSuite) TestHeaderInDev() { mode.Set(mode.TestDev) req := s.newRequest("GET", "version", "") + //Needs an origin to indicate that it is a CORS request + req.Header.Add("Origin", "some-origin") res, err := client.Do(req) assert.Nil(s.T(), err) @@ -108,6 +110,134 @@ func TestHeadersFromConfiguration(t *testing.T) { assert.Equal(t, "Nice", res.Header.Get("New-Cool-Header")) } +func TestHeadersFromCORSConfig(t *testing.T) { + mode.Set(mode.Prod) + db := testdb.NewDBWithDefaultUser(t) + defer db.Close() + + config := config.Configuration{PassStrength: 5} + config.Server.Cors.AllowOrigins = []string{"---", "http://test.com"} + + g, closable := Create(db.GormDatabase, + &model.VersionInfo{Version: "1.0.0", BuildDate: "2018-02-20-17:30:47", Commit: "asdasds"}, + &config, + ) + server := httptest.NewServer(g) + + defer func() { + closable() + server.Close() + }() + + req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", server.URL, "version"), nil) + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Origin", "http://test.com") + assert.Nil(t, err) + + res, err := client.Do(req) + assert.Nil(t, err) + assert.Equal(t, "http://test.com", res.Header.Get("Access-Control-Allow-Origin")) +} + +func TestInvalidOrigin(t *testing.T) { + mode.Set(mode.Prod) + db := testdb.NewDBWithDefaultUser(t) + defer db.Close() + + config := config.Configuration{PassStrength: 5} + config.Server.Cors.AllowOrigins = []string{"---", "http://test.com"} + + g, closable := Create(db.GormDatabase, + &model.VersionInfo{Version: "1.0.0", BuildDate: "2018-02-20-17:30:47", Commit: "asdasds"}, + &config, + ) + server := httptest.NewServer(g) + + defer func() { + closable() + server.Close() + }() + + req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", server.URL, "version"), nil) + req.Header.Add("Origin", "http://test1.com") + assert.Nil(t, err) + + res, err := client.Do(req) + assert.Nil(t, err) + assert.Equal(t, "", res.Header.Get("Access-Control-Allow-Origin")) + assert.Equal(t, http.StatusForbidden, res.StatusCode) +} + +func TestCORSHeaderRegex(t *testing.T) { + mode.Set(mode.Prod) + db := testdb.NewDBWithDefaultUser(t) + defer db.Close() + + config := config.Configuration{PassStrength: 5} + config.Server.Cors.AllowOrigins = []string{"---", "^http://test\\d{3}.com$"} + + g, closable := Create(db.GormDatabase, + &model.VersionInfo{Version: "1.0.0", BuildDate: "2018-02-20-17:30:47", Commit: "asdasds"}, + &config, + ) + server := httptest.NewServer(g) + + defer func() { + closable() + server.Close() + }() + + req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", server.URL, "version"), nil) + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Origin", "http://test123.com") + assert.Nil(t, err) + + res, err := client.Do(req) + assert.Nil(t, err) + assert.Equal(t, "http://test123.com", res.Header.Get("Access-Control-Allow-Origin")) +} + +// We want headers in cors config to override the responseheaders config +func TestCORSConfigOverride(t *testing.T) { + mode.Set(mode.Prod) + db := testdb.NewDBWithDefaultUser(t) + defer db.Close() + + config := config.Configuration{PassStrength: 5} + config.Server.ResponseHeaders = map[string]string{ + "New-Cool-Header": "Nice", + "Access-Control-Allow-Origin": "something-else", + "Access-Control-Allow-Methods": "321test", + "Access-Control-Allow-Headers": "some-headers", + } + config.Server.Cors.AllowOrigins = []string{"http://test123.com", "aaa"} + config.Server.Cors.AllowMethods = []string{"GET", "OPTIONS"} + config.Server.Cors.AllowHeaders = []string{"Content-Type"} + + g, closable := Create(db.GormDatabase, + &model.VersionInfo{Version: "1.0.0", BuildDate: "2018-02-20-17:30:47", Commit: "asdasds"}, + &config, + ) + server := httptest.NewServer(g) + + defer func() { + closable() + server.Close() + }() + + req, err := http.NewRequest("OPTIONS", fmt.Sprintf("%s/%s", server.URL, "version"), nil) + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Origin", "http://test123.com") + assert.Nil(t, err) + + res, err := client.Do(req) + assert.Nil(t, err) + assert.Equal(t, "Nice", res.Header.Get("New-Cool-Header")) + assert.Equal(t, "http://test123.com", res.Header.Get("Access-Control-Allow-Origin")) + assert.Equal(t, "GET,OPTIONS", res.Header.Get("Access-Control-Allow-Methods")) + assert.Equal(t, "Content-Type", res.Header.Get("Access-Control-Allow-Headers")) +} + func (s *IntegrationSuite) TestOptionsRequest() { req := s.newRequest("OPTIONS", "version", "")