From 2425c08bc372809c28bea021cd9ed0de3b8eebb2 Mon Sep 17 00:00:00 2001 From: Shulhan Date: Thu, 13 Feb 2020 20:14:39 +0700 Subject: sql: a new package as an extension to "database/sql" --- lib/sql/client.go | 88 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ lib/sql/row.go | 34 +++++++++++++++++++++ lib/sql/session.go | 24 +++++++++++++++ lib/sql/sql.go | 14 +++++++++ lib/sql/table.go | 55 ++++++++++++++++++++++++++++++++++ 5 files changed, 215 insertions(+) create mode 100644 lib/sql/client.go create mode 100644 lib/sql/row.go create mode 100644 lib/sql/session.go create mode 100644 lib/sql/sql.go create mode 100644 lib/sql/table.go (limited to 'lib/sql') diff --git a/lib/sql/client.go b/lib/sql/client.go new file mode 100644 index 00000000..dd959cbb --- /dev/null +++ b/lib/sql/client.go @@ -0,0 +1,88 @@ +// Copyright 2020, Shulhan . All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sql + +import ( + "database/sql" + "fmt" +) + +// +// Client provide a wrapper for generic database instance. +// +type Client struct { + DB *sql.DB + DriverName string + TableNames []string // List of tables in database. +} + +// +// New wrap a database client to provide additional methods. +// +func New(driverName string, db *sql.DB) (cl *Client, err error) { + cl = &Client{ + DB: db, + DriverName: driverName, + } + + return cl, nil +} + +// +// FetchTableNames return the table names in current database schema sorted +// in ascending order. +// +func (cl *Client) FetchTableNames() (tableNames []string, err error) { + var q, v string + + switch cl.DriverName { + case DriverNameMysql, DriverNamePostgres: + q = ` + SELECT + table_name + FROM + information_schema.tables + ORDER BY + table_name + ` + } + + rows, err := cl.DB.Query(q) + if err != nil { + return nil, fmt.Errorf("FetchTableNames: " + err.Error()) + } + + if len(cl.TableNames) > 0 { + cl.TableNames = cl.TableNames[:0] + } + + for rows.Next() { + err = rows.Scan(&v) + if err != nil { + _ = rows.Close() + return cl.TableNames, err + } + + cl.TableNames = append(cl.TableNames, v) + } + err = rows.Err() + if err != nil { + return nil, err + } + + return cl.TableNames, nil +} + +// +// TruncateTable truncate all data on table `tableName`. +// +func (cl *Client) TruncateTable(tableName string) (err error) { + q := `TRUNCATE TABLE ` + tableName + _, err = cl.DB.Exec(q) + if err != nil { + return fmt.Errorf("TruncateTable %q: %s", tableName, err) + } + return nil +} diff --git a/lib/sql/row.go b/lib/sql/row.go new file mode 100644 index 00000000..524c540f --- /dev/null +++ b/lib/sql/row.go @@ -0,0 +1,34 @@ +// Copyright 2020, Shulhan . All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sql + +// +// Row represent a single row in table. +// +type Row map[string]interface{} + +// +// ExtractSQLFields extract the column's name, column place holder (default is +// "?"), and column values; as slices. +// +func (row Row) ExtractSQLFields() ( + names, holders []string, values []interface{}, +) { + if len(row) == 0 { + return nil, nil, nil + } + + names = make([]string, 0, len(row)) + holders = make([]string, 0, len(row)) + values = make([]interface{}, 0, len(row)) + + for k, v := range row { + names = append(names, k) + holders = append(holders, "?") + values = append(values, v) + } + + return names, holders, values +} diff --git a/lib/sql/session.go b/lib/sql/session.go new file mode 100644 index 00000000..dec4b0cd --- /dev/null +++ b/lib/sql/session.go @@ -0,0 +1,24 @@ +// Copyright 2020, Shulhan . All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sql + +import ( + "context" + "database/sql" +) + +// +// Session is an interface that represent both sql.DB and sql.Tx. +// +type Session interface { + Exec(query string, args ...interface{}) (sql.Result, error) + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + Prepare(query string) (*sql.Stmt, error) + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) + Query(query string, args ...interface{}) (*sql.Rows, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryRow(query string, args ...interface{}) *sql.Row + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row +} diff --git a/lib/sql/sql.go b/lib/sql/sql.go new file mode 100644 index 00000000..d2024ed5 --- /dev/null +++ b/lib/sql/sql.go @@ -0,0 +1,14 @@ +// Copyright 2020, Shulhan . All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +// Package sql is an extension to standard library "database/sql.DB" that +// provide common functionality across DBMS. +// +package sql + +const ( + DriverNameMysql = "mysql" + DriverNamePostgres = "postgres" +) diff --git a/lib/sql/table.go b/lib/sql/table.go new file mode 100644 index 00000000..3dbdb984 --- /dev/null +++ b/lib/sql/table.go @@ -0,0 +1,55 @@ +// Copyright 2020, Shulhan . All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sql + +import ( + "database/sql" + "fmt" + "strings" +) + +// +// Table represent a tuple or table in database. +// +// A table has Name, PrimaryKey, and list of Row. +// +type Table struct { + Name string // Table name, required. + PrimaryKey string // Primary key of table, optional. + Rows []Row // The row or data in the table, optional. +} + +// +// Insert all rows into table, one by one. +// +// On success, it will return list of ID, if table has primary key. +// +func (table *Table) Insert(tx *sql.Tx) (ids []int64, err error) { + for _, row := range table.Rows { + names, holders, values := row.ExtractSQLFields() + if len(names) == 0 { + continue + } + + //nolint: gosec + q := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", + table.Name, strings.Join(names, ","), + strings.Join(holders, ",")) + + res, err := tx.Exec(q, values...) + if err != nil { + return nil, err + } + + id, err := res.LastInsertId() + if err != nil { + continue + } + + ids = append(ids, id) + } + + return ids, nil +} -- cgit v1.3