Use golangci-lint

This commit is contained in:
Jannis Mattheis 2020-11-01 10:36:10 +01:00
parent 44e441a8c5
commit 3454dcd602
47 changed files with 197 additions and 174 deletions

58
.golangci.yml Normal file
View File

@ -0,0 +1,58 @@
run:
skip-dirs:
- plugin/example
- plugin/testing
linters:
enable:
- asciicheck
- deadcode
- depguard
- exportloopref
- gci
- godot
- gofmt
- gofumpt
- goimports
- golint
- gomodguard
- goprintffuncname
- gosimple
- govet
- ineffassign
- interfacer
- misspell
- nakedret
- nolintlint
- sqlclosecheck
- staticcheck
- structcheck
- stylecheck
- typecheck
- unconvert
- unused
- varcheck
- whitespace
disable:
- goerr113
- errcheck
- funlen
- gochecknoglobals
- gocognit
- goconst
- gocyclo
- godox
- gomnd
- lll
- maligned
- nestif
- nlreturn
- noctx
- testpackage
- wsl
linters-settings:
gofumpt:
extra-rules: true
misspell:
locale: US

View File

@ -18,6 +18,7 @@ before_install:
- nvm install 12.10.0 - nvm install 12.10.0
- export GIMME_GO=$(< GO_VERSION) - export GIMME_GO=$(< GO_VERSION)
- eval "$(gimme ${GIMME_GO%.0})"; - eval "$(gimme ${GIMME_GO%.0})";
- curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin v1.31.0
- make download-tools - make download-tools
install: install:

View File

@ -30,20 +30,14 @@ test-js:
rm -rf removeme rm -rf removeme
check-go: check-go:
go vet ./... golangci-lint run
gocyclo -over 10 $(shell find . -iname '*.go' -type f | grep -v /vendor/)
golint -set_exit_status $(shell go list ./... | grep -v mock)
goimports -l $(shell find . -type f -name '*.go' -not -path "./vendor/*" -not -path "./packrd/*")
check-js: check-js:
(cd ui && yarn lint) (cd ui && yarn lint)
(cd ui && yarn testformat) (cd ui && yarn testformat)
download-tools: download-tools:
GO111MODULE=off go get -u golang.org/x/lint/golint
GO111MODULE=off go get -u github.com/fzipp/gocyclo
GO111MODULE=off go get -u github.com/go-swagger/go-swagger/cmd/swagger GO111MODULE=off go get -u github.com/go-swagger/go-swagger/cmd/swagger
GO111MODULE=off go get -u golang.org/x/tools/cmd/goimports
embed-static: embed-static:
go run hack/packr/packr.go go run hack/packr/packr.go

View File

@ -228,7 +228,6 @@ func (a *ApplicationAPI) UpdateApplication(ctx *gin.Context) {
return return
} }
ctx.JSON(200, withResolvedImage(app)) ctx.JSON(200, withResolvedImage(app))
} }
} else { } else {
ctx.AbortWithError(404, fmt.Errorf("app with id %d doesn't exists", id)) ctx.AbortWithError(404, fmt.Errorf("app with id %d doesn't exists", id))

View File

@ -37,8 +37,10 @@ type ApplicationSuite struct {
recorder *httptest.ResponseRecorder recorder *httptest.ResponseRecorder
} }
var originalGenerateApplicationToken func() string var (
var originalGenerateImageName func() string originalGenerateApplicationToken func() string
originalGenerateImageName func() string
)
func (s *ApplicationSuite) BeforeTest(suiteName, testName string) { func (s *ApplicationSuite) BeforeTest(suiteName, testName string) {
originalGenerateApplicationToken = generateApplicationToken originalGenerateApplicationToken = generateApplicationToken
@ -78,6 +80,7 @@ func (s *ApplicationSuite) Test_CreateApplication_mapAllParameters() {
assert.Equal(s.T(), expected, app) assert.Equal(s.T(), expected, app)
} }
} }
func (s *ApplicationSuite) Test_ensureApplicationHasCorrectJsonRepresentation() { func (s *ApplicationSuite) Test_ensureApplicationHasCorrectJsonRepresentation() {
actual := &model.Application{ actual := &model.Application{
ID: 1, ID: 1,
@ -90,6 +93,7 @@ func (s *ApplicationSuite) Test_ensureApplicationHasCorrectJsonRepresentation()
} }
test.JSONEquals(s.T(), actual, `{"id":1,"token":"Aasdasfgeeg","name":"myapp","description":"mydesc", "image": "asd", "internal":true}`) test.JSONEquals(s.T(), actual, `{"id":1,"token":"Aasdasfgeeg","name":"myapp","description":"mydesc", "image": "asd", "internal":true}`)
} }
func (s *ApplicationSuite) Test_CreateApplication_expectBadRequestOnEmptyName() { func (s *ApplicationSuite) Test_CreateApplication_expectBadRequestOnEmptyName() {
s.db.User(5) s.db.User(5)

View File

@ -501,6 +501,7 @@ func (s *MessageSuite) Test_CreateMessage_onQueryData() {
assert.Equal(s.T(), 200, s.recorder.Code) assert.Equal(s.T(), 200, s.recorder.Code)
assert.Equal(s.T(), uint(1), s.notifiedMessage.ID) assert.Equal(s.T(), uint(1), s.notifiedMessage.ID)
} }
func (s *MessageSuite) Test_CreateMessage_onFormData() { func (s *MessageSuite) Test_CreateMessage_onFormData() {
auth.RegisterAuthentication(s.ctx, nil, 4, "app-token") auth.RegisterAuthentication(s.ctx, nil, 4, "app-token")
s.db.User(4).AppWithToken(99, "app-token") s.db.User(4).AppWithToken(99, "app-token")

View File

@ -5,9 +5,8 @@ import (
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"github.com/gotify/location"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gotify/location"
"github.com/gotify/server/v2/auth" "github.com/gotify/server/v2/auth"
"github.com/gotify/server/v2/model" "github.com/gotify/server/v2/model"
"github.com/gotify/server/v2/plugin" "github.com/gotify/server/v2/plugin"

View File

@ -8,8 +8,7 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"gopkg.in/yaml.v2" "github.com/gin-gonic/gin"
"github.com/gotify/server/v2/mode" "github.com/gotify/server/v2/mode"
"github.com/gotify/server/v2/model" "github.com/gotify/server/v2/model"
"github.com/gotify/server/v2/plugin" "github.com/gotify/server/v2/plugin"
@ -17,11 +16,9 @@ import (
"github.com/gotify/server/v2/plugin/testing/mock" "github.com/gotify/server/v2/plugin/testing/mock"
"github.com/gotify/server/v2/test" "github.com/gotify/server/v2/test"
"github.com/gotify/server/v2/test/testdb" "github.com/gotify/server/v2/test/testdb"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"gopkg.in/yaml.v2"
) )
func TestPluginSuite(t *testing.T) { func TestPluginSuite(t *testing.T) {
@ -101,7 +98,6 @@ func (s *PluginSuite) Test_GetPlugins() {
} }
func (s *PluginSuite) Test_EnableDisablePlugin() { func (s *PluginSuite) Test_EnableDisablePlugin() {
{ {
test.WithUser(s.ctx, 1) test.WithUser(s.ctx, 1)
@ -161,7 +157,6 @@ func (s *PluginSuite) Test_EnableDisablePlugin() {
} }
s.resetRecorder() s.resetRecorder()
} }
} }
func (s *PluginSuite) Test_EnableDisablePlugin_EnableReturnsError_expect500() { func (s *PluginSuite) Test_EnableDisablePlugin_EnableReturnsError_expect500() {
@ -239,7 +234,6 @@ func (s *PluginSuite) Test_EnableDisablePlugin_incorrectUser_expectNotFound() {
} }
s.resetRecorder() s.resetRecorder()
} }
} }
func (s *PluginSuite) Test_EnableDisablePlugin_nonExistPlugin_expectNotFound() { func (s *PluginSuite) Test_EnableDisablePlugin_nonExistPlugin_expectNotFound() {
@ -264,7 +258,6 @@ func (s *PluginSuite) Test_EnableDisablePlugin_nonExistPlugin_expectNotFound() {
assert.Equal(s.T(), 404, s.recorder.Code) assert.Equal(s.T(), 404, s.recorder.Code)
s.resetRecorder() s.resetRecorder()
} }
} }
func (s *PluginSuite) Test_EnableDisablePlugin_danglingConf_expectNotFound() { func (s *PluginSuite) Test_EnableDisablePlugin_danglingConf_expectNotFound() {

View File

@ -77,7 +77,7 @@ func (c *client) startReading(pongWait time.Duration) {
// startWriteHandler starts the write loop. The method has the following tasks: // startWriteHandler starts the write loop. The method has the following tasks:
// * ping the client in the interval provided as parameter // * ping the client in the interval provided as parameter
// * write messages send by the channel to the client // * write messages send by the channel to the client
// * on errors exit the loop // * on errors exit the loop.
func (c *client) startWriteHandler(pingPeriod time.Duration) { func (c *client) startWriteHandler(pingPeriod time.Duration) {
pingTicker := time.NewTicker(pingPeriod) pingTicker := time.NewTicker(pingPeriod)
defer func() { defer func() {

View File

@ -166,7 +166,7 @@ func isAllowedOrigin(r *http.Request, allowedOrigins []*regexp.Regexp) bool {
return false return false
} }
if strings.ToLower(u.Host) == strings.ToLower(r.Host) { if strings.EqualFold(u.Host, r.Host) {
return true return true
} }

View File

@ -484,7 +484,6 @@ func startReading(client *testingClient) {
go func() { go func() {
for { for {
_, payload, err := client.conn.ReadMessage() _, payload, err := client.conn.ReadMessage()
if err != nil { if err != nil {
return return
} }

View File

@ -221,7 +221,7 @@ func (a *UserAPI) CreateUser(ctx *gin.Context) {
// $ref: "#/definitions/Error" // $ref: "#/definitions/Error"
func (a *UserAPI) GetUserByID(ctx *gin.Context) { func (a *UserAPI) GetUserByID(ctx *gin.Context) {
withID(ctx, "id", func(id uint) { withID(ctx, "id", func(id uint) {
user, err := a.DB.GetUserByID(uint(id)) user, err := a.DB.GetUserByID(id)
if success := successOrAbort(ctx, 500, err); !success { if success := successOrAbort(ctx, 500, err); !success {
return return
} }

View File

@ -47,6 +47,7 @@ func (s *UserSuite) BeforeTest(suiteName, testName string) {
}) })
s.a = &UserAPI{DB: s.db, UserChangeNotifier: s.notifier} s.a = &UserAPI{DB: s.db, UserChangeNotifier: s.notifier}
} }
func (s *UserSuite) AfterTest(suiteName, testName string) { func (s *UserSuite) AfterTest(suiteName, testName string) {
s.db.Close() s.db.Close()
} }

2
app.go
View File

@ -21,7 +21,7 @@ var (
Commit = "unknown" Commit = "unknown"
// BuildDate the date on which this binary was build. // BuildDate the date on which this binary was build.
BuildDate = "unknown" BuildDate = "unknown"
// Mode the build mode // Mode the build mode.
Mode = mode.Dev Mode = mode.Dev
) )

View File

@ -21,12 +21,12 @@ type Database interface {
GetUserByID(id uint) (*model.User, error) GetUserByID(id uint) (*model.User, error)
} }
// Auth is the provider for authentication middleware // Auth is the provider for authentication middleware.
type Auth struct { type Auth struct {
DB Database DB Database
} }
type authenticate func(tokenID string, user *model.User) (authenticated bool, success bool, userId uint, err error) type authenticate func(tokenID string, user *model.User) (authenticated, success bool, userId uint, err error)
// RequireAdmin returns a gin middleware which requires a client token or basic authentication header to be supplied // RequireAdmin returns a gin middleware which requires a client token or basic authentication header to be supplied
// with the request. Also the authenticated user must be an administrator. // with the request. Also the authenticated user must be an administrator.
@ -112,14 +112,14 @@ func (a *Auth) requireToken(auth authenticate) gin.HandlerFunc {
token := a.tokenFromQueryOrHeader(ctx) token := a.tokenFromQueryOrHeader(ctx)
user, err := a.userFromBasicAuth(ctx) user, err := a.userFromBasicAuth(ctx)
if err != nil { if err != nil {
ctx.AbortWithError(500, errors.New("an error occured while authenticating user")) ctx.AbortWithError(500, errors.New("an error occurred while authenticating user"))
return return
} }
if user != nil || token != "" { if user != nil || token != "" {
authenticated, ok, userID, err := auth(token, user) authenticated, ok, userID, err := auth(token, user)
if err != nil { if err != nil {
ctx.AbortWithError(500, errors.New("an error occured while authenticating user")) ctx.AbortWithError(500, errors.New("an error occurred while authenticating user"))
return return
} else if ok { } else if ok {
RegisterAuthentication(ctx, user, userID, token) RegisterAuthentication(ctx, user, userID, token)

View File

@ -10,7 +10,7 @@ import (
"github.com/gotify/server/v2/mode" "github.com/gotify/server/v2/mode"
) )
// CorsConfig generates a config to use in gin cors middleware based on server configuration // CorsConfig generates a config to use in gin cors middleware based on server configuration.
func CorsConfig(conf *config.Configuration) cors.Config { func CorsConfig(conf *config.Configuration) cors.Config {
corsConf := cors.Config{ corsConf := cors.Config{
MaxAge: 12 * time.Hour, MaxAge: 12 * time.Hour,
@ -19,8 +19,10 @@ func CorsConfig(conf *config.Configuration) cors.Config {
if mode.IsDev() { if mode.IsDev() {
corsConf.AllowAllOrigins = true corsConf.AllowAllOrigins = true
corsConf.AllowMethods = []string{"GET", "POST", "DELETE", "OPTIONS", "PUT"} corsConf.AllowMethods = []string{"GET", "POST", "DELETE", "OPTIONS", "PUT"}
corsConf.AllowHeaders = []string{"X-Gotify-Key", "Authorization", "Content-Type", "Upgrade", "Origin", corsConf.AllowHeaders = []string{
"Connection", "Accept-Encoding", "Accept-Language", "Host"} "X-Gotify-Key", "Authorization", "Content-Type", "Upgrade", "Origin",
"Connection", "Accept-Encoding", "Accept-Language", "Host",
}
} else { } else {
compiledOrigins := compileAllowedCORSOrigins(conf.Server.Cors.AllowOrigins) compiledOrigins := compileAllowedCORSOrigins(conf.Server.Cors.AllowOrigins)
corsConf.AllowMethods = conf.Server.Cors.AllowMethods corsConf.AllowMethods = conf.Server.Cors.AllowMethods

View File

@ -33,6 +33,7 @@ func TestCorsConfig(t *testing.T) {
assert.False(t, allowF("https://test.com")) assert.False(t, allowF("https://test.com"))
assert.False(t, allowF("https://other.com")) assert.False(t, allowF("https://other.com"))
} }
func TestEmptyCorsConfigWithResponseHeaders(t *testing.T) { func TestEmptyCorsConfigWithResponseHeaders(t *testing.T) {
mode.Set(mode.Prod) mode.Set(mode.Prod)
serverConf := config.Configuration{} serverConf := config.Configuration{}
@ -60,8 +61,10 @@ func TestDevCorsConfig(t *testing.T) {
actual := CorsConfig(&serverConf) actual := CorsConfig(&serverConf)
assert.Equal(t, cors.Config{ assert.Equal(t, cors.Config{
AllowHeaders: []string{"X-Gotify-Key", "Authorization", "Content-Type", "Upgrade", "Origin", AllowHeaders: []string{
"Connection", "Accept-Encoding", "Accept-Language", "Host"}, "X-Gotify-Key", "Authorization", "Content-Type", "Upgrade", "Origin",
"Connection", "Accept-Encoding", "Accept-Language", "Host",
},
AllowMethods: []string{"GET", "POST", "DELETE", "OPTIONS", "PUT"}, AllowMethods: []string{"GET", "POST", "DELETE", "OPTIONS", "PUT"},
MaxAge: 12 * time.Hour, MaxAge: 12 * time.Hour,
AllowAllOrigins: true, AllowAllOrigins: true,

View File

@ -7,7 +7,6 @@ import (
"testing" "testing"
"github.com/gotify/server/v2/test" "github.com/gotify/server/v2/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -26,10 +25,7 @@ func TestGenerateNotExistingToken(t *testing.T) {
return fmt.Sprint(count) return fmt.Sprint(count)
}, func(token string) bool { }, func(token string) bool {
count-- count--
if token == "0" { return token != "0"
return false
}
return true
}) })
assert.Equal(t, "0", token) assert.Equal(t, "0", token)
} }

View File

@ -26,7 +26,7 @@ func GetUserID(ctx *gin.Context) uint {
return user.ID return user.ID
} }
// GetTokenID returns the tokenID // GetTokenID returns the tokenID.
func GetTokenID(ctx *gin.Context) string { func GetTokenID(ctx *gin.Context) string {
return ctx.MustGet("tokenid").(string) return ctx.MustGet("tokenid").(string)
} }

View File

@ -38,7 +38,7 @@ func (s *UtilSuite) Test_getToken() {
assert.Equal(s.T(), "asdasda", actualID) assert.Equal(s.T(), "asdasda", actualID)
} }
func (s *UtilSuite) expectUserIDWith(user *model.User, tokenUserID uint, expectedID uint) { func (s *UtilSuite) expectUserIDWith(user *model.User, tokenUserID, expectedID uint) {
ctx, _ := gin.CreateTestContext(httptest.NewRecorder()) ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
RegisterAuthentication(ctx, user, tokenUserID, "") RegisterAuthentication(ctx, user, tokenUserID, "")
actualID := GetUserID(ctx) actualID := GetUserID(ctx)

View File

@ -73,6 +73,6 @@ func Get() *Configuration {
func addTrailingSlashToPaths(conf *Configuration) { func addTrailingSlashToPaths(conf *Configuration) {
if !strings.HasSuffix(conf.UploadedImagesDir, "/") && !strings.HasSuffix(conf.UploadedImagesDir, "\\") { if !strings.HasSuffix(conf.UploadedImagesDir, "/") && !strings.HasSuffix(conf.UploadedImagesDir, "\\") {
conf.UploadedImagesDir = conf.UploadedImagesDir + string(filepath.Separator) conf.UploadedImagesDir += string(filepath.Separator)
} }
} }

View File

@ -6,7 +6,6 @@ import (
) )
func (s *DatabaseSuite) TestApplication() { func (s *DatabaseSuite) TestApplication() {
if app, err := s.db.GetApplicationByToken("asdasdf"); assert.NoError(s.T(), err) { if app, err := s.db.GetApplicationByToken("asdasdf"); assert.NoError(s.T(), err) {
assert.Nil(s.T(), app, "not existing app") assert.Nil(s.T(), app, "not existing app")
} }

View File

@ -8,9 +8,15 @@ import (
"github.com/gotify/server/v2/auth/password" "github.com/gotify/server/v2/auth/password"
"github.com/gotify/server/v2/model" "github.com/gotify/server/v2/model"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/mysql" // enable the mysql dialect
_ "github.com/jinzhu/gorm/dialects/postgres" // enable the postgres dialect // enable the mysql dialect.
_ "github.com/jinzhu/gorm/dialects/sqlite" // enable the sqlite3 dialect _ "github.com/jinzhu/gorm/dialects/mysql"
// enable the postgres dialect.
_ "github.com/jinzhu/gorm/dialects/postgres"
// enable the sqlite3 dialect.
_ "github.com/jinzhu/gorm/dialects/sqlite"
) )
var mkdirAll = os.MkdirAll var mkdirAll = os.MkdirAll
@ -86,7 +92,7 @@ func prepareBlobColumn(dialect string, db *gorm.DB) error {
return nil return nil
} }
func createDirectoryIfSqlite(dialect string, connection string) { func createDirectoryIfSqlite(dialect, connection string) {
if dialect == "sqlite3" { if dialect == "sqlite3" {
if _, err := os.Stat(filepath.Dir(connection)); os.IsNotExist(err) { if _, err := os.Stat(filepath.Dir(connection)); os.IsNotExist(err) {
if err := mkdirAll(filepath.Dir(connection), 0777); err != nil { if err := mkdirAll(filepath.Dir(connection), 0777); err != nil {

View File

@ -6,7 +6,6 @@ import (
"testing" "testing"
"github.com/gotify/server/v2/test" "github.com/gotify/server/v2/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
) )

View File

@ -145,7 +145,6 @@ func (s *DatabaseSuite) TestMessage() {
msgs, err = s.db.GetMessagesByUser(user.ID) msgs, err = s.db.GetMessagesByUser(user.ID)
require.NoError(s.T(), err) require.NoError(s.T(), err)
assert.Empty(s.T(), msgs) assert.Empty(s.T(), msgs)
} }
func (s *DatabaseSuite) TestGetMessagesSince() { func (s *DatabaseSuite) TestGetMessagesSince() {
@ -222,7 +221,6 @@ func (s *DatabaseSuite) TestGetMessagesSince() {
require.NoError(s.T(), err) require.NoError(s.T(), err)
assert.Len(s.T(), actual, 50) assert.Len(s.T(), actual, 50)
hasIDInclusiveBetween(s.T(), actual, 100, 2, 2) hasIDInclusiveBetween(s.T(), actual, 100, 2, 2)
} }
func hasIDInclusiveBetween(t *testing.T, msgs []*model.Message, from, to, decrement int) { func hasIDInclusiveBetween(t *testing.T, msgs []*model.Message, from, to, decrement int) {
@ -236,8 +234,8 @@ func hasIDInclusiveBetween(t *testing.T, msgs []*model.Message, from, to, decrem
assert.Equal(t, index, len(msgs), "not all entries inside msgs were checked") assert.Equal(t, index, len(msgs), "not all entries inside msgs were checked")
} }
// assertEquals compares messages and correctly check dates // assertEquals compares messages and correctly check dates.
func assertEquals(t *testing.T, left *model.Message, right *model.Message) { func assertEquals(t *testing.T, left, right *model.Message) {
assert.Equal(t, left.Date.Unix(), right.Date.Unix()) assert.Equal(t, left.Date.Unix(), right.Date.Unix())
left.Date = right.Date left.Date = right.Date
assert.Equal(t, left, right) assert.Equal(t, left, right)

View File

@ -5,7 +5,7 @@ import (
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
) )
// GetPluginConfByUser gets plugin configurations from a user // GetPluginConfByUser gets plugin configurations from a user.
func (d *GormDatabase) GetPluginConfByUser(userid uint) ([]*model.PluginConf, error) { func (d *GormDatabase) GetPluginConfByUser(userid uint) ([]*model.PluginConf, error) {
var plugins []*model.PluginConf var plugins []*model.PluginConf
err := d.DB.Where("user_id = ?", userid).Find(&plugins).Error err := d.DB.Where("user_id = ?", userid).Find(&plugins).Error
@ -15,7 +15,7 @@ func (d *GormDatabase) GetPluginConfByUser(userid uint) ([]*model.PluginConf, er
return plugins, err return plugins, err
} }
// GetPluginConfByUserAndPath gets plugin configuration by user and file name // GetPluginConfByUserAndPath gets plugin configuration by user and file name.
func (d *GormDatabase) GetPluginConfByUserAndPath(userid uint, path string) (*model.PluginConf, error) { func (d *GormDatabase) GetPluginConfByUserAndPath(userid uint, path string) (*model.PluginConf, error) {
plugin := new(model.PluginConf) plugin := new(model.PluginConf)
err := d.DB.Where("user_id = ? AND module_path = ?", userid, path).First(plugin).Error err := d.DB.Where("user_id = ? AND module_path = ?", userid, path).First(plugin).Error
@ -41,12 +41,12 @@ func (d *GormDatabase) GetPluginConfByApplicationID(appid uint) (*model.PluginCo
return nil, err return nil, err
} }
// CreatePluginConf creates a new plugin configuration // CreatePluginConf creates a new plugin configuration.
func (d *GormDatabase) CreatePluginConf(p *model.PluginConf) error { func (d *GormDatabase) CreatePluginConf(p *model.PluginConf) error {
return d.DB.Create(p).Error return d.DB.Create(p).Error
} }
// GetPluginConfByToken gets plugin configuration by plugin token // GetPluginConfByToken gets plugin configuration by plugin token.
func (d *GormDatabase) GetPluginConfByToken(token string) (*model.PluginConf, error) { func (d *GormDatabase) GetPluginConfByToken(token string) (*model.PluginConf, error) {
plugin := new(model.PluginConf) plugin := new(model.PluginConf)
err := d.DB.Where("token = ?", token).First(plugin).Error err := d.DB.Where("token = ?", token).First(plugin).Error
@ -59,7 +59,7 @@ func (d *GormDatabase) GetPluginConfByToken(token string) (*model.PluginConf, er
return nil, err return nil, err
} }
// GetPluginConfByID gets plugin configuration by plugin ID // GetPluginConfByID gets plugin configuration by plugin ID.
func (d *GormDatabase) GetPluginConfByID(id uint) (*model.PluginConf, error) { func (d *GormDatabase) GetPluginConfByID(id uint) (*model.PluginConf, error) {
plugin := new(model.PluginConf) plugin := new(model.PluginConf)
err := d.DB.Where("id = ?", id).First(plugin).Error err := d.DB.Where("id = ?", id).First(plugin).Error
@ -72,7 +72,7 @@ func (d *GormDatabase) GetPluginConfByID(id uint) (*model.PluginConf, error) {
return nil, err return nil, err
} }
// UpdatePluginConf updates plugin configuration // UpdatePluginConf updates plugin configuration.
func (d *GormDatabase) UpdatePluginConf(p *model.PluginConf) error { func (d *GormDatabase) UpdatePluginConf(p *model.PluginConf) error {
return d.DB.Save(p).Error return d.DB.Save(p).Error
} }

View File

@ -67,5 +67,4 @@ func (s *DatabaseSuite) TestPluginConf() {
require.NoError(s.T(), err) require.NoError(s.T(), err)
assert.Equal(s.T(), false, pluginConf.Enabled) assert.Equal(s.T(), false, pluginConf.Enabled)
assert.Equal(s.T(), testConf, string(pluginConf.Config)) assert.Equal(s.T(), testConf, string(pluginConf.Config))
} }

View File

@ -72,7 +72,6 @@ func (s *DatabaseSuite) TestUser() {
users, err = s.db.GetUsers() users, err = s.db.GetUsers()
require.NoError(s.T(), err) require.NoError(s.T(), err)
assert.Empty(s.T(), users) assert.Empty(s.T(), users)
} }
func (s *DatabaseSuite) TestUserPlugins() { func (s *DatabaseSuite) TestUserPlugins() {
@ -100,7 +99,6 @@ func (s *DatabaseSuite) TestUserPlugins() {
if pluginConf, err := s.db.GetPluginConfByToken("P1234"); assert.NoError(s.T(), err) { if pluginConf, err := s.db.GetPluginConfByToken("P1234"); assert.NoError(s.T(), err) {
assert.Equal(s.T(), "github.com/gotify/example-plugin", pluginConf.ModulePath) assert.Equal(s.T(), "github.com/gotify/example-plugin", pluginConf.ModulePath)
} }
} }
func (s *DatabaseSuite) TestDeleteUserDeletesApplicationsAndClientsAndPluginConfs() { func (s *DatabaseSuite) TestDeleteUserDeletesApplicationsAndClientsAndPluginConfs() {
@ -181,5 +179,4 @@ func (s *DatabaseSuite) TestDeleteUserDeletesApplicationsAndClientsAndPluginConf
msg, err = s.db.GetMessageByID(2000) msg, err = s.db.GetMessageByID(2000)
require.NoError(s.T(), err) require.NoError(s.T(), err)
assert.NotNil(s.T(), msg) assert.NotNil(s.T(), msg)
} }

View File

@ -3,11 +3,11 @@ package mode
import "github.com/gin-gonic/gin" import "github.com/gin-gonic/gin"
const ( const (
// Dev for development mode // Dev for development mode.
Dev = "dev" Dev = "dev"
// Prod for production mode // Prod for production mode.
Prod = "prod" Prod = "prod"
// TestDev used for tests // TestDev used for tests.
TestDev = "testdev" TestDev = "testdev"
) )

View File

@ -4,7 +4,7 @@ import (
"time" "time"
) )
// Message holds information about a message // Message holds information about a message.
type Message struct { type Message struct {
ID uint `gorm:"AUTO_INCREMENT;primary_key;index"` ID uint `gorm:"AUTO_INCREMENT;primary_key;index"`
ApplicationID uint ApplicationID uint

View File

@ -1,6 +1,6 @@
package model package model
// PluginConf holds information about the plugin // PluginConf holds information about the plugin.
type PluginConf struct { type PluginConf struct {
ID uint `gorm:"primary_key;AUTO_INCREMENT;index"` ID uint `gorm:"primary_key;AUTO_INCREMENT;index"`
UserID uint UserID uint

View File

@ -6,23 +6,23 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// Capability is a capability the plugin provides // Capability is a capability the plugin provides.
type Capability string type Capability string
const ( const (
// Messenger sends notifications // Messenger sends notifications.
Messenger = Capability("messenger") Messenger = Capability("messenger")
// Configurer are consigurables // Configurer are consigurables.
Configurer = Capability("configurer") Configurer = Capability("configurer")
// Storager stores data // Storager stores data.
Storager = Capability("storager") Storager = Capability("storager")
// Webhooker registers webhooks // Webhooker registers webhooks.
Webhooker = Capability("webhooker") Webhooker = Capability("webhooker")
// Displayer displays instructions // Displayer displays instructions.
Displayer = Capability("displayer") Displayer = Capability("displayer")
) )
// PluginInstance is an encapsulation layer of plugin instances of different backends // PluginInstance is an encapsulation layer of plugin instances of different backends.
type PluginInstance interface { type PluginInstance interface {
Enable() error Enable() error
Disable() error Disable() error
@ -48,7 +48,7 @@ type PluginInstance interface {
Supports() Capabilities Supports() Capabilities
} }
// HasSupport tests a PluginInstance for a capability // HasSupport tests a PluginInstance for a capability.
func HasSupport(p PluginInstance, toCheck Capability) bool { func HasSupport(p PluginInstance, toCheck Capability) bool {
for _, module := range p.Supports() { for _, module := range p.Supports() {
if module == toCheck { if module == toCheck {
@ -58,10 +58,10 @@ func HasSupport(p PluginInstance, toCheck Capability) bool {
return false return false
} }
// Capabilities is a slice of module // Capabilities is a slice of module.
type Capabilities []Capability type Capabilities []Capability
// Strings converts []Module to []string // Strings converts []Module to []string.
func (m Capabilities) Strings() []string { func (m Capabilities) Strings() []string {
var result []string var result []string
for _, module := range m { for _, module := range m {

View File

@ -1,13 +1,13 @@
package compat package compat
// Plugin is an abstraction of plugin handler // Plugin is an abstraction of plugin handler.
type Plugin interface { type Plugin interface {
PluginInfo() Info PluginInfo() Info
NewPluginInstance(ctx UserContext) PluginInstance NewPluginInstance(ctx UserContext) PluginInstance
APIVersion() string APIVersion() string
} }
// Info is the plugin info // Info is the plugin info.
type Info struct { type Info struct {
Version string Version string
Author string Author string

View File

@ -13,12 +13,12 @@ type PluginV1 struct {
Constructor func(ctx papiv1.UserContext) papiv1.Plugin Constructor func(ctx papiv1.UserContext) papiv1.Plugin
} }
// APIVersion returns the API version // APIVersion returns the API version.
func (c PluginV1) APIVersion() string { func (c PluginV1) APIVersion() string {
return "v1" return "v1"
} }
// PluginInfo implements compat/Plugin // PluginInfo implements compat/Plugin.
func (c PluginV1) PluginInfo() Info { func (c PluginV1) PluginInfo() Info {
return Info{ return Info{
Version: c.Info.Version, Version: c.Info.Version,
@ -31,7 +31,7 @@ func (c PluginV1) PluginInfo() Info {
} }
} }
// NewPluginInstance implements compat/Plugin // NewPluginInstance implements compat/Plugin.
func (c PluginV1) NewPluginInstance(ctx UserContext) PluginInstance { func (c PluginV1) NewPluginInstance(ctx UserContext) PluginInstance {
instance := c.Constructor(papiv1.UserContext{ instance := c.Constructor(papiv1.UserContext{
ID: ctx.ID, ID: ctx.ID,
@ -66,7 +66,7 @@ func (c PluginV1) NewPluginInstance(ctx UserContext) PluginInstance {
return compat return compat
} }
// PluginV1Instance is an adapter for plugin using v1 API // PluginV1Instance is an adapter for plugin using v1 API.
type PluginV1Instance struct { type PluginV1Instance struct {
instance papiv1.Plugin instance papiv1.Plugin
messenger papiv1.Messenger messenger papiv1.Messenger
@ -76,7 +76,7 @@ type PluginV1Instance struct {
displayer papiv1.Displayer displayer papiv1.Displayer
} }
// DefaultConfig see papiv1.Configurer // DefaultConfig see papiv1.Configurer.
func (c *PluginV1Instance) DefaultConfig() interface{} { func (c *PluginV1Instance) DefaultConfig() interface{} {
if c.configurer != nil { if c.configurer != nil {
return c.configurer.DefaultConfig() return c.configurer.DefaultConfig()
@ -84,7 +84,7 @@ func (c *PluginV1Instance) DefaultConfig() interface{} {
return struct{}{} return struct{}{}
} }
// ValidateAndSetConfig see papiv1.Configurer // ValidateAndSetConfig see papiv1.Configurer.
func (c *PluginV1Instance) ValidateAndSetConfig(config interface{}) error { func (c *PluginV1Instance) ValidateAndSetConfig(config interface{}) error {
if c.configurer != nil { if c.configurer != nil {
return c.configurer.ValidateAndSetConfig(config) return c.configurer.ValidateAndSetConfig(config)
@ -92,7 +92,7 @@ func (c *PluginV1Instance) ValidateAndSetConfig(config interface{}) error {
return nil return nil
} }
// GetDisplay see papiv1.Displayer // GetDisplay see papiv1.Displayer.
func (c *PluginV1Instance) GetDisplay(location *url.URL) string { func (c *PluginV1Instance) GetDisplay(location *url.URL) string {
if c.displayer != nil { if c.displayer != nil {
return c.displayer.GetDisplay(location) return c.displayer.GetDisplay(location)
@ -100,28 +100,28 @@ func (c *PluginV1Instance) GetDisplay(location *url.URL) string {
return "" return ""
} }
// SetMessageHandler see papiv1.Messenger // SetMessageHandler see papiv1.Messenger.
func (c *PluginV1Instance) SetMessageHandler(h MessageHandler) { func (c *PluginV1Instance) SetMessageHandler(h MessageHandler) {
if c.messenger != nil { if c.messenger != nil {
c.messenger.SetMessageHandler(&PluginV1MessageHandler{WrapperHandler: h}) c.messenger.SetMessageHandler(&PluginV1MessageHandler{WrapperHandler: h})
} }
} }
// RegisterWebhook see papiv1.Webhooker // RegisterWebhook see papiv1.Webhooker.
func (c *PluginV1Instance) RegisterWebhook(basePath string, mux *gin.RouterGroup) { func (c *PluginV1Instance) RegisterWebhook(basePath string, mux *gin.RouterGroup) {
if c.webhooker != nil { if c.webhooker != nil {
c.webhooker.RegisterWebhook(basePath, mux) c.webhooker.RegisterWebhook(basePath, mux)
} }
} }
// SetStorageHandler see papiv1.Storager // SetStorageHandler see papiv1.Storager.
func (c *PluginV1Instance) SetStorageHandler(handler StorageHandler) { func (c *PluginV1Instance) SetStorageHandler(handler StorageHandler) {
if c.storager != nil { if c.storager != nil {
c.storager.SetStorageHandler(&PluginV1StorageHandler{WrapperHandler: handler}) c.storager.SetStorageHandler(&PluginV1StorageHandler{WrapperHandler: handler})
} }
} }
// Supports returns a slice of capabilities the plugin instance provides // Supports returns a slice of capabilities the plugin instance provides.
func (c *PluginV1Instance) Supports() Capabilities { func (c *PluginV1Instance) Supports() Capabilities {
modules := Capabilities{} modules := Capabilities{}
if c.configurer != nil { if c.configurer != nil {
@ -142,12 +142,12 @@ func (c *PluginV1Instance) Supports() Capabilities {
return modules return modules
} }
// PluginV1MessageHandler is an adapter for messenger plugin handler using v1 API // PluginV1MessageHandler is an adapter for messenger plugin handler using v1 API.
type PluginV1MessageHandler struct { type PluginV1MessageHandler struct {
WrapperHandler MessageHandler WrapperHandler MessageHandler
} }
// SendMessage implements papiv1.MessageHandler // SendMessage implements papiv1.MessageHandler.
func (c *PluginV1MessageHandler) SendMessage(msg papiv1.Message) error { func (c *PluginV1MessageHandler) SendMessage(msg papiv1.Message) error {
return c.WrapperHandler.SendMessage(Message{ return c.WrapperHandler.SendMessage(Message{
Message: msg.Message, Message: msg.Message,
@ -157,27 +157,27 @@ func (c *PluginV1MessageHandler) SendMessage(msg papiv1.Message) error {
}) })
} }
// Enable implements wrapper.Plugin // Enable implements wrapper.Plugin.
func (c *PluginV1Instance) Enable() error { func (c *PluginV1Instance) Enable() error {
return c.instance.Enable() return c.instance.Enable()
} }
// Disable implements wrapper.Plugin // Disable implements wrapper.Plugin.
func (c *PluginV1Instance) Disable() error { func (c *PluginV1Instance) Disable() error {
return c.instance.Disable() return c.instance.Disable()
} }
// PluginV1StorageHandler is a wrapper for v1 storage handler // PluginV1StorageHandler is a wrapper for v1 storage handler.
type PluginV1StorageHandler struct { type PluginV1StorageHandler struct {
WrapperHandler StorageHandler WrapperHandler StorageHandler
} }
// Save implements wrapper.Storager // Save implements wrapper.Storager.
func (c *PluginV1StorageHandler) Save(b []byte) error { func (c *PluginV1StorageHandler) Save(b []byte) error {
return c.WrapperHandler.Save(b) return c.WrapperHandler.Save(b)
} }
// Load implements wrapper.Storager // Load implements wrapper.Storager.
func (c *PluginV1StorageHandler) Load() ([]byte, error) { func (c *PluginV1StorageHandler) Load() ([]byte, error) {
return c.WrapperHandler.Load() return c.WrapperHandler.Load()
} }

View File

@ -3,10 +3,9 @@ package compat
import ( import (
"testing" "testing"
papiv1 "github.com/gotify/plugin-api"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
papiv1 "github.com/gotify/plugin-api"
) )
type v1MockInstance struct { type v1MockInstance struct {
@ -155,6 +154,7 @@ func (s *V1WrapperSuite) TestMessenger_sendMessageWithoutExtras() {
Extras: nil, Extras: nil,
}, handler.msgSent) }, handler.msgSent)
} }
func TestV1Wrapper(t *testing.T) { func TestV1Wrapper(t *testing.T) {
suite.Run(t, new(V1WrapperSuite)) suite.Run(t, new(V1WrapperSuite))
} }

View File

@ -10,10 +10,8 @@ import (
"plugin" "plugin"
"testing" "testing"
"github.com/gotify/server/v2/test"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gotify/server/v2/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
) )
@ -32,9 +30,7 @@ func (s *CompatSuite) SetupSuite() {
exec.Command("go", "get", "-d").Run() exec.Command("go", "get", "-d").Run()
goBuildFlags := []string{"build", "-buildmode=plugin", "-o=" + s.tmpDir.Path("echo.so")} goBuildFlags := []string{"build", "-buildmode=plugin", "-o=" + s.tmpDir.Path("echo.so")}
for _, extraFlag := range extraGoBuildFlags { goBuildFlags = append(goBuildFlags, extraGoBuildFlags...)
goBuildFlags = append(goBuildFlags, extraFlag)
}
cmd := exec.Command("go", goBuildFlags...) cmd := exec.Command("go", goBuildFlags...)
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr
@ -109,6 +105,7 @@ func (s *CompatSuite) TestRegisterWebhook() {
inst.RegisterWebhook("/plugin/4/custom/Pabcd/", g) inst.RegisterWebhook("/plugin/4/custom/Pabcd/", g)
}) })
} }
func (s *CompatSuite) TestEnableDisable() { func (s *CompatSuite) TestEnableDisable() {
inst := s.p.NewPluginInstance(UserContext{ inst := s.p.NewPluginInstance(UserContext{
ID: 5, ID: 5,
@ -143,11 +140,7 @@ func TestWrapIncompatiblePlugins(t *testing.T) {
fName := tmpDir.Path(fmt.Sprintf("broken_%d.so", i)) fName := tmpDir.Path(fmt.Sprintf("broken_%d.so", i))
exec.Command("go", "get", "-d").Run() exec.Command("go", "get", "-d").Run()
goBuildFlags := []string{"build", "-buildmode=plugin", "-o=" + fName} goBuildFlags := []string{"build", "-buildmode=plugin", "-o=" + fName}
goBuildFlags = append(goBuildFlags, extraGoBuildFlags...)
for _, extraFlag := range extraGoBuildFlags {
goBuildFlags = append(goBuildFlags, extraFlag)
}
goBuildFlags = append(goBuildFlags, modulePath) goBuildFlags = append(goBuildFlags, modulePath)
cmd := exec.Command("go", goBuildFlags...) cmd := exec.Command("go", goBuildFlags...)

View File

@ -100,7 +100,7 @@ func NewManager(db Database, directory string, mux *gin.RouterGroup, notifier No
return manager, nil return manager, nil
} }
// ErrAlreadyEnabledOrDisabled is returned on SetPluginEnabled call when a plugin is already enabled or disabled // ErrAlreadyEnabledOrDisabled is returned on SetPluginEnabled call when a plugin is already enabled or disabled.
var ErrAlreadyEnabledOrDisabled = errors.New("config is already enabled/disabled") var ErrAlreadyEnabledOrDisabled = errors.New("config is already enabled/disabled")
func (m *Manager) applicationExists(token string) bool { func (m *Manager) applicationExists(token string) bool {
@ -163,7 +163,7 @@ func (m *Manager) PluginInfo(modulePath string) compat.Info {
} }
} }
// Instance returns an instance with the given ID // Instance returns an instance with the given ID.
func (m *Manager) Instance(pluginID uint) (compat.PluginInstance, error) { func (m *Manager) Instance(pluginID uint) (compat.PluginInstance, error) {
m.mutex.RLock() m.mutex.RLock()
defer m.mutex.RUnlock() defer m.mutex.RUnlock()
@ -174,13 +174,13 @@ func (m *Manager) Instance(pluginID uint) (compat.PluginInstance, error) {
return nil, errors.New("instance not found") return nil, errors.New("instance not found")
} }
// HasInstance returns whether the given plugin ID has a corresponding instance // HasInstance returns whether the given plugin ID has a corresponding instance.
func (m *Manager) HasInstance(pluginID uint) bool { func (m *Manager) HasInstance(pluginID uint) bool {
instance, err := m.Instance(pluginID) instance, err := m.Instance(pluginID)
return err == nil && instance != nil return err == nil && instance != nil
} }
// RemoveUser disabled all plugins of a user when the user is disabled // RemoveUser disabled all plugins of a user when the user is disabled.
func (m *Manager) RemoveUser(userID uint) error { func (m *Manager) RemoveUser(userID uint) error {
for _, p := range m.plugins { for _, p := range m.plugins {
pluginConf, err := m.db.GetPluginConfByUserAndPath(userID, p.PluginInfo().ModulePath) pluginConf, err := m.db.GetPluginConfByUserAndPath(userID, p.PluginInfo().ModulePath)
@ -242,7 +242,7 @@ func (m *Manager) loadPlugins(directory string) error {
return nil return nil
} }
// LoadPlugin loads a compat plugin, exported to sideload plugins for testing purposes // LoadPlugin loads a compat plugin, exported to sideload plugins for testing purposes.
func (m *Manager) LoadPlugin(compatPlugin compat.Plugin) error { func (m *Manager) LoadPlugin(compatPlugin compat.Plugin) error {
modulePath := compatPlugin.PluginInfo().ModulePath modulePath := compatPlugin.PluginInfo().ModulePath
if _, ok := m.plugins[modulePath]; ok { if _, ok := m.plugins[modulePath]; ok {
@ -252,7 +252,7 @@ func (m *Manager) LoadPlugin(compatPlugin compat.Plugin) error {
return nil return nil
} }
// InitializeForUserID initializes all plugin instances for a given user // InitializeForUserID initializes all plugin instances for a given user.
func (m *Manager) InitializeForUserID(userID uint) error { func (m *Manager) InitializeForUserID(userID uint) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@ -268,7 +268,6 @@ func (m *Manager) InitializeForUserID(userID uint) error {
} }
func (m *Manager) initializeForUser(user model.User) error { func (m *Manager) initializeForUser(user model.User) error {
userCtx := compat.UserContext{ userCtx := compat.UserContext{
ID: user.ID, ID: user.ID,
Name: user.Name, Name: user.Name,
@ -326,7 +325,8 @@ func (m *Manager) initializeSingleUserPlugin(userCtx compat.UserContext, p compa
instance.SetMessageHandler(redirectToChannel{ instance.SetMessageHandler(redirectToChannel{
ApplicationID: pluginConf.ApplicationID, ApplicationID: pluginConf.ApplicationID,
UserID: pluginConf.UserID, UserID: pluginConf.UserID,
Messages: m.messages}) Messages: m.messages,
})
} }
if compat.HasSupport(instance, compat.Storager) { if compat.HasSupport(instance, compat.Storager) {
instance.SetStorageHandler(dbStorageHandler{pluginConf.ID, m.db}) instance.SetStorageHandler(dbStorageHandler{pluginConf.ID, m.db})
@ -406,7 +406,6 @@ func (m *Manager) createPluginConf(instance compat.PluginInstance, info compat.I
return nil, err return nil, err
} }
pluginConf.ApplicationID = app.ID pluginConf.ApplicationID = app.ID
} }
if err := m.db.CreatePluginConf(pluginConf); err != nil { if err := m.db.CreatePluginConf(pluginConf); err != nil {
return nil, err return nil, err

View File

@ -13,31 +13,29 @@ import (
"testing" "testing"
"time" "time"
"github.com/gin-gonic/gin"
"github.com/gotify/server/v2/auth" "github.com/gotify/server/v2/auth"
"github.com/gotify/server/v2/model" "github.com/gotify/server/v2/model"
"github.com/gotify/server/v2/plugin/compat" "github.com/gotify/server/v2/plugin/compat"
"github.com/gotify/server/v2/plugin/testing/mock" "github.com/gotify/server/v2/plugin/testing/mock"
"github.com/gotify/server/v2/test" "github.com/gotify/server/v2/test"
"github.com/gotify/server/v2/test/testdb" "github.com/gotify/server/v2/test/testdb"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/gin-gonic/gin"
) )
const examplePluginPath = "github.com/gotify/server/v2/plugin/example/echo" const (
const mockPluginPath = mock.ModulePath examplePluginPath = "github.com/gotify/server/v2/plugin/example/echo"
const danglingPluginPath = "github.com/gotify/server/v2/plugin/testing/removed" mockPluginPath = mock.ModulePath
danglingPluginPath = "github.com/gotify/server/v2/plugin/testing/removed"
)
type ManagerSuite struct { type ManagerSuite struct {
suite.Suite suite.Suite
db *testdb.Database db *testdb.Database
manager *Manager manager *Manager
e *gin.Engine e *gin.Engine
g *gin.RouterGroup
msgReceiver chan MessageWithUserID msgReceiver chan MessageWithUserID
tmpDir test.TmpDir tmpDir test.TmpDir
@ -57,9 +55,7 @@ func (s *ManagerSuite) SetupSuite() {
exec.Command("go", "get", "-d").Run() exec.Command("go", "get", "-d").Run()
goBuildFlags := []string{"build", "-buildmode=plugin", "-o=" + s.tmpDir.Path("echo.so")} goBuildFlags := []string{"build", "-buildmode=plugin", "-o=" + s.tmpDir.Path("echo.so")}
for _, extraFlag := range extraGoBuildFlags { goBuildFlags = append(goBuildFlags, extraGoBuildFlags...)
goBuildFlags = append(goBuildFlags, extraFlag)
}
cmd := exec.Command("go", goBuildFlags...) cmd := exec.Command("go", goBuildFlags...)
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr
@ -98,7 +94,6 @@ func (s *ManagerSuite) getConfForExamplePlugin(uid uint) *model.PluginConf {
pluginConf, err := s.db.GetPluginConfByUserAndPath(uid, examplePluginPath) pluginConf, err := s.db.GetPluginConfByUserAndPath(uid, examplePluginPath)
assert.NoError(s.T(), err) assert.NoError(s.T(), err)
return pluginConf return pluginConf
} }
func (s *ManagerSuite) getConfForMockPlugin(uid uint) *model.PluginConf { func (s *ManagerSuite) getConfForMockPlugin(uid uint) *model.PluginConf {
@ -151,6 +146,7 @@ func (s *ManagerSuite) TestWebhook_successIfEnabled() {
func (s *ManagerSuite) TestInitializePlugin_noOpIfEmpty() { func (s *ManagerSuite) TestInitializePlugin_noOpIfEmpty() {
assert.Nil(s.T(), s.manager.loadPlugins("")) assert.Nil(s.T(), s.manager.loadPlugins(""))
} }
func (s *ManagerSuite) TestInitializePlugin_directoryInvalid_expectError() { func (s *ManagerSuite) TestInitializePlugin_directoryInvalid_expectError() {
assert.Error(s.T(), s.manager.loadPlugins("<<")) assert.Error(s.T(), s.manager.loadPlugins("<<"))
} }
@ -166,9 +162,7 @@ func (s *ManagerSuite) TestInitializePlugin_brokenPlugin_expectError() {
exec.Command("go", "get", "-d").Run() exec.Command("go", "get", "-d").Run()
goBuildFlags := []string{"build", "-buildmode=plugin", "-o=" + tmpDir.Path("empty.so")} goBuildFlags := []string{"build", "-buildmode=plugin", "-o=" + tmpDir.Path("empty.so")}
for _, extraFlag := range extraGoBuildFlags { goBuildFlags = append(goBuildFlags, extraGoBuildFlags...)
goBuildFlags = append(goBuildFlags, extraFlag)
}
cmd := exec.Command("go", goBuildFlags...) cmd := exec.Command("go", goBuildFlags...)
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr
@ -193,7 +187,6 @@ func (s *ManagerSuite) TestInitializePlugin_alreadyEnabledInConf_expectAutoEnabl
assert.Nil(s.T(), s.manager.InitializeForUserID(2)) assert.Nil(s.T(), s.manager.InitializeForUserID(2))
inst := s.getMockPluginInstance(2) inst := s.getMockPluginInstance(2)
assert.True(s.T(), inst.Enabled) assert.True(s.T(), inst.Enabled)
} }
func (s *ManagerSuite) TestInitializePlugin_alreadyEnabledInConf_failedToLoadConfig_disableAutomatically() { func (s *ManagerSuite) TestInitializePlugin_alreadyEnabledInConf_failedToLoadConfig_disableAutomatically() {
@ -209,7 +202,6 @@ func (s *ManagerSuite) TestInitializePlugin_alreadyEnabledInConf_failedToLoadCon
assert.Nil(s.T(), s.manager.InitializeForUserID(3)) assert.Nil(s.T(), s.manager.InitializeForUserID(3))
inst := s.getMockPluginInstance(3) inst := s.getMockPluginInstance(3)
assert.False(s.T(), inst.Enabled) assert.False(s.T(), inst.Enabled)
} }
func (s *ManagerSuite) TestInitializePlugin_alreadyEnabled_cannotEnable_disabledAutomatically() { func (s *ManagerSuite) TestInitializePlugin_alreadyEnabled_cannotEnable_disabledAutomatically() {
@ -360,18 +352,16 @@ func TestNewManager_NonPluginFile_expectError(t *testing.T) {
func TestNewManager_FaultyDB_expectError(t *testing.T) { func TestNewManager_FaultyDB_expectError(t *testing.T) {
tmpDir := test.NewTmpDir("gotify_testnewmanager_faultydb") tmpDir := test.NewTmpDir("gotify_testnewmanager_faultydb")
defer tmpDir.Clean() defer tmpDir.Clean()
for _, suite := range []struct { for _, data := range []struct {
pkg string pkg string
faultyTable string faultyTable string
name string name string
}{{"plugin/example/minimal/", "plugin_confs", "minimal"}, {"plugin/example/clock/", "applications", "clock"}} { }{{"plugin/example/minimal/", "plugin_confs", "minimal"}, {"plugin/example/clock/", "applications", "clock"}} {
test.WithWd(path.Join(test.GetProjectDir(), suite.pkg), func(origWd string) { test.WithWd(path.Join(test.GetProjectDir(), data.pkg), func(origWd string) {
exec.Command("go", "get", "-d").Run() exec.Command("go", "get", "-d").Run()
goBuildFlags := []string{"build", "-buildmode=plugin", "-o=" + tmpDir.Path(fmt.Sprintf("%s.so", suite.name))} goBuildFlags := []string{"build", "-buildmode=plugin", "-o=" + tmpDir.Path(fmt.Sprintf("%s.so", data.name))}
for _, extraFlag := range extraGoBuildFlags { goBuildFlags = append(goBuildFlags, extraGoBuildFlags...)
goBuildFlags = append(goBuildFlags, extraFlag)
}
cmd := exec.Command("go", goBuildFlags...) cmd := exec.Command("go", goBuildFlags...)
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr
@ -379,13 +369,13 @@ func TestNewManager_FaultyDB_expectError(t *testing.T) {
}) })
db := testdb.NewDBWithDefaultUser(t) db := testdb.NewDBWithDefaultUser(t)
db.GormDatabase.DB.Callback().Create().Register("no_create", func(s *gorm.Scope) { db.GormDatabase.DB.Callback().Create().Register("no_create", func(s *gorm.Scope) {
if s.TableName() == suite.faultyTable { if s.TableName() == data.faultyTable {
s.Err(errors.New("database failed")) s.Err(errors.New("database failed"))
} }
}) })
_, err := NewManager(db, tmpDir.Path(), nil, nil) _, err := NewManager(db, tmpDir.Path(), nil, nil)
assert.Error(t, err) assert.Error(t, err)
os.Remove(tmpDir.Path(fmt.Sprintf("%s.so", suite.name))) os.Remove(tmpDir.Path(fmt.Sprintf("%s.so", data.name)))
} }
} }

View File

@ -13,13 +13,13 @@ type redirectToChannel struct {
Messages chan MessageWithUserID Messages chan MessageWithUserID
} }
// MessageWithUserID encapsulates a message with a given user ID // MessageWithUserID encapsulates a message with a given user ID.
type MessageWithUserID struct { type MessageWithUserID struct {
Message model.MessageExternal Message model.MessageExternal
UserID uint UserID uint
} }
// SendMessage sends a message to the underlying message channel // SendMessage sends a message to the underlying message channel.
func (c redirectToChannel) SendMessage(msg compat.Message) error { func (c redirectToChannel) SendMessage(msg compat.Message) error {
c.Messages <- MessageWithUserID{ c.Messages <- MessageWithUserID{
Message: model.MessageExternal{ Message: model.MessageExternal{

View File

@ -4,16 +4,13 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/gin-gonic/gin"
"github.com/gotify/server/v2/model" "github.com/gotify/server/v2/model"
"github.com/gotify/server/v2/test/testdb" "github.com/gotify/server/v2/test/testdb"
"github.com/stretchr/testify/assert"
"github.com/gin-gonic/gin"
) )
func TestRequirePluginEnabled(t *testing.T) { func TestRequirePluginEnabled(t *testing.T) {
db := testdb.NewDBWithDefaultUser(t) db := testdb.NewDBWithDefaultUser(t)
conf := &model.PluginConf{ conf := &model.PluginConf{
ID: 1, ID: 1,

View File

@ -106,7 +106,6 @@ func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Co
clientAuth.Use(authentication.RequireClient()) clientAuth.Use(authentication.RequireClient())
app := clientAuth.Group("/application") app := clientAuth.Group("/application")
{ {
app.GET("", applicationHandler.GetApplications) app.GET("", applicationHandler.GetApplications)
app.POST("", applicationHandler.CreateApplication) app.POST("", applicationHandler.CreateApplication)
@ -119,7 +118,6 @@ func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Co
tokenMessage := app.Group("/:id/message") tokenMessage := app.Group("/:id/message")
{ {
tokenMessage.GET("", messageHandler.GetMessagesWithApplication) tokenMessage.GET("", messageHandler.GetMessagesWithApplication)
tokenMessage.DELETE("", messageHandler.DeleteMessageWithApplication) tokenMessage.DELETE("", messageHandler.DeleteMessageWithApplication)
@ -128,7 +126,6 @@ func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Co
client := clientAuth.Group("/client") client := clientAuth.Group("/client")
{ {
client.GET("", clientHandler.GetClients) client.GET("", clientHandler.GetClients)
client.POST("", clientHandler.CreateClient) client.POST("", clientHandler.CreateClient)
@ -140,7 +137,6 @@ func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Co
message := clientAuth.Group("/message") message := clientAuth.Group("/message")
{ {
message.GET("", messageHandler.GetMessages) message.GET("", messageHandler.GetMessages)
message.DELETE("", messageHandler.DeleteMessages) message.DELETE("", messageHandler.DeleteMessages)

View File

@ -61,7 +61,7 @@ func (s *IntegrationSuite) TestVersionInfo() {
func (s *IntegrationSuite) TestHeaderInDev() { func (s *IntegrationSuite) TestHeaderInDev() {
mode.Set(mode.TestDev) mode.Set(mode.TestDev)
req := s.newRequest("GET", "version", "") req := s.newRequest("GET", "version", "")
//Needs an origin to indicate that it is a CORS request // Needs an origin to indicate that it is a CORS request
req.Header.Add("Origin", "some-origin") req.Header.Add("Origin", "some-origin")
res, err := client.Do(req) res, err := client.Do(req)
@ -176,7 +176,8 @@ func TestAllowedOriginFromResponseHeaders(t *testing.T) {
config := config.Configuration{PassStrength: 5} config := config.Configuration{PassStrength: 5}
config.Server.ResponseHeaders = map[string]string{ config.Server.ResponseHeaders = map[string]string{
"Access-Control-Allow-Origin": "http://test1.com", "Access-Control-Allow-Origin": "http://test1.com",
"Access-Control-Allow-Methods": "GET,POST"} "Access-Control-Allow-Methods": "GET,POST",
}
g, closable := Create(db.GormDatabase, g, closable := Create(db.GormDatabase,
&model.VersionInfo{Version: "1.0.0", BuildDate: "2018-02-20-17:30:47", Commit: "asdasds"}, &model.VersionInfo{Version: "1.0.0", BuildDate: "2018-02-20-17:30:47", Commit: "asdasds"},
@ -213,7 +214,8 @@ func TestAllowedWildcardOriginInHeader(t *testing.T) {
config := config.Configuration{PassStrength: 5} config := config.Configuration{PassStrength: 5}
config.Server.ResponseHeaders = map[string]string{ config.Server.ResponseHeaders = map[string]string{
"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET,POST"} "Access-Control-Allow-Methods": "GET,POST",
}
g, closable := Create(db.GormDatabase, g, closable := Create(db.GormDatabase,
&model.VersionInfo{Version: "1.0.0", BuildDate: "2018-02-20-17:30:47", Commit: "asdasds"}, &model.VersionInfo{Version: "1.0.0", BuildDate: "2018-02-20-17:30:47", Commit: "asdasds"},
@ -265,7 +267,7 @@ func TestCORSHeaderRegex(t *testing.T) {
assert.Equal(t, "http://test123.com", res.Header.Get("Access-Control-Allow-Origin")) assert.Equal(t, "http://test123.com", res.Header.Get("Access-Control-Allow-Origin"))
} }
// We want headers in cors config to override the responseheaders config // We want headers in cors config to override the responseheaders config.
func TestCORSConfigOverride(t *testing.T) { func TestCORSConfigOverride(t *testing.T) {
mode.Set(mode.Prod) mode.Set(mode.Prod)
db := testdb.NewDBWithDefaultUser(t) db := testdb.NewDBWithDefaultUser(t)
@ -398,7 +400,7 @@ func (s *IntegrationSuite) TestAuthentication() {
assert.Equal(s.T(), "android-client", token.Name) assert.Equal(s.T(), "android-client", token.Name)
} }
func (s *IntegrationSuite) newRequest(method, url string, body string) *http.Request { func (s *IntegrationSuite) newRequest(method, url, body string) *http.Request {
req, err := http.NewRequest(method, fmt.Sprintf("%s/%s", s.server.URL, url), strings.NewReader(body)) req, err := http.NewRequest(method, fmt.Sprintf("%s/%s", s.server.URL, url), strings.NewReader(body))
req.Header.Add("Content-Type", "application/json") req.Header.Add("Content-Type", "application/json")
assert.Nil(s.T(), err) assert.Nil(s.T(), err)

View File

@ -10,14 +10,13 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/gin-gonic/gin"
"github.com/gotify/server/v2/config" "github.com/gotify/server/v2/config"
"golang.org/x/crypto/acme/autocert" "golang.org/x/crypto/acme/autocert"
) )
// Run starts the http server and if configured a https server. // Run starts the http server and if configured a https server.
func Run(engine *gin.Engine, conf *config.Configuration) { func Run(router http.Handler, conf *config.Configuration) {
var httpHandler http.Handler = engine httpHandler := router
if *conf.Server.SSL.Enabled { if *conf.Server.SSL.Enabled {
if *conf.Server.SSL.RedirectToHTTPS { if *conf.Server.SSL.RedirectToHTTPS {
@ -27,7 +26,7 @@ func Run(engine *gin.Engine, conf *config.Configuration) {
addr := fmt.Sprintf("%s:%d", conf.Server.SSL.ListenAddr, conf.Server.SSL.Port) addr := fmt.Sprintf("%s:%d", conf.Server.SSL.ListenAddr, conf.Server.SSL.Port)
s := &http.Server{ s := &http.Server{
Addr: addr, Addr: addr,
Handler: engine, Handler: router,
} }
if *conf.Server.SSL.LetsEncrypt.Enabled { if *conf.Server.SSL.LetsEncrypt.Enabled {
@ -73,7 +72,7 @@ func redirectToHTTPS(port string) http.HandlerFunc {
} }
} }
func changePort(hostPort string, port string) string { func changePort(hostPort, port string) string {
host, _, err := net.SplitHostPort(hostPort) host, _, err := net.SplitHostPort(hostPort)
if err != nil { if err != nil {
return hostPort return hostPort

View File

@ -7,14 +7,14 @@ import (
"runtime" "runtime"
) )
// GetProjectDir returns the correct absolute path of this project // GetProjectDir returns the correct absolute path of this project.
func GetProjectDir() string { func GetProjectDir() string {
_, f, _, _ := runtime.Caller(0) _, f, _, _ := runtime.Caller(0)
projectDir, _ := filepath.Abs(path.Join(filepath.Dir(f), "../")) projectDir, _ := filepath.Abs(path.Join(filepath.Dir(f), "../"))
return projectDir return projectDir
} }
// WithWd executes a function with the specified working directory // WithWd executes a function with the specified working directory.
func WithWd(chDir string, f func(origWd string)) { func WithWd(chDir string, f func(origWd string)) {
wd, err := os.Getwd() wd, err := os.Getwd()
if err != nil { if err != nil {

View File

@ -44,5 +44,4 @@ func TestWithWd(t *testing.T) {
WithWd(".", func(string) {}) WithWd(".", func(string) {})
}) })
}) })
} }

View File

@ -156,7 +156,7 @@ func (ab *AppClientBuilder) NewClientWithToken(id uint, token string) *model.Cli
return client return client
} }
// Message creates a message and returns itself // Message creates a message and returns itself.
func (mb *MessageBuilder) Message(id uint) *MessageBuilder { func (mb *MessageBuilder) Message(id uint) *MessageBuilder {
mb.NewMessage(id) mb.NewMessage(id)
return mb return mb

View File

@ -6,22 +6,22 @@ import (
"path" "path"
) )
// TmpDir is a handler to temporary directory // TmpDir is a handler to temporary directory.
type TmpDir struct { type TmpDir struct {
path string path string
} }
// Path returns the path to the temporary directory joined by the elements provided // Path returns the path to the temporary directory joined by the elements provided.
func (c TmpDir) Path(elem ...string) string { func (c TmpDir) Path(elem ...string) string {
return path.Join(append([]string{c.path}, elem...)...) return path.Join(append([]string{c.path}, elem...)...)
} }
// Clean removes the TmpDir // Clean removes the TmpDir.
func (c TmpDir) Clean() error { func (c TmpDir) Clean() error {
return os.RemoveAll(c.path) return os.RemoveAll(c.path)
} }
// NewTmpDir returns a new handle to a tmp dir // NewTmpDir returns a new handle to a tmp dir.
func NewTmpDir(prefix string) TmpDir { func NewTmpDir(prefix string) TmpDir {
dir, _ := ioutil.TempDir("", prefix) dir, _ := ioutil.TempDir("", prefix)
return TmpDir{dir} return TmpDir{dir}