diff options
| author | Jonathan Amsterdam <jba@google.com> | 2020-04-13 14:55:37 -0400 |
|---|---|---|
| committer | Jonathan Amsterdam <jba@google.com> | 2020-04-14 15:38:30 +0000 |
| commit | 86348f6125d8422c830a3ca3c67f247962fbea1c (patch) | |
| tree | 8513c06ec40404213be6b38b4629ba476d991a13 /internal/database/database.go | |
| parent | 90861f00a20f0c21fabcc46ff145201279e67da6 (diff) | |
| download | go-x-pkgsite-86348f6125d8422c830a3ca3c67f247962fbea1c.tar.xz | |
internal/database: use the DB for inside a transaction as well as out
The database.DB type now can represent a DB connection in the
middle of a transaction. Such a DB is created only by calling
DB.Transact.
The resulting API is much simpler, since all the ...Tx methods
disappear.
Change-Id: I41afada87738e1eacdec2fcf115902edddeff867
Reviewed-on: https://team-review.git.corp.google.com/c/golang/discovery/+/716719
Reviewed-by: Julie Qiu <julieqiu@google.com>
Diffstat (limited to 'internal/database/database.go')
| -rw-r--r-- | internal/database/database.go | 64 |
1 files changed, 34 insertions, 30 deletions
diff --git a/internal/database/database.go b/internal/database/database.go index 73cc6412..31e3b59a 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -10,6 +10,7 @@ package database import ( "context" "database/sql" + "errors" "fmt" "regexp" "strings" @@ -25,8 +26,12 @@ import ( // DB wraps a sql.DB. The methods it exports correspond closely to those of // sql.DB. They enhance the original by requiring a context argument, and by // logging the query and any resulting errors. +// +// A DB may represent a transaction. If so, its execution and query methods +// operate within the transaction. type DB struct { db *sql.DB + tx *sql.Tx } // Open creates a new DB for the given connection string. @@ -41,12 +46,16 @@ func Open(driverName, dbinfo string) (_ *DB, err error) { if err = db.Ping(); err != nil { return nil, err } - return &DB{db}, nil + return New(db), nil } // New creates a new DB from a sql.DB. func New(db *sql.DB) *DB { - return &DB{db} + return &DB{db: db} +} + +func (db *DB) InTransaction() bool { + return db.tx != nil } var passwordRegexp = regexp.MustCompile(`password=\S+`) @@ -64,25 +73,27 @@ func (db *DB) Close() error { func (db *DB) Exec(ctx context.Context, query string, args ...interface{}) (res sql.Result, err error) { defer logQuery(ctx, query, args)(&err) + if db.tx != nil { + return db.tx.ExecContext(ctx, query, args...) + } return db.db.ExecContext(ctx, query, args...) } -// ExecTx runs a statement in a transaction. -func ExecTx(ctx context.Context, tx *sql.Tx, query string, args ...interface{}) (res sql.Result, err error) { - defer logQuery(ctx, query, args)(&err) - - return tx.ExecContext(ctx, query, args...) -} - // Query runs the DB query. func (db *DB) Query(ctx context.Context, query string, args ...interface{}) (_ *sql.Rows, err error) { defer logQuery(ctx, query, args)(&err) + if db.tx != nil { + return db.tx.QueryContext(ctx, query, args...) + } return db.db.QueryContext(ctx, query, args...) } // QueryRow runs the query and returns a single row. func (db *DB) QueryRow(ctx context.Context, query string, args ...interface{}) *sql.Row { defer logQuery(ctx, query, args)(nil) + if db.tx != nil { + return db.tx.QueryRowContext(ctx, query, args...) + } return db.db.QueryRowContext(ctx, query, args...) } @@ -92,21 +103,6 @@ func (db *DB) RunQuery(ctx context.Context, query string, f func(*sql.Rows) erro if err != nil { return err } - return processRows(rows, f) -} - -// RunQueryTx is like RunQuery, but runs the query inside a transaction. -func RunQueryTx(ctx context.Context, tx *sql.Tx, query string, f func(*sql.Rows) error, args ...interface{}) (err error) { - defer logQuery(ctx, query, args)(&err) - rows, err := tx.QueryContext(ctx, query, args...) - if err != nil { - return err - } - return processRows(rows, f) -} - -// processRows iterates through rows, calling f on each row. -func processRows(rows *sql.Rows, f func(*sql.Rows) error) error { defer rows.Close() for rows.Next() { if err := f(rows); err != nil { @@ -118,12 +114,18 @@ func processRows(rows *sql.Rows, f func(*sql.Rows) error) error { // Transact executes the given function in the context of a SQL transaction, // rolling back the transaction if the function panics or returns an error. -func (db *DB) Transact(txFunc func(*sql.Tx) error) (err error) { +// +// The given function is called with a DB that is associated with a transaction. +// The DB should be used only inside the function; if it is used to access the +// database after the function returns, the calls will return errors. +func (db *DB) Transact(txFunc func(*DB) error) (err error) { + if db.InTransaction() { + return errors.New("DB.Transact called on a DB already in a transaction") + } tx, err := db.db.Begin() if err != nil { return fmt.Errorf("db.Begin(): %v", err) } - defer func() { if p := recover(); p != nil { tx.Rollback() @@ -137,7 +139,9 @@ func (db *DB) Transact(txFunc func(*sql.Tx) error) (err error) { } }() - if err := txFunc(tx); err != nil { + dbtx := New(db.db) + dbtx.tx = tx + if err := txFunc(dbtx); err != nil { return fmt.Errorf("txFunc(tx): %v", err) } return nil @@ -150,8 +154,8 @@ const OnConflictDoNothing = "ON CONFLICT DO NOTHING" // (<placeholders-for-each-item-in-values>) If conflictNoAction is true, it // append ON CONFLICT DO NOTHING to the end of the query. The query is executed // using a PREPARE statement with the provided values. -func BulkInsert(ctx context.Context, tx *sql.Tx, table string, columns []string, values []interface{}, conflictAction string) (err error) { - defer derrors.Wrap(&err, "bulkInsert(ctx, tx, %q, %v, [%d values], %q)", +func (db *DB) BulkInsert(ctx context.Context, table string, columns []string, values []interface{}, conflictAction string) (err error) { + defer derrors.Wrap(&err, "DB.BulkInsert(ctx, %q, %v, [%d values], %q)", table, columns, len(values), conflictAction) if remainder := len(values) % len(columns); remainder != 0 { @@ -174,7 +178,7 @@ func BulkInsert(ctx context.Context, tx *sql.Tx, table string, columns []string, } valueSlice := values[leftBound:rightBound] query := buildInsertQuery(table, columns, valueSlice, conflictAction) - if _, err := ExecTx(ctx, tx, query, valueSlice...); err != nil { + if _, err := db.Exec(ctx, query, valueSlice...); err != nil { return fmt.Errorf("tx.ExecContext(ctx, [bulk insert query], values[%d:%d]): %v", leftBound, rightBound, err) } } |
