Allow multiple CORS origins
This commit is contained in:
parent
d22326bba8
commit
3f04d50088
|
|
@ -0,0 +1,47 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gotify/server/config"
|
||||
"github.com/gotify/server/mode"
|
||||
)
|
||||
|
||||
// CorsConfig generates a config to use in gin cors middleware based on server configuration
|
||||
func CorsConfig(conf *config.Configuration) cors.Config {
|
||||
corsConf := cors.Config{
|
||||
MaxAge: 12 * time.Hour,
|
||||
}
|
||||
if mode.IsDev() {
|
||||
corsConf.AllowAllOrigins = true
|
||||
corsConf.AllowMethods = []string{"GET", "POST", "DELETE", "OPTIONS", "PUT"}
|
||||
corsConf.AllowHeaders = []string{"X-Gotify-Key", "Authorization", "Content-Type", "Upgrade", "Origin",
|
||||
"Connection", "Accept-Encoding", "Accept-Language", "Host"}
|
||||
} else {
|
||||
compiledOrigins := compileAllowedCORSOrigins(conf.Server.Cors.AllowOrigins)
|
||||
corsConf.AllowMethods = conf.Server.Cors.AllowMethods
|
||||
corsConf.AllowHeaders = conf.Server.Cors.AllowHeaders
|
||||
corsConf.AllowOriginFunc = func(origin string) bool {
|
||||
for _, compiledOrigin := range compiledOrigins {
|
||||
if compiledOrigin.Match([]byte(strings.ToLower(origin))) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return corsConf
|
||||
}
|
||||
|
||||
func compileAllowedCORSOrigins(allowedOrigins []string) []*regexp.Regexp {
|
||||
var compiledAllowedOrigins []*regexp.Regexp
|
||||
for _, origin := range allowedOrigins {
|
||||
compiledAllowedOrigins = append(compiledAllowedOrigins, regexp.MustCompile(origin))
|
||||
}
|
||||
|
||||
return compiledAllowedOrigins
|
||||
}
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gotify/server/config"
|
||||
"github.com/gotify/server/mode"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCorsConfig(t *testing.T) {
|
||||
mode.Set(mode.Prod)
|
||||
serverConf := config.Configuration{}
|
||||
serverConf.Server.Cors.AllowOrigins = []string{"http://test.com"}
|
||||
serverConf.Server.Cors.AllowHeaders = []string{"content-type"}
|
||||
serverConf.Server.Cors.AllowMethods = []string{"GET"}
|
||||
|
||||
actual := CorsConfig(&serverConf)
|
||||
allowF := actual.AllowOriginFunc
|
||||
actual.AllowOriginFunc = nil // func cannot be checked with equal
|
||||
|
||||
assert.Equal(t, cors.Config{
|
||||
AllowAllOrigins: false,
|
||||
AllowHeaders: []string{"content-type"},
|
||||
AllowMethods: []string{"GET"},
|
||||
MaxAge: 12 * time.Hour,
|
||||
}, actual)
|
||||
assert.NotNil(t, allowF)
|
||||
assert.True(t, allowF("http://test.com"))
|
||||
assert.False(t, allowF("https://test.com"))
|
||||
assert.False(t, allowF("https://other.com"))
|
||||
}
|
||||
|
||||
func TestDevCorsConfig(t *testing.T) {
|
||||
mode.Set(mode.Dev)
|
||||
serverConf := config.Configuration{}
|
||||
serverConf.Server.Cors.AllowOrigins = []string{"http://test.com"}
|
||||
serverConf.Server.Cors.AllowHeaders = []string{"content-type"}
|
||||
serverConf.Server.Cors.AllowMethods = []string{"GET"}
|
||||
|
||||
actual := CorsConfig(&serverConf)
|
||||
|
||||
assert.Equal(t, cors.Config{
|
||||
AllowHeaders: []string{"X-Gotify-Key", "Authorization", "Content-Type", "Upgrade", "Origin",
|
||||
"Connection", "Accept-Encoding", "Accept-Language", "Host"},
|
||||
AllowMethods: []string{"GET", "POST", "DELETE", "OPTIONS", "PUT"},
|
||||
MaxAge: 12 * time.Hour,
|
||||
AllowAllOrigins: true,
|
||||
}, actual)
|
||||
}
|
||||
|
|
@ -21,9 +21,18 @@ server:
|
|||
# - myotherdomain.tld
|
||||
|
||||
responseheaders: # response headers are added to every response (default: none)
|
||||
# Access-Control-Allow-Origin: "*"
|
||||
# Access-Control-Allow-Methods: "GET,POST"
|
||||
# X-Custom-Header: "custom value"
|
||||
|
||||
cors: # Sets cors headers only when needed and provides support for multiple allowed origins. Overrides Access-Control-* Headers in response headers.
|
||||
alloworigins:
|
||||
# - ".+.example.com"
|
||||
# - "otherdomain.com"
|
||||
allowmethods:
|
||||
# - "GET"
|
||||
# - "POST"
|
||||
allowheaders:
|
||||
# - "Authorization"
|
||||
# - "content-type"
|
||||
stream:
|
||||
allowedorigins: # allowed origins for websocket connections (same origin is always allowed)
|
||||
# - ".+.example.com"
|
||||
|
|
|
|||
|
|
@ -31,6 +31,11 @@ type Configuration struct {
|
|||
Stream struct {
|
||||
AllowedOrigins []string
|
||||
}
|
||||
Cors struct {
|
||||
AllowOrigins []string
|
||||
AllowMethods []string
|
||||
AllowHeaders []string
|
||||
}
|
||||
}
|
||||
Database struct {
|
||||
Dialect string `default:"sqlite3"`
|
||||
|
|
|
|||
|
|
@ -16,6 +16,9 @@ func TestConfigEnv(t *testing.T) {
|
|||
os.Setenv("GOTIFY_SERVER_RESPONSEHEADERS",
|
||||
"Access-Control-Allow-Origin: \"*\"\nAccess-Control-Allow-Methods: \"GET,POST\"",
|
||||
)
|
||||
os.Setenv("GOTIFY_SERVER_CORS_ALLOWORIGINS", "- \".+.example.com\"\n- \"otherdomain.com\"")
|
||||
os.Setenv("GOTIFY_SERVER_CORS_ALLOWMETHODS", "- \"GET\"\n- \"POST\"")
|
||||
os.Setenv("GOTIFY_SERVER_CORS_ALLOWHEADERS", "- \"Authorization\"\n- \"content-type\"")
|
||||
os.Setenv("GOTIFY_SERVER_STREAM_ALLOWEDORIGINS", "- \".+.example.com\"\n- \"otherdomain.com\"")
|
||||
|
||||
conf := Get()
|
||||
|
|
@ -24,11 +27,17 @@ func TestConfigEnv(t *testing.T) {
|
|||
assert.Equal(t, []string{"push.example.tld", "push.other.tld"}, conf.Server.SSL.LetsEncrypt.Hosts)
|
||||
assert.Equal(t, "*", conf.Server.ResponseHeaders["Access-Control-Allow-Origin"])
|
||||
assert.Equal(t, "GET,POST", conf.Server.ResponseHeaders["Access-Control-Allow-Methods"])
|
||||
assert.Equal(t, []string{".+.example.com", "otherdomain.com"}, conf.Server.Cors.AllowOrigins)
|
||||
assert.Equal(t, []string{"GET", "POST"}, conf.Server.Cors.AllowMethods)
|
||||
assert.Equal(t, []string{"Authorization", "content-type"}, conf.Server.Cors.AllowHeaders)
|
||||
assert.Equal(t, []string{".+.example.com", "otherdomain.com"}, conf.Server.Stream.AllowedOrigins)
|
||||
|
||||
os.Unsetenv("GOTIFY_DEFAULTUSER_NAME")
|
||||
os.Unsetenv("GOTIFY_SERVER_SSL_LETSENCRYPT_HOSTS")
|
||||
os.Unsetenv("GOTIFY_SERVER_RESPONSEHEADERS")
|
||||
os.Unsetenv("GOTIFY_SERVER_CORS_ALLOWORIGINS")
|
||||
os.Unsetenv("GOTIFY_SERVER_CORS_ALLOWMETHODS")
|
||||
os.Unsetenv("GOTIFY_SERVER_CORS_ALLOWHEADERS")
|
||||
os.Unsetenv("GOTIFY_SERVER_STREAM_ALLOWEDORIGINS")
|
||||
}
|
||||
|
||||
|
|
@ -85,6 +94,16 @@ server:
|
|||
responseheaders:
|
||||
Access-Control-Allow-Origin: "*"
|
||||
Access-Control-Allow-Methods: "GET,POST"
|
||||
cors:
|
||||
alloworigins:
|
||||
- ".*"
|
||||
- ".+"
|
||||
allowmethods:
|
||||
- "GET"
|
||||
- "POST"
|
||||
allowheaders:
|
||||
- "Authorization"
|
||||
- "content-type"
|
||||
stream:
|
||||
allowedorigins:
|
||||
- ".+.example.com"
|
||||
|
|
@ -109,6 +128,9 @@ pluginsdir: data/plugins
|
|||
assert.Equal(t, "user name", conf.Database.Connection)
|
||||
assert.Equal(t, "*", conf.Server.ResponseHeaders["Access-Control-Allow-Origin"])
|
||||
assert.Equal(t, "GET,POST", conf.Server.ResponseHeaders["Access-Control-Allow-Methods"])
|
||||
assert.Equal(t, []string{".*", ".+"}, conf.Server.Cors.AllowOrigins)
|
||||
assert.Equal(t, []string{"GET", "POST"}, conf.Server.Cors.AllowMethods)
|
||||
assert.Equal(t, []string{"Authorization", "content-type"}, conf.Server.Cors.AllowHeaders)
|
||||
assert.Equal(t, []string{".+.example.com", "otherdomain.com"}, conf.Server.Stream.AllowedOrigins)
|
||||
assert.Equal(t, "data/plugins", conf.PluginsDir)
|
||||
|
||||
|
|
|
|||
1
go.mod
1
go.mod
|
|
@ -3,6 +3,7 @@ module github.com/gotify/server
|
|||
require (
|
||||
github.com/Southclaws/configor v1.0.0 // indirect
|
||||
github.com/fortytw2/leaktest v1.3.0
|
||||
github.com/gin-contrib/cors v1.3.1
|
||||
github.com/gin-contrib/gzip v0.0.1
|
||||
github.com/gin-gonic/gin v1.5.0
|
||||
github.com/go-playground/universal-translator v0.17.0 // indirect
|
||||
|
|
|
|||
2
go.sum
2
go.sum
|
|
@ -32,6 +32,8 @@ github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5/go.mod h1:a
|
|||
github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw=
|
||||
github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g=
|
||||
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
|
||||
github.com/gin-contrib/cors v1.3.1 h1:doAsuITavI4IOcd0Y19U4B+O0dNWihRyX//nn4sEmgA=
|
||||
github.com/gin-contrib/cors v1.3.1/go.mod h1:jjEJ4268OPZUcU7k9Pm653S7lXUGcqMADzFA61xsmDk=
|
||||
github.com/gin-contrib/gzip v0.0.1 h1:ezvKOL6jH+jlzdHNE4h9h8q8uMpDQjyl0NN0Jd7jozc=
|
||||
github.com/gin-contrib/gzip v0.0.1/go.mod h1:fGBJBCdt6qCZuCAOwWuFhBB4OOq9EFqlo5dEaFhhu5w=
|
||||
github.com/gin-contrib/sse v0.0.0-20170109093832-22d885f9ecc7 h1:AzN37oI0cOS+cougNAV9szl6CVoj2RYwzS3DpUQNtlY=
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@ package router
|
|||
import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/gotify/location"
|
||||
"github.com/gotify/server/api"
|
||||
"github.com/gotify/server/api/stream"
|
||||
|
|
@ -13,7 +13,6 @@ import (
|
|||
"github.com/gotify/server/database"
|
||||
"github.com/gotify/server/docs"
|
||||
"github.com/gotify/server/error"
|
||||
"github.com/gotify/server/mode"
|
||||
"github.com/gotify/server/model"
|
||||
"github.com/gotify/server/plugin"
|
||||
"github.com/gotify/server/ui"
|
||||
|
|
@ -65,17 +64,11 @@ func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Co
|
|||
|
||||
g.Use(func(ctx *gin.Context) {
|
||||
ctx.Header("Content-Type", "application/json")
|
||||
|
||||
if mode.IsDev() {
|
||||
ctx.Header("Access-Control-Allow-Origin", "*")
|
||||
ctx.Header("Access-Control-Allow-Methods", "GET,POST,DELETE,OPTIONS,PUT")
|
||||
ctx.Header("Access-Control-Allow-Headers", "X-Gotify-Key,Authorization,Content-Type,Upgrade,Origin,Connection,Accept-Encoding,Accept-Language,Host")
|
||||
}
|
||||
|
||||
for header, value := range conf.Server.ResponseHeaders {
|
||||
ctx.Header(header, value)
|
||||
}
|
||||
})
|
||||
g.Use(cors.New(auth.CorsConfig(conf)))
|
||||
|
||||
{
|
||||
g.GET("/plugin", authentication.RequireClient(), pluginHandler.GetPlugins)
|
||||
|
|
|
|||
|
|
@ -61,6 +61,8 @@ func (s *IntegrationSuite) TestVersionInfo() {
|
|||
func (s *IntegrationSuite) TestHeaderInDev() {
|
||||
mode.Set(mode.TestDev)
|
||||
req := s.newRequest("GET", "version", "")
|
||||
//Needs an origin to indicate that it is a CORS request
|
||||
req.Header.Add("Origin", "some-origin")
|
||||
|
||||
res, err := client.Do(req)
|
||||
assert.Nil(s.T(), err)
|
||||
|
|
@ -108,6 +110,134 @@ func TestHeadersFromConfiguration(t *testing.T) {
|
|||
assert.Equal(t, "Nice", res.Header.Get("New-Cool-Header"))
|
||||
}
|
||||
|
||||
func TestHeadersFromCORSConfig(t *testing.T) {
|
||||
mode.Set(mode.Prod)
|
||||
db := testdb.NewDBWithDefaultUser(t)
|
||||
defer db.Close()
|
||||
|
||||
config := config.Configuration{PassStrength: 5}
|
||||
config.Server.Cors.AllowOrigins = []string{"---", "http://test.com"}
|
||||
|
||||
g, closable := Create(db.GormDatabase,
|
||||
&model.VersionInfo{Version: "1.0.0", BuildDate: "2018-02-20-17:30:47", Commit: "asdasds"},
|
||||
&config,
|
||||
)
|
||||
server := httptest.NewServer(g)
|
||||
|
||||
defer func() {
|
||||
closable()
|
||||
server.Close()
|
||||
}()
|
||||
|
||||
req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", server.URL, "version"), nil)
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
req.Header.Add("Origin", "http://test.com")
|
||||
assert.Nil(t, err)
|
||||
|
||||
res, err := client.Do(req)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "http://test.com", res.Header.Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
|
||||
func TestInvalidOrigin(t *testing.T) {
|
||||
mode.Set(mode.Prod)
|
||||
db := testdb.NewDBWithDefaultUser(t)
|
||||
defer db.Close()
|
||||
|
||||
config := config.Configuration{PassStrength: 5}
|
||||
config.Server.Cors.AllowOrigins = []string{"---", "http://test.com"}
|
||||
|
||||
g, closable := Create(db.GormDatabase,
|
||||
&model.VersionInfo{Version: "1.0.0", BuildDate: "2018-02-20-17:30:47", Commit: "asdasds"},
|
||||
&config,
|
||||
)
|
||||
server := httptest.NewServer(g)
|
||||
|
||||
defer func() {
|
||||
closable()
|
||||
server.Close()
|
||||
}()
|
||||
|
||||
req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", server.URL, "version"), nil)
|
||||
req.Header.Add("Origin", "http://test1.com")
|
||||
assert.Nil(t, err)
|
||||
|
||||
res, err := client.Do(req)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "", res.Header.Get("Access-Control-Allow-Origin"))
|
||||
assert.Equal(t, http.StatusForbidden, res.StatusCode)
|
||||
}
|
||||
|
||||
func TestCORSHeaderRegex(t *testing.T) {
|
||||
mode.Set(mode.Prod)
|
||||
db := testdb.NewDBWithDefaultUser(t)
|
||||
defer db.Close()
|
||||
|
||||
config := config.Configuration{PassStrength: 5}
|
||||
config.Server.Cors.AllowOrigins = []string{"---", "^http://test\\d{3}.com$"}
|
||||
|
||||
g, closable := Create(db.GormDatabase,
|
||||
&model.VersionInfo{Version: "1.0.0", BuildDate: "2018-02-20-17:30:47", Commit: "asdasds"},
|
||||
&config,
|
||||
)
|
||||
server := httptest.NewServer(g)
|
||||
|
||||
defer func() {
|
||||
closable()
|
||||
server.Close()
|
||||
}()
|
||||
|
||||
req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", server.URL, "version"), nil)
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
req.Header.Add("Origin", "http://test123.com")
|
||||
assert.Nil(t, err)
|
||||
|
||||
res, err := client.Do(req)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "http://test123.com", res.Header.Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
|
||||
// We want headers in cors config to override the responseheaders config
|
||||
func TestCORSConfigOverride(t *testing.T) {
|
||||
mode.Set(mode.Prod)
|
||||
db := testdb.NewDBWithDefaultUser(t)
|
||||
defer db.Close()
|
||||
|
||||
config := config.Configuration{PassStrength: 5}
|
||||
config.Server.ResponseHeaders = map[string]string{
|
||||
"New-Cool-Header": "Nice",
|
||||
"Access-Control-Allow-Origin": "something-else",
|
||||
"Access-Control-Allow-Methods": "321test",
|
||||
"Access-Control-Allow-Headers": "some-headers",
|
||||
}
|
||||
config.Server.Cors.AllowOrigins = []string{"http://test123.com", "aaa"}
|
||||
config.Server.Cors.AllowMethods = []string{"GET", "OPTIONS"}
|
||||
config.Server.Cors.AllowHeaders = []string{"Content-Type"}
|
||||
|
||||
g, closable := Create(db.GormDatabase,
|
||||
&model.VersionInfo{Version: "1.0.0", BuildDate: "2018-02-20-17:30:47", Commit: "asdasds"},
|
||||
&config,
|
||||
)
|
||||
server := httptest.NewServer(g)
|
||||
|
||||
defer func() {
|
||||
closable()
|
||||
server.Close()
|
||||
}()
|
||||
|
||||
req, err := http.NewRequest("OPTIONS", fmt.Sprintf("%s/%s", server.URL, "version"), nil)
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
req.Header.Add("Origin", "http://test123.com")
|
||||
assert.Nil(t, err)
|
||||
|
||||
res, err := client.Do(req)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "Nice", res.Header.Get("New-Cool-Header"))
|
||||
assert.Equal(t, "http://test123.com", res.Header.Get("Access-Control-Allow-Origin"))
|
||||
assert.Equal(t, "GET,OPTIONS", res.Header.Get("Access-Control-Allow-Methods"))
|
||||
assert.Equal(t, "Content-Type", res.Header.Get("Access-Control-Allow-Headers"))
|
||||
}
|
||||
|
||||
func (s *IntegrationSuite) TestOptionsRequest() {
|
||||
req := s.newRequest("OPTIONS", "version", "")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue