diff options
| author | Jonathan Amsterdam <jba@google.com> | 2021-03-30 10:25:18 -0400 |
|---|---|---|
| committer | Jonathan Amsterdam <jba@google.com> | 2021-03-30 16:42:02 +0000 |
| commit | 127896bc9e7e1d1237fbf854d06200fa778ddaaf (patch) | |
| tree | c5800a746ce39eb8c0ced3b56dea4abdf36dabf7 /internal/database | |
| parent | ace022fde69528cbb7a119fcd0d524ade7b9ff82 (diff) | |
| download | go-x-pkgsite-127896bc9e7e1d1237fbf854d06200fa778ddaaf.tar.xz | |
internal/database: CopyUpsert: support dropping a column
You can't do a CopyFrom on a table with a generated column: postgres
complains about the column value being null. To fix, drop the column
on the temporary table.
Change-Id: Ia52f59af6d026b3fcdaafe3c7865a2eb85deb179
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/305830
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
Reviewed-by: Julie Qiu <julie@golang.org>
Diffstat (limited to 'internal/database')
| -rw-r--r-- | internal/database/copy.go | 20 | ||||
| -rw-r--r-- | internal/database/copy_test.go | 70 |
2 files changed, 67 insertions, 23 deletions
diff --git a/internal/database/copy.go b/internal/database/copy.go index 74c5d00e..7a092c0f 100644 --- a/internal/database/copy.go +++ b/internal/database/copy.go @@ -23,11 +23,13 @@ import ( // src is the source of the rows to upsert. // conflictColumns are the columns that might conflict (i.e. that have a UNIQUE // constraint). +// If dropColumn is non-empty, that column will be dropped from the temporary +// table before copying. Use dropColumn for generated ID columns. // // CopyUpsert works by first creating a temporary table, populating it with // CopyFrom, and then running an INSERT...SELECT...ON CONFLICT to upsert its // rows into the original table. -func (db *DB) CopyUpsert(ctx context.Context, table string, columns []string, src pgx.CopyFromSource, conflictColumns []string) (err error) { +func (db *DB) CopyUpsert(ctx context.Context, table string, columns []string, src pgx.CopyFromSource, conflictColumns []string, dropColumn string) (err error) { defer derrors.Wrap(&err, "CopyUpsert(%q)", table) if !db.InTransaction() { @@ -46,8 +48,11 @@ func (db *DB) CopyUpsert(ctx context.Context, table string, columns []string, sr tempTable := fmt.Sprintf("__%s_copy", table) stmt := fmt.Sprintf(` DROP TABLE IF EXISTS %s; - CREATE TEMP TABLE %[1]s (LIKE %s) ON COMMIT DROP + CREATE TEMP TABLE %[1]s (LIKE %s) ON COMMIT DROP; `, tempTable, table) + if dropColumn != "" { + stmt += fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", tempTable, dropColumn) + } _, err = conn.Exec(ctx, stmt) if err != nil { return err @@ -55,12 +60,12 @@ func (db *DB) CopyUpsert(ctx context.Context, table string, columns []string, sr start := time.Now() n, err := conn.CopyFrom(ctx, []string{tempTable}, columns, src) if err != nil { - return err + return fmt.Errorf("CopyFrom: %w", err) } log.Debugf(ctx, "CopyUpsert(%q): copied %d rows in %s", table, n, time.Since(start)) conflictAction := buildUpsertConflictAction(columns, conflictColumns) - query := buildCopyUpsertQuery(table, tempTable, columns, conflictAction) - + cols := strings.Join(columns, ", ") + query := fmt.Sprintf("INSERT INTO %s (%s) SELECT %s FROM %s %s", table, cols, cols, tempTable, conflictAction) defer logQuery(ctx, query, nil, db.instanceID, db.IsRetryable())(&err) start = time.Now() ctag, err := conn.Exec(ctx, query) @@ -72,11 +77,6 @@ func (db *DB) CopyUpsert(ctx context.Context, table string, columns []string, sr }) } -func buildCopyUpsertQuery(table, tempTable string, columns []string, conflictAction string) string { - cols := strings.Join(columns, ", ") - return fmt.Sprintf("INSERT INTO %s (%s) SELECT %s FROM %s %s", table, cols, cols, tempTable, conflictAction) -} - // A RowItem is a row of values or an error. type RowItem struct { Values []interface{} diff --git a/internal/database/copy_test.go b/internal/database/copy_test.go index 5c6f41b6..238af8e8 100644 --- a/internal/database/copy_test.go +++ b/internal/database/copy_test.go @@ -15,18 +15,8 @@ import ( ) func TestCopyUpsert(t *testing.T) { + pgxOnly(t) ctx := context.Background() - conn, err := testDB.db.Conn(ctx) - if err != nil { - t.Fatal(err) - } - conn.Raw(func(c interface{}) error { - if _, ok := c.(*stdlib.Conn); !ok { - t.Skip("skipping; DB driver not pgx") - } - return nil - }) - for _, stmt := range []string{ `DROP TABLE IF EXISTS test_streaming_upsert`, `CREATE TABLE test_streaming_upsert (key INTEGER PRIMARY KEY, value TEXT)`, @@ -40,8 +30,8 @@ func TestCopyUpsert(t *testing.T) { {3, "baz"}, // new row {1, "moo"}, // replace "foo" with "moo" } - err = testDB.Transact(ctx, sql.LevelDefault, func(tx *DB) error { - return tx.CopyUpsert(ctx, "test_streaming_upsert", []string{"key", "value"}, pgx.CopyFromRows(rows), []string{"key"}) + err := testDB.Transact(ctx, sql.LevelDefault, func(tx *DB) error { + return tx.CopyUpsert(ctx, "test_streaming_upsert", []string{"key", "value"}, pgx.CopyFromRows(rows), []string{"key"}, "") }) if err != nil { t.Fatal(err) @@ -66,3 +56,57 @@ func TestCopyUpsert(t *testing.T) { } } + +func TestCopyUpsertGeneratedColumn(t *testing.T) { + pgxOnly(t) + ctx := context.Background() + stmt := ` + DROP TABLE IF EXISTS test_copy_gen; + CREATE TABLE test_copy_gen (id bigint PRIMARY KEY GENERATED ALWAYS AS IDENTITY, key INT, value TEXT, UNIQUE (key)); + INSERT INTO test_copy_gen (key, value) VALUES (11, 'foo'), (12, 'bar')` + if _, err := testDB.Exec(ctx, stmt); err != nil { + t.Fatal(err) + } + + rows := [][]interface{}{ + {13, "baz"}, // new row + {11, "moo"}, // replace "foo" with "moo" + } + err := testDB.Transact(ctx, sql.LevelDefault, func(tx *DB) error { + return tx.CopyUpsert(ctx, "test_copy_gen", []string{"key", "value"}, pgx.CopyFromRows(rows), []string{"key"}, "id") + }) + if err != nil { + t.Fatal(err) + } + + type row struct { + ID int64 + Key int + Value string + } + wantRows := []row{ + {1, 11, "moo"}, + {2, 12, "bar"}, + {3, 13, "baz"}, + } + var gotRows []row + if err := testDB.CollectStructs(ctx, &gotRows, `SELECT * FROM test_copy_gen ORDER BY ID`); err != nil { + t.Fatal(err) + } + if !cmp.Equal(gotRows, wantRows) { + t.Errorf("got %v, want %v", gotRows, wantRows) + } +} + +func pgxOnly(t *testing.T) { + conn, err := testDB.db.Conn(context.Background()) + if err != nil { + t.Fatal(err) + } + conn.Raw(func(c interface{}) error { + if _, ok := c.(*stdlib.Conn); !ok { + t.Skip("skipping; DB driver not pgx") + } + return nil + }) +} |
