diff --git a/auth/authentication.go b/auth/authentication.go index f377901..6f3c036 100644 --- a/auth/authentication.go +++ b/auth/authentication.go @@ -8,15 +8,13 @@ import ( ) const ( - headerName = "Authorization" - headerSchema = "ApiKey " - typeAdmin = 0 - typeAll = 1 - typeWriteOnly = 2 + headerName = "Authorization" + headerSchema = "ApiKey " ) type Database interface { - GetTokenById(id string) *model.Token + GetApplicationById(id string) *model.Application + GetClientById(id string) *model.Client GetUserByName(name string) *model.User GetUserById(id uint) *model.User } @@ -25,39 +23,62 @@ type Auth struct { DB Database } +type authenticate func(tokenId string, user *model.User) (success bool, userId uint) + func (a *Auth) RequireAdmin() gin.HandlerFunc { - return a.requireToken(typeAdmin) + return a.requireToken(func(tokenId string, user *model.User) (bool, uint) { + if user != nil { + return user.Admin, user.Id + } + if token := a.DB.GetClientById(tokenId); token != nil { + return a.DB.GetUserById(token.UserId).Admin, token.UserId + } + return false, 0 + }) } func (a *Auth) RequireAll() gin.HandlerFunc { - return a.requireToken(typeAll) + return a.requireToken(func(tokenId string, user *model.User) (bool, uint) { + if user != nil { + return true, user.Id + } + if token := a.DB.GetClientById(tokenId); token != nil { + return true, token.UserId + } + return false, 0 + }) } func (a *Auth) RequireWrite() gin.HandlerFunc { - return a.requireToken(typeWriteOnly) + return a.requireToken(func(tokenId string, user *model.User) (bool, uint) { + if user != nil { + return false, 0 + } + if token := a.DB.GetApplicationById(tokenId); token != nil { + return true, token.UserId + } + return false, 0 + }) } -func (a *Auth) tokenFromQueryOrHeader(ctx *gin.Context) *model.Token { - if token := a.tokenFromQuery(ctx); token != nil { +func (a *Auth) tokenFromQueryOrHeader(ctx *gin.Context) string { + if token := a.tokenFromQuery(ctx); token != "" { return token - } else if token := a.tokenFromHeader(ctx); token != nil { + } else if token := a.tokenFromHeader(ctx); token != "" { return token } - return nil + return "" } -func (a *Auth) tokenFromQuery(ctx *gin.Context) *model.Token { - if token := ctx.Request.URL.Query().Get("token"); token != "" { - return a.DB.GetTokenById(token) - } - return nil +func (a *Auth) tokenFromQuery(ctx *gin.Context) string { + return ctx.Request.URL.Query().Get("token") } -func (a *Auth) tokenFromHeader(ctx *gin.Context) *model.Token { +func (a *Auth) tokenFromHeader(ctx *gin.Context) string { if header := ctx.Request.Header.Get(headerName); header != "" && strings.HasPrefix(header, headerSchema) { - return a.DB.GetTokenById(strings.TrimPrefix(header, headerSchema)) + return strings.TrimPrefix(header, headerSchema) } - return nil + return "" } func (a *Auth) userFromBasicAuth(ctx *gin.Context) *model.User { @@ -69,33 +90,17 @@ func (a *Auth) userFromBasicAuth(ctx *gin.Context) *model.User { return nil } -func (a *Auth) isAuthenticated(checkType int, token *model.Token, user *model.User) bool { - if token == nil && user == nil { - return false - } - - switch checkType { - case typeWriteOnly: - return true - case typeAll: - return user != nil || (token != nil && !token.WriteOnly) - default: - if user == nil { - user = a.DB.GetUserById(token.UserID) - } - return user != nil && user.Admin - } -} - -func (a *Auth) requireToken(checkType int) gin.HandlerFunc { +func (a *Auth) requireToken(auth authenticate) gin.HandlerFunc { return func(ctx *gin.Context) { token := a.tokenFromQueryOrHeader(ctx) user := a.userFromBasicAuth(ctx) - if a.isAuthenticated(checkType, token, user) { - ctx.Next() - } else { - ctx.AbortWithError(401, errors.New("could not authenticate")) + if user != nil || token != "" { + if ok, _ := auth(token, user); ok { + ctx.Next() + return + } } + ctx.AbortWithError(401, errors.New("could not authenticate")) } } diff --git a/auth/authentication_test.go b/auth/authentication_test.go index b7b7b1b..3cc1d7f 100644 --- a/auth/authentication_test.go +++ b/auth/authentication_test.go @@ -3,8 +3,10 @@ package auth import ( "fmt" "github.com/gin-gonic/gin" + authmock "github.com/jmattheis/memo/auth/mock" "github.com/jmattheis/memo/model" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" "net/http/httptest" "testing" @@ -17,33 +19,54 @@ func TestSuite(t *testing.T) { type AuthenticationSuite struct { suite.Suite auth *Auth + DB *authmock.MockDatabase } func (s *AuthenticationSuite) SetupSuite() { gin.SetMode(gin.TestMode) - s.auth = &Auth{&DBMock{}} + s.DB = &authmock.MockDatabase{} + s.auth = &Auth{s.DB} + s.DB.On("GetClientById", "clienttoken").Return(&model.Client{Id: "clienttoken", UserId: 1, Name: "android phone"}) + s.DB.On("GetClientById", "clienttoken_admin").Return(&model.Client{Id: "clienttoken", UserId: 2, Name: "android phone2"}) + s.DB.On("GetClientById", mock.Anything).Return(nil) + s.DB.On("GetApplicationById", "apptoken").Return(&model.Application{Id: "apptoken", UserId: 1, Name: "backup server", Description: "irrelevant"}) + s.DB.On("GetApplicationById", "apptoken_admin").Return(&model.Application{Id: "apptoken", UserId: 2, Name: "backup server", Description: "irrelevant"}) + s.DB.On("GetApplicationById", mock.Anything).Return(nil) + + s.DB.On("GetUserById", uint(1)).Return(&model.User{Id: 1, Name: "irrelevant", Admin: false}) + s.DB.On("GetUserById", uint(2)).Return(&model.User{Id: 2, Name: "irrelevant", Admin: true}) + + s.DB.On("GetUserByName", "existing").Return(&model.User{Name: "existing", Pass: CreatePassword("pw")}) + s.DB.On("GetUserByName", "admin").Return(&model.User{Name: "admin", Pass: CreatePassword("pw"), Admin: true}) + s.DB.On("GetUserByName", mock.Anything).Return(nil) } func (s *AuthenticationSuite) TestQueryToken() { + // not existing token s.assertQueryRequest("token", "ergerogerg", s.auth.RequireWrite, 401) s.assertQueryRequest("token", "ergerogerg", s.auth.RequireAll, 401) s.assertQueryRequest("token", "ergerogerg", s.auth.RequireAdmin, 401) - s.assertQueryRequest("tokenx", "all", s.auth.RequireWrite, 401) - s.assertQueryRequest("tokenx", "all", s.auth.RequireAll, 401) - s.assertQueryRequest("tokenx", "all", s.auth.RequireAdmin, 401) + // not existing key + s.assertQueryRequest("tokenx", "clienttoken", s.auth.RequireWrite, 401) + s.assertQueryRequest("tokenx", "clienttoken", s.auth.RequireAll, 401) + s.assertQueryRequest("tokenx", "clienttoken", s.auth.RequireAdmin, 401) - s.assertQueryRequest("token", "writeonly", s.auth.RequireWrite, 200) - s.assertQueryRequest("token", "writeonly", s.auth.RequireAll, 401) - s.assertQueryRequest("token", "writeonly", s.auth.RequireAdmin, 401) + // apptoken + s.assertQueryRequest("token", "apptoken", s.auth.RequireWrite, 200) + s.assertQueryRequest("token", "apptoken", s.auth.RequireAll, 401) + s.assertQueryRequest("token", "apptoken", s.auth.RequireAdmin, 401) + s.assertQueryRequest("token", "apptoken_admin", s.auth.RequireWrite, 200) + s.assertQueryRequest("token", "apptoken_admin", s.auth.RequireAll, 401) + s.assertQueryRequest("token", "apptoken_admin", s.auth.RequireAdmin, 401) - s.assertQueryRequest("token", "all", s.auth.RequireWrite, 200) - s.assertQueryRequest("token", "all", s.auth.RequireAll, 200) - s.assertQueryRequest("token", "all", s.auth.RequireAdmin, 401) - - s.assertQueryRequest("token", "admin", s.auth.RequireWrite, 200) - s.assertQueryRequest("token", "admin", s.auth.RequireAll, 200) - s.assertQueryRequest("token", "admin", s.auth.RequireAdmin, 200) + // clienttoken + s.assertQueryRequest("token", "clienttoken", s.auth.RequireWrite, 401) + s.assertQueryRequest("token", "clienttoken", s.auth.RequireAll, 200) + s.assertQueryRequest("token", "clienttoken", s.auth.RequireAdmin, 401) + s.assertQueryRequest("token", "clienttoken_admin", s.auth.RequireWrite, 401) + s.assertQueryRequest("token", "clienttoken_admin", s.auth.RequireAll, 200) + s.assertQueryRequest("token", "clienttoken_admin", s.auth.RequireAdmin, 200) } func (s *AuthenticationSuite) assertQueryRequest(key, value string, f fMiddleware, code int) { @@ -63,29 +86,41 @@ func (s *AuthenticationSuite) TestNothingProvided() { } func (s *AuthenticationSuite) TestHeaderApiKeyToken() { - s.assertHeaderRequest("Authorization", "ergerogerg", s.auth.RequireWrite, 401) - s.assertHeaderRequest("Authorization", "ergerogerg", s.auth.RequireAll, 401) - s.assertHeaderRequest("Authorization", "ergerogerg", s.auth.RequireAdmin, 401) - + // not existing token s.assertHeaderRequest("Authorization", "ApiKey ergerogerg", s.auth.RequireWrite, 401) s.assertHeaderRequest("Authorization", "ApiKey ergerogerg", s.auth.RequireAll, 401) s.assertHeaderRequest("Authorization", "ApiKey ergerogerg", s.auth.RequireAdmin, 401) - s.assertHeaderRequest("Authorizationx", "ApiKey all", s.auth.RequireWrite, 401) - s.assertHeaderRequest("Authorizationx", "ApiKey all", s.auth.RequireAll, 401) - s.assertHeaderRequest("Authorizationx", "ApiKey all", s.auth.RequireAdmin, 401) + // no authentication schema + s.assertHeaderRequest("Authorization", "ergerogerg", s.auth.RequireWrite, 401) + s.assertHeaderRequest("Authorization", "ergerogerg", s.auth.RequireAll, 401) + s.assertHeaderRequest("Authorization", "ergerogerg", s.auth.RequireAdmin, 401) - s.assertHeaderRequest("Authorization", "ApiKey writeonly", s.auth.RequireWrite, 200) - s.assertHeaderRequest("Authorization", "ApiKey writeonly", s.auth.RequireAll, 401) - s.assertHeaderRequest("Authorization", "ApiKey writeonly", s.auth.RequireAdmin, 401) + // wrong authentication schema + s.assertHeaderRequest("Authorization", "ApiKeyx clienttoken", s.auth.RequireWrite, 401) + s.assertHeaderRequest("Authorization", "ApiKeyx clienttoken", s.auth.RequireAll, 401) + s.assertHeaderRequest("Authorization", "ApiKeyx clienttoken", s.auth.RequireAdmin, 401) - s.assertHeaderRequest("Authorization", "ApiKey all", s.auth.RequireWrite, 200) - s.assertHeaderRequest("Authorization", "ApiKey all", s.auth.RequireAll, 200) - s.assertHeaderRequest("Authorization", "ApiKey all", s.auth.RequireAdmin, 401) + // not existing key + s.assertHeaderRequest("Authorizationx", "ApiKey clienttoken", s.auth.RequireWrite, 401) + s.assertHeaderRequest("Authorizationx", "ApiKey clienttoken", s.auth.RequireAll, 401) + s.assertHeaderRequest("Authorizationx", "ApiKey clienttoken", s.auth.RequireAdmin, 401) - s.assertHeaderRequest("Authorization", "ApiKey admin", s.auth.RequireWrite, 200) - s.assertHeaderRequest("Authorization", "ApiKey admin", s.auth.RequireAll, 200) - s.assertHeaderRequest("Authorization", "ApiKey admin", s.auth.RequireAdmin, 200) + // apptoken + s.assertHeaderRequest("Authorization", "ApiKey apptoken", s.auth.RequireWrite, 200) + s.assertHeaderRequest("Authorization", "ApiKey apptoken", s.auth.RequireAll, 401) + s.assertHeaderRequest("Authorization", "ApiKey apptoken", s.auth.RequireAdmin, 401) + s.assertHeaderRequest("Authorization", "ApiKey apptoken_admin", s.auth.RequireWrite, 200) + s.assertHeaderRequest("Authorization", "ApiKey apptoken_admin", s.auth.RequireAll, 401) + s.assertHeaderRequest("Authorization", "ApiKey apptoken_admin", s.auth.RequireAdmin, 401) + + // clienttoken + s.assertHeaderRequest("Authorization", "ApiKey clienttoken", s.auth.RequireWrite, 401) + s.assertHeaderRequest("Authorization", "ApiKey clienttoken", s.auth.RequireAll, 200) + s.assertHeaderRequest("Authorization", "ApiKey clienttoken", s.auth.RequireAdmin, 401) + s.assertHeaderRequest("Authorization", "ApiKey clienttoken_admin", s.auth.RequireWrite, 401) + s.assertHeaderRequest("Authorization", "ApiKey clienttoken_admin", s.auth.RequireAll, 200) + s.assertHeaderRequest("Authorization", "ApiKey clienttoken_admin", s.auth.RequireAdmin, 200) } func (s *AuthenticationSuite) TestBasicAuth() { @@ -94,12 +129,12 @@ func (s *AuthenticationSuite) TestBasicAuth() { s.assertHeaderRequest("Authorization", "Basic ergerogerg", s.auth.RequireAdmin, 401) // user existing:pw - s.assertHeaderRequest("Authorization", "Basic ZXhpc3Rpbmc6cHc=", s.auth.RequireWrite, 200) + s.assertHeaderRequest("Authorization", "Basic ZXhpc3Rpbmc6cHc=", s.auth.RequireWrite, 401) s.assertHeaderRequest("Authorization", "Basic ZXhpc3Rpbmc6cHc=", s.auth.RequireAll, 200) s.assertHeaderRequest("Authorization", "Basic ZXhpc3Rpbmc6cHc=", s.auth.RequireAdmin, 401) // user admin:pw - s.assertHeaderRequest("Authorization", "Basic YWRtaW46cHc=", s.auth.RequireWrite, 200) + s.assertHeaderRequest("Authorization", "Basic YWRtaW46cHc=", s.auth.RequireWrite, 401) s.assertHeaderRequest("Authorization", "Basic YWRtaW46cHc=", s.auth.RequireAll, 200) s.assertHeaderRequest("Authorization", "Basic YWRtaW46cHc=", s.auth.RequireAdmin, 200) @@ -107,6 +142,11 @@ func (s *AuthenticationSuite) TestBasicAuth() { s.assertHeaderRequest("Authorization", "Basic YWRtaW46cHd4", s.auth.RequireWrite, 401) s.assertHeaderRequest("Authorization", "Basic YWRtaW46cHd4", s.auth.RequireAll, 401) s.assertHeaderRequest("Authorization", "Basic YWRtaW46cHd4", s.auth.RequireAdmin, 401) + + // user notexisting:pw + s.assertHeaderRequest("Authorization", "Basic bm90ZXhpc3Rpbmc6cHc=", s.auth.RequireWrite, 401) + s.assertHeaderRequest("Authorization", "Basic bm90ZXhpc3Rpbmc6cHc=", s.auth.RequireAll, 401) + s.assertHeaderRequest("Authorization", "Basic bm90ZXhpc3Rpbmc6cHc=", s.auth.RequireAdmin, 401) } func (s *AuthenticationSuite) assertHeaderRequest(key, value string, f fMiddleware, code int) { @@ -119,37 +159,3 @@ func (s *AuthenticationSuite) assertHeaderRequest(key, value string, f fMiddlewa } type fMiddleware func() gin.HandlerFunc -type DBMock struct{} - -func (d *DBMock) GetTokenById(id string) *model.Token { - if id == "writeonly" { - return &model.Token{Id: "valid", WriteOnly: true, UserID: 1} - } - if id == "all" { - return &model.Token{Id: "valid", WriteOnly: false, UserID: 1} - } - if id == "admin" { - return &model.Token{Id: "valid", WriteOnly: false, UserID: 2} - } - return nil -} - -func (d *DBMock) GetUserByName(name string) *model.User { - if name == "existing" { - return &model.User{Name: "existing", Pass: CreatePassword("pw")} - } - if name == "admin" { - return &model.User{Name: "admin", Pass: CreatePassword("pw"), Admin: true} - } - return nil -} -func (d *DBMock) GetUserById(id uint) *model.User { - if id == 1 { - return &model.User{Name: "existing", Pass: CreatePassword("pw"), Admin: false} - } - - if id == 2 { - return &model.User{Name: "existing", Pass: CreatePassword("pw"), Admin: true} - } - return nil -} diff --git a/auth/mock/mock_database.go b/auth/mock/mock_database.go new file mode 100644 index 0000000..9e4bd11 --- /dev/null +++ b/auth/mock/mock_database.go @@ -0,0 +1,74 @@ +// Code generated by mockery v1.0.0 +package mock + +import mock "github.com/stretchr/testify/mock" +import model "github.com/jmattheis/memo/model" + +// MockDatabase is an autogenerated mock type for the Database type +type MockDatabase struct { + mock.Mock +} + +// GetApplicationById provides a mock function with given fields: id +func (_m *MockDatabase) GetApplicationById(id string) *model.Application { + ret := _m.Called(id) + + var r0 *model.Application + if rf, ok := ret.Get(0).(func(string) *model.Application); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.Application) + } + } + + return r0 +} + +// GetClientById provides a mock function with given fields: id +func (_m *MockDatabase) GetClientById(id string) *model.Client { + ret := _m.Called(id) + + var r0 *model.Client + if rf, ok := ret.Get(0).(func(string) *model.Client); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.Client) + } + } + + return r0 +} + +// GetUserById provides a mock function with given fields: id +func (_m *MockDatabase) GetUserById(id uint) *model.User { + ret := _m.Called(id) + + var r0 *model.User + if rf, ok := ret.Get(0).(func(uint) *model.User); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.User) + } + } + + return r0 +} + +// GetUserByName provides a mock function with given fields: name +func (_m *MockDatabase) GetUserByName(name string) *model.User { + ret := _m.Called(name) + + var r0 *model.User + if rf, ok := ret.Get(0).(func(string) *model.User); ok { + r0 = rf(name) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.User) + } + } + + return r0 +} diff --git a/model/application.go b/model/application.go new file mode 100644 index 0000000..884608c --- /dev/null +++ b/model/application.go @@ -0,0 +1,9 @@ +package model + +type Application struct { + Id string `gorm:"primary_key;unique_index"` + UserId uint `gorm:"index" json:"-"` + Name string `form:"name" query:"name" json:"name" binding:"required"` + Description string `form:"description" query:"description" json:"description"` + Messages []Message `json:"-"` +} diff --git a/model/client.go b/model/client.go new file mode 100644 index 0000000..557eb17 --- /dev/null +++ b/model/client.go @@ -0,0 +1,7 @@ +package model + +type Client struct { + Id string `gorm:"primary_key;unique_index"` + UserId uint `gorm:"index" json:"-"` + Name string `form:"name" query:"name" json:"name" binding:"required"` +} diff --git a/model/message.go b/model/message.go index 44069e5..49911c7 100644 --- a/model/message.go +++ b/model/message.go @@ -1,9 +1,12 @@ package model +import "time" + type Message struct { - ID uint `gorm:"primary_key" gorm:"AUTO_INCREMENT;primary_key;index"` - TokenID string + Id uint `gorm:"AUTO_INCREMENT;primary_key;index"` + TokenId string Message string Title string Priority int + Date time.Time } diff --git a/model/token.go b/model/token.go deleted file mode 100644 index e295041..0000000 --- a/model/token.go +++ /dev/null @@ -1,12 +0,0 @@ -package model - -type Token struct { - Name string - DefaultTitle string - Description string - Icon string - WriteOnly bool - UserID uint `gorm:"index"` - Id string `gorm:"primary_key;unique_index"` - Messages []Message -} diff --git a/model/user.go b/model/user.go index e995eb1..5d62da8 100644 --- a/model/user.go +++ b/model/user.go @@ -1,9 +1,10 @@ package model type User struct { - ID uint `gorm:"primary_key;unique_index" gorm:"AUTO_INCREMENT"` - Name string - Pass []byte - Admin bool - Tokens []Token + Id uint `gorm:"primary_key;unique_index;AUTO_INCREMENT"` + Name string + Pass []byte + Admin bool + Tokens []Application + Clients []Client }