Allow multiple CORS origins

This commit is contained in:
Stewart Thomson 2020-04-26 07:27:24 -04:00 committed by GitHub
parent d22326bba8
commit 3f04d50088
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 272 additions and 11 deletions

47
auth/cors.go Normal file
View File

@ -0,0 +1,47 @@
package auth
import (
"regexp"
"strings"
"time"
"github.com/gin-contrib/cors"
"github.com/gotify/server/config"
"github.com/gotify/server/mode"
)
// CorsConfig generates a config to use in gin cors middleware based on server configuration
func CorsConfig(conf *config.Configuration) cors.Config {
corsConf := cors.Config{
MaxAge: 12 * time.Hour,
}
if mode.IsDev() {
corsConf.AllowAllOrigins = true
corsConf.AllowMethods = []string{"GET", "POST", "DELETE", "OPTIONS", "PUT"}
corsConf.AllowHeaders = []string{"X-Gotify-Key", "Authorization", "Content-Type", "Upgrade", "Origin",
"Connection", "Accept-Encoding", "Accept-Language", "Host"}
} else {
compiledOrigins := compileAllowedCORSOrigins(conf.Server.Cors.AllowOrigins)
corsConf.AllowMethods = conf.Server.Cors.AllowMethods
corsConf.AllowHeaders = conf.Server.Cors.AllowHeaders
corsConf.AllowOriginFunc = func(origin string) bool {
for _, compiledOrigin := range compiledOrigins {
if compiledOrigin.Match([]byte(strings.ToLower(origin))) {
return true
}
}
return false
}
}
return corsConf
}
func compileAllowedCORSOrigins(allowedOrigins []string) []*regexp.Regexp {
var compiledAllowedOrigins []*regexp.Regexp
for _, origin := range allowedOrigins {
compiledAllowedOrigins = append(compiledAllowedOrigins, regexp.MustCompile(origin))
}
return compiledAllowedOrigins
}

52
auth/cors_test.go Normal file
View File

@ -0,0 +1,52 @@
package auth
import (
"testing"
"time"
"github.com/gin-contrib/cors"
"github.com/gotify/server/config"
"github.com/gotify/server/mode"
"github.com/stretchr/testify/assert"
)
func TestCorsConfig(t *testing.T) {
mode.Set(mode.Prod)
serverConf := config.Configuration{}
serverConf.Server.Cors.AllowOrigins = []string{"http://test.com"}
serverConf.Server.Cors.AllowHeaders = []string{"content-type"}
serverConf.Server.Cors.AllowMethods = []string{"GET"}
actual := CorsConfig(&serverConf)
allowF := actual.AllowOriginFunc
actual.AllowOriginFunc = nil // func cannot be checked with equal
assert.Equal(t, cors.Config{
AllowAllOrigins: false,
AllowHeaders: []string{"content-type"},
AllowMethods: []string{"GET"},
MaxAge: 12 * time.Hour,
}, actual)
assert.NotNil(t, allowF)
assert.True(t, allowF("http://test.com"))
assert.False(t, allowF("https://test.com"))
assert.False(t, allowF("https://other.com"))
}
func TestDevCorsConfig(t *testing.T) {
mode.Set(mode.Dev)
serverConf := config.Configuration{}
serverConf.Server.Cors.AllowOrigins = []string{"http://test.com"}
serverConf.Server.Cors.AllowHeaders = []string{"content-type"}
serverConf.Server.Cors.AllowMethods = []string{"GET"}
actual := CorsConfig(&serverConf)
assert.Equal(t, cors.Config{
AllowHeaders: []string{"X-Gotify-Key", "Authorization", "Content-Type", "Upgrade", "Origin",
"Connection", "Accept-Encoding", "Accept-Language", "Host"},
AllowMethods: []string{"GET", "POST", "DELETE", "OPTIONS", "PUT"},
MaxAge: 12 * time.Hour,
AllowAllOrigins: true,
}, actual)
}

View File

@ -21,9 +21,18 @@ server:
# - myotherdomain.tld
responseheaders: # response headers are added to every response (default: none)
# Access-Control-Allow-Origin: "*"
# Access-Control-Allow-Methods: "GET,POST"
# X-Custom-Header: "custom value"
cors: # Sets cors headers only when needed and provides support for multiple allowed origins. Overrides Access-Control-* Headers in response headers.
alloworigins:
# - ".+.example.com"
# - "otherdomain.com"
allowmethods:
# - "GET"
# - "POST"
allowheaders:
# - "Authorization"
# - "content-type"
stream:
allowedorigins: # allowed origins for websocket connections (same origin is always allowed)
# - ".+.example.com"

View File

@ -31,6 +31,11 @@ type Configuration struct {
Stream struct {
AllowedOrigins []string
}
Cors struct {
AllowOrigins []string
AllowMethods []string
AllowHeaders []string
}
}
Database struct {
Dialect string `default:"sqlite3"`

View File

@ -16,6 +16,9 @@ func TestConfigEnv(t *testing.T) {
os.Setenv("GOTIFY_SERVER_RESPONSEHEADERS",
"Access-Control-Allow-Origin: \"*\"\nAccess-Control-Allow-Methods: \"GET,POST\"",
)
os.Setenv("GOTIFY_SERVER_CORS_ALLOWORIGINS", "- \".+.example.com\"\n- \"otherdomain.com\"")
os.Setenv("GOTIFY_SERVER_CORS_ALLOWMETHODS", "- \"GET\"\n- \"POST\"")
os.Setenv("GOTIFY_SERVER_CORS_ALLOWHEADERS", "- \"Authorization\"\n- \"content-type\"")
os.Setenv("GOTIFY_SERVER_STREAM_ALLOWEDORIGINS", "- \".+.example.com\"\n- \"otherdomain.com\"")
conf := Get()
@ -24,11 +27,17 @@ func TestConfigEnv(t *testing.T) {
assert.Equal(t, []string{"push.example.tld", "push.other.tld"}, conf.Server.SSL.LetsEncrypt.Hosts)
assert.Equal(t, "*", conf.Server.ResponseHeaders["Access-Control-Allow-Origin"])
assert.Equal(t, "GET,POST", conf.Server.ResponseHeaders["Access-Control-Allow-Methods"])
assert.Equal(t, []string{".+.example.com", "otherdomain.com"}, conf.Server.Cors.AllowOrigins)
assert.Equal(t, []string{"GET", "POST"}, conf.Server.Cors.AllowMethods)
assert.Equal(t, []string{"Authorization", "content-type"}, conf.Server.Cors.AllowHeaders)
assert.Equal(t, []string{".+.example.com", "otherdomain.com"}, conf.Server.Stream.AllowedOrigins)
os.Unsetenv("GOTIFY_DEFAULTUSER_NAME")
os.Unsetenv("GOTIFY_SERVER_SSL_LETSENCRYPT_HOSTS")
os.Unsetenv("GOTIFY_SERVER_RESPONSEHEADERS")
os.Unsetenv("GOTIFY_SERVER_CORS_ALLOWORIGINS")
os.Unsetenv("GOTIFY_SERVER_CORS_ALLOWMETHODS")
os.Unsetenv("GOTIFY_SERVER_CORS_ALLOWHEADERS")
os.Unsetenv("GOTIFY_SERVER_STREAM_ALLOWEDORIGINS")
}
@ -85,6 +94,16 @@ server:
responseheaders:
Access-Control-Allow-Origin: "*"
Access-Control-Allow-Methods: "GET,POST"
cors:
alloworigins:
- ".*"
- ".+"
allowmethods:
- "GET"
- "POST"
allowheaders:
- "Authorization"
- "content-type"
stream:
allowedorigins:
- ".+.example.com"
@ -109,6 +128,9 @@ pluginsdir: data/plugins
assert.Equal(t, "user name", conf.Database.Connection)
assert.Equal(t, "*", conf.Server.ResponseHeaders["Access-Control-Allow-Origin"])
assert.Equal(t, "GET,POST", conf.Server.ResponseHeaders["Access-Control-Allow-Methods"])
assert.Equal(t, []string{".*", ".+"}, conf.Server.Cors.AllowOrigins)
assert.Equal(t, []string{"GET", "POST"}, conf.Server.Cors.AllowMethods)
assert.Equal(t, []string{"Authorization", "content-type"}, conf.Server.Cors.AllowHeaders)
assert.Equal(t, []string{".+.example.com", "otherdomain.com"}, conf.Server.Stream.AllowedOrigins)
assert.Equal(t, "data/plugins", conf.PluginsDir)

1
go.mod
View File

@ -3,6 +3,7 @@ module github.com/gotify/server
require (
github.com/Southclaws/configor v1.0.0 // indirect
github.com/fortytw2/leaktest v1.3.0
github.com/gin-contrib/cors v1.3.1
github.com/gin-contrib/gzip v0.0.1
github.com/gin-gonic/gin v1.5.0
github.com/go-playground/universal-translator v0.17.0 // indirect

2
go.sum
View File

@ -32,6 +32,8 @@ github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5/go.mod h1:a
github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw=
github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/gin-contrib/cors v1.3.1 h1:doAsuITavI4IOcd0Y19U4B+O0dNWihRyX//nn4sEmgA=
github.com/gin-contrib/cors v1.3.1/go.mod h1:jjEJ4268OPZUcU7k9Pm653S7lXUGcqMADzFA61xsmDk=
github.com/gin-contrib/gzip v0.0.1 h1:ezvKOL6jH+jlzdHNE4h9h8q8uMpDQjyl0NN0Jd7jozc=
github.com/gin-contrib/gzip v0.0.1/go.mod h1:fGBJBCdt6qCZuCAOwWuFhBB4OOq9EFqlo5dEaFhhu5w=
github.com/gin-contrib/sse v0.0.0-20170109093832-22d885f9ecc7 h1:AzN37oI0cOS+cougNAV9szl6CVoj2RYwzS3DpUQNtlY=

View File

@ -3,8 +3,8 @@ package router
import (
"time"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"github.com/gotify/location"
"github.com/gotify/server/api"
"github.com/gotify/server/api/stream"
@ -13,7 +13,6 @@ import (
"github.com/gotify/server/database"
"github.com/gotify/server/docs"
"github.com/gotify/server/error"
"github.com/gotify/server/mode"
"github.com/gotify/server/model"
"github.com/gotify/server/plugin"
"github.com/gotify/server/ui"
@ -65,17 +64,11 @@ func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Co
g.Use(func(ctx *gin.Context) {
ctx.Header("Content-Type", "application/json")
if mode.IsDev() {
ctx.Header("Access-Control-Allow-Origin", "*")
ctx.Header("Access-Control-Allow-Methods", "GET,POST,DELETE,OPTIONS,PUT")
ctx.Header("Access-Control-Allow-Headers", "X-Gotify-Key,Authorization,Content-Type,Upgrade,Origin,Connection,Accept-Encoding,Accept-Language,Host")
}
for header, value := range conf.Server.ResponseHeaders {
ctx.Header(header, value)
}
})
g.Use(cors.New(auth.CorsConfig(conf)))
{
g.GET("/plugin", authentication.RequireClient(), pluginHandler.GetPlugins)

View File

@ -61,6 +61,8 @@ func (s *IntegrationSuite) TestVersionInfo() {
func (s *IntegrationSuite) TestHeaderInDev() {
mode.Set(mode.TestDev)
req := s.newRequest("GET", "version", "")
//Needs an origin to indicate that it is a CORS request
req.Header.Add("Origin", "some-origin")
res, err := client.Do(req)
assert.Nil(s.T(), err)
@ -108,6 +110,134 @@ func TestHeadersFromConfiguration(t *testing.T) {
assert.Equal(t, "Nice", res.Header.Get("New-Cool-Header"))
}
func TestHeadersFromCORSConfig(t *testing.T) {
mode.Set(mode.Prod)
db := testdb.NewDBWithDefaultUser(t)
defer db.Close()
config := config.Configuration{PassStrength: 5}
config.Server.Cors.AllowOrigins = []string{"---", "http://test.com"}
g, closable := Create(db.GormDatabase,
&model.VersionInfo{Version: "1.0.0", BuildDate: "2018-02-20-17:30:47", Commit: "asdasds"},
&config,
)
server := httptest.NewServer(g)
defer func() {
closable()
server.Close()
}()
req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", server.URL, "version"), nil)
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Origin", "http://test.com")
assert.Nil(t, err)
res, err := client.Do(req)
assert.Nil(t, err)
assert.Equal(t, "http://test.com", res.Header.Get("Access-Control-Allow-Origin"))
}
func TestInvalidOrigin(t *testing.T) {
mode.Set(mode.Prod)
db := testdb.NewDBWithDefaultUser(t)
defer db.Close()
config := config.Configuration{PassStrength: 5}
config.Server.Cors.AllowOrigins = []string{"---", "http://test.com"}
g, closable := Create(db.GormDatabase,
&model.VersionInfo{Version: "1.0.0", BuildDate: "2018-02-20-17:30:47", Commit: "asdasds"},
&config,
)
server := httptest.NewServer(g)
defer func() {
closable()
server.Close()
}()
req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", server.URL, "version"), nil)
req.Header.Add("Origin", "http://test1.com")
assert.Nil(t, err)
res, err := client.Do(req)
assert.Nil(t, err)
assert.Equal(t, "", res.Header.Get("Access-Control-Allow-Origin"))
assert.Equal(t, http.StatusForbidden, res.StatusCode)
}
func TestCORSHeaderRegex(t *testing.T) {
mode.Set(mode.Prod)
db := testdb.NewDBWithDefaultUser(t)
defer db.Close()
config := config.Configuration{PassStrength: 5}
config.Server.Cors.AllowOrigins = []string{"---", "^http://test\\d{3}.com$"}
g, closable := Create(db.GormDatabase,
&model.VersionInfo{Version: "1.0.0", BuildDate: "2018-02-20-17:30:47", Commit: "asdasds"},
&config,
)
server := httptest.NewServer(g)
defer func() {
closable()
server.Close()
}()
req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", server.URL, "version"), nil)
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Origin", "http://test123.com")
assert.Nil(t, err)
res, err := client.Do(req)
assert.Nil(t, err)
assert.Equal(t, "http://test123.com", res.Header.Get("Access-Control-Allow-Origin"))
}
// We want headers in cors config to override the responseheaders config
func TestCORSConfigOverride(t *testing.T) {
mode.Set(mode.Prod)
db := testdb.NewDBWithDefaultUser(t)
defer db.Close()
config := config.Configuration{PassStrength: 5}
config.Server.ResponseHeaders = map[string]string{
"New-Cool-Header": "Nice",
"Access-Control-Allow-Origin": "something-else",
"Access-Control-Allow-Methods": "321test",
"Access-Control-Allow-Headers": "some-headers",
}
config.Server.Cors.AllowOrigins = []string{"http://test123.com", "aaa"}
config.Server.Cors.AllowMethods = []string{"GET", "OPTIONS"}
config.Server.Cors.AllowHeaders = []string{"Content-Type"}
g, closable := Create(db.GormDatabase,
&model.VersionInfo{Version: "1.0.0", BuildDate: "2018-02-20-17:30:47", Commit: "asdasds"},
&config,
)
server := httptest.NewServer(g)
defer func() {
closable()
server.Close()
}()
req, err := http.NewRequest("OPTIONS", fmt.Sprintf("%s/%s", server.URL, "version"), nil)
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Origin", "http://test123.com")
assert.Nil(t, err)
res, err := client.Do(req)
assert.Nil(t, err)
assert.Equal(t, "Nice", res.Header.Get("New-Cool-Header"))
assert.Equal(t, "http://test123.com", res.Header.Get("Access-Control-Allow-Origin"))
assert.Equal(t, "GET,OPTIONS", res.Header.Get("Access-Control-Allow-Methods"))
assert.Equal(t, "Content-Type", res.Header.Get("Access-Control-Allow-Headers"))
}
func (s *IntegrationSuite) TestOptionsRequest() {
req := s.newRequest("OPTIONS", "version", "")