diff --git a/api/user.go b/api/user.go index 0432b67..e0e8305 100644 --- a/api/user.go +++ b/api/user.go @@ -17,6 +17,7 @@ type UserDatabase interface { DeleteUserByID(id uint) error UpdateUser(user *model.User) CreateUser(user *model.User) error + CountUser(condition ...interface{}) int } // UserChangeNotifier notifies listeners for user changes. @@ -252,6 +253,10 @@ func (a *UserAPI) GetUserByID(ctx *gin.Context) { func (a *UserAPI) DeleteUserByID(ctx *gin.Context) { withID(ctx, "id", func(id uint) { if user := a.DB.GetUserByID(id); user != nil { + if user.Admin && a.DB.CountUser(&model.User{Admin: true}) == 1 { + ctx.AbortWithError(400, errors.New("cannot delete last admin")) + return + } if err := a.UserChangeNotifier.fireUserDeleted(id); err != nil { ctx.AbortWithError(500, err) return @@ -350,6 +355,10 @@ func (a *UserAPI) UpdateUserByID(ctx *gin.Context) { var user *model.UserExternalWithPass if err := ctx.Bind(&user); err == nil { if oldUser := a.DB.GetUserByID(id); oldUser != nil { + if !user.Admin && oldUser.Admin && a.DB.CountUser(&model.User{Admin: true}) == 1 { + ctx.AbortWithError(400, errors.New("cannot delete last admin")) + return + } internal := a.toInternalUser(user, oldUser.Pass) internal.ID = id a.DB.UpdateUser(internal) diff --git a/api/user_test.go b/api/user_test.go index 029a709..43a8ba6 100644 --- a/api/user_test.go +++ b/api/user_test.go @@ -101,6 +101,19 @@ func (s *UserSuite) Test_GetUserByID_UnknownUser() { assert.Equal(s.T(), 404, s.recorder.Code) } +func (s *UserSuite) Test_DeleteUserByID_LastAdmin_Expect400() { + s.db.CreateUser(&model.User{ + ID: 7, + Name: "admin", + Admin: true, + }) + s.ctx.Params = gin.Params{{Key: "id", Value: "7"}} + + s.a.DeleteUserByID(s.ctx) + + assert.Equal(s.T(), 400, s.recorder.Code) +} + func (s *UserSuite) Test_DeleteUserByID_InvalidID() { s.ctx.Params = gin.Params{{Key: "id", Value: "abc"}} @@ -221,6 +234,22 @@ func (s *UserSuite) Test_UpdateUserByID_InvalidID() { assert.Equal(s.T(), 400, s.recorder.Code) } +func (s *UserSuite) Test_UpdateUserByID_LastAdmin_Expect400() { + s.db.CreateUser(&model.User{ + ID: 7, + Name: "admin", + Admin: true, + }) + + s.ctx.Params = gin.Params{{Key: "id", Value: "7"}} + + s.ctx.Request = httptest.NewRequest("POST", "/user/7", strings.NewReader(`{"name": "admin", "pass": "", "admin": false}`)) + s.ctx.Request.Header.Set("Content-Type", "application/json") + s.a.UpdateUserByID(s.ctx) + + assert.Equal(s.T(), 400, s.recorder.Code) +} + func (s *UserSuite) Test_UpdateUserByID_UnknownUser() { s.ctx.Params = gin.Params{{Key: "id", Value: "2"}} diff --git a/database/user.go b/database/user.go index 9790f4b..961a957 100644 --- a/database/user.go +++ b/database/user.go @@ -24,6 +24,19 @@ func (d *GormDatabase) GetUserByID(id uint) *model.User { return nil } +// CountUser returns the user count which satisfies the given condition. +func (d *GormDatabase) CountUser(condition ...interface{}) int { + c := -1 + handle := d.DB.Model(new(model.User)) + if len(condition) == 1 { + handle = handle.Where(condition[0]) + } else if len(condition) > 1 { + handle = handle.Where(condition[0], condition[1:]...) + } + handle.Count(&c) + return c +} + // GetUsers returns all users. func (d *GormDatabase) GetUsers() []*model.User { var users []*model.User diff --git a/database/user_test.go b/database/user_test.go index 2ff88c3..2b1c601 100644 --- a/database/user_test.go +++ b/database/user_test.go @@ -11,6 +11,7 @@ func (s *DatabaseSuite) TestUser() { jmattheis := s.db.GetUserByID(1) assert.NotNil(s.T(), jmattheis, "on bootup the first user should be automatically created") + assert.Equal(s.T(), 1, s.db.CountUser("admin = ?", true), 1, "there is initially one admin") users := s.db.GetUsers() assert.Len(s.T(), users, 1) @@ -19,6 +20,7 @@ func (s *DatabaseSuite) TestUser() { nicories := &model.User{Name: "nicories", Pass: []byte{1, 2, 3, 4}, Admin: false} s.db.CreateUser(nicories) assert.NotEqual(s.T(), 0, nicories.ID, "on create user a new id should be assigned") + assert.Equal(s.T(), 2, s.db.CountUser(), "two users should exist") assert.Equal(s.T(), nicories, s.db.GetUserByName("nicories")) @@ -35,6 +37,7 @@ func (s *DatabaseSuite) TestUser() { assert.Equal(s.T(), &model.User{ID: nicories.ID, Name: "tom", Pass: []byte{12}, Admin: true}, tom) users = s.db.GetUsers() assert.Len(s.T(), users, 2) + assert.Equal(s.T(), 2, s.db.CountUser(&model.User{Admin: true}), "two admins exist") s.db.DeleteUserByID(tom.ID) users = s.db.GetUsers()