1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
|
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package database adds some useful functionality to a sql.DB.
// It is independent of the database driver and the
// DB schema.
package database
import (
"context"
"database/sql"
"errors"
"fmt"
"io"
"regexp"
"strings"
"sync"
"time"
"github.com/jackc/pgx/v5/pgconn"
"github.com/lib/pq"
"golang.org/x/pkgsite/internal/derrors"
"golang.org/x/pkgsite/internal/log"
)
// DB wraps a sql.DB. The methods it exports correspond closely to those of
// sql.DB. They enhance the original by requiring a context argument, and by
// logging the query and any resulting errors.
//
// A DB may represent a transaction. If so, its execution and query methods
// operate within the transaction.
type DB struct {
	// db is the wrapped database handle.
	db *sql.DB
	// instanceID is passed to logQuery with every query this DB runs.
	instanceID string
	// tx is non-nil only when this DB represents a transaction.
	tx *sql.Tx
	conn *sql.Conn // the Conn of the Tx, when tx != nil
	opts sql.TxOptions // valid when tx != nil
	// mu guards maxRetries.
	mu sync.Mutex
	maxRetries int // max times a single transaction was retried
}
// Open creates a new DB for the given connection string and verifies the
// connection by pinging the database, waiting at most 30 seconds.
//
// driverName and dbinfo are passed to sql.Open unchanged; instanceID is
// recorded on the returned DB. If the ping fails, the underlying sql.DB is
// closed before the error is returned, so a failed Open does not leak it.
// The password in dbinfo is redacted in any returned error.
func Open(driverName, dbinfo, instanceID string) (_ *DB, err error) {
	defer derrors.Wrap(&err, "database.Open(%q, %q)",
		driverName, redactPassword(dbinfo))

	db, err := sql.Open(driverName, dbinfo)
	if err != nil {
		return nil, err
	}
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	if err := db.PingContext(ctx); err != nil {
		// The caller never sees this handle, so release it here. The ping
		// error is the one worth reporting; a Close error is dropped.
		_ = db.Close()
		return nil, err
	}
	return New(db, instanceID), nil
}
// New wraps an existing sql.DB in a DB tagged with instanceID.
// The returned DB does not represent a transaction.
func New(db *sql.DB, instanceID string) *DB {
	d := new(DB)
	d.db = db
	d.instanceID = instanceID
	return d
}
// Underlying exposes the *sql.DB that this DB wraps.
func (db *DB) Underlying() *sql.DB {
	raw := db.db
	return raw
}
// SetPoolSettings configures the connection pool of the underlying sql.DB:
// the maximum number of open and idle connections, and the maximum lifetime
// and idle time of a connection.
//
// If maxOpen is positive and maxIdle exceeds it, a warning is logged and
// maxOpen is used as the idle limit instead.
func (db *DB) SetPoolSettings(maxOpen, maxIdle int, maxLifetime, maxIdleTime time.Duration) {
	idle := maxIdle
	if bounded := maxOpen > 0; bounded && idle > maxOpen {
		log.Warningf(context.Background(), "SetPoolSettings: maxIdle (%d) > maxOpen (%d); capping maxIdle to maxOpen", idle, maxOpen)
		idle = maxOpen
	}
	pool := db.db
	pool.SetMaxOpenConns(maxOpen)
	pool.SetMaxIdleConns(idle)
	pool.SetConnMaxLifetime(maxLifetime)
	pool.SetConnMaxIdleTime(maxIdleTime)
}
// Ping pings the underlying sql.DB using a background context and returns
// the resulting error, if any.
func (db *DB) Ping() error {
	return db.db.PingContext(context.Background())
}
// InTransaction reports whether db represents a transaction.
func (db *DB) InTransaction() bool {
	hasTx := db.tx != nil
	return hasTx
}
// IsRetryable reports whether db represents a transaction whose isolation
// level is one that isRetryable accepts.
func (db *DB) IsRetryable() bool {
	if db.tx == nil {
		return false
	}
	return isRetryable(db.opts.Isolation)
}
// passwordRegexp matches a password setting in a keyword/value connection
// string. The value is either single-quoted (backslash escapes allowed, so it
// may contain spaces and quotes) or a run of non-space characters. The quoted
// alternative comes first so that a quoted value is consumed in full.
var passwordRegexp = regexp.MustCompile(`password=(?:'(?:[^'\\]|\\.)*'|\S+)`)

// redactPassword returns dbinfo with the value of every password setting
// replaced by REDACTED, so the result can be included in errors and logs.
// Strings without a password setting are returned unchanged.
func redactPassword(dbinfo string) string {
	return passwordRegexp.ReplaceAllLiteralString(dbinfo, "password=REDACTED")
}
// Close closes the underlying sql.DB and returns its error, if any.
func (db *DB) Close() error {
	closeErr := db.db.Close()
	return closeErr
}
// Exec executes a SQL statement and returns the number of rows it affected.
// The statement runs inside the transaction if db represents one.
//
// The query, its arguments and any resulting error are passed to logQuery.
// An error from Result.RowsAffected is wrapped, so callers can still
// inspect the underlying driver error with errors.Is and errors.As.
func (db *DB) Exec(ctx context.Context, query string, args ...any) (_ int64, err error) {
	defer logQuery(ctx, query, args, db.instanceID, db.IsRetryable())(&err)

	res, err := db.execResult(ctx, query, args...)
	if err != nil {
		return 0, err
	}
	n, err := res.RowsAffected()
	if err != nil {
		return 0, fmt.Errorf("RowsAffected: %w", err)
	}
	return n, nil
}
// execResult executes a SQL statement and returns the sql.Result, running it
// on the transaction when db has one and on the plain database otherwise.
func (db *DB) execResult(ctx context.Context, query string, args ...any) (res sql.Result, err error) {
	if db.tx == nil {
		return db.db.ExecContext(ctx, query, args...)
	}
	return db.tx.ExecContext(ctx, query, args...)
}
// Query runs query with args and returns the resulting rows. It uses the
// transaction when db represents one. The query and any error are passed
// to logQuery.
func (db *DB) Query(ctx context.Context, query string, args ...any) (_ *sql.Rows, err error) {
	defer logQuery(ctx, query, args, db.instanceID, db.IsRetryable())(&err)

	if db.tx == nil {
		return db.db.QueryContext(ctx, query, args...)
	}
	return db.tx.QueryContext(ctx, query, args...)
}
// QueryRow runs the query and returns a single row. It uses the transaction
// when db represents one.
//
// If ctx is done by the time the query call returns, an error is logged with
// the arguments, the elapsed time, the start time and the context deadline.
func (db *DB) QueryRow(ctx context.Context, query string, args ...any) *sql.Row {
	defer logQuery(ctx, query, args, db.instanceID, db.IsRetryable())(nil)

	start := time.Now()
	defer func() {
		if ctxErr := ctx.Err(); ctxErr != nil {
			d, _ := ctx.Deadline()
			// Pass every value as an argument rather than splicing preformatted
			// text into the format string: a '%' inside args would otherwise be
			// interpreted as a formatting verb.
			log.Errorf(ctx, "QueryRow context error: %v args=%v; elapsed=%q, start=%q, deadline=%q",
				ctxErr, args, time.Since(start), start, d)
		}
	}()
	if db.tx != nil {
		return db.tx.QueryRowContext(ctx, query, args...)
	}
	return db.db.QueryRowContext(ctx, query, args...)
}
// Prepare creates a prepared statement for query. The statement is prepared
// on the transaction when db represents one, and on the plain database
// otherwise. The caller is responsible for closing the returned statement.
//
// The preparation is passed to logQuery with the query text prefixed by
// "preparing ", together with any resulting error.
func (db *DB) Prepare(ctx context.Context, query string) (_ *sql.Stmt, err error) {
	// logQuery returns the function that records the outcome; that returned
	// function is what must be deferred, so it is invoked here with &err.
	defer logQuery(ctx, "preparing "+query, nil, db.instanceID, db.IsRetryable())(&err)

	if db.tx != nil {
		return db.tx.PrepareContext(ctx, query)
	}
	return db.db.PrepareContext(ctx, query)
}
// RunQuery executes query with params and calls f once per result row, in
// order. Iteration ends when the rows are exhausted or f returns a non-nil
// error; that error, or any error from the query or the row iteration, is
// returned.
func (db *DB) RunQuery(ctx context.Context, query string, f func(*sql.Rows) error, params ...any) error {
	rows, queryErr := db.Query(ctx, query, params...)
	if queryErr != nil {
		return queryErr
	}
	if _, rowErr := processRows(rows, f); rowErr != nil {
		return rowErr
	}
	return nil
}
// processRows calls f on each row of rows and then closes rows. It returns
// the number of rows on which f was called (including one for which f
// failed), together with the first error from f or, if f never failed, the
// error from rows.Err.
func processRows(rows *sql.Rows, f func(*sql.Rows) error) (int, error) {
	defer rows.Close()

	count := 0
	for {
		if !rows.Next() {
			return count, rows.Err()
		}
		count++
		if err := f(rows); err != nil {
			return count, err
		}
	}
}
// RunQueryIncrementally executes query, then calls f on each row. It fetches
// rows in groups of size batchSize through a cursor. It stops successfully
// when there are no more rows, or when f returns io.EOF (possibly wrapped).
// Any other error from f, from the query, or from iterating over a batch is
// returned.
//
// The whole operation runs in its own transaction, because cursors require
// one; db must therefore not already be in a transaction.
func (db *DB) RunQueryIncrementally(ctx context.Context, query string, batchSize int, f func(*sql.Rows) error, params ...any) (err error) {
	return db.Transact(ctx, sql.LevelDefault, func(tx *DB) error {
		// Declare a cursor and associate it with the query.
		// It will be closed when the transaction commits.
		if _, err := tx.Exec(ctx, fmt.Sprintf(`DECLARE c CURSOR FOR %s`, query), params...); err != nil {
			return err
		}
		fetch := fmt.Sprintf(`FETCH %d FROM c`, batchSize)
		for {
			// Fetch batchSize rows and process them.
			rows, err := tx.Query(ctx, fetch)
			if err != nil {
				return err
			}
			n, err := processRows(rows, f)
			switch {
			case errors.Is(err, io.EOF):
				// The processing function asked to stop.
				return nil
			case err != nil:
				// Check this before n == 0: a batch that fails before
				// yielding any row must not be mistaken for the end of data.
				return err
			case n == 0:
				// No rows left.
				return nil
			}
		}
	})
}
// Transact runs txFunc inside a SQL transaction at isolation level iso. The
// transaction is rolled back if txFunc panics or returns an error.
//
// txFunc receives a DB bound to the transaction. That DB should be used only
// while txFunc is running; using it to access the database after txFunc
// returns yields errors.
//
// For isolation levels that require it, Transact retries the transaction on
// serialization failure, so txFunc may be called more than once.
func (db *DB) Transact(ctx context.Context, iso sql.IsolationLevel, txFunc func(*DB) error) (err error) {
	defer derrors.Wrap(&err, "Transact(%s)", iso)

	// For the levels which require retry, see
	// https://www.postgresql.org/docs/11/transaction-iso.html.
	run := db.transact
	if isRetryable(iso) {
		run = db.transactWithRetry
	}
	return run(ctx, &sql.TxOptions{Isolation: iso}, txFunc)
}
// isRetryable reports whether transactions at isolation level iso are retried
// on serialization failure: true for repeatable read and serializable only.
func isRetryable(iso sql.IsolationLevel) bool {
	switch iso {
	case sql.LevelRepeatableRead, sql.LevelSerializable:
		return true
	default:
		return false
	}
}
// serializationFailureCode is the Postgres error code returned when a serializable
// transaction fails because it would violate serializability.
// isSerializationFailure compares driver error codes against it, and
// transactWithRetry also searches error text for it.
// See https://www.postgresql.org/docs/current/errcodes-appendix.html.
const serializationFailureCode = "40001"
// transactWithRetry runs txFunc in a transaction via transact, retrying when
// the transaction fails with a serialization failure.
//
// Up to maxRetries retries are made after the first attempt. The wait before
// each retry starts at 125ms and doubles every time. The wait is abandoned,
// and an error wrapping ctx.Err() returned, if ctx is done first. The highest
// retry count seen so far is recorded for MaxRetries.
//
// Any other error from transact is returned as is, unless its text contains
// the serialization failure code without being recognized as such, in which
// case a descriptive error is returned instead.
func (db *DB) transactWithRetry(ctx context.Context, opts *sql.TxOptions, txFunc func(*DB) error) (err error) {
	defer derrors.Wrap(&err, "transactWithRetry(%v)", opts)

	// Retry on serialization failure, up to some max.
	// See https://www.postgresql.org/docs/11/transaction-iso.html.
	const maxRetries = 20
	sleepDur := 125 * time.Millisecond
	for i := 0; i <= maxRetries; i++ {
		err = db.transact(ctx, opts, txFunc)
		if isSerializationFailure(err) {
			db.mu.Lock()
			if i > db.maxRetries {
				db.maxRetries = i
			}
			db.mu.Unlock()
			if i == maxRetries {
				// Out of attempts: there is nothing left to wait for.
				break
			}
			log.Debugf(ctx, "serialization failure; retrying after %s", sleepDur)
			// Wait with a timer rather than time.Sleep so that a cancelled
			// or expired context ends the wait immediately.
			timer := time.NewTimer(sleepDur)
			select {
			case <-ctx.Done():
				timer.Stop()
				return fmt.Errorf("waiting to retry after serialization failure: %w", ctx.Err())
			case <-timer.C:
			}
			sleepDur *= 2
			continue
		}
		if err != nil {
			log.Debugf(ctx, "transactWithRetry: error type %T: %[1]v", err)
			// Catch serialization failures that arrive as an error type
			// isSerializationFailure does not know about.
			if strings.Contains(err.Error(), serializationFailureCode) {
				return fmt.Errorf("error text has %q but not recognized as serialization failure: type %T, err %v",
					serializationFailureCode, err, err)
			}
		}
		if i > 0 {
			log.Debugf(ctx, "retried serializable transaction %d time(s)", i)
		}
		return err
	}
	return fmt.Errorf("reached max number of tries due to serialization failure (%d)", maxRetries)
}
// isSerializationFailure reports whether err is, or wraps, a Postgres error
// whose code is serializationFailureCode. The concrete error type depends on
// the driver, so both the pq and the pgx error types are checked.
func isSerializationFailure(err error) bool {
	var (
		pqErr  *pq.Error
		pgxErr *pgconn.PgError
	)
	switch {
	case errors.As(err, &pqErr) && pqErr.Code == serializationFailureCode:
		return true
	case errors.As(err, &pgxErr) && pgxErr.Code == serializationFailureCode:
		return true
	}
	return false
}
// transact runs txFunc once inside a new transaction that is started with
// opts on a connection obtained from the underlying sql.DB.
//
// txFunc receives a new DB bound to that transaction. The transaction is
// committed if txFunc returns nil, and rolled back if txFunc returns an error
// or panics; a panic is re-raised after the rollback. A commit failure is
// returned as transact's error. Calling transact on a DB that already
// represents a transaction is an error.
func (db *DB) transact(ctx context.Context, opts *sql.TxOptions, txFunc func(*DB) error) (err error) {
	if db.InTransaction() {
		return errors.New("a DB Transact function was called on a DB already in a transaction")
	}
	// Obtain one connection explicitly and begin the transaction on it; the
	// connection is closed when transact exits.
	conn, err := db.db.Conn(ctx)
	if err != nil {
		return err
	}
	defer conn.Close()
	tx, err := conn.BeginTx(ctx, opts)
	if err != nil {
		return fmt.Errorf("conn.BeginTx(): %w", err)
	}
	// Deferred calls run last-in, first-out: this function runs after the one
	// returned by logTransaction (deferred below) and before conn.Close
	// (deferred above).
	defer func() {
		if p := recover(); p != nil {
			// Roll back, then let the panic continue to propagate.
			// The Rollback error is ignored.
			tx.Rollback()
			panic(p)
		} else if err != nil {
			// txFunc failed; the Rollback error is ignored in favor of err.
			tx.Rollback()
		} else {
			// Assigning to the named result makes a commit failure the
			// error that transact returns.
			if txErr := tx.Commit(); txErr != nil {
				err = fmt.Errorf("tx.Commit(): %w", txErr)
			}
		}
	}()
	// Build the DB that txFunc will see: same sql.DB and instance ID, but
	// bound to the transaction, its connection and its options.
	dbtx := New(db.db, db.instanceID)
	dbtx.tx = tx
	dbtx.conn = conn
	dbtx.opts = *opts
	defer dbtx.logTransaction(ctx)(&err)
	if err := txFunc(dbtx); err != nil {
		return fmt.Errorf("txFunc(tx): %w", err)
	}
	return nil
}
// MaxRetries returns the maximum number of times that a serializable
// transaction was retried. The value is read while holding db's mutex.
func (db *DB) MaxRetries() int {
	db.mu.Lock()
	highest := db.maxRetries
	db.mu.Unlock()
	return highest
}
// OnConflictDoNothing is a conflict clause suitable as the conflictAction
// argument of BulkInsert and BulkInsertReturning, which append it to the
// INSERT statement.
const OnConflictDoNothing = "ON CONFLICT DO NOTHING"
// BulkInsert builds and runs a multi-value insert statement of the form:
//
//	INSERT INTO <table> (<columns>) VALUES (<placeholders-for-each-item-in-values>)
//
// values holds the rows' values interleaved, so its length must be a
// multiple of len(columns). A non-empty conflictAction is appended to the
// statement.
//
// The query is executed using a PREPARE statement with the provided values.
func (db *DB) BulkInsert(ctx context.Context, table string, columns []string, values []any, conflictAction string) (err error) {
	defer derrors.Wrap(&err, "DB.BulkInsert(ctx, %q, %v, [%d values], %q)",
		table, columns, len(values), conflictAction)

	// A plain insert returns nothing, so there are no columns to return and
	// no rows to scan.
	var (
		noReturning []string
		noScan      func(*sql.Rows) error
	)
	return db.bulkInsert(ctx, table, columns, noReturning, values, conflictAction, noScan)
}
// BulkInsertReturning is BulkInsert with a RETURNING clause. Besides the
// arguments of BulkInsert, it takes the columns to return and a function that
// scans them, written as if they were the selected columns of a query. See
// TestBulkInsert for an example.
//
// Both returningColumns and scanFunc must be non-nil.
func (db *DB) BulkInsertReturning(ctx context.Context, table string, columns []string, values []any, conflictAction string, returningColumns []string, scanFunc func(*sql.Rows) error) (err error) {
	defer derrors.Wrap(&err, "DB.BulkInsertReturning(ctx, %q, %v, [%d values], %q, %v, scanFunc)",
		table, columns, len(values), conflictAction, returningColumns)

	switch {
	case returningColumns == nil, scanFunc == nil:
		return errors.New("need returningColumns and scan function")
	}
	return db.bulkInsert(ctx, table, columns, returningColumns, values, conflictAction, scanFunc)
}
// BulkUpsert is BulkInsert with the conflict action derived from
// conflictColumns: an "ON CONFLICT (conflict_columns) DO UPDATE" clause is
// added to the statement, with an assignment "c=excluded.c" for every
// column c in columns.
func (db *DB) BulkUpsert(ctx context.Context, table string, columns []string, values []any, conflictColumns []string) error {
	return db.BulkInsert(ctx, table, columns, values,
		buildUpsertConflictAction(columns, conflictColumns))
}
// BulkUpsertReturning combines BulkInsertReturning and BulkUpsert: it returns
// values like the former and builds its conflict action from conflictColumns
// like the latter.
func (db *DB) BulkUpsertReturning(ctx context.Context, table string, columns []string, values []any, conflictColumns, returningColumns []string, scanFunc func(*sql.Rows) error) error {
	return db.BulkInsertReturning(ctx, table, columns, values,
		buildUpsertConflictAction(columns, conflictColumns),
		returningColumns, scanFunc)
}
// bulkInsert implements BulkInsert and BulkInsertReturning.
//
// values holds the rows' values interleaved; its length must be a multiple
// of len(columns), and columns must not be empty. The values are inserted in
// batches of at most maxParameters parameters, each batch holding whole rows.
// When returningColumns is non-nil, each batch is run as a query and scanFunc
// is called on every returned row; otherwise the batch is simply executed.
//
// It returns the first error encountered; batches that already ran are not
// undone by bulkInsert itself.
func (db *DB) bulkInsert(ctx context.Context, table string, columns, returningColumns []string, values []any, conflictAction string, scanFunc func(*sql.Rows) error) (err error) {
	// Guard the divisions below against an empty column list.
	if len(columns) == 0 {
		return errors.New("need at least one column to insert")
	}
	if remainder := len(values) % len(columns); remainder != 0 {
		return fmt.Errorf("modulus of len(values) and len(columns) must be 0: got %d", remainder)
	}

	// Postgres supports up to 65535 parameters, but stop well before that
	// so we don't construct humongous queries.
	const maxParameters = 1000
	// stride is the largest multiple of len(columns) not above maxParameters.
	stride := (maxParameters / len(columns)) * len(columns)
	if stride == 0 {
		// This is a pathological case (len(columns) > maxParameters), but we
		// handle it cautiously.
		return fmt.Errorf("too many columns to insert: %d", len(columns))
	}

	prepare := func(n int) (*sql.Stmt, error) {
		return db.Prepare(ctx, buildInsertQuery(table, columns, returningColumns, n, conflictAction))
	}

	// At most two statements are prepared: one shared by every full-size
	// batch, and one for a final, shorter batch. The deferred Close calls
	// inside the loop are therefore bounded.
	var stmt *sql.Stmt
	for leftBound := 0; leftBound < len(values); leftBound += stride {
		rightBound := leftBound + stride
		if rightBound <= len(values) && stmt == nil {
			stmt, err = prepare(stride)
			if err != nil {
				return err
			}
			defer stmt.Close()
		} else if rightBound > len(values) {
			rightBound = len(values)
			stmt, err = prepare(rightBound - leftBound)
			if err != nil {
				return err
			}
			defer stmt.Close()
		}
		valueSlice := values[leftBound:rightBound]
		if returningColumns == nil {
			_, err = stmt.ExecContext(ctx, valueSlice...)
		} else {
			var rows *sql.Rows
			rows, err = stmt.QueryContext(ctx, valueSlice...)
			if err == nil {
				_, err = processRows(rows, scanFunc)
			}
		}
		if err != nil {
			return fmt.Errorf("running bulk insert query, values[%d:%d]): %w", leftBound, rightBound, err)
		}
	}
	return nil
}
// buildInsertQuery builds a multi-value insert query, following the format:
//
//	INSERT INTO <table> (<columns>) VALUES (<placeholders-for-each-item-in-values>) <conflictAction>
//
// Placeholders are numbered $1 through $nvalues and grouped into
// parenthesized sets of len(columns) each. If returningColumns is not empty,
// a RETURNING clause is appended to the query.
//
// When calling buildInsertQuery, it must be true that nvalues % len(columns) == 0.
func buildInsertQuery(table string, columns, returningColumns []string, nvalues int, conflictAction string) string {
	var q strings.Builder
	q.WriteString("INSERT INTO " + table)
	q.WriteString("(" + strings.Join(columns, ", ") + ") VALUES")

	width := len(columns)
	group := make([]string, 0, width)
	for n := 1; n <= nvalues; n++ {
		// Collect placeholders until one full set of column values is ready.
		group = append(group, fmt.Sprintf("$%d", n))
		if n%width != 0 {
			continue
		}
		// Emit the completed set and start over with an empty one.
		q.WriteString("(" + strings.Join(group, ", ") + ")")
		group = group[:0]
		// Separate sets with a comma, except after the last one.
		if n < nvalues {
			q.WriteString(", ")
		}
	}

	if conflictAction != "" {
		q.WriteString(" ")
		q.WriteString(conflictAction)
	}
	if len(returningColumns) > 0 {
		q.WriteString(" RETURNING " + strings.Join(returningColumns, ", "))
	}
	return q.String()
}
// buildUpsertConflictAction returns an "ON CONFLICT (...) DO UPDATE SET ..."
// clause. conflictColumns form the conflict target, and every column c in
// columns is assigned with "c=excluded.c".
func buildUpsertConflictAction(columns, conflictColumns []string) string {
	assignments := make([]string, len(columns))
	for i, col := range columns {
		assignments[i] = col + "=excluded." + col
	}
	target := strings.Join(conflictColumns, ", ")
	return "ON CONFLICT (" + target + ") DO UPDATE SET " + strings.Join(assignments, ", ")
}
// maxBulkUpdateArrayLen is the maximum size of an array that BulkUpdate will send to
// Postgres in one statement; BulkUpdate splits longer inputs into chunks of at
// most this many rows. (Postgres has no size limit on arrays, but we want to
// keep the statements to a reasonable size.)
// It is a variable for testing.
var maxBulkUpdateArrayLen = 10000
// BulkUpdate updates many rows of table, issuing one UPDATE statement per
// chunk of at most maxBulkUpdateArrayLen rows.
//
// BulkUpdate does not start a transaction itself. For the chunks to be
// applied atomically, db should represent a transaction (see Transact).
//
// Columns must contain the names of some of table's columns. The first is treated
// as a key; that is, the values to update are matched with existing rows by comparing
// the values of the first column. At least two columns are required.
//
// Types holds the database type of each column, so it must have the same
// length as columns. For example,
//
//	[]string{"INT", "TEXT"}
//
// Values contains one slice of values per column, all of the same length.
// (Note that this is unlike BulkInsert, which takes a single slice of
// interleaved values.)
func (db *DB) BulkUpdate(ctx context.Context, table string, columns, types []string, values [][]any) (err error) {
	defer derrors.Wrap(&err, "DB.BulkUpdate(ctx, tx, %q, %v, [%d values])",
		table, columns, len(values))

	if len(columns) < 2 {
		return errors.New("need at least two columns")
	}
	// buildBulkUpdateQuery indexes types by column position, so a short
	// types slice must be rejected here rather than cause a panic there.
	if len(types) != len(columns) {
		return errors.New("len(types) != len(columns)")
	}
	if len(columns) != len(values) {
		return errors.New("len(values) != len(columns)")
	}
	nRows := len(values[0])
	for _, v := range values[1:] {
		if len(v) != nRows {
			return errors.New("all values slices must be the same length")
		}
	}
	query := buildBulkUpdateQuery(table, columns, types)
	for left := 0; left < nRows; left += maxBulkUpdateArrayLen {
		right := min(left+maxBulkUpdateArrayLen, nRows)
		// One array argument per column, each holding this chunk's rows.
		args := make([]any, 0, len(values))
		for _, vs := range values {
			args = append(args, pq.Array(vs[left:right]))
		}
		if _, err := db.Exec(ctx, query, args...); err != nil {
			return fmt.Errorf("db.Exec(%q, values[%d:%d]): %w", query, left, right, err)
		}
	}
	return nil
}
// buildBulkUpdateQuery returns the UPDATE statement used by BulkUpdate. The
// statement takes one array parameter per column ($1 for columns[0], and so
// on), unnests the arrays into a derived table named data, and sets every
// non-key column of each row of table whose key column, columns[0], matches.
//
// types[i] is the database type of columns[i]; it is used to cast the i'th
// array parameter.
func buildBulkUpdateQuery(table string, columns, types []string) string {
	// Build "c = data.c" for each non-key column.
	var assignments []string
	for _, col := range columns[1:] {
		assignments = append(assignments, col+" = data."+col)
	}
	// Build "UNNEST($1::TYPE[]) AS c" for each column.
	// We need the type, or Postgres complains that UNNEST is not unique.
	var sources []string
	for i := range columns {
		sources = append(sources, fmt.Sprintf("UNNEST($%d::%s[]) AS %s", i+1, types[i], columns[i]))
	}
	setClause := strings.Join(assignments, ", ")
	selectList := strings.Join(sources, ", ")
	return fmt.Sprintf(`
UPDATE %[1]s
SET %[2]s
FROM (SELECT %[3]s) AS data
WHERE %[1]s.%[4]s = data.%[4]s`,
		table,      // 1
		setClause,  // 2
		selectList, // 3
		columns[0], // 4
	)
}
// Collect1 runs query, which must select a single column that can be scanned
// into a value of type T, and returns the scanned values in row order.
// On error it returns a nil slice together with the error.
func Collect1[T any](ctx context.Context, db *DB, query string, args ...any) (ts []T, err error) {
	defer derrors.WrapStack(&err, "Collect1(%q)", query)

	var collected []T
	scanOne := func(rows *sql.Rows) error {
		var v T
		if scanErr := rows.Scan(&v); scanErr != nil {
			return scanErr
		}
		collected = append(collected, v)
		return nil
	}
	if runErr := db.RunQuery(ctx, query, scanOne, args...); runErr != nil {
		return nil, runErr
	}
	return collected, nil
}
// emptyStringScanner is a sql.Scanner that behaves like sql.NullString but
// stores its result in a plain string, writing the empty string for NULL.
type emptyStringScanner struct {
	ptr *string // destination of the scanned value
}

// Scan implements sql.Scanner. It scans value as a sql.NullString would and
// stores the resulting string, which is empty for NULL, through e.ptr. On
// error the destination is left unchanged.
func (e emptyStringScanner) Scan(value any) error {
	var scanned sql.NullString
	err := scanned.Scan(value)
	if err == nil {
		*e.ptr = scanned.String
	}
	return err
}

// NullIsEmpty returns a sql.Scanner that writes the empty string to s if the
// sql.Value is NULL, and the scanned string otherwise.
func NullIsEmpty(s *string) sql.Scanner {
	return emptyStringScanner{ptr: s}
}
|