aboutsummaryrefslogtreecommitdiff
path: root/internal/database/database.go
diff options
context:
space:
mode:
authorJonathan Amsterdam <jba@google.com>2020-04-13 14:55:37 -0400
committerJonathan Amsterdam <jba@google.com>2020-04-14 15:38:30 +0000
commit86348f6125d8422c830a3ca3c67f247962fbea1c (patch)
tree8513c06ec40404213be6b38b4629ba476d991a13 /internal/database/database.go
parent90861f00a20f0c21fabcc46ff145201279e67da6 (diff)
downloadgo-x-pkgsite-86348f6125d8422c830a3ca3c67f247962fbea1c.tar.xz
internal/database: use the DB for inside a transaction as well as out
The database.DB type now can represent a DB connection in the middle of a transaction. Such a DB is created only by calling DB.Transact. The resulting API is much simpler, since all the ...Tx methods disappear. Change-Id: I41afada87738e1eacdec2fcf115902edddeff867 Reviewed-on: https://team-review.git.corp.google.com/c/golang/discovery/+/716719 Reviewed-by: Julie Qiu <julieqiu@google.com>
Diffstat (limited to 'internal/database/database.go')
-rw-r--r--internal/database/database.go64
1 files changed, 34 insertions, 30 deletions
diff --git a/internal/database/database.go b/internal/database/database.go
index 73cc6412..31e3b59a 100644
--- a/internal/database/database.go
+++ b/internal/database/database.go
@@ -10,6 +10,7 @@ package database
import (
"context"
"database/sql"
+ "errors"
"fmt"
"regexp"
"strings"
@@ -25,8 +26,12 @@ import (
// DB wraps a sql.DB. The methods it exports correspond closely to those of
// sql.DB. They enhance the original by requiring a context argument, and by
// logging the query and any resulting errors.
+//
+// A DB may represent a transaction. If so, its execution and query methods
+// operate within the transaction.
type DB struct {
db *sql.DB
+ tx *sql.Tx
}
// Open creates a new DB for the given connection string.
@@ -41,12 +46,16 @@ func Open(driverName, dbinfo string) (_ *DB, err error) {
if err = db.Ping(); err != nil {
return nil, err
}
- return &DB{db}, nil
+ return New(db), nil
}
// New creates a new DB from a sql.DB.
func New(db *sql.DB) *DB {
- return &DB{db}
+ return &DB{db: db}
+}
+
+func (db *DB) InTransaction() bool {
+ return db.tx != nil
}
var passwordRegexp = regexp.MustCompile(`password=\S+`)
@@ -64,25 +73,27 @@ func (db *DB) Close() error {
func (db *DB) Exec(ctx context.Context, query string, args ...interface{}) (res sql.Result, err error) {
defer logQuery(ctx, query, args)(&err)
+ if db.tx != nil {
+ return db.tx.ExecContext(ctx, query, args...)
+ }
return db.db.ExecContext(ctx, query, args...)
}
-// ExecTx runs a statement in a transaction.
-func ExecTx(ctx context.Context, tx *sql.Tx, query string, args ...interface{}) (res sql.Result, err error) {
- defer logQuery(ctx, query, args)(&err)
-
- return tx.ExecContext(ctx, query, args...)
-}
-
// Query runs the DB query.
func (db *DB) Query(ctx context.Context, query string, args ...interface{}) (_ *sql.Rows, err error) {
defer logQuery(ctx, query, args)(&err)
+ if db.tx != nil {
+ return db.tx.QueryContext(ctx, query, args...)
+ }
return db.db.QueryContext(ctx, query, args...)
}
// QueryRow runs the query and returns a single row.
func (db *DB) QueryRow(ctx context.Context, query string, args ...interface{}) *sql.Row {
defer logQuery(ctx, query, args)(nil)
+ if db.tx != nil {
+ return db.tx.QueryRowContext(ctx, query, args...)
+ }
return db.db.QueryRowContext(ctx, query, args...)
}
@@ -92,21 +103,6 @@ func (db *DB) RunQuery(ctx context.Context, query string, f func(*sql.Rows) erro
if err != nil {
return err
}
- return processRows(rows, f)
-}
-
-// RunQueryTx is like RunQuery, but runs the query inside a transaction.
-func RunQueryTx(ctx context.Context, tx *sql.Tx, query string, f func(*sql.Rows) error, args ...interface{}) (err error) {
- defer logQuery(ctx, query, args)(&err)
- rows, err := tx.QueryContext(ctx, query, args...)
- if err != nil {
- return err
- }
- return processRows(rows, f)
-}
-
-// processRows iterates through rows, calling f on each row.
-func processRows(rows *sql.Rows, f func(*sql.Rows) error) error {
defer rows.Close()
for rows.Next() {
if err := f(rows); err != nil {
@@ -118,12 +114,18 @@ func processRows(rows *sql.Rows, f func(*sql.Rows) error) error {
// Transact executes the given function in the context of a SQL transaction,
// rolling back the transaction if the function panics or returns an error.
-func (db *DB) Transact(txFunc func(*sql.Tx) error) (err error) {
+//
+// The given function is called with a DB that is associated with a transaction.
+// The DB should be used only inside the function; if it is used to access the
+// database after the function returns, the calls will return errors.
+func (db *DB) Transact(txFunc func(*DB) error) (err error) {
+ if db.InTransaction() {
+ return errors.New("DB.Transact called on a DB already in a transaction")
+ }
tx, err := db.db.Begin()
if err != nil {
return fmt.Errorf("db.Begin(): %v", err)
}
-
defer func() {
if p := recover(); p != nil {
tx.Rollback()
@@ -137,7 +139,9 @@ func (db *DB) Transact(txFunc func(*sql.Tx) error) (err error) {
}
}()
- if err := txFunc(tx); err != nil {
+ dbtx := New(db.db)
+ dbtx.tx = tx
+ if err := txFunc(dbtx); err != nil {
return fmt.Errorf("txFunc(tx): %v", err)
}
return nil
@@ -150,8 +154,8 @@ const OnConflictDoNothing = "ON CONFLICT DO NOTHING"
// (<placeholders-for-each-item-in-values>) If conflictNoAction is true, it
// append ON CONFLICT DO NOTHING to the end of the query. The query is executed
// using a PREPARE statement with the provided values.
-func BulkInsert(ctx context.Context, tx *sql.Tx, table string, columns []string, values []interface{}, conflictAction string) (err error) {
- defer derrors.Wrap(&err, "bulkInsert(ctx, tx, %q, %v, [%d values], %q)",
+func (db *DB) BulkInsert(ctx context.Context, table string, columns []string, values []interface{}, conflictAction string) (err error) {
+ defer derrors.Wrap(&err, "DB.BulkInsert(ctx, %q, %v, [%d values], %q)",
table, columns, len(values), conflictAction)
if remainder := len(values) % len(columns); remainder != 0 {
@@ -174,7 +178,7 @@ func BulkInsert(ctx context.Context, tx *sql.Tx, table string, columns []string,
}
valueSlice := values[leftBound:rightBound]
query := buildInsertQuery(table, columns, valueSlice, conflictAction)
- if _, err := ExecTx(ctx, tx, query, valueSlice...); err != nil {
+ if _, err := db.Exec(ctx, query, valueSlice...); err != nil {
return fmt.Errorf("tx.ExecContext(ctx, [bulk insert query], values[%d:%d]): %v", leftBound, rightBound, err)
}
}