Prevent removing last admin (#130)

This commit is contained in:
饺子w 2019-02-27 01:46:42 +08:00 committed by Jannis Mattheis
parent ec5b1f8c30
commit 2fa395cb84
4 changed files with 54 additions and 0 deletions

View File

@ -17,6 +17,7 @@ type UserDatabase interface {
DeleteUserByID(id uint) error DeleteUserByID(id uint) error
UpdateUser(user *model.User) UpdateUser(user *model.User)
CreateUser(user *model.User) error CreateUser(user *model.User) error
CountUser(condition ...interface{}) int
} }
// UserChangeNotifier notifies listeners for user changes. // UserChangeNotifier notifies listeners for user changes.
@ -252,6 +253,10 @@ func (a *UserAPI) GetUserByID(ctx *gin.Context) {
func (a *UserAPI) DeleteUserByID(ctx *gin.Context) { func (a *UserAPI) DeleteUserByID(ctx *gin.Context) {
withID(ctx, "id", func(id uint) { withID(ctx, "id", func(id uint) {
if user := a.DB.GetUserByID(id); user != nil { if user := a.DB.GetUserByID(id); user != nil {
if user.Admin && a.DB.CountUser(&model.User{Admin: true}) == 1 {
ctx.AbortWithError(400, errors.New("cannot delete last admin"))
return
}
if err := a.UserChangeNotifier.fireUserDeleted(id); err != nil { if err := a.UserChangeNotifier.fireUserDeleted(id); err != nil {
ctx.AbortWithError(500, err) ctx.AbortWithError(500, err)
return return
@ -350,6 +355,10 @@ func (a *UserAPI) UpdateUserByID(ctx *gin.Context) {
var user *model.UserExternalWithPass var user *model.UserExternalWithPass
if err := ctx.Bind(&user); err == nil { if err := ctx.Bind(&user); err == nil {
if oldUser := a.DB.GetUserByID(id); oldUser != nil { if oldUser := a.DB.GetUserByID(id); oldUser != nil {
if !user.Admin && oldUser.Admin && a.DB.CountUser(&model.User{Admin: true}) == 1 {
ctx.AbortWithError(400, errors.New("cannot delete last admin"))
return
}
internal := a.toInternalUser(user, oldUser.Pass) internal := a.toInternalUser(user, oldUser.Pass)
internal.ID = id internal.ID = id
a.DB.UpdateUser(internal) a.DB.UpdateUser(internal)

View File

@ -101,6 +101,19 @@ func (s *UserSuite) Test_GetUserByID_UnknownUser() {
assert.Equal(s.T(), 404, s.recorder.Code) assert.Equal(s.T(), 404, s.recorder.Code)
} }
func (s *UserSuite) Test_DeleteUserByID_LastAdmin_Expect400() {
s.db.CreateUser(&model.User{
ID: 7,
Name: "admin",
Admin: true,
})
s.ctx.Params = gin.Params{{Key: "id", Value: "7"}}
s.a.DeleteUserByID(s.ctx)
assert.Equal(s.T(), 400, s.recorder.Code)
}
func (s *UserSuite) Test_DeleteUserByID_InvalidID() { func (s *UserSuite) Test_DeleteUserByID_InvalidID() {
s.ctx.Params = gin.Params{{Key: "id", Value: "abc"}} s.ctx.Params = gin.Params{{Key: "id", Value: "abc"}}
@ -221,6 +234,22 @@ func (s *UserSuite) Test_UpdateUserByID_InvalidID() {
assert.Equal(s.T(), 400, s.recorder.Code) assert.Equal(s.T(), 400, s.recorder.Code)
} }
func (s *UserSuite) Test_UpdateUserByID_LastAdmin_Expect400() {
s.db.CreateUser(&model.User{
ID: 7,
Name: "admin",
Admin: true,
})
s.ctx.Params = gin.Params{{Key: "id", Value: "7"}}
s.ctx.Request = httptest.NewRequest("POST", "/user/7", strings.NewReader(`{"name": "admin", "pass": "", "admin": false}`))
s.ctx.Request.Header.Set("Content-Type", "application/json")
s.a.UpdateUserByID(s.ctx)
assert.Equal(s.T(), 400, s.recorder.Code)
}
func (s *UserSuite) Test_UpdateUserByID_UnknownUser() { func (s *UserSuite) Test_UpdateUserByID_UnknownUser() {
s.ctx.Params = gin.Params{{Key: "id", Value: "2"}} s.ctx.Params = gin.Params{{Key: "id", Value: "2"}}

View File

@ -24,6 +24,19 @@ func (d *GormDatabase) GetUserByID(id uint) *model.User {
return nil return nil
} }
// CountUser returns the user count which satisfies the given condition.
func (d *GormDatabase) CountUser(condition ...interface{}) int {
c := -1
handle := d.DB.Model(new(model.User))
if len(condition) == 1 {
handle = handle.Where(condition[0])
} else if len(condition) > 1 {
handle = handle.Where(condition[0], condition[1:]...)
}
handle.Count(&c)
return c
}
// GetUsers returns all users. // GetUsers returns all users.
func (d *GormDatabase) GetUsers() []*model.User { func (d *GormDatabase) GetUsers() []*model.User {
var users []*model.User var users []*model.User

View File

@ -11,6 +11,7 @@ func (s *DatabaseSuite) TestUser() {
jmattheis := s.db.GetUserByID(1) jmattheis := s.db.GetUserByID(1)
assert.NotNil(s.T(), jmattheis, "on bootup the first user should be automatically created") assert.NotNil(s.T(), jmattheis, "on bootup the first user should be automatically created")
assert.Equal(s.T(), 1, s.db.CountUser("admin = ?", true), 1, "there is initially one admin")
users := s.db.GetUsers() users := s.db.GetUsers()
assert.Len(s.T(), users, 1) assert.Len(s.T(), users, 1)
@ -19,6 +20,7 @@ func (s *DatabaseSuite) TestUser() {
nicories := &model.User{Name: "nicories", Pass: []byte{1, 2, 3, 4}, Admin: false} nicories := &model.User{Name: "nicories", Pass: []byte{1, 2, 3, 4}, Admin: false}
s.db.CreateUser(nicories) s.db.CreateUser(nicories)
assert.NotEqual(s.T(), 0, nicories.ID, "on create user a new id should be assigned") assert.NotEqual(s.T(), 0, nicories.ID, "on create user a new id should be assigned")
assert.Equal(s.T(), 2, s.db.CountUser(), "two users should exist")
assert.Equal(s.T(), nicories, s.db.GetUserByName("nicories")) assert.Equal(s.T(), nicories, s.db.GetUserByName("nicories"))
@ -35,6 +37,7 @@ func (s *DatabaseSuite) TestUser() {
assert.Equal(s.T(), &model.User{ID: nicories.ID, Name: "tom", Pass: []byte{12}, Admin: true}, tom) assert.Equal(s.T(), &model.User{ID: nicories.ID, Name: "tom", Pass: []byte{12}, Admin: true}, tom)
users = s.db.GetUsers() users = s.db.GetUsers()
assert.Len(s.T(), users, 2) assert.Len(s.T(), users, 2)
assert.Equal(s.T(), 2, s.db.CountUser(&model.User{Admin: true}), "two admins exist")
s.db.DeleteUserByID(tom.ID) s.db.DeleteUserByID(tom.ID)
users = s.db.GetUsers() users = s.db.GetUsers()