aboutsummaryrefslogtreecommitdiff
path: root/lib/sql/client.go
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sql/client.go')
-rw-r--r--lib/sql/client.go199
1 files changed, 198 insertions, 1 deletions
diff --git a/lib/sql/client.go b/lib/sql/client.go
index dd959cbb..82bb8c20 100644
--- a/lib/sql/client.go
+++ b/lib/sql/client.go
@@ -6,14 +6,29 @@ package sql
import (
"database/sql"
+ "errors"
"fmt"
+ "io/ioutil"
+ "log"
+ "net/http"
+ "os"
+ "path"
+ "path/filepath"
+ "sort"
+ "strings"
+)
+
+const (
+ sqlExtension = ".sql"
+ sqlComment = "--"
+ sqlTerminator = ';'
)
//
// Client provide a wrapper for generic database instance.
//
type Client struct {
- DB *sql.DB
+ *sql.DB
DriverName string
TableNames []string // List of tables in database.
}
@@ -76,6 +91,188 @@ func (cl *Client) FetchTableNames() (tableNames []string, err error) {
}
//
+// Migrate the database using list of SQL files inside a directory.
+// Each SQL file in directory will be executed in alphabetical order based on
+// the last state.
+//
+// The state of migration will be saved in table "_migration" including the
+// SQL file name that has been executed and the timestamp.
+//
+func (cl *Client) Migrate(fs http.FileSystem) (err error) {
+ root, err := fs.Open("/")
+ if err != nil {
+ return fmt.Errorf("Migrate: %w", err)
+ }
+
+ fis, err := root.Readdir(0)
+ if err != nil {
+ return fmt.Errorf("Migrate: %w", err)
+ }
+
+ sort.SliceStable(fis, func(x, y int) bool {
+ return fis[x].Name() < fis[y].Name()
+ })
+
+ lastFile, err := cl.migrateInit()
+ if err != nil {
+ return fmt.Errorf("Migrate: %w", err)
+ }
+
+ var x int
+ if len(lastFile) > 0 {
+ for ; x < len(fis); x++ {
+ if fis[x].Name() == lastFile {
+ break
+ }
+ }
+ if x == len(fis) {
+ x = 0
+ } else {
+ x++
+ }
+ }
+ for ; x < len(fis); x++ {
+ name := fis[x].Name()
+
+ sqlRaw, err := loadSQL(fs, fis[x], name)
+ if err != nil {
+ return fmt.Errorf("Migrate %q: %w", name, err)
+ }
+ if len(sqlRaw) == 0 {
+ continue
+ }
+
+ err = cl.migrateApply(name, sqlRaw)
+ if err != nil {
+ return fmt.Errorf("Migrate %q: %w", name, err)
+ }
+ }
+ return nil
+}
+
+//
+// migrateInit get the last file in table migration or if its not exist create
+// the migration table.
+//
+func (cl *Client) migrateInit() (lastFile string, err error) {
+ lastFile, err = cl.migrateLastFile()
+ if err == nil {
+ return lastFile, nil
+ }
+
+ err = cl.migrateCreateTable()
+ if err != nil {
+ return "", err
+ }
+
+ return "", nil
+}
+
+//
+// migrateLastFile return the last finished migration or empty string if table
+// migration does not exist.
+//
+func (cl *Client) migrateLastFile() (file string, err error) {
+ q := `
+ SELECT filename
+ FROM _migration
+ ORDER BY filename DESC
+ LIMIT 1
+ `
+
+ err = cl.DB.QueryRow(q).Scan(&file)
+ if err != nil && !errors.Is(err, sql.ErrNoRows) {
+ return "", err
+ }
+
+ return file, nil
+}
+
+func (cl *Client) migrateCreateTable() (err error) {
+ q := `
+ CREATE TABLE _migration (
+ filename VARCHAR(1024)
+ , applied_at TIMESTAMP DEFAULT NOW()
+ );
+ `
+ _, err = cl.DB.Exec(q)
+ if err != nil {
+ return fmt.Errorf("migrateCreateTable: %w", err)
+ }
+ return nil
+}
+
+func (cl *Client) migrateApply(filename string, sqlRaw []byte) (err error) {
+ tx, err := cl.DB.Begin()
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.Exec(string(sqlRaw))
+ if err == nil {
+ err = cl.migrateFinished(tx, filename)
+ }
+ if err != nil {
+ err2 := tx.Rollback()
+ if err2 != nil {
+ log.Printf("migrateApply %s: %s", filename, err2)
+ }
+ return fmt.Errorf("migrateApply: %w", err)
+ }
+
+ err = tx.Commit()
+ if err != nil {
+ return fmt.Errorf("migrateApply: %w", err)
+ }
+
+ return nil
+}
+
+func (cl *Client) migrateFinished(tx *sql.Tx, file string) (err error) {
+ var q string
+
+ switch cl.DriverName {
+ case DriverNamePostgres:
+ q = `INSERT INTO _migration (filename) VALUES ($1)`
+ case DriverNameMysql:
+ q = `INSERT INTO _migration (filename) VALUES (?)`
+ }
+
+ _, err = tx.Exec(q, file)
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func loadSQL(fs http.FileSystem, fi os.FileInfo, filename string) (
+ sqlRaw []byte, err error,
+) {
+ if strings.ToLower(filepath.Ext(filename)) != sqlExtension {
+ return nil, nil
+ }
+
+ if !fi.Mode().IsRegular() {
+ return nil, nil
+ }
+
+ fileSQL := path.Join("/", filename)
+
+ file, err := fs.Open(fileSQL)
+ if err != nil {
+ return nil, err
+ }
+
+ sqlRaw, err = ioutil.ReadAll(file)
+ if err != nil {
+ return nil, err
+ }
+
+ return sqlRaw, nil
+}
+
+//
// TruncateTable truncate all data on table `tableName`.
//
func (cl *Client) TruncateTable(tableName string) (err error) {