aboutsummaryrefslogtreecommitdiff
path: root/internal/database/database.go
diff options
context:
space:
mode:
authorJonathan Amsterdam <jba@google.com>2021-03-24 19:42:09 -0400
committerJonathan Amsterdam <jba@google.com>2021-03-25 11:02:39 +0000
commitb80d4dd39a545d302609dbc52e3dc3e7a6fc18ae (patch)
treee287f10e4c2ce53a7b980d38c972aa2515ad75b1 /internal/database/database.go
parent030d6bd4c085ea8d1dc4e94ce4d5e69736df9114 (diff)
downloadgo-x-pkgsite-b80d4dd39a545d302609dbc52e3dc3e7a6fc18ae.tar.xz
internal/database: add CopyUpsert
Add the CopyUpsert method, which uses an efficient Postgres protocol to insert rows. For this to work, we need the connection underlying a sql.Tx value. Since sql.Tx doesn't expose its connection, we create one explicitly in DB.transact. Change-Id: Ie48ce7a4318f4531d4756f779943188a6f0fb6cd Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/304631 Trust: Jonathan Amsterdam <jba@google.com> Run-TryBot: Jonathan Amsterdam <jba@google.com> TryBot-Result: kokoro <noreply+kokoro@google.com> Reviewed-by: Julie Qiu <julie@golang.org>
Diffstat (limited to 'internal/database/database.go')
-rw-r--r--internal/database/database.go12
1 files changed, 10 insertions, 2 deletions
diff --git a/internal/database/database.go b/internal/database/database.go
index 96e858f4..312f83af 100644
--- a/internal/database/database.go
+++ b/internal/database/database.go
@@ -33,6 +33,7 @@ type DB struct {
db *sql.DB
instanceID string
tx *sql.Tx
+ conn *sql.Conn // the Conn of the Tx, when tx != nil
opts sql.TxOptions // valid when tx != nil
mu sync.Mutex
maxRetries int // max times a single transaction was retried
@@ -237,9 +238,15 @@ func (db *DB) transact(ctx context.Context, opts *sql.TxOptions, txFunc func(*DB
if db.InTransaction() {
return errors.New("a DB Transact function was called on a DB already in a transaction")
}
- tx, err := db.db.BeginTx(ctx, opts)
+ conn, err := db.db.Conn(ctx)
if err != nil {
- return fmt.Errorf("db.BeginTx(): %w", err)
+ return err
+ }
+ defer conn.Close()
+
+ tx, err := conn.BeginTx(ctx, opts)
+ if err != nil {
+ return fmt.Errorf("conn.BeginTx(): %w", err)
}
defer func() {
if p := recover(); p != nil {
@@ -256,6 +263,7 @@ func (db *DB) transact(ctx context.Context, opts *sql.TxOptions, txFunc func(*DB
dbtx := New(db.db, db.instanceID)
dbtx.tx = tx
+ dbtx.conn = conn
dbtx.opts = *opts
defer dbtx.logTransaction(ctx)(&err)
if err := txFunc(dbtx); err != nil {