diff options
| author | Jonathan Amsterdam <jba@google.com> | 2021-03-24 19:42:09 -0400 |
|---|---|---|
| committer | Jonathan Amsterdam <jba@google.com> | 2021-03-25 11:02:39 +0000 |
| commit | b80d4dd39a545d302609dbc52e3dc3e7a6fc18ae (patch) | |
| tree | e287f10e4c2ce53a7b980d38c972aa2515ad75b1 /internal/database | |
| parent | 030d6bd4c085ea8d1dc4e94ce4d5e69736df9114 (diff) | |
| download | go-x-pkgsite-b80d4dd39a545d302609dbc52e3dc3e7a6fc18ae.tar.xz | |
internal/database: add CopyUpsert
Add the CopyUpsert method, which uses an efficient
Postgres protocol to insert rows.
For this to work, we need the connection underlying a sql.Tx value.
Since sql.Tx doesn't expose its connection, we create one explicitly
in DB.transact.
Change-Id: Ie48ce7a4318f4531d4756f779943188a6f0fb6cd
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/304631
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Julie Qiu <julie@golang.org>
Diffstat (limited to 'internal/database')
| -rw-r--r-- | internal/database/copy.go | 114 | ||||
| -rw-r--r-- | internal/database/copy_test.go | 68 | ||||
| -rw-r--r-- | internal/database/database.go | 12 |
3 files changed, 192 insertions, 2 deletions
diff --git a/internal/database/copy.go b/internal/database/copy.go new file mode 100644 index 00000000..90dcc07d --- /dev/null +++ b/internal/database/copy.go @@ -0,0 +1,114 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package database + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/stdlib" + "golang.org/x/pkgsite/internal/derrors" + "golang.org/x/pkgsite/internal/log" +) + +// CopyUpsert upserts rows into table using the pgx driver's CopyFrom method. +// It returns an error if the underlying driver is not pgx. +// columns is the list of columns to upsert. +// src is the source of the rows to upsert. +// conflictColumns are the columns that might conflict (i.e. that have a UNIQUE +// constraint). +// +// CopyUpsert works by first creating a temporary table, populating it with +// CopyFrom, and then running an INSERT...SELECT...ON CONFLICT to upsert its +// rows into the original table. +func (db *DB) CopyUpsert(ctx context.Context, table string, columns []string, src pgx.CopyFromSource, conflictColumns []string) (err error) { + defer derrors.Wrap(&err, "CopyUpsert(%q)", table) + + if !db.InTransaction() { + return errors.New("not in a transaction") + } + + return db.conn.Raw(func(c interface{}) error { + if w, ok := c.(*wrapConn); ok { + c = w.underlying + } + stdConn, ok := c.(*stdlib.Conn) + if !ok { + return fmt.Errorf("DB driver is not pgx or wrapper; conn type is %T", c) + } + conn := stdConn.Conn() + tempTable := fmt.Sprintf("__%s_copy", table) + stmt := fmt.Sprintf(` + DROP TABLE IF EXISTS %s; + CREATE TEMP TABLE %[1]s AS SELECT * FROM %s LIMIT 0 + `, tempTable, table) + _, err = conn.Exec(ctx, stmt) + if err != nil { + return err + } + start := time.Now() + n, err := conn.CopyFrom(ctx, []string{tempTable}, columns, src) + if err != nil { + return err + } + log.Debugf(ctx, "CopyUpsert(%q): copied %d rows in %s", table, n, time.Since(start)) + conflictAction := buildUpsertConflictAction(columns, conflictColumns) + query := buildCopyUpsertQuery(table, tempTable, columns, conflictAction) + + defer logQuery(ctx, query, nil, db.instanceID, db.IsRetryable())(&err) + start = time.Now() + ctag, err := conn.Exec(ctx, query) + if err != nil { + return err + } + log.Debugf(ctx, "CopyUpsert(%q): upserted %d rows in %s", table, ctag.RowsAffected(), time.Since(start)) + return nil + }) +} + +func buildCopyUpsertQuery(table, tempTable string, columns []string, conflictAction string) string { + cols := strings.Join(columns, ", ") + return fmt.Sprintf("INSERT INTO %s (%s) SELECT %s FROM %s %s", table, cols, cols, tempTable, conflictAction) +} + +// A RowItem is a row of values or an error. +type RowItem struct { + Values []interface{} + Err error +} + +// CopyFromChan returns a CopyFromSource that gets its rows from a channel. +func CopyFromChan(c <-chan RowItem) pgx.CopyFromSource { + return &chanCopySource{c: c} +} + +type chanCopySource struct { + c <-chan RowItem + next RowItem +} + +// Next implements CopyFromSource.Next. +func (cs *chanCopySource) Next() bool { + if cs.next.Err != nil { + return false + } + var ok bool + cs.next, ok = <-cs.c + return ok +} + +// Values implements CopyFromSource.Values. +func (cs *chanCopySource) Values() ([]interface{}, error) { + return cs.next.Values, cs.next.Err +} + +// Err implements CopyFromSource.Err. +func (cs *chanCopySource) Err() error { + return cs.next.Err +} diff --git a/internal/database/copy_test.go b/internal/database/copy_test.go new file mode 100644 index 00000000..5c6f41b6 --- /dev/null +++ b/internal/database/copy_test.go @@ -0,0 +1,68 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package database + +import ( + "context" + "database/sql" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/stdlib" +) + +func TestCopyUpsert(t *testing.T) { + ctx := context.Background() + conn, err := testDB.db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + conn.Raw(func(c interface{}) error { + if _, ok := c.(*stdlib.Conn); !ok { + t.Skip("skipping; DB driver not pgx") + } + return nil + }) + + for _, stmt := range []string{ + `DROP TABLE IF EXISTS test_streaming_upsert`, + `CREATE TABLE test_streaming_upsert (key INTEGER PRIMARY KEY, value TEXT)`, + `INSERT INTO test_streaming_upsert (key, value) VALUES (1, 'foo'), (2, 'bar')`, + } { + if _, err := testDB.Exec(ctx, stmt); err != nil { + t.Fatal(err) + } + } + rows := [][]interface{}{ + {3, "baz"}, // new row + {1, "moo"}, // replace "foo" with "moo" + } + err = testDB.Transact(ctx, sql.LevelDefault, func(tx *DB) error { + return tx.CopyUpsert(ctx, "test_streaming_upsert", []string{"key", "value"}, pgx.CopyFromRows(rows), []string{"key"}) + }) + if err != nil { + t.Fatal(err) + } + + type row struct { + Key int + Value string + } + + wantRows := []row{ + {1, "moo"}, + {2, "bar"}, + {3, "baz"}, + } + var gotRows []row + if err := testDB.CollectStructs(ctx, &gotRows, `SELECT * FROM test_streaming_upsert ORDER BY key`); err != nil { + t.Fatal(err) + } + if !cmp.Equal(gotRows, wantRows) { + t.Errorf("got %v, want %v", gotRows, wantRows) + } + +} diff --git a/internal/database/database.go b/internal/database/database.go index 96e858f4..312f83af 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -33,6 +33,7 @@ type DB struct { db *sql.DB instanceID string tx *sql.Tx + conn *sql.Conn // the Conn of the Tx, when tx != nil opts sql.TxOptions // valid when tx != nil mu sync.Mutex maxRetries int // max times a single transaction was retried @@ -237,9 +238,15 @@ func (db *DB) transact(ctx context.Context, opts *sql.TxOptions, txFunc func(*DB if db.InTransaction() { return errors.New("a DB Transact function was called on a DB already in a transaction") } - tx, err := db.db.BeginTx(ctx, opts) + conn, err := db.db.Conn(ctx) if err != nil { - return fmt.Errorf("db.BeginTx(): %w", err) + return err + } + defer conn.Close() + + tx, err := conn.BeginTx(ctx, opts) + if err != nil { + return fmt.Errorf("conn.BeginTx(): %w", err) } defer func() { if p := recover(); p != nil { @@ -256,6 +263,7 @@ func (db *DB) transact(ctx context.Context, opts *sql.TxOptions, txFunc func(*DB dbtx := New(db.db, db.instanceID) dbtx.tx = tx + dbtx.conn = conn dbtx.opts = *opts defer dbtx.logTransaction(ctx)(&err) if err := txFunc(dbtx); err != nil { |
