aboutsummaryrefslogtreecommitdiff
path: root/internal/database/database.go
diff options
context:
space:
mode:
authorJonathan Amsterdam <jba@google.com>2020-03-20 15:46:37 -0400
committerJulie Qiu <julie@golang.org>2020-04-06 17:09:52 -0400
commit05893675c0420a4a558a833ca4a22b1aa499f314 (patch)
treed6db62ee820a613e709c0355464685b6a3e8d540 /internal/database/database.go
parent89fb626f90cefb9fb3ed02cbdebacd9655be17d6 (diff)
downloadgo-x-pkgsite-05893675c0420a4a558a833ca4a22b1aa499f314.tar.xz
internal/database: add RunQueryTx
Add a function to run a query inside a transaction. Use it in the couple of places it can be used. Change-Id: If9230f434633ff1b5c60087c8281039b533354c6 Reviewed-on: https://team-review.git.corp.google.com/c/golang/discovery/+/696538 Reviewed-by: Julie Qiu <julieqiu@google.com>
Diffstat (limited to 'internal/database/database.go')
-rw-r--r--internal/database/database.go18
1 files changed, 16 insertions, 2 deletions
diff --git a/internal/database/database.go b/internal/database/database.go
index 6f6a2be2..73cc6412 100644
--- a/internal/database/database.go
+++ b/internal/database/database.go
@@ -67,7 +67,7 @@ func (db *DB) Exec(ctx context.Context, query string, args ...interface{}) (res
return db.db.ExecContext(ctx, query, args...)
}
-// ExecTx runs a query in a transaction.
+// ExecTx runs a statement in a transaction.
func ExecTx(ctx context.Context, tx *sql.Tx, query string, args ...interface{}) (res sql.Result, err error) {
defer logQuery(ctx, query, args)(&err)
@@ -92,8 +92,22 @@ func (db *DB) RunQuery(ctx context.Context, query string, f func(*sql.Rows) erro
if err != nil {
return err
}
- defer rows.Close()
+ return processRows(rows, f)
+}
+// RunQueryTx is like RunQuery, but runs the query inside a transaction.
+func RunQueryTx(ctx context.Context, tx *sql.Tx, query string, f func(*sql.Rows) error, args ...interface{}) (err error) {
+ defer logQuery(ctx, query, args)(&err)
+ rows, err := tx.QueryContext(ctx, query, args...)
+ if err != nil {
+ return err
+ }
+ return processRows(rows, f)
+}
+
+// processRows iterates through rows, calling f on each row.
+func processRows(rows *sql.Rows, f func(*sql.Rows) error) error {
+ defer rows.Close()
for rows.Next() {
if err := f(rows); err != nil {
return err