Add authentication middleware
This commit is contained in:
parent
d76ab85396
commit
662e1f4fb5
|
|
@ -0,0 +1,101 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jmattheis/memo/model"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
headerName = "Authorization"
|
||||
headerSchema = "ApiKey "
|
||||
typeAdmin = 0
|
||||
typeAll = 1
|
||||
typeWriteOnly = 2
|
||||
)
|
||||
|
||||
type Database interface {
|
||||
GetTokenById(id string) *model.Token
|
||||
GetUserByName(name string) *model.User
|
||||
GetUserById(id uint) *model.User
|
||||
}
|
||||
|
||||
type Auth struct {
|
||||
DB Database
|
||||
}
|
||||
|
||||
func (a *Auth) RequireAdmin() gin.HandlerFunc {
|
||||
return a.requireToken(typeAdmin)
|
||||
}
|
||||
|
||||
func (a *Auth) RequireAll() gin.HandlerFunc {
|
||||
return a.requireToken(typeAll)
|
||||
}
|
||||
|
||||
func (a *Auth) RequireWrite() gin.HandlerFunc {
|
||||
return a.requireToken(typeWriteOnly)
|
||||
}
|
||||
|
||||
func (a *Auth) tokenFromQueryOrHeader(ctx *gin.Context) *model.Token {
|
||||
if token := a.tokenFromQuery(ctx); token != nil {
|
||||
return token
|
||||
} else if token := a.tokenFromHeader(ctx); token != nil {
|
||||
return token
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Auth) tokenFromQuery(ctx *gin.Context) *model.Token {
|
||||
if token := ctx.Request.URL.Query().Get("token"); token != "" {
|
||||
return a.DB.GetTokenById(token)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Auth) tokenFromHeader(ctx *gin.Context) *model.Token {
|
||||
if header := ctx.Request.Header.Get(headerName); header != "" && strings.HasPrefix(header, headerSchema) {
|
||||
return a.DB.GetTokenById(strings.TrimPrefix(header, headerSchema))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Auth) userFromBasicAuth(ctx *gin.Context) *model.User {
|
||||
if name, pass, ok := ctx.Request.BasicAuth(); ok {
|
||||
if user := a.DB.GetUserByName(name); user != nil && ComparePassword(user.Pass, []byte(pass)) {
|
||||
return user
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Auth) isAuthenticated(checkType int, token *model.Token, user *model.User) bool {
|
||||
if token == nil && user == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
switch checkType {
|
||||
case typeWriteOnly:
|
||||
return true
|
||||
case typeAll:
|
||||
return user != nil || (token != nil && !token.WriteOnly)
|
||||
default:
|
||||
if user == nil {
|
||||
user = a.DB.GetUserById(token.UserID)
|
||||
}
|
||||
return user != nil && user.Admin
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Auth) requireToken(checkType int) gin.HandlerFunc {
|
||||
return func(ctx *gin.Context) {
|
||||
token := a.tokenFromQueryOrHeader(ctx)
|
||||
user := a.userFromBasicAuth(ctx)
|
||||
|
||||
if a.isAuthenticated(checkType, token, user) {
|
||||
ctx.Next()
|
||||
} else {
|
||||
ctx.AbortWithError(401, errors.New("could not authenticate"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,155 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jmattheis/memo/model"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSuite(t *testing.T) {
|
||||
suite.Run(t, new(AuthenticationSuite))
|
||||
}
|
||||
|
||||
type AuthenticationSuite struct {
|
||||
suite.Suite
|
||||
auth *Auth
|
||||
}
|
||||
|
||||
func (s *AuthenticationSuite) SetupSuite() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
s.auth = &Auth{&DBMock{}}
|
||||
}
|
||||
|
||||
func (s *AuthenticationSuite) TestQueryToken() {
|
||||
s.assertQueryRequest("token", "ergerogerg", s.auth.RequireWrite, 401)
|
||||
s.assertQueryRequest("token", "ergerogerg", s.auth.RequireAll, 401)
|
||||
s.assertQueryRequest("token", "ergerogerg", s.auth.RequireAdmin, 401)
|
||||
|
||||
s.assertQueryRequest("tokenx", "all", s.auth.RequireWrite, 401)
|
||||
s.assertQueryRequest("tokenx", "all", s.auth.RequireAll, 401)
|
||||
s.assertQueryRequest("tokenx", "all", s.auth.RequireAdmin, 401)
|
||||
|
||||
s.assertQueryRequest("token", "writeonly", s.auth.RequireWrite, 200)
|
||||
s.assertQueryRequest("token", "writeonly", s.auth.RequireAll, 401)
|
||||
s.assertQueryRequest("token", "writeonly", s.auth.RequireAdmin, 401)
|
||||
|
||||
s.assertQueryRequest("token", "all", s.auth.RequireWrite, 200)
|
||||
s.assertQueryRequest("token", "all", s.auth.RequireAll, 200)
|
||||
s.assertQueryRequest("token", "all", s.auth.RequireAdmin, 401)
|
||||
|
||||
s.assertQueryRequest("token", "admin", s.auth.RequireWrite, 200)
|
||||
s.assertQueryRequest("token", "admin", s.auth.RequireAll, 200)
|
||||
s.assertQueryRequest("token", "admin", s.auth.RequireAdmin, 200)
|
||||
}
|
||||
|
||||
func (s *AuthenticationSuite) assertQueryRequest(key, value string, f fMiddleware, code int) {
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Request = httptest.NewRequest("GET", fmt.Sprintf("/?%s=%s", key, value), nil)
|
||||
f()(ctx)
|
||||
assert.Equal(s.T(), code, recorder.Code)
|
||||
}
|
||||
|
||||
func (s *AuthenticationSuite) TestNothingProvided() {
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Request = httptest.NewRequest("GET", "/", nil)
|
||||
s.auth.RequireWrite()(ctx)
|
||||
assert.Equal(s.T(), 401, recorder.Code)
|
||||
}
|
||||
|
||||
func (s *AuthenticationSuite) TestHeaderApiKeyToken() {
|
||||
s.assertHeaderRequest("Authorization", "ergerogerg", s.auth.RequireWrite, 401)
|
||||
s.assertHeaderRequest("Authorization", "ergerogerg", s.auth.RequireAll, 401)
|
||||
s.assertHeaderRequest("Authorization", "ergerogerg", s.auth.RequireAdmin, 401)
|
||||
|
||||
s.assertHeaderRequest("Authorization", "ApiKey ergerogerg", s.auth.RequireWrite, 401)
|
||||
s.assertHeaderRequest("Authorization", "ApiKey ergerogerg", s.auth.RequireAll, 401)
|
||||
s.assertHeaderRequest("Authorization", "ApiKey ergerogerg", s.auth.RequireAdmin, 401)
|
||||
|
||||
s.assertHeaderRequest("Authorizationx", "ApiKey all", s.auth.RequireWrite, 401)
|
||||
s.assertHeaderRequest("Authorizationx", "ApiKey all", s.auth.RequireAll, 401)
|
||||
s.assertHeaderRequest("Authorizationx", "ApiKey all", s.auth.RequireAdmin, 401)
|
||||
|
||||
s.assertHeaderRequest("Authorization", "ApiKey writeonly", s.auth.RequireWrite, 200)
|
||||
s.assertHeaderRequest("Authorization", "ApiKey writeonly", s.auth.RequireAll, 401)
|
||||
s.assertHeaderRequest("Authorization", "ApiKey writeonly", s.auth.RequireAdmin, 401)
|
||||
|
||||
s.assertHeaderRequest("Authorization", "ApiKey all", s.auth.RequireWrite, 200)
|
||||
s.assertHeaderRequest("Authorization", "ApiKey all", s.auth.RequireAll, 200)
|
||||
s.assertHeaderRequest("Authorization", "ApiKey all", s.auth.RequireAdmin, 401)
|
||||
|
||||
s.assertHeaderRequest("Authorization", "ApiKey admin", s.auth.RequireWrite, 200)
|
||||
s.assertHeaderRequest("Authorization", "ApiKey admin", s.auth.RequireAll, 200)
|
||||
s.assertHeaderRequest("Authorization", "ApiKey admin", s.auth.RequireAdmin, 200)
|
||||
}
|
||||
|
||||
func (s *AuthenticationSuite) TestBasicAuth() {
|
||||
s.assertHeaderRequest("Authorization", "Basic ergerogerg", s.auth.RequireWrite, 401)
|
||||
s.assertHeaderRequest("Authorization", "Basic ergerogerg", s.auth.RequireAll, 401)
|
||||
s.assertHeaderRequest("Authorization", "Basic ergerogerg", s.auth.RequireAdmin, 401)
|
||||
|
||||
// user existing:pw
|
||||
s.assertHeaderRequest("Authorization", "Basic ZXhpc3Rpbmc6cHc=", s.auth.RequireWrite, 200)
|
||||
s.assertHeaderRequest("Authorization", "Basic ZXhpc3Rpbmc6cHc=", s.auth.RequireAll, 200)
|
||||
s.assertHeaderRequest("Authorization", "Basic ZXhpc3Rpbmc6cHc=", s.auth.RequireAdmin, 401)
|
||||
|
||||
// user admin:pw
|
||||
s.assertHeaderRequest("Authorization", "Basic YWRtaW46cHc=", s.auth.RequireWrite, 200)
|
||||
s.assertHeaderRequest("Authorization", "Basic YWRtaW46cHc=", s.auth.RequireAll, 200)
|
||||
s.assertHeaderRequest("Authorization", "Basic YWRtaW46cHc=", s.auth.RequireAdmin, 200)
|
||||
|
||||
// user admin:pwx
|
||||
s.assertHeaderRequest("Authorization", "Basic YWRtaW46cHd4", s.auth.RequireWrite, 401)
|
||||
s.assertHeaderRequest("Authorization", "Basic YWRtaW46cHd4", s.auth.RequireAll, 401)
|
||||
s.assertHeaderRequest("Authorization", "Basic YWRtaW46cHd4", s.auth.RequireAdmin, 401)
|
||||
}
|
||||
|
||||
func (s *AuthenticationSuite) assertHeaderRequest(key, value string, f fMiddleware, code int) {
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Request = httptest.NewRequest("GET", "/", nil)
|
||||
ctx.Request.Header.Set(key, value)
|
||||
f()(ctx)
|
||||
assert.Equal(s.T(), code, recorder.Code)
|
||||
}
|
||||
|
||||
type fMiddleware func() gin.HandlerFunc
|
||||
type DBMock struct{}
|
||||
|
||||
func (d *DBMock) GetTokenById(id string) *model.Token {
|
||||
if id == "writeonly" {
|
||||
return &model.Token{Id: "valid", WriteOnly: true, UserID: 1}
|
||||
}
|
||||
if id == "all" {
|
||||
return &model.Token{Id: "valid", WriteOnly: false, UserID: 1}
|
||||
}
|
||||
if id == "admin" {
|
||||
return &model.Token{Id: "valid", WriteOnly: false, UserID: 2}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DBMock) GetUserByName(name string) *model.User {
|
||||
if name == "existing" {
|
||||
return &model.User{Name: "existing", Pass: CreatePassword("pw")}
|
||||
}
|
||||
if name == "admin" {
|
||||
return &model.User{Name: "admin", Pass: CreatePassword("pw"), Admin: true}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (d *DBMock) GetUserById(id uint) *model.User {
|
||||
if id == 1 {
|
||||
return &model.User{Name: "existing", Pass: CreatePassword("pw"), Admin: false}
|
||||
}
|
||||
|
||||
if id == 2 {
|
||||
return &model.User{Name: "existing", Pass: CreatePassword("pw"), Admin: true}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
package auth
|
||||
|
||||
import "golang.org/x/crypto/bcrypt"
|
||||
|
||||
var strength = 13
|
||||
|
||||
func CreatePassword(pw string) []byte {
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(pw), strength)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return hashedPassword
|
||||
}
|
||||
|
||||
func ComparePassword(hashedPassword, password []byte) bool {
|
||||
return bcrypt.CompareHashAndPassword(hashedPassword, password) == nil
|
||||
}
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPasswordSuccess(t *testing.T) {
|
||||
password := CreatePassword("secret")
|
||||
assert.Equal(t, true, ComparePassword(password, []byte("secret")))
|
||||
}
|
||||
|
||||
func TestPasswordFailure(t *testing.T) {
|
||||
password := CreatePassword("secret")
|
||||
assert.Equal(t, false, ComparePassword(password, []byte("secretx")))
|
||||
}
|
||||
|
||||
func TestBCryptFailure(t *testing.T) {
|
||||
strength = 12312 // invalid value
|
||||
assert.Panics(t, func() { CreatePassword("secret") })
|
||||
}
|
||||
Loading…
Reference in New Issue