diff options
Diffstat (limited to 'lib/sql')
| -rw-r--r-- | lib/sql/client.go | 199 |
1 files changed, 198 insertions, 1 deletions
diff --git a/lib/sql/client.go b/lib/sql/client.go index dd959cbb..82bb8c20 100644 --- a/lib/sql/client.go +++ b/lib/sql/client.go @@ -6,14 +6,29 @@ package sql import ( "database/sql" + "errors" "fmt" + "io/ioutil" + "log" + "net/http" + "os" + "path" + "path/filepath" + "sort" + "strings" +) + +const ( + sqlExtension = ".sql" + sqlComment = "--" + sqlTerminator = ';' ) // // Client provide a wrapper for generic database instance. // type Client struct { - DB *sql.DB + *sql.DB DriverName string TableNames []string // List of tables in database. } @@ -76,6 +91,188 @@ func (cl *Client) FetchTableNames() (tableNames []string, err error) { } // +// Migrate the database using list of SQL files inside a directory. +// Each SQL file in directory will be executed in alphabetical order based on +// the last state. +// +// The state of migration will be saved in table "_migration" including the +// SQL file name that has been executed and the timestamp. +// +func (cl *Client) Migrate(fs http.FileSystem) (err error) { + root, err := fs.Open("/") + if err != nil { + return fmt.Errorf("Migrate: %w", err) + } + + fis, err := root.Readdir(0) + if err != nil { + return fmt.Errorf("Migrate: %w", err) + } + + sort.SliceStable(fis, func(x, y int) bool { + return fis[x].Name() < fis[y].Name() + }) + + lastFile, err := cl.migrateInit() + if err != nil { + return fmt.Errorf("Migrate: %w", err) + } + + var x int + if len(lastFile) > 0 { + for ; x < len(fis); x++ { + if fis[x].Name() == lastFile { + break + } + } + if x == len(fis) { + x = 0 + } else { + x++ + } + } + for ; x < len(fis); x++ { + name := fis[x].Name() + + sqlRaw, err := loadSQL(fs, fis[x], name) + if err != nil { + return fmt.Errorf("Migrate %q: %w", name, err) + } + if len(sqlRaw) == 0 { + continue + } + + err = cl.migrateApply(name, sqlRaw) + if err != nil { + return fmt.Errorf("Migrate %q: %w", name, err) + } + } + return nil +} + +// +// migrateInit get the last file in table migration or if its not exist create +// the migration table. +// +func (cl *Client) migrateInit() (lastFile string, err error) { + lastFile, err = cl.migrateLastFile() + if err == nil { + return lastFile, nil + } + + err = cl.migrateCreateTable() + if err != nil { + return "", err + } + + return "", nil +} + +// +// migrateLastFile return the last finished migration or empty string if table +// migration does not exist. +// +func (cl *Client) migrateLastFile() (file string, err error) { + q := ` + SELECT filename + FROM _migration + ORDER BY filename DESC + LIMIT 1 + ` + + err = cl.DB.QueryRow(q).Scan(&file) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return "", err + } + + return file, nil +} + +func (cl *Client) migrateCreateTable() (err error) { + q := ` + CREATE TABLE _migration ( + filename VARCHAR(1024) + , applied_at TIMESTAMP DEFAULT NOW() + ); + ` + _, err = cl.DB.Exec(q) + if err != nil { + return fmt.Errorf("migrateCreateTable: %w", err) + } + return nil +} + +func (cl *Client) migrateApply(filename string, sqlRaw []byte) (err error) { + tx, err := cl.DB.Begin() + if err != nil { + return err + } + + _, err = tx.Exec(string(sqlRaw)) + if err == nil { + err = cl.migrateFinished(tx, filename) + } + if err != nil { + err2 := tx.Rollback() + if err2 != nil { + log.Printf("migrateApply %s: %s", filename, err2) + } + return fmt.Errorf("migrateApply: %w", err) + } + + err = tx.Commit() + if err != nil { + return fmt.Errorf("migrateApply: %w", err) + } + + return nil +} + +func (cl *Client) migrateFinished(tx *sql.Tx, file string) (err error) { + var q string + + switch cl.DriverName { + case DriverNamePostgres: + q = `INSERT INTO _migration (filename) VALUES ($1)` + case DriverNameMysql: + q = `INSERT INTO _migration (filename) VALUES (?)` + } + + _, err = tx.Exec(q, file) + if err != nil { + return err + } + + return nil +} + +func loadSQL(fs http.FileSystem, fi os.FileInfo, filename string) ( + sqlRaw []byte, err error, +) { + if strings.ToLower(filepath.Ext(filename)) != sqlExtension { + return nil, nil + } + + if !fi.Mode().IsRegular() { + return nil, nil + } + + fileSQL := path.Join("/", filename) + + file, err := fs.Open(fileSQL) + if err != nil { + return nil, err + } + + sqlRaw, err = ioutil.ReadAll(file) + if err != nil { + return nil, err + } + + return sqlRaw, nil +} + +// // TruncateTable truncate all data on table `tableName`. // func (cl *Client) TruncateTable(tableName string) (err error) { |
