diff --git a/api/application.go b/api/application.go index a236fc0..3087a08 100644 --- a/api/application.go +++ b/api/application.go @@ -16,9 +16,9 @@ import ( // The ApplicationDatabase interface for encapsulating database access. type ApplicationDatabase interface { CreateApplication(application *model.Application) error - GetApplicationByToken(token string) *model.Application - GetApplicationByID(id uint) *model.Application - GetApplicationsByUser(userID uint) []*model.Application + GetApplicationByToken(token string) (*model.Application, error) + GetApplicationByID(id uint) (*model.Application, error) + GetApplicationsByUser(userID uint) ([]*model.Application, error) DeleteApplicationByID(id uint) error UpdateApplication(application *model.Application) error } @@ -68,7 +68,9 @@ func (a *ApplicationAPI) CreateApplication(ctx *gin.Context) { app.Token = auth.GenerateNotExistingToken(generateApplicationToken, a.applicationExists) app.UserID = auth.GetUserID(ctx) app.Internal = false - a.DB.CreateApplication(&app) + if success := successOrAbort(ctx, 500, a.DB.CreateApplication(&app)); !success { + return + } ctx.JSON(200, withResolvedImage(&app)) } } @@ -99,7 +101,10 @@ func (a *ApplicationAPI) CreateApplication(ctx *gin.Context) { // $ref: "#/definitions/Error" func (a *ApplicationAPI) GetApplications(ctx *gin.Context) { userID := auth.GetUserID(ctx) - apps := a.DB.GetApplicationsByUser(userID) + apps, err := a.DB.GetApplicationsByUser(userID) + if success := successOrAbort(ctx, 500, err); !success { + return + } for _, app := range apps { withResolvedImage(app) } @@ -142,12 +147,18 @@ func (a *ApplicationAPI) GetApplications(ctx *gin.Context) { // $ref: "#/definitions/Error" func (a *ApplicationAPI) DeleteApplication(ctx *gin.Context) { withID(ctx, "id", func(id uint) { - if app := a.DB.GetApplicationByID(id); app != nil && app.UserID == auth.GetUserID(ctx) { + app, err := a.DB.GetApplicationByID(id) + if success := successOrAbort(ctx, 500, err); !success { + return + } + if app != nil && app.UserID == auth.GetUserID(ctx) { if app.Internal { ctx.AbortWithError(400, errors.New("cannot delete internal application")) return } - a.DB.DeleteApplicationByID(id) + if success := successOrAbort(ctx, 500, a.DB.DeleteApplicationByID(id)); !success { + return + } if app.Image != "" { os.Remove(a.ImageDir + app.Image) } @@ -201,15 +212,21 @@ func (a *ApplicationAPI) DeleteApplication(ctx *gin.Context) { // $ref: "#/definitions/Error" func (a *ApplicationAPI) UpdateApplication(ctx *gin.Context) { withID(ctx, "id", func(id uint) { - if app := a.DB.GetApplicationByID(id); app != nil && app.UserID == auth.GetUserID(ctx) { + app, err := a.DB.GetApplicationByID(id) + if success := successOrAbort(ctx, 500, err); !success { + return + } + if app != nil && app.UserID == auth.GetUserID(ctx) { newValues := &model.Application{} if err := ctx.Bind(newValues); err == nil { app.Description = newValues.Description app.Name = newValues.Name - a.DB.UpdateApplication(app) - + if success := successOrAbort(ctx, 500, a.DB.UpdateApplication(app)); !success { + return + } ctx.JSON(200, withResolvedImage(app)) + } } else { ctx.AbortWithError(404, fmt.Errorf("app with id %d doesn't exists", id)) @@ -265,7 +282,11 @@ func (a *ApplicationAPI) UpdateApplication(ctx *gin.Context) { // $ref: "#/definitions/Error" func (a *ApplicationAPI) UploadApplicationImage(ctx *gin.Context) { withID(ctx, "id", func(id uint) { - if app := a.DB.GetApplicationByID(id); app != nil && app.UserID == auth.GetUserID(ctx) { + app, err := a.DB.GetApplicationByID(id) + if success := successOrAbort(ctx, 500, err); !success { + return + } + if app != nil && app.UserID == auth.GetUserID(ctx) { file, err := ctx.FormFile("file") if err == http.ErrMissingFile { ctx.AbortWithError(400, errors.New("file with key 'file' must be present")) @@ -284,12 +305,11 @@ func (a *ApplicationAPI) UploadApplicationImage(ctx *gin.Context) { ext := filepath.Ext(file.Filename) - name := generateImageName() - for exist(a.ImageDir + name + ext) { - name = generateImageName() - } + name := generateNonExistingImageName(a.ImageDir, func() string { + return generateImageName() + ext + }) - err = ctx.SaveUploadedFile(file, a.ImageDir+name+ext) + err = ctx.SaveUploadedFile(file, a.ImageDir+name) if err != nil { ctx.AbortWithError(500, err) return @@ -299,11 +319,13 @@ func (a *ApplicationAPI) UploadApplicationImage(ctx *gin.Context) { os.Remove(a.ImageDir + app.Image) } - app.Image = name + ext - a.DB.UpdateApplication(app) + app.Image = name + if success := successOrAbort(ctx, 500, a.DB.UpdateApplication(app)); !success { + return + } ctx.JSON(200, withResolvedImage(app)) } else { - ctx.AbortWithError(404, fmt.Errorf("client with id %d doesn't exists", id)) + ctx.AbortWithError(404, fmt.Errorf("app with id %d doesn't exists", id)) } }) } @@ -318,7 +340,8 @@ func withResolvedImage(app *model.Application) *model.Application { } func (a *ApplicationAPI) applicationExists(token string) bool { - return a.DB.GetApplicationByToken(token) != nil + app, _ := a.DB.GetApplicationByToken(token) + return app != nil } func exist(path string) bool { @@ -327,3 +350,12 @@ func exist(path string) bool { } return true } + +func generateNonExistingImageName(imgDir string, gen func() string) string { + for { + name := gen() + if !exist(imgDir + name) { + return name + } + } +} diff --git a/api/application_test.go b/api/application_test.go index e41c033..f3d52a5 100644 --- a/api/application_test.go +++ b/api/application_test.go @@ -74,7 +74,9 @@ func (s *ApplicationSuite) Test_CreateApplication_mapAllParameters() { Description: "description_text", } assert.Equal(s.T(), 200, s.recorder.Code) - assert.Equal(s.T(), expected, s.db.GetApplicationByID(1)) + if app, err := s.db.GetApplicationByID(1); assert.NoError(s.T(), err) { + assert.Equal(s.T(), expected, app) + } } func (s *ApplicationSuite) Test_ensureApplicationHasCorrectJsonRepresentation() { actual := &model.Application{ @@ -96,7 +98,9 @@ func (s *ApplicationSuite) Test_CreateApplication_expectBadRequestOnEmptyName() s.a.CreateApplication(s.ctx) assert.Equal(s.T(), 400, s.recorder.Code) - assert.Empty(s.T(), s.db.GetApplicationsByUser(5)) + if app, err := s.db.GetApplicationsByUser(5); assert.NoError(s.T(), err) { + assert.Empty(s.T(), app) + } } func (s *ApplicationSuite) Test_DeleteApplication_expectNotFoundOnCurrentUserIsNotOwner() { @@ -122,7 +126,9 @@ func (s *ApplicationSuite) Test_CreateApplication_onlyRequiredParameters() { expected := &model.Application{ID: 1, Token: firstApplicationToken, Name: "custom_name", UserID: 5} assert.Equal(s.T(), 200, s.recorder.Code) - assert.Contains(s.T(), s.db.GetApplicationsByUser(5), expected) + if app, err := s.db.GetApplicationsByUser(5); assert.NoError(s.T(), err) { + assert.Contains(s.T(), app, expected) + } } func (s *ApplicationSuite) Test_CreateApplication_returnsApplicationWithID() { @@ -155,7 +161,9 @@ func (s *ApplicationSuite) Test_CreateApplication_withExistingToken() { expected := &model.Application{ID: 2, Token: secondApplicationToken, Name: "custom_name", UserID: 5} assert.Equal(s.T(), 200, s.recorder.Code) - assert.Contains(s.T(), s.db.GetApplicationsByUser(5), expected) + if app, err := s.db.GetApplicationsByUser(5); assert.NoError(s.T(), err) { + assert.Contains(s.T(), app, expected) + } } func (s *ApplicationSuite) Test_GetApplications() { @@ -275,16 +283,18 @@ func (s *ApplicationSuite) Test_UploadAppImage_WithImageFile_expectSuccess() { s.a.UploadApplicationImage(s.ctx) - imgName := s.db.GetApplicationByID(1).Image + if app, err := s.db.GetApplicationByID(1); assert.NoError(s.T(), err) { + imgName := app.Image - assert.Equal(s.T(), 200, s.recorder.Code) - _, err = os.Stat(imgName) - assert.Nil(s.T(), err) + assert.Equal(s.T(), 200, s.recorder.Code) + _, err = os.Stat(imgName) + assert.Nil(s.T(), err) - s.a.DeleteApplication(s.ctx) + s.a.DeleteApplication(s.ctx) - _, err = os.Stat(imgName) - assert.True(s.T(), os.IsNotExist(err)) + _, err = os.Stat(imgName) + assert.True(s.T(), os.IsNotExist(err)) + } } func (s *ApplicationSuite) Test_UploadAppImage_WithImageFile_DeleteExstingImageAndGenerateNewName() { @@ -399,7 +409,9 @@ func (s *ApplicationSuite) Test_UpdateApplicationNameAndDescription_expectSucces } assert.Equal(s.T(), 200, s.recorder.Code) - assert.Equal(s.T(), expected, s.db.GetApplicationByID(2)) + if app, err := s.db.GetApplicationByID(2); assert.NoError(s.T(), err) { + assert.Equal(s.T(), expected, app) + } } func (s *ApplicationSuite) Test_UpdateApplicationName_expectSuccess() { @@ -419,7 +431,9 @@ func (s *ApplicationSuite) Test_UpdateApplicationName_expectSuccess() { } assert.Equal(s.T(), 200, s.recorder.Code) - assert.Equal(s.T(), expected, s.db.GetApplicationByID(2)) + if app, err := s.db.GetApplicationByID(2); assert.NoError(s.T(), err) { + assert.Equal(s.T(), expected, app) + } } func (s *ApplicationSuite) Test_UpdateApplication_preservesImage() { @@ -434,7 +448,9 @@ func (s *ApplicationSuite) Test_UpdateApplication_preservesImage() { s.a.UpdateApplication(s.ctx) assert.Equal(s.T(), 200, s.recorder.Code) - assert.Equal(s.T(), "existing.png", s.db.GetApplicationByID(2).Image) + if app, err := s.db.GetApplicationByID(2); assert.NoError(s.T(), err) { + assert.Equal(s.T(), "existing.png", app.Image) + } } func (s *ApplicationSuite) Test_UpdateApplication_setEmptyDescription() { @@ -449,7 +465,9 @@ func (s *ApplicationSuite) Test_UpdateApplication_setEmptyDescription() { s.a.UpdateApplication(s.ctx) assert.Equal(s.T(), 200, s.recorder.Code) - assert.Equal(s.T(), "", s.db.GetApplicationByID(2).Description) + if app, err := s.db.GetApplicationByID(2); assert.NoError(s.T(), err) { + assert.Equal(s.T(), "", app.Description) + } } func (s *ApplicationSuite) Test_UpdateApplication_expectNotFound() { diff --git a/api/client.go b/api/client.go index b897ae5..b892853 100644 --- a/api/client.go +++ b/api/client.go @@ -11,9 +11,9 @@ import ( // The ClientDatabase interface for encapsulating database access. type ClientDatabase interface { CreateClient(client *model.Client) error - GetClientByToken(token string) *model.Client - GetClientByID(id uint) *model.Client - GetClientsByUser(userID uint) []*model.Client + GetClientByToken(token string) (*model.Client, error) + GetClientByID(id uint) (*model.Client, error) + GetClientsByUser(userID uint) ([]*model.Client, error) DeleteClientByID(id uint) error UpdateClient(client *model.Client) error } @@ -69,13 +69,18 @@ type ClientAPI struct { // $ref: "#/definitions/Error" func (a *ClientAPI) UpdateClient(ctx *gin.Context) { withID(ctx, "id", func(id uint) { - if client := a.DB.GetClientByID(id); client != nil && client.UserID == auth.GetUserID(ctx) { + client, err := a.DB.GetClientByID(id) + if success := successOrAbort(ctx, 500, err); !success { + return + } + if client != nil && client.UserID == auth.GetUserID(ctx) { newValues := &model.Client{} if err := ctx.Bind(newValues); err == nil { client.Name = newValues.Name - a.DB.UpdateClient(client) - + if success := successOrAbort(ctx, 500, a.DB.UpdateClient(client)); !success { + return + } ctx.JSON(200, client) } } else { @@ -122,7 +127,9 @@ func (a *ClientAPI) CreateClient(ctx *gin.Context) { if err := ctx.Bind(&client); err == nil { client.Token = auth.GenerateNotExistingToken(generateClientToken, a.clientExists) client.UserID = auth.GetUserID(ctx) - a.DB.CreateClient(&client) + if success := successOrAbort(ctx, 500, a.DB.CreateClient(&client)); !success { + return + } ctx.JSON(200, client) } } @@ -153,7 +160,10 @@ func (a *ClientAPI) CreateClient(ctx *gin.Context) { // $ref: "#/definitions/Error" func (a *ClientAPI) GetClients(ctx *gin.Context) { userID := auth.GetUserID(ctx) - clients := a.DB.GetClientsByUser(userID) + clients, err := a.DB.GetClientsByUser(userID) + if success := successOrAbort(ctx, 500, err); !success { + return + } ctx.JSON(200, clients) } @@ -193,9 +203,13 @@ func (a *ClientAPI) GetClients(ctx *gin.Context) { // $ref: "#/definitions/Error" func (a *ClientAPI) DeleteClient(ctx *gin.Context) { withID(ctx, "id", func(id uint) { - if client := a.DB.GetClientByID(id); client != nil && client.UserID == auth.GetUserID(ctx) { + client, err := a.DB.GetClientByID(id) + if success := successOrAbort(ctx, 500, err); !success { + return + } + if client != nil && client.UserID == auth.GetUserID(ctx) { a.NotifyDeleted(client.UserID, client.Token) - a.DB.DeleteClientByID(id) + successOrAbort(ctx, 500, a.DB.DeleteClientByID(id)) } else { ctx.AbortWithError(404, fmt.Errorf("client with id %d doesn't exists", id)) } @@ -203,5 +217,6 @@ func (a *ClientAPI) DeleteClient(ctx *gin.Context) { } func (a *ClientAPI) clientExists(token string) bool { - return a.DB.GetClientByToken(token) != nil + client, _ := a.DB.GetClientByToken(token) + return client != nil } diff --git a/api/client_test.go b/api/client_test.go index 6e2f414..084fa4f 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -71,7 +71,9 @@ func (s *ClientSuite) Test_CreateClient_mapAllParameters() { expected := &model.Client{ID: 1, Token: firstClientToken, UserID: 5, Name: "custom_name"} assert.Equal(s.T(), 200, s.recorder.Code) - assert.Contains(s.T(), s.db.GetClientsByUser(5), expected) + if clients, err := s.db.GetClientsByUser(5); assert.NoError(s.T(), err) { + assert.Contains(s.T(), clients, expected) + } } func (s *ClientSuite) Test_CreateClient_expectBadRequestOnEmptyName() { @@ -83,7 +85,9 @@ func (s *ClientSuite) Test_CreateClient_expectBadRequestOnEmptyName() { s.a.CreateClient(s.ctx) assert.Equal(s.T(), 400, s.recorder.Code) - assert.Empty(s.T(), s.db.GetClientsByUser(5)) + if clients, err := s.db.GetClientsByUser(5); assert.NoError(s.T(), err) { + assert.Empty(s.T(), clients) + } } func (s *ClientSuite) Test_DeleteClient_expectNotFoundOnCurrentUserIsNotOwner() { @@ -184,7 +188,9 @@ func (s *ClientSuite) Test_UpdateClient_expectSuccess() { } assert.Equal(s.T(), 200, s.recorder.Code) - assert.Equal(s.T(), expected, s.db.GetClientByID(1)) + if client, err := s.db.GetClientByID(1); assert.NoError(s.T(), err) { + assert.Equal(s.T(), expected, client) + } } func (s *ClientSuite) Test_UpdateClient_expectNotFound() { diff --git a/api/errorHandling.go b/api/errorHandling.go new file mode 100644 index 0000000..f5d2b11 --- /dev/null +++ b/api/errorHandling.go @@ -0,0 +1,10 @@ +package api + +import "github.com/gin-gonic/gin" + +func successOrAbort(ctx *gin.Context, code int, err error) (success bool) { + if err != nil { + ctx.AbortWithError(code, err) + } + return err == nil +} diff --git a/api/errorHandling_test.go b/api/errorHandling_test.go new file mode 100644 index 0000000..0b166b7 --- /dev/null +++ b/api/errorHandling_test.go @@ -0,0 +1,20 @@ +package api + +import ( + "errors" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" +) + +func TestErrorHandling(t *testing.T) { + rec := httptest.NewRecorder() + + ctx, _ := gin.CreateTestContext(rec) + successOrAbort(ctx, 500, errors.New("err")) + + if rec.Code != 500 { + t.Fail() + } +} diff --git a/api/message.go b/api/message.go index 88cd6d3..77f407d 100644 --- a/api/message.go +++ b/api/message.go @@ -16,15 +16,15 @@ import ( // The MessageDatabase interface for encapsulating database access. type MessageDatabase interface { - GetMessagesByApplicationSince(appID uint, limit int, since uint) []*model.Message - GetApplicationByID(id uint) *model.Application - GetMessagesByUserSince(userID uint, limit int, since uint) []*model.Message + GetMessagesByApplicationSince(appID uint, limit int, since uint) ([]*model.Message, error) + GetApplicationByID(id uint) (*model.Application, error) + GetMessagesByUserSince(userID uint, limit int, since uint) ([]*model.Message, error) DeleteMessageByID(id uint) error - GetMessageByID(id uint) *model.Message + GetMessageByID(id uint) (*model.Message, error) DeleteMessagesByUser(userID uint) error DeleteMessagesByApplication(applicationID uint) error CreateMessage(message *model.Message) error - GetApplicationByToken(token string) *model.Application + GetApplicationByToken(token string) (*model.Application, error) } var timeNow = time.Now @@ -89,7 +89,10 @@ func (a *MessageAPI) GetMessages(ctx *gin.Context) { userID := auth.GetUserID(ctx) withPaging(ctx, func(params *pagingParams) { // the +1 is used to check if there are more messages and will be removed on buildWithPaging - messages := a.DB.GetMessagesByUserSince(userID, params.Limit+1, params.Since) + messages, err := a.DB.GetMessagesByUserSince(userID, params.Limit+1, params.Since) + if success := successOrAbort(ctx, 500, err); !success { + return + } ctx.JSON(200, buildWithPaging(ctx, params, messages)) }) } @@ -174,9 +177,16 @@ func withPaging(ctx *gin.Context, f func(pagingParams *pagingParams)) { func (a *MessageAPI) GetMessagesWithApplication(ctx *gin.Context) { withID(ctx, "id", func(id uint) { withPaging(ctx, func(params *pagingParams) { - if app := a.DB.GetApplicationByID(id); app != nil && app.UserID == auth.GetUserID(ctx) { + app, err := a.DB.GetApplicationByID(id) + if success := successOrAbort(ctx, 500, err); !success { + return + } + if app != nil && app.UserID == auth.GetUserID(ctx) { // the +1 is used to check if there are more messages and will be removed on buildWithPaging - messages := a.DB.GetMessagesByApplicationSince(id, params.Limit+1, params.Since) + messages, err := a.DB.GetMessagesByApplicationSince(id, params.Limit+1, params.Since) + if success := successOrAbort(ctx, 500, err); !success { + return + } ctx.JSON(200, buildWithPaging(ctx, params, messages)) } else { ctx.AbortWithError(404, errors.New("application does not exist")) @@ -206,7 +216,7 @@ func (a *MessageAPI) GetMessagesWithApplication(ctx *gin.Context) { // $ref: "#/definitions/Error" func (a *MessageAPI) DeleteMessages(ctx *gin.Context) { userID := auth.GetUserID(ctx) - a.DB.DeleteMessagesByUser(userID) + successOrAbort(ctx, 500, a.DB.DeleteMessagesByUser(userID)) } // DeleteMessageWithApplication deletes all messages from a specific application. @@ -244,8 +254,12 @@ func (a *MessageAPI) DeleteMessages(ctx *gin.Context) { // $ref: "#/definitions/Error" func (a *MessageAPI) DeleteMessageWithApplication(ctx *gin.Context) { withID(ctx, "id", func(id uint) { - if application := a.DB.GetApplicationByID(id); application != nil && application.UserID == auth.GetUserID(ctx) { - a.DB.DeleteMessagesByApplication(id) + application, err := a.DB.GetApplicationByID(id) + if success := successOrAbort(ctx, 500, err); !success { + return + } + if application != nil && application.UserID == auth.GetUserID(ctx) { + successOrAbort(ctx, 500, a.DB.DeleteMessagesByApplication(id)) } else { ctx.AbortWithError(404, errors.New("application does not exists")) } @@ -287,10 +301,22 @@ func (a *MessageAPI) DeleteMessageWithApplication(ctx *gin.Context) { // $ref: "#/definitions/Error" func (a *MessageAPI) DeleteMessage(ctx *gin.Context) { withID(ctx, "id", func(id uint) { - if msg := a.DB.GetMessageByID(id); msg != nil && a.DB.GetApplicationByID(msg.ApplicationID).UserID == auth.GetUserID(ctx) { - a.DB.DeleteMessageByID(id) + msg, err := a.DB.GetMessageByID(id) + if success := successOrAbort(ctx, 500, err); !success { + return + } + if msg == nil { + ctx.AbortWithError(404, errors.New("message does not exist")) + return + } + app, err := a.DB.GetApplicationByID(msg.ApplicationID) + if success := successOrAbort(ctx, 500, err); !success { + return + } + if app != nil && app.UserID == auth.GetUserID(ctx) { + successOrAbort(ctx, 500, a.DB.DeleteMessageByID(id)) } else { - ctx.AbortWithError(404, errors.New("message does not exists")) + ctx.AbortWithError(404, errors.New("message does not exist")) } }) } @@ -332,14 +358,19 @@ func (a *MessageAPI) DeleteMessage(ctx *gin.Context) { func (a *MessageAPI) CreateMessage(ctx *gin.Context) { message := model.MessageExternal{} if err := ctx.Bind(&message); err == nil { - application := a.DB.GetApplicationByToken(auth.GetTokenID(ctx)) + application, err := a.DB.GetApplicationByToken(auth.GetTokenID(ctx)) + if success := successOrAbort(ctx, 500, err); !success { + return + } message.ApplicationID = application.ID if strings.TrimSpace(message.Title) == "" { message.Title = application.Name } message.Date = timeNow() msgInternal := toInternalMessage(&message) - a.DB.CreateMessage(msgInternal) + if success := successOrAbort(ctx, 500, a.DB.CreateMessage(msgInternal)); !success { + return + } a.Notifier.Notify(auth.GetUserID(ctx), toExternalMessage(msgInternal)) ctx.JSON(200, toExternalMessage(msgInternal)) } diff --git a/api/message_test.go b/api/message_test.go index e1de091..e1430ea 100644 --- a/api/message_test.go +++ b/api/message_test.go @@ -329,7 +329,8 @@ func (s *MessageSuite) Test_CreateMessage_onJson_allParams() { s.a.CreateMessage(s.ctx) - msgs := s.db.GetMessagesByApplication(7) + msgs, err := s.db.GetMessagesByApplication(7) + assert.NoError(s.T(), err) expected := &model.MessageExternal{ID: 1, ApplicationID: 7, Title: "mytitle", Message: "mymessage", Priority: 1, Date: t} assert.Len(s.T(), msgs, 1) assert.Equal(s.T(), expected, toExternalMessage(msgs[0])) @@ -349,7 +350,8 @@ func (s *MessageSuite) Test_CreateMessage_WithTitle() { s.a.CreateMessage(s.ctx) - msgs := s.db.GetMessagesByApplication(5) + msgs, err := s.db.GetMessagesByApplication(5) + assert.NoError(s.T(), err) expected := &model.MessageExternal{ID: 1, ApplicationID: 5, Title: "mytitle", Message: "mymessage", Date: t} assert.Len(s.T(), msgs, 1) assert.Equal(s.T(), expected, toExternalMessage(msgs[0])) @@ -366,7 +368,9 @@ func (s *MessageSuite) Test_CreateMessage_failWhenNoMessage() { s.a.CreateMessage(s.ctx) - assert.Empty(s.T(), s.db.GetMessagesByApplication(1)) + if msgs, err := s.db.GetMessagesByApplication(1); assert.NoError(s.T(), err) { + assert.Empty(s.T(), msgs) + } assert.Equal(s.T(), 400, s.recorder.Code) assert.Nil(s.T(), s.notifiedMessage) } @@ -380,7 +384,8 @@ func (s *MessageSuite) Test_CreateMessage_WithoutTitle() { s.a.CreateMessage(s.ctx) - msgs := s.db.GetMessagesByApplication(8) + msgs, err := s.db.GetMessagesByApplication(8) + assert.NoError(s.T(), err) assert.Len(s.T(), msgs, 1) assert.Equal(s.T(), "Application name", msgs[0].Title) assert.Equal(s.T(), 200, s.recorder.Code) @@ -396,7 +401,8 @@ func (s *MessageSuite) Test_CreateMessage_WithBlankTitle() { s.a.CreateMessage(s.ctx) - msgs := s.db.GetMessagesByApplication(8) + msgs, err := s.db.GetMessagesByApplication(8) + assert.NoError(s.T(), err) assert.Len(s.T(), msgs, 1) assert.Equal(s.T(), "Application name", msgs[0].Title) assert.Equal(s.T(), 200, s.recorder.Code) @@ -415,7 +421,8 @@ func (s *MessageSuite) Test_CreateMessage_WithExtras() { s.a.CreateMessage(s.ctx) - msgs := s.db.GetMessagesByApplication(8) + msgs, err := s.db.GetMessagesByApplication(8) + assert.NoError(s.T(), err) expected := &model.MessageExternal{ ID: 1, ApplicationID: 8, @@ -450,7 +457,9 @@ func (s *MessageSuite) Test_CreateMessage_failWhenPriorityNotNumber() { assert.Equal(s.T(), 400, s.recorder.Code) assert.Nil(s.T(), s.notifiedMessage) - assert.Empty(s.T(), s.db.GetMessagesByApplication(1)) + if msgs, err := s.db.GetMessagesByApplication(1); assert.NoError(s.T(), err) { + assert.Empty(s.T(), msgs) + } } func (s *MessageSuite) Test_CreateMessage_onQueryData() { @@ -468,7 +477,8 @@ func (s *MessageSuite) Test_CreateMessage_onQueryData() { expected := &model.MessageExternal{ID: 1, ApplicationID: 2, Title: "mytitle", Message: "mymessage", Priority: 1, Date: t} - msgs := s.db.GetMessagesByApplication(2) + msgs, err := s.db.GetMessagesByApplication(2) + assert.NoError(s.T(), err) assert.Len(s.T(), msgs, 1) assert.Equal(s.T(), expected, toExternalMessage(msgs[0])) assert.Equal(s.T(), 200, s.recorder.Code) @@ -488,7 +498,8 @@ func (s *MessageSuite) Test_CreateMessage_onFormData() { s.a.CreateMessage(s.ctx) expected := &model.MessageExternal{ID: 1, ApplicationID: 99, Title: "mytitle", Message: "mymessage", Priority: 1, Date: t} - msgs := s.db.GetMessagesByApplication(99) + msgs, err := s.db.GetMessagesByApplication(99) + assert.NoError(s.T(), err) assert.Len(s.T(), msgs, 1) assert.Equal(s.T(), expected, toExternalMessage(msgs[0])) assert.Equal(s.T(), 200, s.recorder.Code) diff --git a/api/plugin.go b/api/plugin.go index b162433..dedda15 100644 --- a/api/plugin.go +++ b/api/plugin.go @@ -17,9 +17,9 @@ import ( // The PluginDatabase interface for encapsulating database access. type PluginDatabase interface { - GetPluginConfByUser(userid uint) []*model.PluginConf + GetPluginConfByUser(userid uint) ([]*model.PluginConf, error) UpdatePluginConf(p *model.PluginConf) error - GetPluginConfByID(id uint) *model.PluginConf + GetPluginConfByID(id uint) (*model.PluginConf, error) } // The PluginAPI provides handlers for managing plugins. @@ -63,7 +63,10 @@ type PluginAPI struct { // $ref: "#/definitions/Error" func (c *PluginAPI) GetPlugins(ctx *gin.Context) { userID := auth.GetUserID(ctx) - plugins := c.DB.GetPluginConfByUser(userID) + plugins, err := c.DB.GetPluginConfByUser(userID) + if success := successOrAbort(ctx, 500, err); !success { + return + } result := make([]model.PluginConfExternal, 0) for _, conf := range plugins { if inst, err := c.Manager.Instance(conf.ID); err == nil { @@ -120,12 +123,15 @@ func (c *PluginAPI) GetPlugins(ctx *gin.Context) { // $ref: "#/definitions/Error" func (c *PluginAPI) EnablePlugin(ctx *gin.Context) { withID(ctx, "id", func(id uint) { - conf := c.DB.GetPluginConfByID(id) + conf, err := c.DB.GetPluginConfByID(id) + if success := successOrAbort(ctx, 500, err); !success { + return + } if conf == nil || !isPluginOwner(ctx, conf) { ctx.AbortWithError(404, errors.New("unknown plugin")) return } - _, err := c.Manager.Instance(id) + _, err = c.Manager.Instance(id) if err != nil { ctx.AbortWithError(404, errors.New("plugin instance not found")) return @@ -174,12 +180,15 @@ func (c *PluginAPI) EnablePlugin(ctx *gin.Context) { // $ref: "#/definitions/Error" func (c *PluginAPI) DisablePlugin(ctx *gin.Context) { withID(ctx, "id", func(id uint) { - conf := c.DB.GetPluginConfByID(id) + conf, err := c.DB.GetPluginConfByID(id) + if success := successOrAbort(ctx, 500, err); !success { + return + } if conf == nil || !isPluginOwner(ctx, conf) { ctx.AbortWithError(404, errors.New("unknown plugin")) return } - _, err := c.Manager.Instance(id) + _, err = c.Manager.Instance(id) if err != nil { ctx.AbortWithError(404, errors.New("plugin instance not found")) return @@ -230,7 +239,10 @@ func (c *PluginAPI) DisablePlugin(ctx *gin.Context) { // $ref: "#/definitions/Error" func (c *PluginAPI) GetDisplay(ctx *gin.Context) { withID(ctx, "id", func(id uint) { - conf := c.DB.GetPluginConfByID(id) + conf, err := c.DB.GetPluginConfByID(id) + if success := successOrAbort(ctx, 500, err); !success { + return + } if conf == nil || !isPluginOwner(ctx, conf) { ctx.AbortWithError(404, errors.New("unknown plugin")) return @@ -287,7 +299,10 @@ func (c *PluginAPI) GetDisplay(ctx *gin.Context) { // $ref: "#/definitions/Error" func (c *PluginAPI) GetConfig(ctx *gin.Context) { withID(ctx, "id", func(id uint) { - conf := c.DB.GetPluginConfByID(id) + conf, err := c.DB.GetPluginConfByID(id) + if success := successOrAbort(ctx, 500, err); !success { + return + } if conf == nil || !isPluginOwner(ctx, conf) { ctx.AbortWithError(404, errors.New("unknown plugin")) return @@ -305,7 +320,6 @@ func (c *PluginAPI) GetConfig(ctx *gin.Context) { ctx.Header("content-type", "application/x-yaml") ctx.Writer.Write(conf.Config) }) - } // UpdateConfig updates Configurer plugin configuration in YAML format. @@ -348,7 +362,10 @@ func (c *PluginAPI) GetConfig(ctx *gin.Context) { // $ref: "#/definitions/Error" func (c *PluginAPI) UpdateConfig(ctx *gin.Context) { withID(ctx, "id", func(id uint) { - conf := c.DB.GetPluginConfByID(id) + conf, err := c.DB.GetPluginConfByID(id) + if success := successOrAbort(ctx, 500, err); !success { + return + } if conf == nil || !isPluginOwner(ctx, conf) { ctx.AbortWithError(404, errors.New("unknown plugin")) return @@ -378,7 +395,7 @@ func (c *PluginAPI) UpdateConfig(ctx *gin.Context) { return } conf.Config = newconfBytes - c.DB.UpdatePluginConf(conf) + successOrAbort(ctx, 500, c.DB.UpdatePluginConf(conf)) }) } diff --git a/api/plugin_test.go b/api/plugin_test.go index 79f7114..56a2a86 100644 --- a/api/plugin_test.go +++ b/api/plugin_test.go @@ -65,7 +65,9 @@ func (s *PluginSuite) BeforeTest(suiteName, testName string) { } func (s *PluginSuite) getDanglingConf(uid uint) *model.PluginConf { - return s.db.GetPluginConfByUserAndPath(uid, "github.com/gotify/server/plugin/example/removed") + conf, err := s.db.GetPluginConfByUserAndPath(uid, "github.com/gotify/server/plugin/example/removed") + assert.NoError(s.T(), err) + return conf } func (s *PluginSuite) resetRecorder() { @@ -109,7 +111,9 @@ func (s *PluginSuite) Test_EnableDisablePlugin() { assert.Equal(s.T(), 200, s.recorder.Code) - assert.True(s.T(), s.db.GetPluginConfByUserAndPath(1, mock.ModulePath).Enabled) + if pluginConf, err := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath); assert.NoError(s.T(), err) { + assert.True(s.T(), pluginConf.Enabled) + } s.resetRecorder() } @@ -122,7 +126,9 @@ func (s *PluginSuite) Test_EnableDisablePlugin() { assert.Equal(s.T(), 400, s.recorder.Code) - assert.True(s.T(), s.db.GetPluginConfByUserAndPath(1, mock.ModulePath).Enabled) + if pluginConf, err := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath); assert.NoError(s.T(), err) { + assert.True(s.T(), pluginConf.Enabled) + } s.resetRecorder() } @@ -135,7 +141,9 @@ func (s *PluginSuite) Test_EnableDisablePlugin() { assert.Equal(s.T(), 200, s.recorder.Code) - assert.False(s.T(), s.db.GetPluginConfByUserAndPath(1, mock.ModulePath).Enabled) + if pluginConf, err := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath); assert.NoError(s.T(), err) { + assert.False(s.T(), pluginConf.Enabled) + } s.resetRecorder() } @@ -148,7 +156,9 @@ func (s *PluginSuite) Test_EnableDisablePlugin() { assert.Equal(s.T(), 400, s.recorder.Code) - assert.False(s.T(), s.db.GetPluginConfByUserAndPath(1, mock.ModulePath).Enabled) + if pluginConf, err := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath); assert.NoError(s.T(), err) { + assert.False(s.T(), pluginConf.Enabled) + } s.resetRecorder() } @@ -158,7 +168,8 @@ func (s *PluginSuite) Test_EnableDisablePlugin_EnableReturnsError_expect500() { s.db.User(16) assert.Nil(s.T(), s.manager.InitializeForUserID(16)) mock.ReturnErrorOnEnableForUser(16, errors.New("test error")) - conf := s.db.GetPluginConfByUserAndPath(16, mock.ModulePath) + conf, err := s.db.GetPluginConfByUserAndPath(16, mock.ModulePath) + assert.NoError(s.T(), err) { test.WithUser(s.ctx, 16) @@ -168,7 +179,9 @@ func (s *PluginSuite) Test_EnableDisablePlugin_EnableReturnsError_expect500() { assert.Equal(s.T(), 500, s.recorder.Code) - assert.False(s.T(), s.db.GetPluginConfByUserAndPath(1, mock.ModulePath).Enabled) + if pluginConf, err := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath); assert.NoError(s.T(), err) { + assert.False(s.T(), pluginConf.Enabled) + } s.resetRecorder() } } @@ -177,7 +190,8 @@ func (s *PluginSuite) Test_EnableDisablePlugin_DisableReturnsError_expect500() { s.db.User(17) assert.Nil(s.T(), s.manager.InitializeForUserID(17)) mock.ReturnErrorOnDisableForUser(17, errors.New("test error")) - conf := s.db.GetPluginConfByUserAndPath(17, mock.ModulePath) + conf, err := s.db.GetPluginConfByUserAndPath(17, mock.ModulePath) + assert.NoError(s.T(), err) s.manager.SetPluginEnabled(conf.ID, true) { @@ -188,7 +202,9 @@ func (s *PluginSuite) Test_EnableDisablePlugin_DisableReturnsError_expect500() { assert.Equal(s.T(), 500, s.recorder.Code) - assert.False(s.T(), s.db.GetPluginConfByUserAndPath(1, mock.ModulePath).Enabled) + if pluginConf, err := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath); assert.NoError(s.T(), err) { + assert.False(s.T(), pluginConf.Enabled) + } s.resetRecorder() } } @@ -203,7 +219,9 @@ func (s *PluginSuite) Test_EnableDisablePlugin_incorrectUser_expectNotFound() { assert.Equal(s.T(), 404, s.recorder.Code) - assert.False(s.T(), s.db.GetPluginConfByUserAndPath(1, mock.ModulePath).Enabled) + if pluginConf, err := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath); assert.NoError(s.T(), err) { + assert.False(s.T(), pluginConf.Enabled) + } s.resetRecorder() } @@ -216,7 +234,9 @@ func (s *PluginSuite) Test_EnableDisablePlugin_incorrectUser_expectNotFound() { assert.Equal(s.T(), 404, s.recorder.Code) - assert.False(s.T(), s.db.GetPluginConfByUserAndPath(1, mock.ModulePath).Enabled) + if pluginConf, err := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath); assert.NoError(s.T(), err) { + assert.False(s.T(), pluginConf.Enabled) + } s.resetRecorder() } @@ -274,7 +294,8 @@ func (s *PluginSuite) Test_EnableDisablePlugin_danglingConf_expectNotFound() { } func (s *PluginSuite) Test_GetDisplay() { - conf := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath) + conf, err := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath) + assert.NoError(s.T(), err) inst, err := s.manager.Instance(conf.ID) assert.Nil(s.T(), err) mockInst := inst.(*mock.PluginInstance) @@ -294,7 +315,8 @@ func (s *PluginSuite) Test_GetDisplay() { } func (s *PluginSuite) Test_GetDisplay_NotImplemented_expectEmptyString() { - conf := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath) + conf, err := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath) + assert.NoError(s.T(), err) inst, err := s.manager.Instance(conf.ID) assert.Nil(s.T(), err) mockInst := inst.(*mock.PluginInstance) @@ -315,7 +337,8 @@ func (s *PluginSuite) Test_GetDisplay_NotImplemented_expectEmptyString() { } func (s *PluginSuite) Test_GetDisplay_incorrectUser_expectNotFound() { - conf := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath) + conf, err := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath) + assert.NoError(s.T(), err) inst, err := s.manager.Instance(conf.ID) assert.Nil(s.T(), err) mockInst := inst.(*mock.PluginInstance) @@ -360,7 +383,8 @@ func (s *PluginSuite) Test_GetDisplay_nonExistPlugin_expectNotFound() { } func (s *PluginSuite) Test_GetConfig() { - conf := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath) + conf, err := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath) + assert.NoError(s.T(), err) inst, err := s.manager.Instance(conf.ID) assert.Nil(s.T(), err) mockInst := inst.(*mock.PluginInstance) @@ -381,7 +405,8 @@ func (s *PluginSuite) Test_GetConfig() { } func (s *PluginSuite) Test_GetConfg_notImplemeted_expect400() { - conf := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath) + conf, err := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath) + assert.NoError(s.T(), err) inst, err := s.manager.Instance(conf.ID) assert.Nil(s.T(), err) mockInst := inst.(*mock.PluginInstance) @@ -401,7 +426,8 @@ func (s *PluginSuite) Test_GetConfg_notImplemeted_expect400() { } func (s *PluginSuite) Test_GetConfig_incorrectUser_expectNotFound() { - conf := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath) + conf, err := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath) + assert.NoError(s.T(), err) { test.WithUser(s.ctx, 2) @@ -441,7 +467,8 @@ func (s *PluginSuite) Test_GetConfig_nonExistPlugin_expectNotFound() { } func (s *PluginSuite) Test_UpdateConfig() { - conf := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath) + conf, err := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath) + assert.NoError(s.T(), err) inst, err := s.manager.Instance(conf.ID) assert.Nil(s.T(), err) mockInst := inst.(*mock.PluginInstance) @@ -463,7 +490,10 @@ func (s *PluginSuite) Test_UpdateConfig() { assert.Equal(s.T(), 200, s.recorder.Code) assert.Equal(s.T(), newConfig, mockInst.Config, "config should be received by plugin") - pluginFromDBBytes := s.db.GetPluginConfByID(conf.ID).Config + var pluginFromDBBytes []byte + if pluginConf, err := s.db.GetPluginConfByID(conf.ID); assert.NoError(s.T(), err) { + pluginFromDBBytes = pluginConf.Config + } pluginFromDB := new(mock.PluginConfig) err := yaml.Unmarshal(pluginFromDBBytes, pluginFromDB) assert.Nil(s.T(), err) @@ -472,7 +502,8 @@ func (s *PluginSuite) Test_UpdateConfig() { } func (s *PluginSuite) Test_UpdateConfig_invalidConfig_expect400() { - conf := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath) + conf, err := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath) + assert.NoError(s.T(), err) inst, err := s.manager.Instance(conf.ID) assert.Nil(s.T(), err) mockInst := inst.(*mock.PluginInstance) @@ -496,7 +527,10 @@ func (s *PluginSuite) Test_UpdateConfig_invalidConfig_expect400() { assert.Equal(s.T(), 400, s.recorder.Code) assert.Equal(s.T(), origConfig, mockInst.Config, "config should not be received by plugin") - pluginFromDBBytes := s.db.GetPluginConfByID(conf.ID).Config + var pluginFromDBBytes []byte + if pluginConf, err := s.db.GetPluginConfByID(conf.ID); assert.NoError(s.T(), err) { + pluginFromDBBytes = pluginConf.Config + } pluginFromDB := new(mock.PluginConfig) err := yaml.Unmarshal(pluginFromDBBytes, pluginFromDB) assert.Nil(s.T(), err) @@ -505,7 +539,8 @@ func (s *PluginSuite) Test_UpdateConfig_invalidConfig_expect400() { } func (s *PluginSuite) Test_UpdateConfig_malformedYAML_expect400() { - conf := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath) + conf, err := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath) + assert.NoError(s.T(), err) inst, err := s.manager.Instance(conf.ID) assert.Nil(s.T(), err) mockInst := inst.(*mock.PluginInstance) @@ -524,7 +559,10 @@ func (s *PluginSuite) Test_UpdateConfig_malformedYAML_expect400() { assert.Equal(s.T(), 400, s.recorder.Code) assert.Equal(s.T(), origConfig, mockInst.Config, "config should not be received by plugin") - pluginFromDBBytes := s.db.GetPluginConfByID(conf.ID).Config + var pluginFromDBBytes []byte + if pluginConf, err := s.db.GetPluginConfByID(conf.ID); assert.NoError(s.T(), err) { + pluginFromDBBytes = pluginConf.Config + } pluginFromDB := new(mock.PluginConfig) err := yaml.Unmarshal(pluginFromDBBytes, pluginFromDB) assert.Nil(s.T(), err) @@ -533,7 +571,8 @@ func (s *PluginSuite) Test_UpdateConfig_malformedYAML_expect400() { } func (s *PluginSuite) Test_UpdateConfig_ioError_expect500() { - conf := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath) + conf, err := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath) + assert.NoError(s.T(), err) inst, err := s.manager.Instance(conf.ID) assert.Nil(s.T(), err) mockInst := inst.(*mock.PluginInstance) @@ -550,7 +589,10 @@ func (s *PluginSuite) Test_UpdateConfig_ioError_expect500() { assert.Equal(s.T(), 500, s.recorder.Code) assert.Equal(s.T(), origConfig, mockInst.Config, "config should not be received by plugin") - pluginFromDBBytes := s.db.GetPluginConfByID(conf.ID).Config + var pluginFromDBBytes []byte + if pluginConf, err := s.db.GetPluginConfByID(conf.ID); assert.NoError(s.T(), err) { + pluginFromDBBytes = pluginConf.Config + } pluginFromDB := new(mock.PluginConfig) err := yaml.Unmarshal(pluginFromDBBytes, pluginFromDB) assert.Nil(s.T(), err) @@ -559,7 +601,8 @@ func (s *PluginSuite) Test_UpdateConfig_ioError_expect500() { } func (s *PluginSuite) Test_UpdateConfig_notImplemented_expect400() { - conf := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath) + conf, err := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath) + assert.NoError(s.T(), err) inst, err := s.manager.Instance(conf.ID) assert.Nil(s.T(), err) mockInst := inst.(*mock.PluginInstance) @@ -586,7 +629,8 @@ func (s *PluginSuite) Test_UpdateConfig_notImplemented_expect400() { } func (s *PluginSuite) Test_UpdateConfig_incorrectUser_expectNotFound() { - conf := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath) + conf, err := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath) + assert.NoError(s.T(), err) inst, err := s.manager.Instance(conf.ID) assert.Nil(s.T(), err) mockInst := inst.(*mock.PluginInstance) @@ -609,7 +653,10 @@ func (s *PluginSuite) Test_UpdateConfig_incorrectUser_expectNotFound() { assert.Equal(s.T(), 404, s.recorder.Code) assert.Equal(s.T(), origConfig, mockInst.Config, "config should not be received by plugin") - pluginFromDBBytes := s.db.GetPluginConfByID(conf.ID).Config + var pluginFromDBBytes []byte + if pluginConf, err := s.db.GetPluginConfByID(conf.ID); assert.NoError(s.T(), err) { + pluginFromDBBytes = pluginConf.Config + } pluginFromDB := new(mock.PluginConfig) err := yaml.Unmarshal(pluginFromDBBytes, pluginFromDB) assert.Nil(s.T(), err) diff --git a/api/user.go b/api/user.go index e0e8305..109de8f 100644 --- a/api/user.go +++ b/api/user.go @@ -11,13 +11,13 @@ import ( // The UserDatabase interface for encapsulating database access. type UserDatabase interface { - GetUsers() []*model.User - GetUserByID(id uint) *model.User - GetUserByName(name string) *model.User + GetUsers() ([]*model.User, error) + GetUserByID(id uint) (*model.User, error) + GetUserByName(name string) (*model.User, error) DeleteUserByID(id uint) error - UpdateUser(user *model.User) + UpdateUser(user *model.User) error CreateUser(user *model.User) error - CountUser(condition ...interface{}) int + CountUser(condition ...interface{}) (int, error) } // UserChangeNotifier notifies listeners for user changes. @@ -85,8 +85,10 @@ type UserAPI struct { // schema: // $ref: "#/definitions/Error" func (a *UserAPI) GetUsers(ctx *gin.Context) { - users := a.DB.GetUsers() - + users, err := a.DB.GetUsers() + if success := successOrAbort(ctx, 500, err); !success { + return + } var resp []*model.UserExternal for _, user := range users { resp = append(resp, toExternalUser(user)) @@ -117,7 +119,10 @@ func (a *UserAPI) GetUsers(ctx *gin.Context) { // schema: // $ref: "#/definitions/Error" func (a *UserAPI) GetCurrentUser(ctx *gin.Context) { - user := a.DB.GetUserByID(auth.GetUserID(ctx)) + user, err := a.DB.GetUserByID(auth.GetUserID(ctx)) + if success := successOrAbort(ctx, 500, err); !success { + return + } ctx.JSON(200, toExternalUser(user)) } @@ -158,8 +163,14 @@ func (a *UserAPI) CreateUser(ctx *gin.Context) { user := model.UserExternalWithPass{} if err := ctx.Bind(&user); err == nil { internal := a.toInternalUser(&user, []byte{}) - if a.DB.GetUserByName(internal.Name) == nil { - a.DB.CreateUser(internal) + existingUser, err := a.DB.GetUserByName(internal.Name) + if success := successOrAbort(ctx, 500, err); !success { + return + } + if existingUser == nil { + if success := successOrAbort(ctx, 500, a.DB.CreateUser(internal)); !success { + return + } if err := a.UserChangeNotifier.fireUserAdded(internal.ID); err != nil { ctx.AbortWithError(500, err) return @@ -209,7 +220,11 @@ func (a *UserAPI) CreateUser(ctx *gin.Context) { // $ref: "#/definitions/Error" func (a *UserAPI) GetUserByID(ctx *gin.Context) { withID(ctx, "id", func(id uint) { - if user := a.DB.GetUserByID(uint(id)); user != nil { + user, err := a.DB.GetUserByID(uint(id)) + if success := successOrAbort(ctx, 500, err); !success { + return + } + if user != nil { ctx.JSON(200, toExternalUser(user)) } else { ctx.AbortWithError(404, errors.New("user does not exist")) @@ -252,8 +267,16 @@ func (a *UserAPI) GetUserByID(ctx *gin.Context) { // $ref: "#/definitions/Error" func (a *UserAPI) DeleteUserByID(ctx *gin.Context) { withID(ctx, "id", func(id uint) { - if user := a.DB.GetUserByID(id); user != nil { - if user.Admin && a.DB.CountUser(&model.User{Admin: true}) == 1 { + user, err := a.DB.GetUserByID(id) + if success := successOrAbort(ctx, 500, err); !success { + return + } + if user != nil { + adminCount, err := a.DB.CountUser(&model.User{Admin: true}) + if success := successOrAbort(ctx, 500, err); !success { + return + } + if user.Admin && adminCount == 1 { ctx.AbortWithError(400, errors.New("cannot delete last admin")) return } @@ -261,7 +284,7 @@ func (a *UserAPI) DeleteUserByID(ctx *gin.Context) { ctx.AbortWithError(500, err) return } - a.DB.DeleteUserByID(id) + successOrAbort(ctx, 500, a.DB.DeleteUserByID(id)) } else { ctx.AbortWithError(404, errors.New("user does not exist")) } @@ -302,9 +325,12 @@ func (a *UserAPI) DeleteUserByID(ctx *gin.Context) { func (a *UserAPI) ChangePassword(ctx *gin.Context) { pw := model.UserExternalPass{} if err := ctx.Bind(&pw); err == nil { - user := a.DB.GetUserByID(auth.GetUserID(ctx)) + user, err := a.DB.GetUserByID(auth.GetUserID(ctx)) + if success := successOrAbort(ctx, 500, err); !success { + return + } user.Pass = password.CreatePassword(pw.Pass, a.PasswordStrength) - a.DB.UpdateUser(user) + successOrAbort(ctx, 500, a.DB.UpdateUser(user)) } } @@ -354,14 +380,24 @@ func (a *UserAPI) UpdateUserByID(ctx *gin.Context) { withID(ctx, "id", func(id uint) { var user *model.UserExternalWithPass if err := ctx.Bind(&user); err == nil { - if oldUser := a.DB.GetUserByID(id); oldUser != nil { - if !user.Admin && oldUser.Admin && a.DB.CountUser(&model.User{Admin: true}) == 1 { + oldUser, err := a.DB.GetUserByID(id) + if success := successOrAbort(ctx, 500, err); !success { + return + } + if oldUser != nil { + adminCount, err := a.DB.CountUser(&model.User{Admin: true}) + if success := successOrAbort(ctx, 500, err); !success { + return + } + if !user.Admin && oldUser.Admin && adminCount == 1 { ctx.AbortWithError(400, errors.New("cannot delete last admin")) return } internal := a.toInternalUser(user, oldUser.Pass) internal.ID = id - a.DB.UpdateUser(internal) + if success := successOrAbort(ctx, 500, a.DB.UpdateUser(internal)); !success { + return + } ctx.JSON(200, toExternalUser(internal)) } else { ctx.AbortWithError(404, errors.New("user does not exist")) diff --git a/api/user_test.go b/api/user_test.go index 43a8ba6..4b90b14 100644 --- a/api/user_test.go +++ b/api/user_test.go @@ -173,15 +173,20 @@ func (s *UserSuite) Test_CreateUser() { test.BodyEquals(s.T(), user, s.recorder) assert.Equal(s.T(), 200, s.recorder.Code) - created := s.db.GetUserByName("tom") - assert.NotNil(s.T(), created) - assert.True(s.T(), password.ComparePassword(created.Pass, []byte("mylittlepony"))) + if created, err := s.db.GetUserByName("tom"); assert.NoError(s.T(), err) { + assert.NotNil(s.T(), created) + assert.True(s.T(), password.ComparePassword(created.Pass, []byte("mylittlepony"))) + } assert.True(s.T(), s.notifiedAdd) } func (s *UserSuite) Test_CreateUser_NotifyFail() { s.notifier.OnUserAdded(func(id uint) error { - if s.db.GetUserByID(id).Name == "eva" { + user, err := s.db.GetUserByID(id) + if err != nil { + return err + } + if user.Name == "eva" { return errors.New("some error") } return nil @@ -272,7 +277,8 @@ func (s *UserSuite) Test_UpdateUserByID_UpdateNotPassword() { s.a.UpdateUserByID(s.ctx) assert.Equal(s.T(), 200, s.recorder.Code) - user := s.db.GetUserByID(2) + user, err := s.db.GetUserByID(2) + assert.NoError(s.T(), err) assert.NotNil(s.T(), user) assert.True(s.T(), password.ComparePassword(user.Pass, []byte("old"))) } @@ -288,7 +294,8 @@ func (s *UserSuite) Test_UpdateUserByID_UpdatePassword() { s.a.UpdateUserByID(s.ctx) assert.Equal(s.T(), 200, s.recorder.Code) - user := s.db.GetUserByID(2) + user, err := s.db.GetUserByID(2) + assert.NoError(s.T(), err) assert.NotNil(s.T(), user) assert.True(s.T(), password.ComparePassword(user.Pass, []byte("new"))) } @@ -303,7 +310,8 @@ func (s *UserSuite) Test_UpdatePassword() { s.a.ChangePassword(s.ctx) assert.Equal(s.T(), 200, s.recorder.Code) - user := s.db.GetUserByID(1) + user, err := s.db.GetUserByID(1) + assert.NoError(s.T(), err) assert.NotNil(s.T(), user) assert.True(s.T(), password.ComparePassword(user.Pass, []byte("new"))) } @@ -318,7 +326,8 @@ func (s *UserSuite) Test_UpdatePassword_EmptyPassword() { s.a.ChangePassword(s.ctx) assert.Equal(s.T(), 400, s.recorder.Code) - user := s.db.GetUserByID(1) + user, err := s.db.GetUserByID(1) + assert.NoError(s.T(), err) assert.NotNil(s.T(), user) assert.True(s.T(), password.ComparePassword(user.Pass, []byte("old"))) } diff --git a/auth/authentication.go b/auth/authentication.go index b6f48a4..f7acba2 100644 --- a/auth/authentication.go +++ b/auth/authentication.go @@ -14,11 +14,11 @@ const ( // The Database interface for encapsulating database access. type Database interface { - GetApplicationByToken(token string) *model.Application - GetClientByToken(token string) *model.Client - GetPluginConfByToken(token string) *model.PluginConf - GetUserByName(name string) *model.User - GetUserByID(id uint) *model.User + GetApplicationByToken(token string) (*model.Application, error) + GetClientByToken(token string) (*model.Client, error) + GetPluginConfByToken(token string) (*model.PluginConf, error) + GetUserByName(name string) (*model.User, error) + GetUserByID(id uint) (*model.User, error) } // Auth is the provider for authentication middleware @@ -26,46 +26,56 @@ type Auth struct { DB Database } -type authenticate func(tokenID string, user *model.User) (authenticated bool, success bool, userId uint) +type authenticate func(tokenID string, user *model.User) (authenticated bool, success bool, userId uint, err error) // RequireAdmin returns a gin middleware which requires a client token or basic authentication header to be supplied // with the request. Also the authenticated user must be an administrator. func (a *Auth) RequireAdmin() gin.HandlerFunc { - return a.requireToken(func(tokenID string, user *model.User) (bool, bool, uint) { + return a.requireToken(func(tokenID string, user *model.User) (bool, bool, uint, error) { if user != nil { - return true, user.Admin, user.ID + return true, user.Admin, user.ID, nil } - if token := a.DB.GetClientByToken(tokenID); token != nil { - return true, a.DB.GetUserByID(token.UserID).Admin, token.UserID + if token, err := a.DB.GetClientByToken(tokenID); err != nil { + return false, false, 0, err + } else if token != nil { + user, err := a.DB.GetUserByID(token.UserID) + if err != nil { + return false, false, token.UserID, err + } + return true, user.Admin, token.UserID, nil } - return false, false, 0 + return false, false, 0, nil }) } // RequireClient returns a gin middleware which requires a client token or basic authentication header to be supplied // with the request. func (a *Auth) RequireClient() gin.HandlerFunc { - return a.requireToken(func(tokenID string, user *model.User) (bool, bool, uint) { + return a.requireToken(func(tokenID string, user *model.User) (bool, bool, uint, error) { if user != nil { - return true, true, user.ID + return true, true, user.ID, nil } - if token := a.DB.GetClientByToken(tokenID); token != nil { - return true, true, token.UserID + if token, err := a.DB.GetClientByToken(tokenID); err != nil { + return false, false, 0, err + } else if token != nil { + return true, true, token.UserID, nil } - return false, false, 0 + return false, false, 0, nil }) } // RequireApplicationToken returns a gin middleware which requires an application token to be supplied with the request. func (a *Auth) RequireApplicationToken() gin.HandlerFunc { - return a.requireToken(func(tokenID string, user *model.User) (bool, bool, uint) { + return a.requireToken(func(tokenID string, user *model.User) (bool, bool, uint, error) { if user != nil { - return true, false, 0 + return true, false, 0, nil } - if token := a.DB.GetApplicationByToken(tokenID); token != nil { - return true, true, token.UserID + if token, err := a.DB.GetApplicationByToken(tokenID); err != nil { + return false, false, 0, err + } else if token != nil { + return true, true, token.UserID, nil } - return false, false, 0 + return false, false, 0, nil }) } @@ -86,22 +96,32 @@ func (a *Auth) tokenFromHeader(ctx *gin.Context) string { return ctx.Request.Header.Get(headerName) } -func (a *Auth) userFromBasicAuth(ctx *gin.Context) *model.User { +func (a *Auth) userFromBasicAuth(ctx *gin.Context) (*model.User, error) { if name, pass, ok := ctx.Request.BasicAuth(); ok { - if user := a.DB.GetUserByName(name); user != nil && password.ComparePassword(user.Pass, []byte(pass)) { - return user + if user, err := a.DB.GetUserByName(name); err != nil { + return nil, err + } else if user != nil && password.ComparePassword(user.Pass, []byte(pass)) { + return user, nil } } - return nil + return nil, nil } func (a *Auth) requireToken(auth authenticate) gin.HandlerFunc { return func(ctx *gin.Context) { token := a.tokenFromQueryOrHeader(ctx) - user := a.userFromBasicAuth(ctx) + user, err := a.userFromBasicAuth(ctx) + if err != nil { + ctx.AbortWithError(500, errors.New("an error occured while authenticating user")) + return + } if user != nil || token != "" { - if authenticated, ok, userID := auth(token, user); ok { + authenticated, ok, userID, err := auth(token, user) + if err != nil { + ctx.AbortWithError(500, errors.New("an error occured while authenticating user")) + return + } else if ok { RegisterAuthentication(ctx, user, userID, token) ctx.Next() return diff --git a/database/application.go b/database/application.go index 4640774..7fa6528 100644 --- a/database/application.go +++ b/database/application.go @@ -2,26 +2,33 @@ package database import ( "github.com/gotify/server/model" + "github.com/jinzhu/gorm" ) // GetApplicationByToken returns the application for the given token or nil. -func (d *GormDatabase) GetApplicationByToken(token string) *model.Application { +func (d *GormDatabase) GetApplicationByToken(token string) (*model.Application, error) { app := new(model.Application) - d.DB.Where("token = ?", token).Find(app) - if app.Token == token { - return app + err := d.DB.Where("token = ?", token).Find(app).Error + if err == gorm.ErrRecordNotFound { + err = nil } - return nil + if app.Token == token { + return app, err + } + return nil, err } // GetApplicationByID returns the application for the given id or nil. -func (d *GormDatabase) GetApplicationByID(id uint) *model.Application { +func (d *GormDatabase) GetApplicationByID(id uint) (*model.Application, error) { app := new(model.Application) - d.DB.Where("id = ?", id).Find(app) - if app.ID == id { - return app + err := d.DB.Where("id = ?", id).Find(app).Error + if err == gorm.ErrRecordNotFound { + err = nil } - return nil + if app.ID == id { + return app, err + } + return nil, err } // CreateApplication creates an application. @@ -36,10 +43,13 @@ func (d *GormDatabase) DeleteApplicationByID(id uint) error { } // GetApplicationsByUser returns all applications from a user. -func (d *GormDatabase) GetApplicationsByUser(userID uint) []*model.Application { +func (d *GormDatabase) GetApplicationsByUser(userID uint) ([]*model.Application, error) { var apps []*model.Application - d.DB.Where("user_id = ?", userID).Find(&apps) - return apps + err := d.DB.Where("user_id = ?", userID).Find(&apps).Error + if err == gorm.ErrRecordNotFound { + err = nil + } + return apps, err } // UpdateApplication updates an application. diff --git a/database/application_test.go b/database/application_test.go index 7a72d7b..29e0be4 100644 --- a/database/application_test.go +++ b/database/application_test.go @@ -6,57 +6,87 @@ import ( ) func (s *DatabaseSuite) TestApplication() { - assert.Nil(s.T(), s.db.GetApplicationByToken("asdasdf"), "not existing app") - assert.Nil(s.T(), s.db.GetApplicationByID(uint(1)), "not existing app") + + if app, err := s.db.GetApplicationByToken("asdasdf"); assert.NoError(s.T(), err) { + assert.Nil(s.T(), app, "not existing app") + } + + if app, err := s.db.GetApplicationByID(uint(1)); assert.NoError(s.T(), err) { + assert.Nil(s.T(), app, "not existing app") + } user := &model.User{Name: "test", Pass: []byte{1}} s.db.CreateUser(user) assert.NotEqual(s.T(), 0, user.ID) - apps := s.db.GetApplicationsByUser(user.ID) - assert.Empty(s.T(), apps) + if apps, err := s.db.GetApplicationsByUser(user.ID); assert.NoError(s.T(), err) { + assert.Empty(s.T(), apps) + } app := &model.Application{UserID: user.ID, Token: "C0000000000", Name: "backupserver"} s.db.CreateApplication(app) - apps = s.db.GetApplicationsByUser(user.ID) - assert.Len(s.T(), apps, 1) - assert.Contains(s.T(), apps, app) + if apps, err := s.db.GetApplicationsByUser(user.ID); assert.NoError(s.T(), err) { + assert.Len(s.T(), apps, 1) + assert.Contains(s.T(), apps, app) + } - newApp := s.db.GetApplicationByToken(app.Token) - assert.Equal(s.T(), app, newApp) + newApp, err := s.db.GetApplicationByToken(app.Token) + if assert.NoError(s.T(), err) { + assert.Equal(s.T(), app, newApp) + } - newApp = s.db.GetApplicationByID(app.ID) - assert.Equal(s.T(), app, newApp) + newApp, err = s.db.GetApplicationByID(app.ID) + if assert.NoError(s.T(), err) { + assert.Equal(s.T(), app, newApp) + } newApp.Image = "asdasd" - s.db.UpdateApplication(newApp) + assert.NoError(s.T(), s.db.UpdateApplication(newApp)) - newApp = s.db.GetApplicationByID(app.ID) - assert.Equal(s.T(), "asdasd", newApp.Image) + newApp, err = s.db.GetApplicationByID(app.ID) + if assert.NoError(s.T(), err) { + assert.Equal(s.T(), "asdasd", newApp.Image) + } - s.db.DeleteApplicationByID(app.ID) + assert.NoError(s.T(), s.db.DeleteApplicationByID(app.ID)) - apps = s.db.GetApplicationsByUser(user.ID) - assert.Empty(s.T(), apps) + if apps, err := s.db.GetApplicationsByUser(user.ID); assert.NoError(s.T(), err) { + assert.Empty(s.T(), apps) + } - assert.Nil(s.T(), s.db.GetApplicationByID(app.ID)) + if app, err := s.db.GetApplicationByID(app.ID); assert.NoError(s.T(), err) { + assert.Nil(s.T(), app) + } } func (s *DatabaseSuite) TestDeleteAppDeletesMessages() { - s.db.CreateApplication(&model.Application{ID: 55, Token: "token"}) - s.db.CreateApplication(&model.Application{ID: 66, Token: "token2"}) - s.db.CreateMessage(&model.Message{ID: 12, ApplicationID: 55}) - s.db.CreateMessage(&model.Message{ID: 13, ApplicationID: 66}) - s.db.CreateMessage(&model.Message{ID: 14, ApplicationID: 55}) - s.db.CreateMessage(&model.Message{ID: 15, ApplicationID: 55}) + assert.NoError(s.T(), s.db.CreateApplication(&model.Application{ID: 55, Token: "token"})) + assert.NoError(s.T(), s.db.CreateApplication(&model.Application{ID: 66, Token: "token2"})) + assert.NoError(s.T(), s.db.CreateMessage(&model.Message{ID: 12, ApplicationID: 55})) + assert.NoError(s.T(), s.db.CreateMessage(&model.Message{ID: 13, ApplicationID: 66})) + assert.NoError(s.T(), s.db.CreateMessage(&model.Message{ID: 14, ApplicationID: 55})) + assert.NoError(s.T(), s.db.CreateMessage(&model.Message{ID: 15, ApplicationID: 55})) - s.db.DeleteApplicationByID(55) + assert.NoError(s.T(), s.db.DeleteApplicationByID(55)) - assert.Nil(s.T(), s.db.GetMessageByID(12)) - assert.NotNil(s.T(), s.db.GetMessageByID(13)) - assert.Nil(s.T(), s.db.GetMessageByID(14)) - assert.Nil(s.T(), s.db.GetMessageByID(15)) - assert.Empty(s.T(), s.db.GetMessagesByApplication(55)) - assert.NotEmpty(s.T(), s.db.GetMessagesByApplication(66)) + if msg, err := s.db.GetMessageByID(12); assert.NoError(s.T(), err) { + assert.Nil(s.T(), msg) + } + if msg, err := s.db.GetMessageByID(13); assert.NoError(s.T(), err) { + assert.NotNil(s.T(), msg) + } + if msg, err := s.db.GetMessageByID(14); assert.NoError(s.T(), err) { + assert.Nil(s.T(), msg) + } + if msg, err := s.db.GetMessageByID(15); assert.NoError(s.T(), err) { + assert.Nil(s.T(), msg) + } + + if msgs, err := s.db.GetMessagesByApplication(55); assert.NoError(s.T(), err) { + assert.Empty(s.T(), msgs) + } + if msgs, err := s.db.GetMessagesByApplication(66); assert.NoError(s.T(), err) { + assert.NotEmpty(s.T(), msgs) + } } diff --git a/database/client.go b/database/client.go index cdff19d..a9b186e 100644 --- a/database/client.go +++ b/database/client.go @@ -1,25 +1,34 @@ package database -import "github.com/gotify/server/model" +import ( + "github.com/gotify/server/model" + "github.com/jinzhu/gorm" +) // GetClientByID returns the client for the given id or nil. -func (d *GormDatabase) GetClientByID(id uint) *model.Client { +func (d *GormDatabase) GetClientByID(id uint) (*model.Client, error) { client := new(model.Client) - d.DB.Where("id = ?", id).Find(client) - if client.ID == id { - return client + err := d.DB.Where("id = ?", id).Find(client).Error + if err == gorm.ErrRecordNotFound { + err = nil } - return nil + if client.ID == id { + return client, err + } + return nil, err } // GetClientByToken returns the client for the given token or nil. -func (d *GormDatabase) GetClientByToken(token string) *model.Client { +func (d *GormDatabase) GetClientByToken(token string) (*model.Client, error) { client := new(model.Client) - d.DB.Where("token = ?", token).Find(client) - if client.Token == token { - return client + err := d.DB.Where("token = ?", token).Find(client).Error + if err == gorm.ErrRecordNotFound { + err = nil } - return nil + if client.Token == token { + return client, err + } + return nil, err } // CreateClient creates a client. @@ -28,10 +37,13 @@ func (d *GormDatabase) CreateClient(client *model.Client) error { } // GetClientsByUser returns all clients from a user. -func (d *GormDatabase) GetClientsByUser(userID uint) []*model.Client { +func (d *GormDatabase) GetClientsByUser(userID uint) ([]*model.Client, error) { var clients []*model.Client - d.DB.Where("user_id = ?", userID).Find(&clients) - return clients + err := d.DB.Where("user_id = ?", userID).Find(&clients).Error + if err == gorm.ErrRecordNotFound { + err = nil + } + return clients, err } // DeleteClientByID deletes a client by its id. diff --git a/database/client_test.go b/database/client_test.go index 6c0aa3e..f343ffb 100644 --- a/database/client_test.go +++ b/database/client_test.go @@ -6,38 +6,51 @@ import ( ) func (s *DatabaseSuite) TestClient() { - assert.Nil(s.T(), s.db.GetClientByID(1), "not existing client") - assert.Nil(s.T(), s.db.GetClientByToken("asdasd"), "not existing client") + if client, err := s.db.GetClientByID(1); assert.NoError(s.T(), err) { + assert.Nil(s.T(), client, "not existing client") + } + if client, err := s.db.GetClientByToken("asdasd"); assert.NoError(s.T(), err) { + assert.Nil(s.T(), client, "not existing client") + } user := &model.User{Name: "test", Pass: []byte{1}} s.db.CreateUser(user) assert.NotEqual(s.T(), 0, user.ID) - clients := s.db.GetClientsByUser(user.ID) - assert.Empty(s.T(), clients) + if clients, err := s.db.GetClientsByUser(user.ID); assert.NoError(s.T(), err) { + assert.Empty(s.T(), clients) + } client := &model.Client{UserID: user.ID, Token: "C0000000000", Name: "android"} - s.db.CreateClient(client) + assert.NoError(s.T(), s.db.CreateClient(client)) - clients = s.db.GetClientsByUser(user.ID) - assert.Len(s.T(), clients, 1) - assert.Contains(s.T(), clients, client) + if clients, err := s.db.GetClientsByUser(user.ID); assert.NoError(s.T(), err) { + assert.Len(s.T(), clients, 1) + assert.Contains(s.T(), clients, client) + } - newClient := s.db.GetClientByID(client.ID) - assert.Equal(s.T(), client, newClient) + newClient, err := s.db.GetClientByID(client.ID) + if assert.NoError(s.T(), err) { + assert.Equal(s.T(), client, newClient) + } - newClient = s.db.GetClientByToken(client.Token) - assert.Equal(s.T(), client, newClient) + if newClient, err := s.db.GetClientByToken(client.Token); assert.NoError(s.T(), err) { + assert.Equal(s.T(), client, newClient) + } updateClient := &model.Client{ID: client.ID, UserID: user.ID, Token: "C0000000000", Name: "new_name"} s.db.UpdateClient(updateClient) - updatedClient := s.db.GetClientByID(client.ID) - assert.Equal(s.T(), updateClient, updatedClient) + if updatedClient, err := s.db.GetClientByID(client.ID); assert.NoError(s.T(), err) { + assert.Equal(s.T(), updateClient, updatedClient) + } s.db.DeleteClientByID(client.ID) - clients = s.db.GetClientsByUser(user.ID) - assert.Empty(s.T(), clients) + if clients, err := s.db.GetClientsByUser(user.ID); assert.NoError(s.T(), err) { + assert.Empty(s.T(), clients) + } - assert.Nil(s.T(), s.db.GetClientByID(client.ID)) + if client, err := s.db.GetClientByID(client.ID); assert.NoError(s.T(), err) { + assert.Nil(s.T(), client) + } } diff --git a/database/message.go b/database/message.go index 99f1906..72e95b9 100644 --- a/database/message.go +++ b/database/message.go @@ -2,16 +2,20 @@ package database import ( "github.com/gotify/server/model" + "github.com/jinzhu/gorm" ) // GetMessageByID returns the messages for the given id or nil. -func (d *GormDatabase) GetMessageByID(id uint) *model.Message { +func (d *GormDatabase) GetMessageByID(id uint) (*model.Message, error) { msg := new(model.Message) - d.DB.Find(msg, id) - if msg.ID == id { - return msg + err := d.DB.Find(msg, id).Error + if err == gorm.ErrRecordNotFound { + err = nil } - return nil + if msg.ID == id { + return msg, err + } + return nil, err } // CreateMessage creates a message. @@ -20,43 +24,55 @@ func (d *GormDatabase) CreateMessage(message *model.Message) error { } // GetMessagesByUser returns all messages from a user. -func (d *GormDatabase) GetMessagesByUser(userID uint) []*model.Message { +func (d *GormDatabase) GetMessagesByUser(userID uint) ([]*model.Message, error) { var messages []*model.Message - d.DB.Joins("JOIN applications ON applications.user_id = ?", userID). - Where("messages.application_id = applications.id").Order("id desc").Find(&messages) - return messages + err := d.DB.Joins("JOIN applications ON applications.user_id = ?", userID). + Where("messages.application_id = applications.id").Order("id desc").Find(&messages).Error + if err == gorm.ErrRecordNotFound { + err = nil + } + return messages, err } // GetMessagesByUserSince returns limited messages from a user. // If since is 0 it will be ignored. -func (d *GormDatabase) GetMessagesByUserSince(userID uint, limit int, since uint) []*model.Message { +func (d *GormDatabase) GetMessagesByUserSince(userID uint, limit int, since uint) ([]*model.Message, error) { var messages []*model.Message db := d.DB.Joins("JOIN applications ON applications.user_id = ?", userID). Where("messages.application_id = applications.id").Order("id desc").Limit(limit) if since != 0 { db = db.Where("messages.id < ?", since) } - db.Find(&messages) - return messages + err := db.Find(&messages).Error + if err == gorm.ErrRecordNotFound { + err = nil + } + return messages, err } // GetMessagesByApplication returns all messages from an application. -func (d *GormDatabase) GetMessagesByApplication(tokenID uint) []*model.Message { +func (d *GormDatabase) GetMessagesByApplication(tokenID uint) ([]*model.Message, error) { var messages []*model.Message - d.DB.Where("application_id = ?", tokenID).Order("id desc").Find(&messages) - return messages + err := d.DB.Where("application_id = ?", tokenID).Order("id desc").Find(&messages).Error + if err == gorm.ErrRecordNotFound { + err = nil + } + return messages, err } // GetMessagesByApplicationSince returns limited messages from an application. // If since is 0 it will be ignored. -func (d *GormDatabase) GetMessagesByApplicationSince(appID uint, limit int, since uint) []*model.Message { +func (d *GormDatabase) GetMessagesByApplicationSince(appID uint, limit int, since uint) ([]*model.Message, error) { var messages []*model.Message db := d.DB.Where("application_id = ?", appID).Order("id desc").Limit(limit) if since != 0 { db = db.Where("messages.id < ?", since) } - db.Find(&messages) - return messages + err := db.Find(&messages).Error + if err == gorm.ErrRecordNotFound { + err = nil + } + return messages, err } // DeleteMessageByID deletes a message by its id. @@ -71,7 +87,8 @@ func (d *GormDatabase) DeleteMessagesByApplication(applicationID uint) error { // DeleteMessagesByUser deletes all messages from a user. func (d *GormDatabase) DeleteMessagesByUser(userID uint) error { - for _, app := range d.GetApplicationsByUser(userID) { + app, _ := d.GetApplicationsByUser(userID) + for _, app := range app { d.DeleteMessagesByApplication(app.ID) } return nil diff --git a/database/message_test.go b/database/message_test.go index 8f83bdc..3629060 100644 --- a/database/message_test.go +++ b/database/message_test.go @@ -6,10 +6,13 @@ import ( "github.com/gotify/server/model" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func (s *DatabaseSuite) TestMessage() { - assert.Nil(s.T(), s.db.GetMessageByID(5), "not existing message") + messages, err := s.db.GetMessageByID(5) + require.NoError(s.T(), err) + assert.Nil(s.T(), messages, "not existing message") user := &model.User{Name: "test", Pass: []byte{1}} s.db.CreateUser(user) @@ -19,116 +22,140 @@ func (s *DatabaseSuite) TestMessage() { s.db.CreateApplication(backupServer) assert.NotEqual(s.T(), 0, backupServer.ID) - msgs := s.db.GetMessagesByUser(user.ID) + msgs, err := s.db.GetMessagesByUser(user.ID) + require.NoError(s.T(), err) assert.Empty(s.T(), msgs) - msgs = s.db.GetMessagesByApplication(backupServer.ID) + + msgs, err = s.db.GetMessagesByApplication(backupServer.ID) + require.NoError(s.T(), err) assert.Empty(s.T(), msgs) backupdone := &model.Message{ApplicationID: backupServer.ID, Message: "backup done", Title: "backup", Priority: 1, Date: time.Now()} - s.db.CreateMessage(backupdone) + require.NoError(s.T(), s.db.CreateMessage(backupdone)) assert.NotEqual(s.T(), 0, backupdone.ID) - assertEquals(s.T(), s.db.GetMessageByID(backupdone.ID), backupdone) + messages, err = s.db.GetMessageByID(backupdone.ID) + require.NoError(s.T(), err) + assertEquals(s.T(), messages, backupdone) - msgs = s.db.GetMessagesByUser(user.ID) + msgs, err = s.db.GetMessagesByUser(user.ID) + require.NoError(s.T(), err) assert.Len(s.T(), msgs, 1) assertEquals(s.T(), msgs[0], backupdone) - msgs = s.db.GetMessagesByApplication(backupServer.ID) + msgs, err = s.db.GetMessagesByApplication(backupServer.ID) + require.NoError(s.T(), err) assert.Len(s.T(), msgs, 1) assertEquals(s.T(), msgs[0], backupdone) loginServer := &model.Application{UserID: user.ID, Token: "A0000000001", Name: "loginserver"} - s.db.CreateApplication(loginServer) + require.NoError(s.T(), s.db.CreateApplication(loginServer)) assert.NotEqual(s.T(), 0, loginServer.ID) logindone := &model.Message{ApplicationID: loginServer.ID, Message: "login done", Title: "login", Priority: 1, Date: time.Now()} - s.db.CreateMessage(logindone) + require.NoError(s.T(), s.db.CreateMessage(logindone)) assert.NotEqual(s.T(), 0, logindone.ID) - msgs = s.db.GetMessagesByUser(user.ID) + msgs, err = s.db.GetMessagesByUser(user.ID) + require.NoError(s.T(), err) assert.Len(s.T(), msgs, 2) assertEquals(s.T(), msgs[0], logindone) assertEquals(s.T(), msgs[1], backupdone) - msgs = s.db.GetMessagesByApplication(backupServer.ID) + msgs, err = s.db.GetMessagesByApplication(backupServer.ID) + require.NoError(s.T(), err) assert.Len(s.T(), msgs, 1) assertEquals(s.T(), msgs[0], backupdone) loginfailed := &model.Message{ApplicationID: loginServer.ID, Message: "login failed", Title: "login", Priority: 1, Date: time.Now()} - s.db.CreateMessage(loginfailed) + require.NoError(s.T(), s.db.CreateMessage(loginfailed)) assert.NotEqual(s.T(), 0, loginfailed.ID) - msgs = s.db.GetMessagesByApplication(backupServer.ID) + msgs, err = s.db.GetMessagesByApplication(backupServer.ID) + require.NoError(s.T(), err) assert.Len(s.T(), msgs, 1) assertEquals(s.T(), msgs[0], backupdone) - msgs = s.db.GetMessagesByApplication(loginServer.ID) + msgs, err = s.db.GetMessagesByApplication(loginServer.ID) + require.NoError(s.T(), err) assert.Len(s.T(), msgs, 2) assertEquals(s.T(), msgs[0], loginfailed) assertEquals(s.T(), msgs[1], logindone) - msgs = s.db.GetMessagesByUser(user.ID) + msgs, err = s.db.GetMessagesByUser(user.ID) + require.NoError(s.T(), err) assert.Len(s.T(), msgs, 3) assertEquals(s.T(), msgs[0], loginfailed) assertEquals(s.T(), msgs[1], logindone) assertEquals(s.T(), msgs[2], backupdone) backupfailed := &model.Message{ApplicationID: backupServer.ID, Message: "backup failed", Title: "backup", Priority: 1, Date: time.Now()} - s.db.CreateMessage(backupfailed) + require.NoError(s.T(), s.db.CreateMessage(backupfailed)) assert.NotEqual(s.T(), 0, backupfailed.ID) - msgs = s.db.GetMessagesByUser(user.ID) + msgs, err = s.db.GetMessagesByUser(user.ID) + require.NoError(s.T(), err) assert.Len(s.T(), msgs, 4) assertEquals(s.T(), msgs[0], backupfailed) assertEquals(s.T(), msgs[1], loginfailed) assertEquals(s.T(), msgs[2], logindone) assertEquals(s.T(), msgs[3], backupdone) - msgs = s.db.GetMessagesByApplication(loginServer.ID) + msgs, err = s.db.GetMessagesByApplication(loginServer.ID) + require.NoError(s.T(), err) assert.Len(s.T(), msgs, 2) assertEquals(s.T(), msgs[0], loginfailed) assertEquals(s.T(), msgs[1], logindone) - s.db.DeleteMessagesByApplication(loginServer.ID) - assert.Empty(s.T(), s.db.GetMessagesByApplication(loginServer.ID)) + require.NoError(s.T(), s.db.DeleteMessagesByApplication(loginServer.ID)) + msgs, err = s.db.GetMessagesByApplication(loginServer.ID) + require.NoError(s.T(), err) + assert.Empty(s.T(), msgs) - msgs = s.db.GetMessagesByUser(user.ID) + msgs, err = s.db.GetMessagesByUser(user.ID) + require.NoError(s.T(), err) assert.Len(s.T(), msgs, 2) assertEquals(s.T(), msgs[0], backupfailed) assertEquals(s.T(), msgs[1], backupdone) logindone = &model.Message{ApplicationID: loginServer.ID, Message: "login done", Title: "login", Priority: 1, Date: time.Now()} - s.db.CreateMessage(logindone) + require.NoError(s.T(), s.db.CreateMessage(logindone)) assert.NotEqual(s.T(), 0, logindone.ID) - msgs = s.db.GetMessagesByUser(user.ID) + msgs, err = s.db.GetMessagesByUser(user.ID) + require.NoError(s.T(), err) assert.Len(s.T(), msgs, 3) assertEquals(s.T(), msgs[0], logindone) assertEquals(s.T(), msgs[1], backupfailed) assertEquals(s.T(), msgs[2], backupdone) s.db.DeleteMessagesByUser(user.ID) - assert.Empty(s.T(), s.db.GetMessagesByUser(user.ID)) + msgs, err = s.db.GetMessagesByUser(user.ID) + require.NoError(s.T(), err) + assert.Empty(s.T(), msgs) logout := &model.Message{ApplicationID: loginServer.ID, Message: "logout success", Title: "logout", Priority: 1, Date: time.Now()} s.db.CreateMessage(logout) - msgs = s.db.GetMessagesByUser(user.ID) + msgs, err = s.db.GetMessagesByUser(user.ID) + require.NoError(s.T(), err) assert.Len(s.T(), msgs, 1) assertEquals(s.T(), msgs[0], logout) - s.db.DeleteMessageByID(logout.ID) - assert.Empty(s.T(), s.db.GetMessagesByUser(user.ID)) + require.NoError(s.T(), s.db.DeleteMessageByID(logout.ID)) + msgs, err = s.db.GetMessagesByUser(user.ID) + require.NoError(s.T(), err) + assert.Empty(s.T(), msgs) + } func (s *DatabaseSuite) TestGetMessagesSince() { user := &model.User{Name: "test", Pass: []byte{1}} - s.db.CreateUser(user) + require.NoError(s.T(), s.db.CreateUser(user)) app := &model.Application{UserID: user.ID, Token: "A0000000000"} app2 := &model.Application{UserID: user.ID, Token: "A0000000001"} - s.db.CreateApplication(app) - s.db.CreateApplication(app2) + require.NoError(s.T(), s.db.CreateApplication(app)) + require.NoError(s.T(), s.db.CreateApplication(app2)) curDate := time.Now() for i := 1; i <= 500; i++ { @@ -136,53 +163,66 @@ func (s *DatabaseSuite) TestGetMessagesSince() { s.db.CreateMessage(&model.Message{ApplicationID: app2.ID, Message: "abc", Date: curDate.Add(time.Duration(i) * time.Second)}) } - actual := s.db.GetMessagesByUserSince(user.ID, 50, 0) + actual, err := s.db.GetMessagesByUserSince(user.ID, 50, 0) + require.NoError(s.T(), err) assert.Len(s.T(), actual, 50) hasIDInclusiveBetween(s.T(), actual, 1000, 951, 1) - actual = s.db.GetMessagesByUserSince(user.ID, 50, 951) + actual, err = s.db.GetMessagesByUserSince(user.ID, 50, 951) + require.NoError(s.T(), err) assert.Len(s.T(), actual, 50) hasIDInclusiveBetween(s.T(), actual, 950, 901, 1) - actual = s.db.GetMessagesByUserSince(user.ID, 100, 951) + actual, err = s.db.GetMessagesByUserSince(user.ID, 100, 951) + require.NoError(s.T(), err) assert.Len(s.T(), actual, 100) hasIDInclusiveBetween(s.T(), actual, 950, 851, 1) - actual = s.db.GetMessagesByUserSince(user.ID, 100, 51) + actual, err = s.db.GetMessagesByUserSince(user.ID, 100, 51) + require.NoError(s.T(), err) assert.Len(s.T(), actual, 50) hasIDInclusiveBetween(s.T(), actual, 50, 1, 1) - actual = s.db.GetMessagesByApplicationSince(app.ID, 50, 0) + actual, err = s.db.GetMessagesByApplicationSince(app.ID, 50, 0) + require.NoError(s.T(), err) assert.Len(s.T(), actual, 50) hasIDInclusiveBetween(s.T(), actual, 999, 901, 2) - actual = s.db.GetMessagesByApplicationSince(app.ID, 50, 901) + actual, err = s.db.GetMessagesByApplicationSince(app.ID, 50, 901) + require.NoError(s.T(), err) assert.Len(s.T(), actual, 50) hasIDInclusiveBetween(s.T(), actual, 899, 801, 2) - actual = s.db.GetMessagesByApplicationSince(app.ID, 100, 666) + actual, err = s.db.GetMessagesByApplicationSince(app.ID, 100, 666) + require.NoError(s.T(), err) assert.Len(s.T(), actual, 100) hasIDInclusiveBetween(s.T(), actual, 665, 467, 2) - actual = s.db.GetMessagesByApplicationSince(app.ID, 100, 101) + actual, err = s.db.GetMessagesByApplicationSince(app.ID, 100, 101) + require.NoError(s.T(), err) assert.Len(s.T(), actual, 50) hasIDInclusiveBetween(s.T(), actual, 99, 1, 2) - actual = s.db.GetMessagesByApplicationSince(app2.ID, 50, 0) + actual, err = s.db.GetMessagesByApplicationSince(app2.ID, 50, 0) + require.NoError(s.T(), err) assert.Len(s.T(), actual, 50) hasIDInclusiveBetween(s.T(), actual, 1000, 902, 2) - actual = s.db.GetMessagesByApplicationSince(app2.ID, 50, 902) + actual, err = s.db.GetMessagesByApplicationSince(app2.ID, 50, 902) + require.NoError(s.T(), err) assert.Len(s.T(), actual, 50) hasIDInclusiveBetween(s.T(), actual, 900, 802, 2) - actual = s.db.GetMessagesByApplicationSince(app2.ID, 100, 667) + actual, err = s.db.GetMessagesByApplicationSince(app2.ID, 100, 667) + require.NoError(s.T(), err) assert.Len(s.T(), actual, 100) hasIDInclusiveBetween(s.T(), actual, 666, 468, 2) - actual = s.db.GetMessagesByApplicationSince(app2.ID, 100, 102) + actual, err = s.db.GetMessagesByApplicationSince(app2.ID, 100, 102) + require.NoError(s.T(), err) assert.Len(s.T(), actual, 50) hasIDInclusiveBetween(s.T(), actual, 100, 2, 2) + } func hasIDInclusiveBetween(t *testing.T, msgs []*model.Message, from, to, decrement int) { diff --git a/database/migration_test.go b/database/migration_test.go index a198a2f..43c3c91 100644 --- a/database/migration_test.go +++ b/database/migration_test.go @@ -47,17 +47,25 @@ func (s *MigrationSuite) TestMigration() { assert.True(s.T(), db.DB.HasTable(new(model.Application))) // a user already exist, not adding a new user - assert.Nil(s.T(), db.GetUserByName("admin")) + if user, err := db.GetUserByName("admin"); assert.NoError(s.T(), err) { + assert.Nil(s.T(), user) + } // the old user should persist - assert.Equal(s.T(), true, db.GetUserByName("test_user").Admin) + if user, err := db.GetUserByName("test_user"); assert.NoError(s.T(), err) { + assert.Equal(s.T(), true, user.Admin) + } // we should be able to create applications - assert.Nil(s.T(), db.CreateApplication(&model.Application{ - Token: "A1234", - UserID: db.GetUserByName("test_user").ID, - Description: "this is a test application", - Name: "test application", - })) - assert.Equal(s.T(), "test application", db.GetApplicationByToken("A1234").Name) + if user, err := db.GetUserByName("test_user"); assert.NoError(s.T(), err) { + assert.Nil(s.T(), db.CreateApplication(&model.Application{ + Token: "A1234", + UserID: user.ID, + Description: "this is a test application", + Name: "test application", + })) + } + if app, err := db.GetApplicationByToken("A1234"); assert.NoError(s.T(), err) { + assert.Equal(s.T(), "test application", app.Name) + } } diff --git a/database/plugin.go b/database/plugin.go index 689b194..bf5b5a5 100644 --- a/database/plugin.go +++ b/database/plugin.go @@ -2,33 +2,43 @@ package database import ( "github.com/gotify/server/model" + "github.com/jinzhu/gorm" ) // GetPluginConfByUser gets plugin configurations from a user -func (d *GormDatabase) GetPluginConfByUser(userid uint) []*model.PluginConf { +func (d *GormDatabase) GetPluginConfByUser(userid uint) ([]*model.PluginConf, error) { var plugins []*model.PluginConf - d.DB.Where("user_id = ?", userid).Find(&plugins) - return plugins + err := d.DB.Where("user_id = ?", userid).Find(&plugins).Error + if err == gorm.ErrRecordNotFound { + err = nil + } + return plugins, err } // GetPluginConfByUserAndPath gets plugin configuration by user and file name -func (d *GormDatabase) GetPluginConfByUserAndPath(userid uint, path string) *model.PluginConf { +func (d *GormDatabase) GetPluginConfByUserAndPath(userid uint, path string) (*model.PluginConf, error) { plugin := new(model.PluginConf) - d.DB.Where("user_id = ? AND module_path = ?", userid, path).First(plugin) - if plugin.ModulePath == path { - return plugin + err := d.DB.Where("user_id = ? AND module_path = ?", userid, path).First(plugin).Error + if err == gorm.ErrRecordNotFound { + err = nil } - return nil + if plugin.ModulePath == path { + return plugin, err + } + return nil, err } // GetPluginConfByApplicationID gets plugin configuration by its internal appid. -func (d *GormDatabase) GetPluginConfByApplicationID(appid uint) *model.PluginConf { +func (d *GormDatabase) GetPluginConfByApplicationID(appid uint) (*model.PluginConf, error) { plugin := new(model.PluginConf) - d.DB.Where("application_id = ?", appid).First(plugin) - if plugin.ApplicationID == appid { - return plugin + err := d.DB.Where("application_id = ?", appid).First(plugin).Error + if err == gorm.ErrRecordNotFound { + err = nil } - return nil + if plugin.ApplicationID == appid { + return plugin, err + } + return nil, err } // CreatePluginConf creates a new plugin configuration @@ -37,23 +47,29 @@ func (d *GormDatabase) CreatePluginConf(p *model.PluginConf) error { } // GetPluginConfByToken gets plugin configuration by plugin token -func (d *GormDatabase) GetPluginConfByToken(token string) *model.PluginConf { +func (d *GormDatabase) GetPluginConfByToken(token string) (*model.PluginConf, error) { plugin := new(model.PluginConf) - d.DB.Where("token = ?", token).First(plugin) - if plugin.Token == token { - return plugin + err := d.DB.Where("token = ?", token).First(plugin).Error + if err == gorm.ErrRecordNotFound { + err = nil } - return nil + if plugin.Token == token { + return plugin, err + } + return nil, err } // GetPluginConfByID gets plugin configuration by plugin ID -func (d *GormDatabase) GetPluginConfByID(id uint) *model.PluginConf { +func (d *GormDatabase) GetPluginConfByID(id uint) (*model.PluginConf, error) { plugin := new(model.PluginConf) - d.DB.Where("id = ?", id).First(plugin) - if plugin.ID == id { - return plugin + err := d.DB.Where("id = ?", id).First(plugin).Error + if err == gorm.ErrRecordNotFound { + err = nil } - return nil + if plugin.ID == id { + return plugin, err + } + return nil, err } // UpdatePluginConf updates plugin configuration diff --git a/database/plugin_test.go b/database/plugin_test.go index 516b3d0..9e56b20 100644 --- a/database/plugin_test.go +++ b/database/plugin_test.go @@ -3,6 +3,7 @@ package database import ( "github.com/gotify/server/model" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func (s *DatabaseSuite) TestPluginConf() { @@ -18,23 +19,53 @@ func (s *DatabaseSuite) TestPluginConf() { assert.Nil(s.T(), s.db.CreatePluginConf(&plugin)) assert.Equal(s.T(), uint(1), plugin.ID) - assert.Equal(s.T(), "Pabc", s.db.GetPluginConfByUserAndPath(1, "github.com/gotify/example-plugin").Token) - assert.Equal(s.T(), true, s.db.GetPluginConfByToken("Pabc").Enabled) - assert.Equal(s.T(), "Pabc", s.db.GetPluginConfByApplicationID(2).Token) - assert.Equal(s.T(), "github.com/gotify/example-plugin", s.db.GetPluginConfByID(1).ModulePath) + pluginConf, err := s.db.GetPluginConfByUserAndPath(1, "github.com/gotify/example-plugin") + require.NoError(s.T(), err) + assert.Equal(s.T(), "Pabc", pluginConf.Token) - assert.Nil(s.T(), s.db.GetPluginConfByToken("Pnotexist")) - assert.Nil(s.T(), s.db.GetPluginConfByID(12)) - assert.Nil(s.T(), s.db.GetPluginConfByUserAndPath(1, "not/exist")) - assert.Nil(s.T(), s.db.GetPluginConfByApplicationID(99)) + pluginConf, err = s.db.GetPluginConfByToken("Pabc") + require.NoError(s.T(), err) + assert.Equal(s.T(), true, pluginConf.Enabled) - assert.Len(s.T(), s.db.GetPluginConfByUser(1), 1) - assert.Len(s.T(), s.db.GetPluginConfByUser(0), 0) + pluginConf, err = s.db.GetPluginConfByApplicationID(2) + require.NoError(s.T(), err) + assert.Equal(s.T(), "Pabc", pluginConf.Token) + + pluginConf, err = s.db.GetPluginConfByID(1) + require.NoError(s.T(), err) + assert.Equal(s.T(), "github.com/gotify/example-plugin", pluginConf.ModulePath) + + pluginConf, err = s.db.GetPluginConfByToken("Pnotexist") + require.NoError(s.T(), err) + assert.Nil(s.T(), pluginConf) + + pluginConf, err = s.db.GetPluginConfByID(12) + require.NoError(s.T(), err) + assert.Nil(s.T(), pluginConf) + + pluginConf, err = s.db.GetPluginConfByUserAndPath(1, "not/exist") + require.NoError(s.T(), err) + assert.Nil(s.T(), pluginConf) + + pluginConf, err = s.db.GetPluginConfByApplicationID(99) + require.NoError(s.T(), err) + assert.Nil(s.T(), pluginConf) + + pluginConfs, err := s.db.GetPluginConfByUser(1) + require.NoError(s.T(), err) + assert.Len(s.T(), pluginConfs, 1) + + pluginConfs, err = s.db.GetPluginConfByUser(0) + require.NoError(s.T(), err) + assert.Len(s.T(), pluginConfs, 0) testConf := `{"test_config_key":"hello"}` plugin.Enabled = false plugin.Config = []byte(testConf) assert.Nil(s.T(), s.db.UpdatePluginConf(&plugin)) - assert.Equal(s.T(), false, s.db.GetPluginConfByToken("Pabc").Enabled) - assert.Equal(s.T(), testConf, string(s.db.GetPluginConfByToken("Pabc").Config)) + pluginConf, err = s.db.GetPluginConfByToken("Pabc") + require.NoError(s.T(), err) + assert.Equal(s.T(), false, pluginConf.Enabled) + assert.Equal(s.T(), testConf, string(pluginConf.Config)) + } diff --git a/database/user.go b/database/user.go index 961a957..3991863 100644 --- a/database/user.go +++ b/database/user.go @@ -2,30 +2,37 @@ package database import ( "github.com/gotify/server/model" + "github.com/jinzhu/gorm" ) // GetUserByName returns the user by the given name or nil. -func (d *GormDatabase) GetUserByName(name string) *model.User { +func (d *GormDatabase) GetUserByName(name string) (*model.User, error) { user := new(model.User) - d.DB.Where("name = ?", name).Find(user) - if user.Name == name { - return user + err := d.DB.Where("name = ?", name).Find(user).Error + if err == gorm.ErrRecordNotFound { + err = nil } - return nil + if user.Name == name { + return user, err + } + return nil, err } // GetUserByID returns the user by the given id or nil. -func (d *GormDatabase) GetUserByID(id uint) *model.User { +func (d *GormDatabase) GetUserByID(id uint) (*model.User, error) { user := new(model.User) - d.DB.Find(user, id) - if user.ID == id { - return user + err := d.DB.Find(user, id).Error + if err == gorm.ErrRecordNotFound { + err = nil } - return nil + if user.ID == id { + return user, err + } + return nil, err } // CountUser returns the user count which satisfies the given condition. -func (d *GormDatabase) CountUser(condition ...interface{}) int { +func (d *GormDatabase) CountUser(condition ...interface{}) (int, error) { c := -1 handle := d.DB.Model(new(model.User)) if len(condition) == 1 { @@ -33,34 +40,37 @@ func (d *GormDatabase) CountUser(condition ...interface{}) int { } else if len(condition) > 1 { handle = handle.Where(condition[0], condition[1:]...) } - handle.Count(&c) - return c + err := handle.Count(&c).Error + return c, err } // GetUsers returns all users. -func (d *GormDatabase) GetUsers() []*model.User { +func (d *GormDatabase) GetUsers() ([]*model.User, error) { var users []*model.User - d.DB.Find(&users) - return users + err := d.DB.Find(&users).Error + return users, err } // DeleteUserByID deletes a user by its id. func (d *GormDatabase) DeleteUserByID(id uint) error { - for _, app := range d.GetApplicationsByUser(id) { + apps, _ := d.GetApplicationsByUser(id) + for _, app := range apps { d.DeleteApplicationByID(app.ID) } - for _, client := range d.GetClientsByUser(id) { + clients, _ := d.GetClientsByUser(id) + for _, client := range clients { d.DeleteClientByID(client.ID) } - for _, conf := range d.GetPluginConfByUser(id) { + pluginConfs, _ := d.GetPluginConfByUser(id) + for _, conf := range pluginConfs { d.DeletePluginConfByID(conf.ID) } return d.DB.Where("id = ?", id).Delete(&model.User{}).Error } // UpdateUser updates a user. -func (d *GormDatabase) UpdateUser(user *model.User) { - d.DB.Save(user) +func (d *GormDatabase) UpdateUser(user *model.User) error { + return d.DB.Save(user).Error } // CreateUser creates a user. diff --git a/database/user_test.go b/database/user_test.go index 2b1c601..3c96794 100644 --- a/database/user_test.go +++ b/database/user_test.go @@ -3,28 +3,44 @@ package database import ( "github.com/gotify/server/model" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func (s *DatabaseSuite) TestUser() { - assert.Nil(s.T(), s.db.GetUserByID(55), "not existing user") - assert.Nil(s.T(), s.db.GetUserByName("nicories"), "not existing user") + user, err := s.db.GetUserByID(55) + require.NoError(s.T(), err) + assert.Nil(s.T(), user, "not existing user") - jmattheis := s.db.GetUserByID(1) + user, err = s.db.GetUserByName("nicories") + require.NoError(s.T(), err) + assert.Nil(s.T(), user, "not existing user") + + jmattheis, err := s.db.GetUserByID(1) + require.NoError(s.T(), err) assert.NotNil(s.T(), jmattheis, "on bootup the first user should be automatically created") - assert.Equal(s.T(), 1, s.db.CountUser("admin = ?", true), 1, "there is initially one admin") - users := s.db.GetUsers() + adminCount, err := s.db.CountUser("admin = ?", true) + require.NoError(s.T(), err) + assert.Equal(s.T(), 1, adminCount, 1, "there is initially one admin") + + users, err := s.db.GetUsers() + require.NoError(s.T(), err) assert.Len(s.T(), users, 1) assert.Contains(s.T(), users, jmattheis) nicories := &model.User{Name: "nicories", Pass: []byte{1, 2, 3, 4}, Admin: false} s.db.CreateUser(nicories) assert.NotEqual(s.T(), 0, nicories.ID, "on create user a new id should be assigned") - assert.Equal(s.T(), 2, s.db.CountUser(), "two users should exist") + userCount, err := s.db.CountUser() + require.NoError(s.T(), err) + assert.Equal(s.T(), 2, userCount, "two users should exist") - assert.Equal(s.T(), nicories, s.db.GetUserByName("nicories")) + user, err = s.db.GetUserByName("nicories") + require.NoError(s.T(), err) + assert.Equal(s.T(), nicories, user) - users = s.db.GetUsers() + users, err = s.db.GetUsers() + require.NoError(s.T(), err) assert.Len(s.T(), users, 2) assert.Contains(s.T(), users, jmattheis) assert.Contains(s.T(), users, nicories) @@ -32,73 +48,138 @@ func (s *DatabaseSuite) TestUser() { nicories.Name = "tom" nicories.Pass = []byte{12} nicories.Admin = true - s.db.UpdateUser(nicories) - tom := s.db.GetUserByID(nicories.ID) - assert.Equal(s.T(), &model.User{ID: nicories.ID, Name: "tom", Pass: []byte{12}, Admin: true}, tom) - users = s.db.GetUsers() - assert.Len(s.T(), users, 2) - assert.Equal(s.T(), 2, s.db.CountUser(&model.User{Admin: true}), "two admins exist") + require.NoError(s.T(), s.db.UpdateUser(nicories)) - s.db.DeleteUserByID(tom.ID) - users = s.db.GetUsers() + tom, err := s.db.GetUserByID(nicories.ID) + require.NoError(s.T(), err) + assert.Equal(s.T(), &model.User{ID: nicories.ID, Name: "tom", Pass: []byte{12}, Admin: true}, tom) + + users, err = s.db.GetUsers() + require.NoError(s.T(), err) + assert.Len(s.T(), users, 2) + + adminCount, err = s.db.CountUser(&model.User{Admin: true}) + require.NoError(s.T(), err) + assert.Equal(s.T(), 2, adminCount, "two admins exist") + + require.NoError(s.T(), s.db.DeleteUserByID(tom.ID)) + users, err = s.db.GetUsers() + require.NoError(s.T(), err) assert.Len(s.T(), users, 1) assert.Contains(s.T(), users, jmattheis) s.db.DeleteUserByID(jmattheis.ID) - users = s.db.GetUsers() + users, err = s.db.GetUsers() + require.NoError(s.T(), err) assert.Empty(s.T(), users) + } func (s *DatabaseSuite) TestUserPlugins() { - s.db.CreateUser(&model.User{Name: "geek", ID: 16}) - s.db.CreatePluginConf(&model.PluginConf{ - UserID: s.db.GetUserByName("geek").ID, - ModulePath: "github.com/gotify/example-plugin", - Token: "P1234", - Enabled: true, - }) - s.db.CreatePluginConf(&model.PluginConf{ - UserID: s.db.GetUserByName("geek").ID, - ModulePath: "github.com/gotify/example-plugin/v2", - Token: "P5678", - Enabled: true, - }) + assert.NoError(s.T(), s.db.CreateUser(&model.User{Name: "geek", ID: 16})) + if geekUser, err := s.db.GetUserByName("geek"); assert.NoError(s.T(), err) { + s.db.CreatePluginConf(&model.PluginConf{ + UserID: geekUser.ID, + ModulePath: "github.com/gotify/example-plugin", + Token: "P1234", + Enabled: true, + }) + s.db.CreatePluginConf(&model.PluginConf{ + UserID: geekUser.ID, + ModulePath: "github.com/gotify/example-plugin/v2", + Token: "P5678", + Enabled: true, + }) + } - assert.Len(s.T(), s.db.GetPluginConfByUser(s.db.GetUserByName("geek").ID), 2) - assert.Equal(s.T(), "github.com/gotify/example-plugin", s.db.GetPluginConfByToken("P1234").ModulePath) + if geekUser, err := s.db.GetUserByName("geek"); assert.NoError(s.T(), err) { + if pluginConfs, err := s.db.GetPluginConfByUser(geekUser.ID); assert.NoError(s.T(), err) { + assert.Len(s.T(), pluginConfs, 2) + } + } + if pluginConf, err := s.db.GetPluginConfByToken("P1234"); assert.NoError(s.T(), err) { + assert.Equal(s.T(), "github.com/gotify/example-plugin", pluginConf.ModulePath) + } } func (s *DatabaseSuite) TestDeleteUserDeletesApplicationsAndClientsAndPluginConfs() { - s.db.CreateUser(&model.User{Name: "nicories", ID: 10}) - s.db.CreateApplication(&model.Application{ID: 100, Token: "apptoken", UserID: 10}) - s.db.CreateMessage(&model.Message{ID: 1000, ApplicationID: 100}) - s.db.CreateClient(&model.Client{ID: 10000, Token: "clienttoken", UserID: 10}) - s.db.CreatePluginConf(&model.PluginConf{ID: 1000, Token: "plugintoken", UserID: 10}) + require.NoError(s.T(), s.db.CreateUser(&model.User{Name: "nicories", ID: 10})) + require.NoError(s.T(), s.db.CreateApplication(&model.Application{ID: 100, Token: "apptoken", UserID: 10})) + require.NoError(s.T(), s.db.CreateMessage(&model.Message{ID: 1000, ApplicationID: 100})) + require.NoError(s.T(), s.db.CreateClient(&model.Client{ID: 10000, Token: "clienttoken", UserID: 10})) + require.NoError(s.T(), s.db.CreatePluginConf(&model.PluginConf{ID: 1000, Token: "plugintoken", UserID: 10})) - s.db.CreateUser(&model.User{Name: "nicories2", ID: 20}) - s.db.CreateApplication(&model.Application{ID: 200, Token: "apptoken2", UserID: 20}) - s.db.CreateMessage(&model.Message{ID: 2000, ApplicationID: 200}) - s.db.CreateClient(&model.Client{ID: 20000, Token: "clienttoken2", UserID: 20}) - s.db.CreatePluginConf(&model.PluginConf{ID: 2000, Token: "plugintoken2", UserID: 20}) + require.NoError(s.T(), s.db.CreateUser(&model.User{Name: "nicories2", ID: 20})) + require.NoError(s.T(), s.db.CreateApplication(&model.Application{ID: 200, Token: "apptoken2", UserID: 20})) + require.NoError(s.T(), s.db.CreateMessage(&model.Message{ID: 2000, ApplicationID: 200})) + require.NoError(s.T(), s.db.CreateClient(&model.Client{ID: 20000, Token: "clienttoken2", UserID: 20})) + require.NoError(s.T(), s.db.CreatePluginConf(&model.PluginConf{ID: 2000, Token: "plugintoken2", UserID: 20})) - s.db.DeleteUserByID(10) + require.NoError(s.T(), s.db.DeleteUserByID(10)) - assert.Nil(s.T(), s.db.GetApplicationByToken("apptoken")) - assert.Nil(s.T(), s.db.GetClientByToken("clienttoken")) - assert.Empty(s.T(), s.db.GetClientsByUser(10)) - assert.Empty(s.T(), s.db.GetApplicationsByUser(10)) - assert.Empty(s.T(), s.db.GetMessagesByApplication(100)) - assert.Empty(s.T(), s.db.GetMessagesByUser(10)) - assert.Empty(s.T(), s.db.GetPluginConfByUser(10)) - assert.Nil(s.T(), s.db.GetMessageByID(1000)) + app, err := s.db.GetApplicationByToken("apptoken") + require.NoError(s.T(), err) + assert.Nil(s.T(), app) + + client, err := s.db.GetClientByToken("clienttoken") + require.NoError(s.T(), err) + assert.Nil(s.T(), client) + + clients, err := s.db.GetClientsByUser(10) + require.NoError(s.T(), err) + assert.Empty(s.T(), clients) + + apps, err := s.db.GetApplicationsByUser(10) + require.NoError(s.T(), err) + assert.Empty(s.T(), apps) + + msgs, err := s.db.GetMessagesByApplication(100) + require.NoError(s.T(), err) + assert.Empty(s.T(), msgs) + + msgs, err = s.db.GetMessagesByUser(10) + require.NoError(s.T(), err) + assert.Empty(s.T(), msgs) + + pluginConfs, err := s.db.GetPluginConfByUser(10) + require.NoError(s.T(), err) + assert.Empty(s.T(), pluginConfs) + + msg, err := s.db.GetMessageByID(1000) + require.NoError(s.T(), err) + assert.Nil(s.T(), msg) + + app, err = s.db.GetApplicationByToken("apptoken2") + require.NoError(s.T(), err) + assert.NotNil(s.T(), app) + + client, err = s.db.GetClientByToken("clienttoken2") + require.NoError(s.T(), err) + assert.NotNil(s.T(), client) + + clients, err = s.db.GetClientsByUser(20) + require.NoError(s.T(), err) + assert.NotEmpty(s.T(), clients) + + apps, err = s.db.GetApplicationsByUser(20) + require.NoError(s.T(), err) + assert.NotEmpty(s.T(), apps) + + pluginConf, err := s.db.GetPluginConfByUser(20) + require.NoError(s.T(), err) + assert.NotEmpty(s.T(), pluginConf) + + msgs, err = s.db.GetMessagesByApplication(200) + require.NoError(s.T(), err) + assert.NotEmpty(s.T(), msgs) + + msgs, err = s.db.GetMessagesByUser(20) + require.NoError(s.T(), err) + assert.NotEmpty(s.T(), msgs) + + msg, err = s.db.GetMessageByID(2000) + require.NoError(s.T(), err) + assert.NotNil(s.T(), msg) - assert.NotNil(s.T(), s.db.GetApplicationByToken("apptoken2")) - assert.NotNil(s.T(), s.db.GetClientByToken("clienttoken2")) - assert.NotEmpty(s.T(), s.db.GetClientsByUser(20)) - assert.NotEmpty(s.T(), s.db.GetApplicationsByUser(20)) - assert.NotEmpty(s.T(), s.db.GetPluginConfByUser(20)) - assert.NotEmpty(s.T(), s.db.GetMessagesByApplication(200)) - assert.NotEmpty(s.T(), s.db.GetMessagesByUser(20)) - assert.NotNil(s.T(), s.db.GetMessageByID(2000)) } diff --git a/plugin/manager.go b/plugin/manager.go index ee786d5..fc70a7b 100644 --- a/plugin/manager.go +++ b/plugin/manager.go @@ -23,19 +23,19 @@ import ( // The Database interface for encapsulating database access. type Database interface { - GetUsers() []*model.User - GetPluginConfByUserAndPath(userid uint, path string) *model.PluginConf + GetUsers() ([]*model.User, error) + GetPluginConfByUserAndPath(userid uint, path string) (*model.PluginConf, error) CreatePluginConf(p *model.PluginConf) error - GetPluginConfByApplicationID(appid uint) *model.PluginConf + GetPluginConfByApplicationID(appid uint) (*model.PluginConf, error) UpdatePluginConf(p *model.PluginConf) error CreateMessage(message *model.Message) error - GetPluginConfByID(id uint) *model.PluginConf - GetPluginConfByToken(token string) *model.PluginConf - GetUserByID(id uint) *model.User + GetPluginConfByID(id uint) (*model.PluginConf, error) + GetPluginConfByToken(token string) (*model.PluginConf, error) + GetUserByID(id uint) (*model.User, error) CreateApplication(application *model.Application) error UpdateApplication(app *model.Application) error - GetApplicationsByUser(userID uint) []*model.Application - GetApplicationByToken(token string) *model.Application + GetApplicationsByUser(userID uint) ([]*model.Application, error) + GetApplicationByToken(token string) (*model.Application, error) } // Notifier notifies when a new message was created. @@ -87,7 +87,11 @@ func NewManager(db Database, directory string, mux *gin.RouterGroup, notifier No return nil, err } - for _, user := range manager.db.GetUsers() { + users, err := manager.db.GetUsers() + if err != nil { + return nil, err + } + for _, user := range users { if err := manager.initializeForUser(*user); err != nil { return nil, err } @@ -100,11 +104,13 @@ func NewManager(db Database, directory string, mux *gin.RouterGroup, notifier No var ErrAlreadyEnabledOrDisabled = errors.New("config is already enabled/disabled") func (m *Manager) applicationExists(token string) bool { - return m.db.GetApplicationByToken(token) != nil + app, _ := m.db.GetApplicationByToken(token) + return app != nil } func (m *Manager) pluginConfExists(token string) bool { - return m.db.GetPluginConfByToken(token) != nil + pluginConf, _ := m.db.GetPluginConfByToken(token) + return pluginConf != nil } // SetPluginEnabled sets the plugins enabled state. @@ -113,7 +119,10 @@ func (m *Manager) SetPluginEnabled(pluginID uint, enabled bool) error { if err != nil { return errors.New("instance not found") } - conf := m.db.GetPluginConfByID(pluginID) + conf, err := m.db.GetPluginConfByID(pluginID) + if err != nil { + return err + } if conf.Enabled == enabled { return ErrAlreadyEnabledOrDisabled @@ -131,7 +140,9 @@ func (m *Manager) SetPluginEnabled(pluginID uint, enabled bool) error { return err } - conf = m.db.GetPluginConfByID(pluginID) // conf might be updated by instance + if newConf, err := m.db.GetPluginConfByID(pluginID); /* conf might be updated by instance */ err == nil { + conf = newConf + } conf.Enabled = enabled return m.db.UpdatePluginConf(conf) } @@ -172,7 +183,13 @@ func (m *Manager) HasInstance(pluginID uint) bool { // RemoveUser disabled all plugins of a user when the user is disabled func (m *Manager) RemoveUser(userID uint) error { for _, p := range m.plugins { - pluginConf := m.db.GetPluginConfByUserAndPath(userID, p.PluginInfo().ModulePath) + pluginConf, err := m.db.GetPluginConfByUserAndPath(userID, p.PluginInfo().ModulePath) + if err != nil { + return err + } + if pluginConf == nil { + continue + } if pluginConf.Enabled { inst, err := m.Instance(pluginConf.ID) if err != nil { @@ -240,7 +257,10 @@ func (m *Manager) InitializeForUserID(userID uint) error { m.mutex.Lock() defer m.mutex.Unlock() - user := m.db.GetUserByID(userID) + user, err := m.db.GetUserByID(userID) + if err != nil { + return err + } if user != nil { return m.initializeForUser(*user) } @@ -261,8 +281,16 @@ func (m *Manager) initializeForUser(user model.User) error { } } - for _, app := range m.db.GetApplicationsByUser(user.ID) { - if conf := m.db.GetPluginConfByApplicationID(app.ID); conf != nil { + apps, err := m.db.GetApplicationsByUser(user.ID) + if err != nil { + return err + } + for _, app := range apps { + conf, err := m.db.GetPluginConfByApplicationID(app.ID) + if err != nil { + return err + } + if conf != nil { _, compatExist := m.plugins[conf.ModulePath] app.Internal = compatExist } else { @@ -279,7 +307,10 @@ func (m *Manager) initializeSingleUserPlugin(userCtx compat.UserContext, p compa instance := p.NewPluginInstance(userCtx) userID := userCtx.ID - pluginConf := m.db.GetPluginConfByUserAndPath(userID, info.ModulePath) + pluginConf, err := m.db.GetPluginConfByUserAndPath(userID, info.ModulePath) + if err != nil { + return err + } if pluginConf == nil { var err error diff --git a/plugin/manager_test.go b/plugin/manager_test.go index c989481..2c7dbab 100644 --- a/plugin/manager_test.go +++ b/plugin/manager_test.go @@ -85,7 +85,9 @@ func (s *ManagerSuite) SetupSuite() { s.msgReceiver = make(chan MessageWithUserID) assert.Contains(s.T(), s.manager.plugins, examplePluginPath) - assert.NotNil(s.T(), s.db.GetPluginConfByUserAndPath(1, examplePluginPath)) + if pluginConf, err := s.db.GetPluginConfByUserAndPath(1, examplePluginPath); assert.NoError(s.T(), err) { + assert.NotNil(s.T(), pluginConf) + } } func (s *ManagerSuite) TearDownSuite() { @@ -93,11 +95,16 @@ func (s *ManagerSuite) TearDownSuite() { } func (s *ManagerSuite) getConfForExamplePlugin(uid uint) *model.PluginConf { - return s.db.GetPluginConfByUserAndPath(uid, examplePluginPath) + pluginConf, err := s.db.GetPluginConfByUserAndPath(uid, examplePluginPath) + assert.NoError(s.T(), err) + return pluginConf + } func (s *ManagerSuite) getConfForMockPlugin(uid uint) *model.PluginConf { - return s.db.GetPluginConfByUserAndPath(uid, mockPluginPath) + pluginConf, err := s.db.GetPluginConfByUserAndPath(uid, mockPluginPath) + assert.NoError(s.T(), err) + return pluginConf } func (s *ManagerSuite) getMockPluginInstance(uid uint) *mock.PluginInstance { @@ -394,53 +401,69 @@ func TestNewManager_InternalApplicationManagement(t *testing.T) { UserID: 1, }) - assert.True(t, db.GetApplicationByToken("Ainternal_obsolete").Internal) + if app, err := db.GetApplicationByToken("Ainternal_obsolete"); assert.NoError(t, err) { + assert.True(t, app.Internal) + } _, err := NewManager(db, "", nil, nil) assert.Nil(t, err) - assert.False(t, db.GetApplicationByToken("Ainternal_obsolete").Internal) + if app, err := db.GetApplicationByToken("Ainternal_obsolete"); assert.NoError(t, err) { + assert.False(t, app.Internal) + } } { // Application exist, conf exist, no compat - db.CreateApplication(&model.Application{ + assert.NoError(t, db.CreateApplication(&model.Application{ Token: "Ainternal_not_loaded", Internal: true, Name: "not loaded plugin application", UserID: 1, - }) - db.CreatePluginConf(&model.PluginConf{ - ApplicationID: db.GetApplicationByToken("Ainternal_not_loaded").ID, - UserID: 1, - Enabled: true, - Token: auth.GeneratePluginToken(), - }) + })) + if app, err := db.GetApplicationByToken("Ainternal_not_loaded"); assert.NoError(t, err) { + assert.NoError(t, db.CreatePluginConf(&model.PluginConf{ + ApplicationID: app.ID, + UserID: 1, + Enabled: true, + Token: auth.GeneratePluginToken(), + })) + } - assert.True(t, db.GetApplicationByToken("Ainternal_not_loaded").Internal) + if app, err := db.GetApplicationByToken("Ainternal_not_loaded"); assert.NoError(t, err) { + assert.True(t, app.Internal) + } _, err := NewManager(db, "", nil, nil) assert.Nil(t, err) - assert.False(t, db.GetApplicationByToken("Ainternal_not_loaded").Internal) + if app, err := db.GetApplicationByToken("Ainternal_not_loaded"); assert.NoError(t, err) { + assert.False(t, app.Internal) + } } { // Application exist, conf exist, has compat - db.CreateApplication(&model.Application{ + assert.NoError(t, db.CreateApplication(&model.Application{ Token: "Ainternal_loaded", Internal: false, Name: "not loaded plugin application", UserID: 1, - }) - db.CreatePluginConf(&model.PluginConf{ - ApplicationID: db.GetApplicationByToken("Ainternal_loaded").ID, - UserID: 1, - Enabled: true, - ModulePath: mock.ModulePath, - Token: auth.GeneratePluginToken(), - }) + })) + if app, err := db.GetApplicationByToken("Ainternal_loaded"); assert.NoError(t, err) { + assert.NoError(t, db.CreatePluginConf(&model.PluginConf{ + ApplicationID: app.ID, + UserID: 1, + Enabled: true, + ModulePath: mock.ModulePath, + Token: auth.GeneratePluginToken(), + })) + } - assert.False(t, db.GetApplicationByToken("Ainternal_loaded").Internal) + if app, err := db.GetApplicationByToken("Ainternal_loaded"); assert.NoError(t, err) { + assert.False(t, app.Internal) + } manager, err := NewManager(db, "", nil, nil) assert.Nil(t, err) assert.Nil(t, manager.LoadPlugin(new(mock.Plugin))) assert.Nil(t, manager.InitializeForUserID(1)) - assert.True(t, db.GetApplicationByToken("Ainternal_loaded").Internal) + if app, err := db.GetApplicationByToken("Ainternal_loaded"); assert.NoError(t, err) { + assert.True(t, app.Internal) + } } } diff --git a/plugin/pluginenabled.go b/plugin/pluginenabled.go index 826c49b..8e7365a 100644 --- a/plugin/pluginenabled.go +++ b/plugin/pluginenabled.go @@ -8,7 +8,12 @@ import ( func requirePluginEnabled(id uint, db Database) gin.HandlerFunc { return func(c *gin.Context) { - if conf := db.GetPluginConfByID(id); conf == nil || !conf.Enabled { + conf, err := db.GetPluginConfByID(id) + if err != nil { + c.AbortWithError(500, err) + return + } + if conf == nil || !conf.Enabled { c.AbortWithError(400, errors.New("plugin is disabled")) } } diff --git a/plugin/storagehandler.go b/plugin/storagehandler.go index 00b7e2f..372d369 100644 --- a/plugin/storagehandler.go +++ b/plugin/storagehandler.go @@ -6,11 +6,18 @@ type dbStorageHandler struct { } func (c dbStorageHandler) Save(b []byte) error { - conf := c.db.GetPluginConfByID(c.pluginID) + conf, err := c.db.GetPluginConfByID(c.pluginID) + if err != nil { + return err + } conf.Storage = b return c.db.UpdatePluginConf(conf) } func (c dbStorageHandler) Load() ([]byte, error) { - return c.db.GetPluginConfByID(c.pluginID).Storage, nil + pluginConf, err := c.db.GetPluginConfByID(c.pluginID) + if err != nil { + return nil, err + } + return pluginConf.Storage, nil } diff --git a/test/testdb/database.go b/test/testdb/database.go index 7a80b29..d7b2dd2 100644 --- a/test/testdb/database.go +++ b/test/testdb/database.go @@ -171,42 +171,58 @@ func (mb *MessageBuilder) NewMessage(id uint) model.Message { // AssertAppNotExist asserts that the app does not exist. func (d *Database) AssertAppNotExist(id uint) { - assert.True(d.t, d.GetApplicationByID(id) == nil, "app %d must not exist", id) + if app, err := d.GetApplicationByID(id); assert.NoError(d.t, err) { + assert.True(d.t, app == nil, "app %d must not exist", id) + } } // AssertUserNotExist asserts that the user does not exist. func (d *Database) AssertUserNotExist(id uint) { - assert.True(d.t, d.GetUserByID(id) == nil, "user %d must not exist", id) + if user, err := d.GetUserByID(id); assert.NoError(d.t, err) { + assert.True(d.t, user == nil, "user %d must not exist", id) + } } // AssertClientNotExist asserts that the client does not exist. func (d *Database) AssertClientNotExist(id uint) { - assert.True(d.t, d.GetClientByID(id) == nil, "client %d must not exist", id) + if client, err := d.GetClientByID(id); assert.NoError(d.t, err) { + assert.True(d.t, client == nil, "client %d must not exist", id) + } } // AssertMessageNotExist asserts that the messages does not exist. func (d *Database) AssertMessageNotExist(ids ...uint) { for _, id := range ids { - assert.True(d.t, d.GetMessageByID(id) == nil, "message %d must not exist", id) + if msg, err := d.GetMessageByID(id); assert.NoError(d.t, err) { + assert.True(d.t, msg == nil, "message %d must not exist", id) + } } } // AssertAppExist asserts that the app does exist. func (d *Database) AssertAppExist(id uint) { - assert.False(d.t, d.GetApplicationByID(id) == nil, "app %d must exist", id) + if app, err := d.GetApplicationByID(id); assert.NoError(d.t, err) { + assert.False(d.t, app == nil, "app %d must exist", id) + } } // AssertUserExist asserts that the user does exist. func (d *Database) AssertUserExist(id uint) { - assert.False(d.t, d.GetUserByID(id) == nil, "user %d must exist", id) + if user, err := d.GetUserByID(id); assert.NoError(d.t, err) { + assert.False(d.t, user == nil, "user %d must exist", id) + } } // AssertClientExist asserts that the client does exist. func (d *Database) AssertClientExist(id uint) { - assert.False(d.t, d.GetClientByID(id) == nil, "client %d must exist", id) + if client, err := d.GetClientByID(id); assert.NoError(d.t, err) { + assert.False(d.t, client == nil, "client %d must exist", id) + } } // AssertMessageExist asserts that the message does exist. func (d *Database) AssertMessageExist(id uint) { - assert.False(d.t, d.GetMessageByID(id) == nil, "message %d must exist", id) + if msg, err := d.GetMessageByID(id); assert.NoError(d.t, err) { + assert.False(d.t, msg == nil, "message %d must exist", id) + } } diff --git a/test/testdb/database_test.go b/test/testdb/database_test.go index 84a37b4..c973368 100644 --- a/test/testdb/database_test.go +++ b/test/testdb/database_test.go @@ -12,7 +12,9 @@ import ( func Test_WithDefault(t *testing.T) { db := testdb.NewDBWithDefaultUser(t) - assert.NotNil(t, db.GetUserByName("admin")) + if user, err := db.GetUserByName("admin"); assert.NoError(t, err) { + assert.NotNil(t, user) + } db.Close() } @@ -45,7 +47,9 @@ func (s *DatabaseSuite) Test_Users() { users := []*model.User{{ID: 1, Name: "user1"}, {ID: 2, Name: "user2"}, {ID: 3, Name: "tom"}} - assert.Equal(s.T(), users, s.db.GetUsers()) + if usersActual, err := s.db.GetUsers(); assert.NoError(s.T(), err) { + assert.Equal(s.T(), users, usersActual) + } s.db.AssertUserExist(1) s.db.AssertUserExist(2) s.db.AssertUserExist(3) @@ -68,9 +72,13 @@ func (s *DatabaseSuite) Test_Clients() { assert.Equal(s.T(), newClientExpected, newClientActual) userOneExpected := []*model.Client{{ID: 1, Token: "client1", UserID: 1}, {ID: 2, Token: "asdf", UserID: 1}} - assert.Equal(s.T(), userOneExpected, s.db.GetClientsByUser(1)) + if clients, err := s.db.GetClientsByUser(1); assert.NoError(s.T(), err) { + assert.Equal(s.T(), userOneExpected, clients) + } userTwoExpected := []*model.Client{{ID: 5, Token: "client5", UserID: 2}} - assert.Equal(s.T(), userTwoExpected, s.db.GetClientsByUser(2)) + if clients, err := s.db.GetClientsByUser(2); assert.NoError(s.T(), err) { + assert.Equal(s.T(), userTwoExpected, clients) + } s.db.AssertClientExist(1) s.db.AssertClientExist(2) @@ -99,9 +107,13 @@ func (s *DatabaseSuite) Test_Apps() { assert.Equal(s.T(), newInternalAppExpected, newInternalAppActual) userOneExpected := []*model.Application{{ID: 1, Token: "app1", UserID: 1}, {ID: 2, Token: "asdf", UserID: 1}, {ID: 3, Token: "qwer", UserID: 1, Internal: true}} - assert.Equal(s.T(), userOneExpected, s.db.GetApplicationsByUser(1)) + if app, err := s.db.GetApplicationsByUser(1); assert.NoError(s.T(), err) { + assert.Equal(s.T(), userOneExpected, app) + } userTwoExpected := []*model.Application{{ID: 5, Token: "app5", UserID: 2, Internal: true}} - assert.Equal(s.T(), userTwoExpected, s.db.GetApplicationsByUser(2)) + if app, err := s.db.GetApplicationsByUser(2); assert.NoError(s.T(), err) { + assert.Equal(s.T(), userTwoExpected, app) + } newAppWithName := userBuilder.NewAppWithTokenAndName(7, "test-token", "app name") newAppWithNameExpected := &model.Application{ID: 7, Token: "test-token", UserID: 1, Name: "app name"} @@ -139,9 +151,13 @@ func (s *DatabaseSuite) Test_Messages() { s.db.User(2).App(2).Message(4).Message(5) userOneExpected := []*model.Message{{ID: 2, ApplicationID: 1}, {ID: 1, ApplicationID: 1}} - assert.Equal(s.T(), userOneExpected, s.db.GetMessagesByUser(1)) + if msgs, err := s.db.GetMessagesByUser(1); assert.NoError(s.T(), err) { + assert.Equal(s.T(), userOneExpected, msgs) + } userTwoExpected := []*model.Message{{ID: 5, ApplicationID: 2}, {ID: 4, ApplicationID: 2}} - assert.Equal(s.T(), userTwoExpected, s.db.GetMessagesByUser(2)) + if msgs, err := s.db.GetMessagesByUser(2); assert.NoError(s.T(), err) { + assert.Equal(s.T(), userTwoExpected, msgs) + } s.db.AssertMessageExist(1) s.db.AssertMessageExist(2)