diff options
| author | Jonathan Amsterdam <jba@google.com> | 2020-06-09 10:41:39 -0400 |
|---|---|---|
| committer | Jonathan Amsterdam <jba@google.com> | 2020-06-09 22:51:11 +0000 |
| commit | 07a265813a215f44b5e2973d4f194d15e04fb850 (patch) | |
| tree | d4ab6ae293085b75c9469b171387893ef56884dc /internal/database/database.go | |
| parent | 5f3d28792fdeabf75b026a258b05935d1bfb8417 (diff) | |
| download | go-x-pkgsite-07a265813a215f44b5e2973d4f194d15e04fb850.tar.xz | |
internal/database: support bulk upsert
Add DB.BulkUpsert, which adds an ON CONFLICT clause to the INSERT
that replaces existing column values.
Change-Id: I59f36be0bcb0c0854f42da489e265f2a1396c439
Reviewed-on: https://team-review.git.corp.google.com/c/golang/discovery/+/766360
Reviewed-by: Julie Qiu <julieqiu@google.com>
Diffstat (limited to 'internal/database/database.go')
| -rw-r--r-- | internal/database/database.go | 34 |
1 files changed, 30 insertions, 4 deletions
diff --git a/internal/database/database.go b/internal/database/database.go index 25c90b96..48d0a606 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -219,10 +219,11 @@ func (db *DB) MaxRetries() int { const OnConflictDoNothing = "ON CONFLICT DO NOTHING" // BulkInsert constructs and executes a multi-value insert statement. The -// query is constructed using the format: INSERT TO <table> (<columns>) VALUES -// (<placeholders-for-each-item-in-values>) If conflictNoAction is true, it -// append ON CONFLICT DO NOTHING to the end of the query. The query is executed -// using a PREPARE statement with the provided values. +// query is constructed using the format: +// INSERT INTO <table> (<columns>) VALUES (<placeholders-for-each-item-in-values>) +// If conflictAction is not empty, it is appended to the statement. +// +// The query is executed using a PREPARE statement with the provided values. func (db *DB) BulkInsert(ctx context.Context, table string, columns []string, values []interface{}, conflictAction string) (err error) { defer derrors.Wrap(&err, "DB.BulkInsert(ctx, %q, %v, [%d values], %q)", table, columns, len(values), conflictAction) @@ -244,6 +245,21 @@ func (db *DB) BulkInsertReturning(ctx context.Context, table string, columns []s return db.bulkInsert(ctx, table, columns, returningColumns, values, conflictAction, scanFunc) } +// BulkUpsert is like BulkInsert, but instead of a conflict action, a list of +// conflicting columns is provided. An "ON CONFLICT (conflict_columns) DO +// UPDATE" clause is added to the statement, with assignments "c=excluded.c" for +// every column c. +func (db *DB) BulkUpsert(ctx context.Context, table string, columns []string, values []interface{}, conflictColumns []string) error { + conflictAction := buildUpsertConflictAction(columns, conflictColumns) + return db.BulkInsert(ctx, table, columns, values, conflictAction) +} + +// BulkUpsertReturning is like BulkInsertReturning, but performs an upsert like BulkUpsert. +func (db *DB) BulkUpsertReturning(ctx context.Context, table string, columns []string, values []interface{}, conflictColumns, returningColumns []string, scanFunc func(*sql.Rows) error) error { + conflictAction := buildUpsertConflictAction(columns, conflictColumns) + return db.BulkInsertReturning(ctx, table, columns, values, conflictAction, returningColumns, scanFunc) +} + func (db *DB) bulkInsert(ctx context.Context, table string, columns, returningColumns []string, values []interface{}, conflictAction string, scanFunc func(*sql.Rows) error) (err error) { if remainder := len(values) % len(columns); remainder != 0 { return fmt.Errorf("modulus of len(values) and len(columns) must be 0: got %d", remainder) @@ -338,6 +354,16 @@ func buildInsertQuery(table string, columns, returningColumns []string, nvalues return b.String() } +func buildUpsertConflictAction(columns, conflictColumns []string) string { + var sets []string + for _, c := range columns { + sets = append(sets, fmt.Sprintf("%s=excluded.%[1]s", c)) + } + return fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET %s", + strings.Join(conflictColumns, ", "), + strings.Join(sets, ", ")) +} + // maxBulkUpdateArrayLen is the maximum size of an array that BulkUpdate will send to // Postgres. (Postgres has no size limit on arrays, but we want to keep the statements // to a reasonable size.) |
