diff --git a/README.md b/README.md
index 1595ce4..644b729 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,10 @@
# Gotify Server
-[![Build Status][badge-travis]][travis]
-[![codecov][badge-codecov]][codecov]
-[![Go Report Card][badge-go-report]][go-report]
+[![Build Status][badge-travis]][travis]
+[![codecov][badge-codecov]][codecov]
+[![Go Report Card][badge-go-report]][go-report]
[![Swagger Valid][badge-swagger]][swagger]
[![FOSSA Status][fossa-badge]][fossa]
-[![Api Docs][badge-api-docs]][api-docs]
+[![Api Docs][badge-api-docs]][api-docs]
[![latest release version][badge-release]][release]
@@ -46,7 +46,7 @@ Google Play and the Google Play logo are trademarks of Google LLC.
The docker image is available on docker hub at [gotify/server][docker-normal].
``` bash
-$ docker run -p 80:80 -v /etc/gotify/data:/app/data gotify/server
+$ docker run -p 80:80 -v /etc/gotify/data:/app/data gotify/server
```
Also there is a specific docker image for arm-7 processors (raspberry pi), named [gotify/server-arm7][docker-arm7].
``` bash
@@ -60,7 +60,7 @@ Visit the [releases page](https://github.com/gotify/server/releases) and downloa
## Configuration
### File
-When strings contain reserved characters then they need to be escaped.
+When strings contain reserved characters then they need to be escaped.
[List of reserved characters and how to escape them](https://stackoverflow.com/a/22235064/4244993).
``` yml
@@ -82,6 +82,10 @@ server:
responseheaders: # response headers are added to every response (default: none)
Access-Control-Allow-Origin: "*"
Access-Control-Allow-Methods: "GET,POST"
+ stream:
+ allowedorigins: # allowed origins for websocket connections (same origin is always allowed)
+ - ".+.example.com"
+ - "otherdomain.com"
database: # for database see (configure database section)
dialect: sqlite3
connection: data/gotify.db
@@ -94,8 +98,8 @@ uploadedimagesdir: data/images # the directory for storing uploaded images
### Environment
-Escaped characters in list or map environment settings (`GOTIFY_SERVER_RESPONSEHEADERS` and
-`GOTIFY_SERVER_SSL_LETSENCRYPT_HOSTS`) need to be escaped as well.
+Escaped characters in list or map environment settings (`GOTIFY_SERVER_RESPONSEHEADERS` and
+`GOTIFY_SERVER_SSL_LETSENCRYPT_HOSTS`) need to be escaped as well.
[List of reserved characters and how to escape them](https://stackoverflow.com/a/22235064/4244993).
``` bash
@@ -111,6 +115,7 @@ GOTIFY_SERVER_SSL_LETSENCRYPT_CACHE=certs
# lists are a little weird but do-able (:
GOTIFY_SERVER_SSL_LETSENCRYPT_HOSTS=- mydomain.tld\n- myotherdomain.tld
GOTIFY_SERVER_RESPONSEHEADERS="Access-Control-Allow-Origin: \"*\"\nAccess-Control-Allow-Methods: \"GET,POST\""
+GOTIFY_SERVER_STREAM_ALLOWEDORIGINS="- \".+.example.com\"\n- \"otherdomain.com\""
GOTIFY_DATABASE_DIALECT=sqlite3
GOTIFY_DATABASE_CONNECTION=gotify.db
GOTIFY_DEFAULTUSER_NAME=admin
@@ -126,7 +131,7 @@ GOTIFY_UPLOADEDIMAGESDIR=images
| mysql | `gotify:secret@/gotifydb?charset=utf8&parseTime=True&loc=Local ` |
| postgres | `host=localhost port=3306 user=gotify dbname=gotify password=secret` |
-When using postgres without SSL then `sslmode=disable` must be added to the connection string.
+When using postgres without SSL then `sslmode=disable` must be added to the connection string.
See [#90](https://github.com/gotify/server/issues/90).
## Push Message Examples
@@ -141,7 +146,7 @@ $ http -f POST "https://push.example.de/message?token=" title="my titl
```
[More examples can be found here](ADD_MESSAGE_EXAMPLES.md)
-Also you can use [gotify/cli](https://github.com/gotify/cli) to push messages.
+Also you can use [gotify/cli](https://github.com/gotify/cli) to push messages.
The CLI stores url and token in a config file.
```bash
@@ -200,7 +205,7 @@ $ go test ./...
```
## Versioning
-We use [SemVer](http://semver.org/) for versioning. For the versions available, see the
+We use [SemVer](http://semver.org/) for versioning. For the versions available, see the
[tags on this repository](https://github.com/gotify/server/tags).
## License
diff --git a/api/stream/stream.go b/api/stream/stream.go
index ed742c9..07a0784 100644
--- a/api/stream/stream.go
+++ b/api/stream/stream.go
@@ -1,6 +1,7 @@
package stream
import (
+ "regexp"
"sync"
"time"
@@ -14,46 +15,25 @@ import (
"github.com/gotify/server/model"
)
-var upgrader = websocket.Upgrader{
- ReadBufferSize: 1024,
- WriteBufferSize: 1024,
- CheckOrigin: func(r *http.Request) bool {
- if mode.IsDev() {
- return true
- }
- return checkSameOrigin(r)
- },
-}
-
-func checkSameOrigin(r *http.Request) bool {
- origin := r.Header["Origin"]
- if len(origin) == 0 {
- return true
- }
- u, err := url.Parse(origin[0])
- if err != nil {
- return false
- }
- return u.Host == r.Host
-}
-
// The API provides a handler for a WebSocket stream API.
type API struct {
clients map[uint][]*client
lock sync.RWMutex
pingPeriod time.Duration
pongTimeout time.Duration
+ upgrader *websocket.Upgrader
}
// New creates a new instance of API.
// pingPeriod: is the interval, in which is server sends the a ping to the client.
// pongTimeout: is the duration after the connection will be terminated, when the client does not respond with the
// pong command.
-func New(pingPeriod, pongTimeout time.Duration) *API {
+func New(pingPeriod, pongTimeout time.Duration, allowedWebSocketOrigins []string) *API {
return &API{
clients: make(map[uint][]*client),
pingPeriod: pingPeriod,
pongTimeout: pingPeriod + pongTimeout,
+ upgrader: newUpgrader(allowedWebSocketOrigins),
}
}
@@ -147,7 +127,7 @@ func (a *API) register(client *client) {
// schema:
// $ref: "#/definitions/Error"
func (a *API) Handle(ctx *gin.Context) {
- conn, err := upgrader.Upgrade(ctx.Writer, ctx.Request, nil)
+ conn, err := a.upgrader.Upgrade(ctx.Writer, ctx.Request, nil)
if err != nil {
return
}
@@ -172,3 +152,50 @@ func (a *API) Close() {
delete(a.clients, k)
}
}
+
+func isAllowedOrigin(r *http.Request, allowedOrigins []*regexp.Regexp) bool {
+ origin := r.Header["Origin"]
+ if len(origin) == 0 {
+ return true
+ }
+
+ u, err := url.Parse(origin[0])
+ if err != nil {
+ return false
+ }
+
+ if u.Hostname() == r.Host {
+ return true
+ }
+
+ for _, allowedOrigin := range allowedOrigins {
+ if allowedOrigin.Match([]byte(u.Hostname())) {
+ return true
+ }
+ }
+
+ return false
+}
+
+func newUpgrader(allowedWebSocketOrigins []string) *websocket.Upgrader {
+ compiledAllowedOrigins := compileAllowedWebSocketOrigins(allowedWebSocketOrigins)
+ return &websocket.Upgrader{
+ ReadBufferSize: 1024,
+ WriteBufferSize: 1024,
+ CheckOrigin: func(r *http.Request) bool {
+ if mode.IsDev() {
+ return true
+ }
+ return isAllowedOrigin(r, compiledAllowedOrigins)
+ },
+ }
+}
+
+func compileAllowedWebSocketOrigins(allowedOrigins []string) []*regexp.Regexp {
+ var compiledAllowedOrigins []*regexp.Regexp
+ for _, origin := range allowedOrigins {
+ compiledAllowedOrigins = append(compiledAllowedOrigins, regexp.MustCompile(origin))
+ }
+
+ return compiledAllowedOrigins
+}
diff --git a/api/stream/stream_test.go b/api/stream/stream_test.go
index cf88b08..a6248d2 100644
--- a/api/stream/stream_test.go
+++ b/api/stream/stream_test.go
@@ -405,14 +405,37 @@ func Test_sameOrigin_returnsTrue(t *testing.T) {
mode.Set(mode.Prod)
req := httptest.NewRequest("GET", "http://example.com/stream", nil)
req.Header.Set("Origin", "http://example.com")
- actual := checkSameOrigin(req)
+ actual := isAllowedOrigin(req, nil)
assert.True(t, actual)
}
+func Test_isAllowedOrigin_withoutAllowedOrigins_failsWhenNotSameOrigin(t *testing.T) {
+ mode.Set(mode.Prod)
+ req := httptest.NewRequest("GET", "http://example.com/stream", nil)
+ req.Header.Set("Origin", "http://gorify.example.com")
+ actual := isAllowedOrigin(req, nil)
+ assert.False(t, actual)
+}
+
+func Test_isAllowedOriginMatching(t *testing.T) {
+ mode.Set(mode.Prod)
+ compiledAllowedOrigins := compileAllowedWebSocketOrigins([]string{"go.{4}\\.example\\.com", "go\\.example\\.com"})
+
+ req := httptest.NewRequest("GET", "http://example.me/stream", nil)
+ req.Header.Set("Origin", "http://gorify.example.com")
+ assert.True(t, isAllowedOrigin(req, compiledAllowedOrigins))
+
+ req.Header.Set("Origin", "http://go.example.com")
+ assert.True(t, isAllowedOrigin(req, compiledAllowedOrigins))
+
+ req.Header.Set("Origin", "http://hello.example.com")
+ assert.False(t, isAllowedOrigin(req, compiledAllowedOrigins))
+}
+
func Test_emptyOrigin_returnsTrue(t *testing.T) {
mode.Set(mode.Prod)
req := httptest.NewRequest("GET", "http://example.com/stream", nil)
- actual := checkSameOrigin(req)
+ actual := isAllowedOrigin(req, nil)
assert.True(t, actual)
}
@@ -420,7 +443,7 @@ func Test_otherOrigin_returnsFalse(t *testing.T) {
mode.Set(mode.Prod)
req := httptest.NewRequest("GET", "http://example.com/stream", nil)
req.Header.Set("Origin", "http://otherexample.de")
- actual := checkSameOrigin(req)
+ actual := isAllowedOrigin(req, nil)
assert.False(t, actual)
}
@@ -428,10 +451,15 @@ func Test_invalidOrigin_returnsFalse(t *testing.T) {
mode.Set(mode.Prod)
req := httptest.NewRequest("GET", "http://example.com/stream", nil)
req.Header.Set("Origin", "http\\://otherexample.de")
- actual := checkSameOrigin(req)
+ actual := isAllowedOrigin(req, nil)
assert.False(t, actual)
}
+func Test_compileAllowedWebSocketOrigins(t *testing.T) {
+ assert.Equal(t, 0, len(compileAllowedWebSocketOrigins([]string{})))
+ assert.Equal(t, 3, len(compileAllowedWebSocketOrigins([]string{"^.*$", "", "abc"})))
+}
+
func clients(api *API, user uint) []*client {
api.lock.RLock()
defer api.lock.RUnlock()
@@ -439,7 +467,7 @@ func clients(api *API, user uint) []*client {
return api.clients[user]
}
-func testClient(t *testing.T, url string) *testingClient {
+func testClient(t *testing.T, url string) *testingClient {
client := createClient(t, url)
startReading(client)
return client
@@ -507,11 +535,11 @@ func (c *testingClient) expectNoMessage() {
}
func bootTestServer(handlerFunc gin.HandlerFunc) (*httptest.Server, *API) {
-
r := gin.New()
r.Use(handlerFunc)
// all 4 seconds a ping, and the client has 1 second to respond
- api := New(4*time.Second, 1*time.Second)
+ api := New(4*time.Second, 1*time.Second, []string{})
+
r.GET("/", api.Handle)
server := httptest.NewServer(r)
return server, api
diff --git a/config/config.go b/config/config.go
index c8f3e09..c80288c 100644
--- a/config/config.go
+++ b/config/config.go
@@ -25,6 +25,9 @@ type Configuration struct {
}
}
ResponseHeaders map[string]string
+ Stream struct {
+ AllowedOrigins []string
+ }
}
Database struct {
Dialect string `default:"sqlite3"`
diff --git a/config/config_test.go b/config/config_test.go
index 666cb1c..c4c64dd 100644
--- a/config/config_test.go
+++ b/config/config_test.go
@@ -15,15 +15,20 @@ func TestConfigEnv(t *testing.T) {
os.Setenv("GOTIFY_SERVER_RESPONSEHEADERS",
"Access-Control-Allow-Origin: \"*\"\nAccess-Control-Allow-Methods: \"GET,POST\"",
)
+ os.Setenv("GOTIFY_SERVER_STREAM_ALLOWEDORIGINS", "- \".+.example.com\"\n- \"otherdomain.com\"")
+
conf := Get()
assert.Equal(t, 80, conf.Server.Port, "should use defaults")
assert.Equal(t, "jmattheis", conf.DefaultUser.Name, "should not use default but env var")
assert.Equal(t, []string{"push.example.tld", "push.other.tld"}, conf.Server.SSL.LetsEncrypt.Hosts)
assert.Equal(t, "*", conf.Server.ResponseHeaders["Access-Control-Allow-Origin"])
assert.Equal(t, "GET,POST", conf.Server.ResponseHeaders["Access-Control-Allow-Methods"])
+ assert.Equal(t, []string{".+.example.com", "otherdomain.com"}, conf.Server.Stream.AllowedOrigins)
+
os.Unsetenv("GOTIFY_DEFAULTUSER_NAME")
os.Unsetenv("GOTIFY_SERVER_SSL_LETSENCRYPT_HOSTS")
os.Unsetenv("GOTIFY_SERVER_RESPONSEHEADERS")
+ os.Unsetenv("GOTIFY_SERVER_STREAM_ALLOWEDORIGINS")
}
func TestAddSlash(t *testing.T) {
@@ -75,6 +80,10 @@ server:
responseheaders:
Access-Control-Allow-Origin: "*"
Access-Control-Allow-Methods: "GET,POST"
+ stream:
+ allowedorigins:
+ - ".+.example.com"
+ - "otherdomain.com"
database:
dialect: mysql
connection: user name
@@ -94,6 +103,7 @@ defaultuser:
assert.Equal(t, "user name", conf.Database.Connection)
assert.Equal(t, "*", conf.Server.ResponseHeaders["Access-Control-Allow-Origin"])
assert.Equal(t, "GET,POST", conf.Server.ResponseHeaders["Access-Control-Allow-Methods"])
+ assert.Equal(t, []string{".+.example.com", "otherdomain.com"}, conf.Server.Stream.AllowedOrigins)
assert.Nil(t, os.Remove("config.yml"))
}
diff --git a/router/router.go b/router/router.go
index ccf49fe..dc2adf5 100644
--- a/router/router.go
+++ b/router/router.go
@@ -23,7 +23,7 @@ import (
// Create creates the gin engine with all routes.
func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Configuration) (*gin.Engine, func()) {
- streamHandler := stream.New(200*time.Second, 15*time.Second)
+ streamHandler := stream.New(200*time.Second, 15*time.Second, conf.Server.Stream.AllowedOrigins)
authentication := auth.Auth{DB: db}
messageHandler := api.MessageAPI{Notifier: streamHandler, DB: db}
clientHandler := api.ClientAPI{