From dffe12418bb7f459b79cbe87e5ae2fe22ba46e12 Mon Sep 17 00:00:00 2001 From: Jannis Mattheis Date: Sat, 10 Feb 2018 23:08:13 +0100 Subject: [PATCH] Add database wrapper --- database/application.go | 32 +++++++++ database/application_test.go | 34 +++++++++ database/client.go | 30 ++++++++ database/client_test.go | 34 +++++++++ database/database.go | 35 ++++++++++ database/database_test.go | 34 +++++++++ database/message.go | 53 ++++++++++++++ database/message_test.go | 129 +++++++++++++++++++++++++++++++++++ database/user.go | 47 +++++++++++++ database/user_test.go | 47 +++++++++++++ 10 files changed, 475 insertions(+) create mode 100644 database/application.go create mode 100644 database/application_test.go create mode 100644 database/client.go create mode 100644 database/client_test.go create mode 100644 database/database.go create mode 100644 database/database_test.go create mode 100644 database/message.go create mode 100644 database/message_test.go create mode 100644 database/user.go create mode 100644 database/user_test.go diff --git a/database/application.go b/database/application.go new file mode 100644 index 0000000..e139ff9 --- /dev/null +++ b/database/application.go @@ -0,0 +1,32 @@ +package database + +import ( + "github.com/jmattheis/memo/model" +) + +// GetApplicationByID returns the application for the given id or nil. +func (d *GormDatabase) GetApplicationByID(id string) *model.Application { + app := new(model.Application) + d.DB.Where("id = ?", id).Find(app) + if app.ID == id { + return app + } + return nil +} + +// CreateApplication creates an application. +func (d *GormDatabase) CreateApplication(application *model.Application) error { + return d.DB.Create(application).Error +} + +// DeleteApplicationByID deletes an application by its id. +func (d *GormDatabase) DeleteApplicationByID(id string) error { + return d.DB.Where("id = ?", id).Delete(&model.Application{}).Error +} + +// GetApplicationsByUser returns all applications from a user. +func (d *GormDatabase) GetApplicationsByUser(userID uint) []*model.Application { + var apps []*model.Application + d.DB.Where("user_id = ?", userID).Find(&apps) + return apps +} diff --git a/database/application_test.go b/database/application_test.go new file mode 100644 index 0000000..f26a696 --- /dev/null +++ b/database/application_test.go @@ -0,0 +1,34 @@ +package database + +import ( + "github.com/jmattheis/memo/model" + "github.com/stretchr/testify/assert" +) + +func (s *DatabaseSuite) TestApplication() { + assert.Nil(s.T(), s.db.GetApplicationByID("asdasdf"), "not existing app") + + user := &model.User{Name: "test", Pass: []byte{1}} + s.db.CreateUser(user) + assert.NotEqual(s.T(), 0, user.ID) + + apps := s.db.GetApplicationsByUser(user.ID) + assert.Empty(s.T(), apps) + + app := &model.Application{UserID: user.ID, ID: "C0000000000", Name: "backupserver"} + s.db.CreateApplication(app) + + apps = s.db.GetApplicationsByUser(user.ID) + assert.Len(s.T(), apps, 1) + assert.Contains(s.T(), apps, app) + + newApp := s.db.GetApplicationByID(app.ID) + assert.Equal(s.T(), app, newApp) + + s.db.DeleteApplicationByID(app.ID) + + apps = s.db.GetApplicationsByUser(user.ID) + assert.Empty(s.T(), apps) + + assert.Nil(s.T(), s.db.GetApplicationByID(app.ID)) +} diff --git a/database/client.go b/database/client.go new file mode 100644 index 0000000..e4a2de0 --- /dev/null +++ b/database/client.go @@ -0,0 +1,30 @@ +package database + +import "github.com/jmattheis/memo/model" + +// GetClientByID returns the client for the given id or nil. +func (d *GormDatabase) GetClientByID(id string) *model.Client { + client := new(model.Client) + d.DB.Where("id = ?", id).Find(client) + if client.ID == id { + return client + } + return nil +} + +// CreateClient creates a client. +func (d *GormDatabase) CreateClient(client *model.Client) error { + return d.DB.Create(client).Error +} + +// GetClientsByUser returns all clients from a user. +func (d *GormDatabase) GetClientsByUser(userID uint) []*model.Client { + var clients []*model.Client + d.DB.Where("user_id = ?", userID).Find(&clients) + return clients +} + +// DeleteClientByID deletes a client by its id. +func (d *GormDatabase) DeleteClientByID(id string) error { + return d.DB.Where("id = ?", id).Delete(&model.Client{}).Error +} diff --git a/database/client_test.go b/database/client_test.go new file mode 100644 index 0000000..5a3324a --- /dev/null +++ b/database/client_test.go @@ -0,0 +1,34 @@ +package database + +import ( + "github.com/jmattheis/memo/model" + "github.com/stretchr/testify/assert" +) + +func (s *DatabaseSuite) TestClient() { + assert.Nil(s.T(), s.db.GetClientByID("asdasdf"), "not existing client") + + user := &model.User{Name: "test", Pass: []byte{1}} + s.db.CreateUser(user) + assert.NotEqual(s.T(), 0, user.ID) + + clients := s.db.GetClientsByUser(user.ID) + assert.Empty(s.T(), clients) + + client := &model.Client{UserID: user.ID, ID: "C0000000000", Name: "android"} + s.db.CreateClient(client) + + clients = s.db.GetClientsByUser(user.ID) + assert.Len(s.T(), clients, 1) + assert.Contains(s.T(), clients, client) + + newClient := s.db.GetClientByID(client.ID) + assert.Equal(s.T(), client, newClient) + + s.db.DeleteClientByID(client.ID) + + clients = s.db.GetClientsByUser(user.ID) + assert.Empty(s.T(), clients) + + assert.Nil(s.T(), s.db.GetClientByID(client.ID)) +} diff --git a/database/database.go b/database/database.go new file mode 100644 index 0000000..5ce8156 --- /dev/null +++ b/database/database.go @@ -0,0 +1,35 @@ +package database + +import ( + "github.com/jinzhu/gorm" + _ "github.com/jinzhu/gorm/dialects/mysql" // enable the mysql dialect + _ "github.com/jinzhu/gorm/dialects/postgres" // enable the postgres dialect + _ "github.com/jinzhu/gorm/dialects/sqlite" // enable the sqlite3 dialect + "github.com/jmattheis/memo/auth" + "github.com/jmattheis/memo/model" +) + +// New creates a new wrapper for the gorm database framework. +func New(dialect, connection, defaultUser, defaultPass string) (*GormDatabase, error) { + db, err := gorm.Open(dialect, connection) + if err != nil { + return nil, err + } + if !db.HasTable(new(model.User)) && !db.HasTable(new(model.Message)) && + !db.HasTable(new(model.Client)) && !db.HasTable(new(model.Application)) { + db.AutoMigrate(new(model.User), new(model.Application), new(model.Message), new(model.Client)) + db.Create(&model.User{Name: defaultUser, Pass: auth.CreatePassword(defaultPass), Admin: true}) + } + + return &GormDatabase{DB: db}, nil +} + +// GormDatabase is a wrapper for the gorm framework. +type GormDatabase struct { + DB *gorm.DB +} + +// Close closes the gorm database connection. +func (d *GormDatabase) Close() { + d.DB.Close() +} diff --git a/database/database_test.go b/database/database_test.go new file mode 100644 index 0000000..bd1abbb --- /dev/null +++ b/database/database_test.go @@ -0,0 +1,34 @@ +package database + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +func TestDatabaseSuite(t *testing.T) { + suite.Run(t, new(DatabaseSuite)) +} + +type DatabaseSuite struct { + suite.Suite + db *GormDatabase +} + +func (s *DatabaseSuite) BeforeTest(suiteName, testName string) { + db, err := New("sqlite3", "testdb.db", "defaultUser", "defaultPass") + assert.Nil(s.T(), err) + s.db = db +} + +func (s *DatabaseSuite) AfterTest(suiteName, testName string) { + s.db.Close() + assert.Nil(s.T(), os.Remove("testdb.db")) +} + +func TestInvalidDialect(t *testing.T) { + _, err := New("asdf", "testdb.db", "defaultUser", "defaultPass") + assert.NotNil(t, err) +} diff --git a/database/message.go b/database/message.go new file mode 100644 index 0000000..6ecb007 --- /dev/null +++ b/database/message.go @@ -0,0 +1,53 @@ +package database + +import ( + "github.com/jmattheis/memo/model" +) + +// GetMessageByID returns the messages for the given id or nil. +func (d *GormDatabase) GetMessageByID(id uint) *model.Message { + msg := new(model.Message) + d.DB.Find(msg, id) + if msg.ID == id { + return msg + } + return nil +} + +// CreateMessage creates a message. +func (d *GormDatabase) CreateMessage(message *model.Message) error { + return d.DB.Create(message).Error +} + +// GetMessagesByUser returns all messages from a user. +func (d *GormDatabase) GetMessagesByUser(userID uint) []*model.Message { + var messages []*model.Message + d.DB.Joins("JOIN applications ON applications.user_id = ?", userID). + Where("messages.application_id = applications.id").Order("date desc").Find(&messages) + return messages +} + +// GetMessagesByApplication returns all messages from an application. +func (d *GormDatabase) GetMessagesByApplication(tokenID string) []*model.Message { + var messages []*model.Message + d.DB.Where("application_id = ?", tokenID).Order("date desc").Find(&messages) + return messages +} + +// DeleteMessageByID deletes a message by its id. +func (d *GormDatabase) DeleteMessageByID(id uint) error { + return d.DB.Where("id = ?", id).Delete(&model.Message{}).Error +} + +// DeleteMessagesByApplication deletes all messages from an application. +func (d *GormDatabase) DeleteMessagesByApplication(applicationID string) error { + return d.DB.Where("application_id = ?", applicationID).Delete(&model.Message{}).Error +} + +// DeleteMessagesByUser deletes all messages from a user. +func (d *GormDatabase) DeleteMessagesByUser(userID uint) error { + for _, app := range d.GetApplicationsByUser(userID) { + d.DB.Model(app).Association("Messages").Clear() + } + return nil +} diff --git a/database/message_test.go b/database/message_test.go new file mode 100644 index 0000000..148ad91 --- /dev/null +++ b/database/message_test.go @@ -0,0 +1,129 @@ +package database + +import ( + "testing" + "time" + + "github.com/jmattheis/memo/model" + "github.com/stretchr/testify/assert" +) + +func (s *DatabaseSuite) TestMessage() { + assert.Nil(s.T(), s.db.GetMessageByID(5), "not existing message") + + user := &model.User{Name: "test", Pass: []byte{1}} + s.db.CreateUser(user) + assert.NotEqual(s.T(), 0, user.ID) + + backupServer := &model.Application{UserID: user.ID, ID: "A0000000000", Name: "backupserver"} + s.db.CreateApplication(backupServer) + assert.NotEqual(s.T(), 0, backupServer.ID) + + msgs := s.db.GetMessagesByUser(user.ID) + assert.Empty(s.T(), msgs) + msgs = s.db.GetMessagesByApplication(backupServer.ID) + assert.Empty(s.T(), msgs) + + backupdone := &model.Message{ApplicationID: backupServer.ID, Message: "backup done", Title: "backup", Priority: 1, Date: time.Now()} + s.db.CreateMessage(backupdone) + assert.NotEqual(s.T(), 0, backupdone.ID) + + assertEquals(s.T(), s.db.GetMessageByID(backupdone.ID), backupdone) + + msgs = s.db.GetMessagesByUser(user.ID) + assert.Len(s.T(), msgs, 1) + assertEquals(s.T(), msgs[0], backupdone) + + msgs = s.db.GetMessagesByApplication(backupServer.ID) + assert.Len(s.T(), msgs, 1) + assertEquals(s.T(), msgs[0], backupdone) + + loginServer := &model.Application{UserID: user.ID, ID: "A0000000001", Name: "loginserver"} + s.db.CreateApplication(loginServer) + assert.NotEqual(s.T(), 0, loginServer.ID) + + logindone := &model.Message{ApplicationID: loginServer.ID, Message: "login done", Title: "login", Priority: 1, Date: time.Now()} + s.db.CreateMessage(logindone) + assert.NotEqual(s.T(), 0, logindone.ID) + + msgs = s.db.GetMessagesByUser(user.ID) + assert.Len(s.T(), msgs, 2) + assertEquals(s.T(), msgs[0], logindone) + assertEquals(s.T(), msgs[1], backupdone) + + msgs = s.db.GetMessagesByApplication(backupServer.ID) + assert.Len(s.T(), msgs, 1) + assertEquals(s.T(), msgs[0], backupdone) + + loginfailed := &model.Message{ApplicationID: loginServer.ID, Message: "login failed", Title: "login", Priority: 1, Date: time.Now()} + s.db.CreateMessage(loginfailed) + assert.NotEqual(s.T(), 0, loginfailed.ID) + + msgs = s.db.GetMessagesByApplication(backupServer.ID) + assert.Len(s.T(), msgs, 1) + assertEquals(s.T(), msgs[0], backupdone) + + msgs = s.db.GetMessagesByApplication(loginServer.ID) + assert.Len(s.T(), msgs, 2) + assertEquals(s.T(), msgs[0], loginfailed) + assertEquals(s.T(), msgs[1], logindone) + + msgs = s.db.GetMessagesByUser(user.ID) + assert.Len(s.T(), msgs, 3) + assertEquals(s.T(), msgs[0], loginfailed) + assertEquals(s.T(), msgs[1], logindone) + assertEquals(s.T(), msgs[2], backupdone) + + backupfailed := &model.Message{ApplicationID: backupServer.ID, Message: "backup failed", Title: "backup", Priority: 1, Date: time.Now()} + s.db.CreateMessage(backupfailed) + assert.NotEqual(s.T(), 0, backupfailed.ID) + + msgs = s.db.GetMessagesByUser(user.ID) + assert.Len(s.T(), msgs, 4) + assertEquals(s.T(), msgs[0], backupfailed) + assertEquals(s.T(), msgs[1], loginfailed) + assertEquals(s.T(), msgs[2], logindone) + assertEquals(s.T(), msgs[3], backupdone) + + msgs = s.db.GetMessagesByApplication(loginServer.ID) + assert.Len(s.T(), msgs, 2) + assertEquals(s.T(), msgs[0], loginfailed) + assertEquals(s.T(), msgs[1], logindone) + + s.db.DeleteMessagesByApplication(loginServer.ID) + assert.Empty(s.T(), s.db.GetMessagesByApplication(loginServer.ID)) + + msgs = s.db.GetMessagesByUser(user.ID) + assert.Len(s.T(), msgs, 2) + assertEquals(s.T(), msgs[0], backupfailed) + assertEquals(s.T(), msgs[1], backupdone) + + logindone = &model.Message{ApplicationID: loginServer.ID, Message: "login done", Title: "login", Priority: 1, Date: time.Now()} + s.db.CreateMessage(logindone) + assert.NotEqual(s.T(), 0, logindone.ID) + + msgs = s.db.GetMessagesByUser(user.ID) + assert.Len(s.T(), msgs, 3) + assertEquals(s.T(), msgs[0], logindone) + assertEquals(s.T(), msgs[1], backupfailed) + assertEquals(s.T(), msgs[2], backupdone) + + s.db.DeleteMessagesByUser(user.ID) + assert.Empty(s.T(), s.db.GetMessagesByUser(user.ID)) + + logout := &model.Message{ApplicationID: loginServer.ID, Message: "logout success", Title: "logout", Priority: 1, Date: time.Now()} + s.db.CreateMessage(logout) + msgs = s.db.GetMessagesByUser(user.ID) + assert.Len(s.T(), msgs, 1) + assertEquals(s.T(), msgs[0], logout) + + s.db.DeleteMessageByID(logout.ID) + assert.Empty(s.T(), s.db.GetMessagesByUser(user.ID)) +} + +// assertEquals compares messages and correctly check dates +func assertEquals(t *testing.T, left *model.Message, right *model.Message) { + assert.Equal(t, left.Date.Unix(), right.Date.Unix()) + left.Date = right.Date + assert.Equal(t, left, right) +} diff --git a/database/user.go b/database/user.go new file mode 100644 index 0000000..b6dae74 --- /dev/null +++ b/database/user.go @@ -0,0 +1,47 @@ +package database + +import ( + "github.com/jmattheis/memo/model" +) + +// GetUserByName returns the user by the given name or nil. +func (d *GormDatabase) GetUserByName(name string) *model.User { + user := new(model.User) + d.DB.Where("name = ?", name).Find(user) + if user.Name == name { + return user + } + return nil +} + +// GetUserByID returns the user by the given id or nil. +func (d *GormDatabase) GetUserByID(id uint) *model.User { + user := new(model.User) + d.DB.Find(user, id) + if user.ID == id { + return user + } + return nil +} + +// GetUsers returns all users. +func (d *GormDatabase) GetUsers() []*model.User { + var users []*model.User + d.DB.Find(&users) + return users +} + +// DeleteUserByID deletes a user by its id. +func (d *GormDatabase) DeleteUserByID(id uint) error { + return d.DB.Where("id = ?", id).Delete(&model.User{}).Error +} + +// UpdateUser updates a user. +func (d *GormDatabase) UpdateUser(user *model.User) { + d.DB.Save(user) +} + +// CreateUser creates a user. +func (d *GormDatabase) CreateUser(user *model.User) error { + return d.DB.Create(user).Error +} diff --git a/database/user_test.go b/database/user_test.go new file mode 100644 index 0000000..7116b1c --- /dev/null +++ b/database/user_test.go @@ -0,0 +1,47 @@ +package database + +import ( + "github.com/jmattheis/memo/model" + "github.com/stretchr/testify/assert" +) + +func (s *DatabaseSuite) TestUser() { + assert.Nil(s.T(), s.db.GetUserByID(55), "not existing user") + assert.Nil(s.T(), s.db.GetUserByName("nicories"), "not existing user") + + jmattheis := s.db.GetUserByID(1) + assert.NotNil(s.T(), jmattheis, "on bootup the first user should be automatically created") + + users := s.db.GetUsers() + assert.Len(s.T(), users, 1) + assert.Contains(s.T(), users, jmattheis) + + nicories := &model.User{Name: "nicories", Pass: []byte{1, 2, 3, 4}, Admin: false} + s.db.CreateUser(nicories) + assert.NotEqual(s.T(), 0, nicories.ID, "on create user a new id should be assigned") + + assert.Equal(s.T(), nicories, s.db.GetUserByName("nicories")) + + users = s.db.GetUsers() + assert.Len(s.T(), users, 2) + assert.Contains(s.T(), users, jmattheis) + assert.Contains(s.T(), users, nicories) + + nicories.Name = "tom" + nicories.Pass = []byte{12} + nicories.Admin = true + s.db.UpdateUser(nicories) + tom := s.db.GetUserByID(nicories.ID) + assert.Equal(s.T(), &model.User{ID: nicories.ID, Name: "tom", Pass: []byte{12}, Admin: true}, tom) + users = s.db.GetUsers() + assert.Len(s.T(), users, 2) + + s.db.DeleteUserByID(tom.ID) + users = s.db.GetUsers() + assert.Len(s.T(), users, 1) + assert.Contains(s.T(), users, jmattheis) + + s.db.DeleteUserByID(jmattheis.ID) + users = s.db.GetUsers() + assert.Empty(s.T(), users) +}