diff options
| author | Jonathan Amsterdam <jba@google.com> | 2021-03-24 19:42:09 -0400 |
|---|---|---|
| committer | Jonathan Amsterdam <jba@google.com> | 2021-03-25 11:02:39 +0000 |
| commit | b80d4dd39a545d302609dbc52e3dc3e7a6fc18ae (patch) | |
| tree | e287f10e4c2ce53a7b980d38c972aa2515ad75b1 /internal/database/database.go | |
| parent | 030d6bd4c085ea8d1dc4e94ce4d5e69736df9114 (diff) | |
| download | go-x-pkgsite-b80d4dd39a545d302609dbc52e3dc3e7a6fc18ae.tar.xz | |
internal/database: add CopyUpsert
Add the CopyUpsert method, which uses an efficient
Postgres protocol to insert rows.
For this to work, we need the connection underlying a sql.Tx value.
Since sql.Tx doesn't expose its connection, we create one explicitly
in DB.transact.
Change-Id: Ie48ce7a4318f4531d4756f779943188a6f0fb6cd
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/304631
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Julie Qiu <julie@golang.org>
Diffstat (limited to 'internal/database/database.go')
| -rw-r--r-- | internal/database/database.go | 12 |
1 files changed, 10 insertions, 2 deletions
diff --git a/internal/database/database.go b/internal/database/database.go index 96e858f4..312f83af 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -33,6 +33,7 @@ type DB struct { db *sql.DB instanceID string tx *sql.Tx + conn *sql.Conn // the Conn of the Tx, when tx != nil opts sql.TxOptions // valid when tx != nil mu sync.Mutex maxRetries int // max times a single transaction was retried @@ -237,9 +238,15 @@ func (db *DB) transact(ctx context.Context, opts *sql.TxOptions, txFunc func(*DB if db.InTransaction() { return errors.New("a DB Transact function was called on a DB already in a transaction") } - tx, err := db.db.BeginTx(ctx, opts) + conn, err := db.db.Conn(ctx) if err != nil { - return fmt.Errorf("db.BeginTx(): %w", err) + return err + } + defer conn.Close() + + tx, err := conn.BeginTx(ctx, opts) + if err != nil { + return fmt.Errorf("conn.BeginTx(): %w", err) } defer func() { if p := recover(); p != nil { @@ -256,6 +263,7 @@ func (db *DB) transact(ctx context.Context, opts *sql.TxOptions, txFunc func(*DB dbtx := New(db.db, db.instanceID) dbtx.tx = tx + dbtx.conn = conn dbtx.opts = *opts defer dbtx.logTransaction(ctx)(&err) if err := txFunc(dbtx); err != nil { |
