aboutsummaryrefslogtreecommitdiff
path: root/internal/database/database.go
diff options
context:
space:
mode:
authorJonathan Amsterdam <jba@google.com>2020-06-09 10:41:39 -0400
committerJonathan Amsterdam <jba@google.com>2020-06-09 22:51:11 +0000
commit07a265813a215f44b5e2973d4f194d15e04fb850 (patch)
treed4ab6ae293085b75c9469b171387893ef56884dc /internal/database/database.go
parent5f3d28792fdeabf75b026a258b05935d1bfb8417 (diff)
downloadgo-x-pkgsite-07a265813a215f44b5e2973d4f194d15e04fb850.tar.xz
internal/database: support bulk upsert
Add DB.BulkUpsert, which adds an ON CONFLICT clause to the INSERT that replaces existing column values. Change-Id: I59f36be0bcb0c0854f42da489e265f2a1396c439 Reviewed-on: https://team-review.git.corp.google.com/c/golang/discovery/+/766360 Reviewed-by: Julie Qiu <julieqiu@google.com>
Diffstat (limited to 'internal/database/database.go')
-rw-r--r--internal/database/database.go34
1 files changed, 30 insertions, 4 deletions
diff --git a/internal/database/database.go b/internal/database/database.go
index 25c90b96..48d0a606 100644
--- a/internal/database/database.go
+++ b/internal/database/database.go
@@ -219,10 +219,11 @@ func (db *DB) MaxRetries() int {
const OnConflictDoNothing = "ON CONFLICT DO NOTHING"
// BulkInsert constructs and executes a multi-value insert statement. The
-// query is constructed using the format: INSERT TO <table> (<columns>) VALUES
-// (<placeholders-for-each-item-in-values>) If conflictNoAction is true, it
-// append ON CONFLICT DO NOTHING to the end of the query. The query is executed
-// using a PREPARE statement with the provided values.
+// query is constructed using the format:
+// INSERT INTO <table> (<columns>) VALUES (<placeholders-for-each-item-in-values>)
+// If conflictAction is not empty, it is appended to the statement.
+//
+// The query is executed using a PREPARE statement with the provided values.
func (db *DB) BulkInsert(ctx context.Context, table string, columns []string, values []interface{}, conflictAction string) (err error) {
defer derrors.Wrap(&err, "DB.BulkInsert(ctx, %q, %v, [%d values], %q)",
table, columns, len(values), conflictAction)
@@ -244,6 +245,21 @@ func (db *DB) BulkInsertReturning(ctx context.Context, table string, columns []s
return db.bulkInsert(ctx, table, columns, returningColumns, values, conflictAction, scanFunc)
}
+// BulkUpsert is like BulkInsert, but instead of a conflict action, a list of
+// conflicting columns is provided. An "ON CONFLICT (conflict_columns) DO
+// UPDATE" clause is added to the statement, with assignments "c=excluded.c" for
+// every column c.
+func (db *DB) BulkUpsert(ctx context.Context, table string, columns []string, values []interface{}, conflictColumns []string) error {
+ conflictAction := buildUpsertConflictAction(columns, conflictColumns)
+ return db.BulkInsert(ctx, table, columns, values, conflictAction)
+}
+
+// BulkUpsertReturning is like BulkInsertReturning, but performs an upsert like BulkUpsert.
+func (db *DB) BulkUpsertReturning(ctx context.Context, table string, columns []string, values []interface{}, conflictColumns, returningColumns []string, scanFunc func(*sql.Rows) error) error {
+ conflictAction := buildUpsertConflictAction(columns, conflictColumns)
+ return db.BulkInsertReturning(ctx, table, columns, values, conflictAction, returningColumns, scanFunc)
+}
+
func (db *DB) bulkInsert(ctx context.Context, table string, columns, returningColumns []string, values []interface{}, conflictAction string, scanFunc func(*sql.Rows) error) (err error) {
if remainder := len(values) % len(columns); remainder != 0 {
return fmt.Errorf("modulus of len(values) and len(columns) must be 0: got %d", remainder)
@@ -338,6 +354,16 @@ func buildInsertQuery(table string, columns, returningColumns []string, nvalues
return b.String()
}
+func buildUpsertConflictAction(columns, conflictColumns []string) string {
+ var sets []string
+ for _, c := range columns {
+ sets = append(sets, fmt.Sprintf("%s=excluded.%[1]s", c))
+ }
+ return fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET %s",
+ strings.Join(conflictColumns, ", "),
+ strings.Join(sets, ", "))
+}
+
// maxBulkUpdateArrayLen is the maximum size of an array that BulkUpdate will send to
// Postgres. (Postgres has no size limit on arrays, but we want to keep the statements
// to a reasonable size.)