diff --git a/database/database.go b/database/database.go index d658d0b..7bae41d 100644 --- a/database/database.go +++ b/database/database.go @@ -1,6 +1,9 @@ package database import ( + "os" + "path/filepath" + "github.com/gotify/server/auth/password" "github.com/gotify/server/model" "github.com/jinzhu/gorm" @@ -11,6 +14,8 @@ import ( // New creates a new wrapper for the gorm database framework. func New(dialect, connection, defaultUser, defaultPass string, strength int, createDefaultUser bool) (*GormDatabase, error) { + createDirectoryIfSqlite(dialect, connection) + db, err := gorm.Open(dialect, connection) if err != nil { return nil, err @@ -39,6 +44,16 @@ func New(dialect, connection, defaultUser, defaultPass string, strength int, cre return &GormDatabase{DB: db}, nil } +func createDirectoryIfSqlite(dialect string, connection string) { + if dialect == "sqlite3" { + if _, err := os.Stat(filepath.Dir(connection)); os.IsNotExist(err) { + if err := os.MkdirAll(filepath.Dir(connection), 0777); err != nil { + panic(err) + } + } + } +} + // GormDatabase is a wrapper for the gorm framework. type GormDatabase struct { DB *gorm.DB diff --git a/database/database_test.go b/database/database_test.go index 562b212..4b37491 100644 --- a/database/database_test.go +++ b/database/database_test.go @@ -4,6 +4,9 @@ import ( "os" "testing" + "errors" + + "github.com/bouk/monkey" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" ) @@ -32,3 +35,39 @@ func TestInvalidDialect(t *testing.T) { _, err := New("asdf", "testdb.db", "defaultUser", "defaultPass", 5, true) assert.NotNil(t, err) } + +func TestCreateSqliteFolder(t *testing.T) { + // ensure path not exists + os.RemoveAll("somepath") + + db, err := New("sqlite3", "somepath/testdb.db", "defaultUser", "defaultPass", 5, true) + assert.Nil(t, err) + assert.DirExists(t, "somepath") + db.Close() + + assert.Nil(t, os.RemoveAll("somepath")) +} + +func TestWithAlreadyExistingSqliteFolder(t *testing.T) { + // ensure path not exists + os.RemoveAll("somepath") + os.MkdirAll("somepath", 0777) + + db, err := New("sqlite3", "somepath/testdb.db", "defaultUser", "defaultPass", 5, true) + assert.Nil(t, err) + assert.DirExists(t, "somepath") + db.Close() + + assert.Nil(t, os.RemoveAll("somepath")) +} + +func TestPanicsOnMkdirError(t *testing.T) { + patch := monkey.Patch(os.MkdirAll, func(string, os.FileMode) error { return errors.New("whoops") }) + defer patch.Unpatch() + // ensure path not exists + os.RemoveAll("somepath") + + assert.Panics(t, func() { + New("sqlite3", "somepath/testdb.db", "defaultUser", "defaultPass", 5, true) + }) +}