Fix websocket allowed origin (#150)
This commit is contained in:
parent
3e8abdefa7
commit
178c76f410
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
|
@ -155,22 +156,22 @@ func (a *API) Close() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func isAllowedOrigin(r *http.Request, allowedOrigins []*regexp.Regexp) bool {
|
func isAllowedOrigin(r *http.Request, allowedOrigins []*regexp.Regexp) bool {
|
||||||
origin := r.Header["Origin"]
|
origin := r.Header.Get("origin")
|
||||||
if len(origin) == 0 {
|
if origin == "" {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
u, err := url.Parse(origin[0])
|
u, err := url.Parse(origin)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if u.Hostname() == r.Host {
|
if strings.ToLower(u.Host) == strings.ToLower(r.Host) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, allowedOrigin := range allowedOrigins {
|
for _, allowedOrigin := range allowedOrigins {
|
||||||
if allowedOrigin.Match([]byte(u.Hostname())) {
|
if allowedOrigin.Match([]byte(strings.ToLower(u.Hostname()))) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -408,6 +408,14 @@ func Test_sameOrigin_returnsTrue(t *testing.T) {
|
||||||
assert.True(t, actual)
|
assert.True(t, actual)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_sameOrigin_returnsTrue_withCustomPort(t *testing.T) {
|
||||||
|
mode.Set(mode.Prod)
|
||||||
|
req := httptest.NewRequest("GET", "http://example.com:8080/stream", nil)
|
||||||
|
req.Header.Set("Origin", "http://example.com:8080")
|
||||||
|
actual := isAllowedOrigin(req, nil)
|
||||||
|
assert.True(t, actual)
|
||||||
|
}
|
||||||
|
|
||||||
func Test_isAllowedOrigin_withoutAllowedOrigins_failsWhenNotSameOrigin(t *testing.T) {
|
func Test_isAllowedOrigin_withoutAllowedOrigins_failsWhenNotSameOrigin(t *testing.T) {
|
||||||
mode.Set(mode.Prod)
|
mode.Set(mode.Prod)
|
||||||
req := httptest.NewRequest("GET", "http://example.com/stream", nil)
|
req := httptest.NewRequest("GET", "http://example.com/stream", nil)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue