diff --git a/runner/runner.go b/runner/runner.go index 1bf2b7f..8b9642f 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -8,6 +8,7 @@ import ( "net" "net/http" "strconv" + "strings" "time" "github.com/gotify/server/v2/config" @@ -75,7 +76,11 @@ func redirectToHTTPS(port string) http.HandlerFunc { func changePort(hostPort, port string) string { host, _, err := net.SplitHostPort(hostPort) if err != nil { - return hostPort + // There is no exported error. + if !strings.Contains(err.Error(), "missing port") { + return hostPort + } + host = hostPort } return net.JoinHostPort(host, port) } diff --git a/runner/runner_test.go b/runner/runner_test.go new file mode 100644 index 0000000..fa22255 --- /dev/null +++ b/runner/runner_test.go @@ -0,0 +1,35 @@ +package runner + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRedirect(t *testing.T) { + cases := []struct { + Request string + TLS int + Expect string + }{ + {Request: "http://gotify.net/meow", TLS: 443, Expect: "https://gotify.net:443/meow"}, + {Request: "http://gotify.net:8080/meow", TLS: 443, Expect: "https://gotify.net:443/meow"}, + {Request: "http://gotify.net:8080/meow", TLS: 8443, Expect: "https://gotify.net:8443/meow"}, + } + + for _, testCase := range cases { + name := fmt.Sprintf("%s -- %d -> %s", testCase.Request, testCase.TLS, testCase.Expect) + t.Run(name, func(t *testing.T) { + req := httptest.NewRequest("GET", testCase.Request, nil) + rec := httptest.NewRecorder() + + redirectToHTTPS(fmt.Sprint(testCase.TLS)).ServeHTTP(rec, req) + + assert.Equal(t, http.StatusFound, rec.Result().StatusCode) + assert.Equal(t, testCase.Expect, rec.Header().Get("location")) + }) + } +}