Add authentication middleware

This commit is contained in:
Jannis Mattheis 2018-01-19 22:19:27 +01:00 committed by Jannis Mattheis
parent d76ab85396
commit 662e1f4fb5
4 changed files with 294 additions and 0 deletions

101
auth/authentication.go Normal file
View File

@ -0,0 +1,101 @@
package auth
import (
"errors"
"github.com/gin-gonic/gin"
"github.com/jmattheis/memo/model"
"strings"
)
const (
headerName = "Authorization"
headerSchema = "ApiKey "
typeAdmin = 0
typeAll = 1
typeWriteOnly = 2
)
type Database interface {
GetTokenById(id string) *model.Token
GetUserByName(name string) *model.User
GetUserById(id uint) *model.User
}
type Auth struct {
DB Database
}
func (a *Auth) RequireAdmin() gin.HandlerFunc {
return a.requireToken(typeAdmin)
}
func (a *Auth) RequireAll() gin.HandlerFunc {
return a.requireToken(typeAll)
}
func (a *Auth) RequireWrite() gin.HandlerFunc {
return a.requireToken(typeWriteOnly)
}
func (a *Auth) tokenFromQueryOrHeader(ctx *gin.Context) *model.Token {
if token := a.tokenFromQuery(ctx); token != nil {
return token
} else if token := a.tokenFromHeader(ctx); token != nil {
return token
}
return nil
}
func (a *Auth) tokenFromQuery(ctx *gin.Context) *model.Token {
if token := ctx.Request.URL.Query().Get("token"); token != "" {
return a.DB.GetTokenById(token)
}
return nil
}
func (a *Auth) tokenFromHeader(ctx *gin.Context) *model.Token {
if header := ctx.Request.Header.Get(headerName); header != "" && strings.HasPrefix(header, headerSchema) {
return a.DB.GetTokenById(strings.TrimPrefix(header, headerSchema))
}
return nil
}
func (a *Auth) userFromBasicAuth(ctx *gin.Context) *model.User {
if name, pass, ok := ctx.Request.BasicAuth(); ok {
if user := a.DB.GetUserByName(name); user != nil && ComparePassword(user.Pass, []byte(pass)) {
return user
}
}
return nil
}
func (a *Auth) isAuthenticated(checkType int, token *model.Token, user *model.User) bool {
if token == nil && user == nil {
return false
}
switch checkType {
case typeWriteOnly:
return true
case typeAll:
return user != nil || (token != nil && !token.WriteOnly)
default:
if user == nil {
user = a.DB.GetUserById(token.UserID)
}
return user != nil && user.Admin
}
}
func (a *Auth) requireToken(checkType int) gin.HandlerFunc {
return func(ctx *gin.Context) {
token := a.tokenFromQueryOrHeader(ctx)
user := a.userFromBasicAuth(ctx)
if a.isAuthenticated(checkType, token, user) {
ctx.Next()
} else {
ctx.AbortWithError(401, errors.New("could not authenticate"))
}
}
}

155
auth/authentication_test.go Normal file
View File

@ -0,0 +1,155 @@
package auth
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/jmattheis/memo/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
"net/http/httptest"
"testing"
)
func TestSuite(t *testing.T) {
suite.Run(t, new(AuthenticationSuite))
}
type AuthenticationSuite struct {
suite.Suite
auth *Auth
}
func (s *AuthenticationSuite) SetupSuite() {
gin.SetMode(gin.TestMode)
s.auth = &Auth{&DBMock{}}
}
func (s *AuthenticationSuite) TestQueryToken() {
s.assertQueryRequest("token", "ergerogerg", s.auth.RequireWrite, 401)
s.assertQueryRequest("token", "ergerogerg", s.auth.RequireAll, 401)
s.assertQueryRequest("token", "ergerogerg", s.auth.RequireAdmin, 401)
s.assertQueryRequest("tokenx", "all", s.auth.RequireWrite, 401)
s.assertQueryRequest("tokenx", "all", s.auth.RequireAll, 401)
s.assertQueryRequest("tokenx", "all", s.auth.RequireAdmin, 401)
s.assertQueryRequest("token", "writeonly", s.auth.RequireWrite, 200)
s.assertQueryRequest("token", "writeonly", s.auth.RequireAll, 401)
s.assertQueryRequest("token", "writeonly", s.auth.RequireAdmin, 401)
s.assertQueryRequest("token", "all", s.auth.RequireWrite, 200)
s.assertQueryRequest("token", "all", s.auth.RequireAll, 200)
s.assertQueryRequest("token", "all", s.auth.RequireAdmin, 401)
s.assertQueryRequest("token", "admin", s.auth.RequireWrite, 200)
s.assertQueryRequest("token", "admin", s.auth.RequireAll, 200)
s.assertQueryRequest("token", "admin", s.auth.RequireAdmin, 200)
}
func (s *AuthenticationSuite) assertQueryRequest(key, value string, f fMiddleware, code int) {
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest("GET", fmt.Sprintf("/?%s=%s", key, value), nil)
f()(ctx)
assert.Equal(s.T(), code, recorder.Code)
}
func (s *AuthenticationSuite) TestNothingProvided() {
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest("GET", "/", nil)
s.auth.RequireWrite()(ctx)
assert.Equal(s.T(), 401, recorder.Code)
}
func (s *AuthenticationSuite) TestHeaderApiKeyToken() {
s.assertHeaderRequest("Authorization", "ergerogerg", s.auth.RequireWrite, 401)
s.assertHeaderRequest("Authorization", "ergerogerg", s.auth.RequireAll, 401)
s.assertHeaderRequest("Authorization", "ergerogerg", s.auth.RequireAdmin, 401)
s.assertHeaderRequest("Authorization", "ApiKey ergerogerg", s.auth.RequireWrite, 401)
s.assertHeaderRequest("Authorization", "ApiKey ergerogerg", s.auth.RequireAll, 401)
s.assertHeaderRequest("Authorization", "ApiKey ergerogerg", s.auth.RequireAdmin, 401)
s.assertHeaderRequest("Authorizationx", "ApiKey all", s.auth.RequireWrite, 401)
s.assertHeaderRequest("Authorizationx", "ApiKey all", s.auth.RequireAll, 401)
s.assertHeaderRequest("Authorizationx", "ApiKey all", s.auth.RequireAdmin, 401)
s.assertHeaderRequest("Authorization", "ApiKey writeonly", s.auth.RequireWrite, 200)
s.assertHeaderRequest("Authorization", "ApiKey writeonly", s.auth.RequireAll, 401)
s.assertHeaderRequest("Authorization", "ApiKey writeonly", s.auth.RequireAdmin, 401)
s.assertHeaderRequest("Authorization", "ApiKey all", s.auth.RequireWrite, 200)
s.assertHeaderRequest("Authorization", "ApiKey all", s.auth.RequireAll, 200)
s.assertHeaderRequest("Authorization", "ApiKey all", s.auth.RequireAdmin, 401)
s.assertHeaderRequest("Authorization", "ApiKey admin", s.auth.RequireWrite, 200)
s.assertHeaderRequest("Authorization", "ApiKey admin", s.auth.RequireAll, 200)
s.assertHeaderRequest("Authorization", "ApiKey admin", s.auth.RequireAdmin, 200)
}
func (s *AuthenticationSuite) TestBasicAuth() {
s.assertHeaderRequest("Authorization", "Basic ergerogerg", s.auth.RequireWrite, 401)
s.assertHeaderRequest("Authorization", "Basic ergerogerg", s.auth.RequireAll, 401)
s.assertHeaderRequest("Authorization", "Basic ergerogerg", s.auth.RequireAdmin, 401)
// user existing:pw
s.assertHeaderRequest("Authorization", "Basic ZXhpc3Rpbmc6cHc=", s.auth.RequireWrite, 200)
s.assertHeaderRequest("Authorization", "Basic ZXhpc3Rpbmc6cHc=", s.auth.RequireAll, 200)
s.assertHeaderRequest("Authorization", "Basic ZXhpc3Rpbmc6cHc=", s.auth.RequireAdmin, 401)
// user admin:pw
s.assertHeaderRequest("Authorization", "Basic YWRtaW46cHc=", s.auth.RequireWrite, 200)
s.assertHeaderRequest("Authorization", "Basic YWRtaW46cHc=", s.auth.RequireAll, 200)
s.assertHeaderRequest("Authorization", "Basic YWRtaW46cHc=", s.auth.RequireAdmin, 200)
// user admin:pwx
s.assertHeaderRequest("Authorization", "Basic YWRtaW46cHd4", s.auth.RequireWrite, 401)
s.assertHeaderRequest("Authorization", "Basic YWRtaW46cHd4", s.auth.RequireAll, 401)
s.assertHeaderRequest("Authorization", "Basic YWRtaW46cHd4", s.auth.RequireAdmin, 401)
}
func (s *AuthenticationSuite) assertHeaderRequest(key, value string, f fMiddleware, code int) {
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest("GET", "/", nil)
ctx.Request.Header.Set(key, value)
f()(ctx)
assert.Equal(s.T(), code, recorder.Code)
}
type fMiddleware func() gin.HandlerFunc
type DBMock struct{}
func (d *DBMock) GetTokenById(id string) *model.Token {
if id == "writeonly" {
return &model.Token{Id: "valid", WriteOnly: true, UserID: 1}
}
if id == "all" {
return &model.Token{Id: "valid", WriteOnly: false, UserID: 1}
}
if id == "admin" {
return &model.Token{Id: "valid", WriteOnly: false, UserID: 2}
}
return nil
}
func (d *DBMock) GetUserByName(name string) *model.User {
if name == "existing" {
return &model.User{Name: "existing", Pass: CreatePassword("pw")}
}
if name == "admin" {
return &model.User{Name: "admin", Pass: CreatePassword("pw"), Admin: true}
}
return nil
}
func (d *DBMock) GetUserById(id uint) *model.User {
if id == 1 {
return &model.User{Name: "existing", Pass: CreatePassword("pw"), Admin: false}
}
if id == 2 {
return &model.User{Name: "existing", Pass: CreatePassword("pw"), Admin: true}
}
return nil
}

17
auth/password.go Normal file
View File

@ -0,0 +1,17 @@
package auth
import "golang.org/x/crypto/bcrypt"
var strength = 13
func CreatePassword(pw string) []byte {
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(pw), strength)
if err != nil {
panic(err)
}
return hashedPassword
}
func ComparePassword(hashedPassword, password []byte) bool {
return bcrypt.CompareHashAndPassword(hashedPassword, password) == nil
}

21
auth/password_test.go Normal file
View File

@ -0,0 +1,21 @@
package auth
import (
"github.com/stretchr/testify/assert"
"testing"
)
func TestPasswordSuccess(t *testing.T) {
password := CreatePassword("secret")
assert.Equal(t, true, ComparePassword(password, []byte("secret")))
}
func TestPasswordFailure(t *testing.T) {
password := CreatePassword("secret")
assert.Equal(t, false, ComparePassword(password, []byte("secretx")))
}
func TestBCryptFailure(t *testing.T) {
strength = 12312 // invalid value
assert.Panics(t, func() { CreatePassword("secret") })
}