aboutsummaryrefslogtreecommitdiff
path: root/internal/database/database.go
diff options
context:
space:
mode:
authorJonathan Amsterdam <jba@google.com>2020-04-16 07:19:18 -0400
committerJonathan Amsterdam <jba@google.com>2020-04-16 18:58:15 +0000
commit26de0025b3e114c5d38c103f265d8d65ba725020 (patch)
treeb2b25f81c1315d36df574429ccc14f286fc1f5dd /internal/database/database.go
parent890b9ef865466059e8746dfbfb0fbcbb4095fb24 (diff)
downloadgo-x-pkgsite-26de0025b3e114c5d38c103f265d8d65ba725020.tar.xz
internal/database: support returning values from a bulk insert
Add DB.BulkInsertReturning, which supports the INSERT ... RETURNING feature. Change-Id: I8b7ca21295addde1ef29331d6f3d587bd848b4fd Reviewed-on: https://team-review.git.corp.google.com/c/golang/discovery/+/719741 CI-Result: Cloud Build <devtools-proctor-result-processor@system.gserviceaccount.com> Reviewed-by: Julie Qiu <julieqiu@google.com>
Diffstat (limited to 'internal/database/database.go')
-rw-r--r--internal/database/database.go42
1 files changed, 34 insertions, 8 deletions
diff --git a/internal/database/database.go b/internal/database/database.go
index 5c208f1f..3aefd19f 100644
--- a/internal/database/database.go
+++ b/internal/database/database.go
@@ -160,6 +160,24 @@ func (db *DB) BulkInsert(ctx context.Context, table string, columns []string, va
defer derrors.Wrap(&err, "DB.BulkInsert(ctx, %q, %v, [%d values], %q)",
table, columns, len(values), conflictAction)
+ return db.bulkInsert(ctx, table, columns, nil, values, conflictAction, nil)
+}
+
+// BulkInsertReturning is like BulkInsert, but supports returning values from the INSERT statement.
+// In addition to the arguments of BulkInsert, it takes a list of columns to return and a function
+// to scan those columns. To get the returned values, provide a function that scans them as if
+// they were the selected columns of a query. See TestBulkInsert for an example.
+func (db *DB) BulkInsertReturning(ctx context.Context, table string, columns []string, values []interface{}, conflictAction string, returningColumns []string, scanFunc func(*sql.Rows) error) (err error) {
+ defer derrors.Wrap(&err, "DB.BulkInsertReturning(ctx, %q, %v, [%d values], %q, %v, scanFunc)",
+ table, columns, len(values), conflictAction, returningColumns)
+
+ if returningColumns == nil || scanFunc == nil {
+ return errors.New("need returningColumns and scan function")
+ }
+ return db.bulkInsert(ctx, table, columns, returningColumns, values, conflictAction, scanFunc)
+}
+
+func (db *DB) bulkInsert(ctx context.Context, table string, columns, returningColumns []string, values []interface{}, conflictAction string, scanFunc func(*sql.Rows) error) (err error) {
if remainder := len(values) % len(columns); remainder != 0 {
return fmt.Errorf("modulus of len(values) and len(columns) must be 0: got %d", remainder)
}
@@ -179,22 +197,28 @@ func (db *DB) BulkInsert(ctx context.Context, table string, columns []string, va
rightBound = len(values)
}
valueSlice := values[leftBound:rightBound]
- query := buildInsertQuery(table, columns, valueSlice, conflictAction)
- if _, err := db.Exec(ctx, query, valueSlice...); err != nil {
- return fmt.Errorf("tx.ExecContext(ctx, [bulk insert query], values[%d:%d]): %v", leftBound, rightBound, err)
+ query := buildInsertQuery(table, columns, returningColumns, valueSlice, conflictAction)
+ var err error
+ if returningColumns == nil {
+ _, err = db.Exec(ctx, query, valueSlice...)
+ } else {
+ err = db.RunQuery(ctx, query, scanFunc, valueSlice...)
+ }
+ if err != nil {
+ return fmt.Errorf("running bulk insert query, values[%d:%d]): %v", leftBound, rightBound, err)
}
}
return nil
}
// buildInsertQuery builds an multi-value insert query, following the format:
-// INSERT TO <table> (<columns>) VALUES
-// (<placeholders-for-each-item-in-values>) If conflictNoAction is true, it
-// append ON CONFLICT DO NOTHING to the end of the query.
+// INSERT TO <table> (<columns>) VALUES (<placeholders-for-each-item-in-values>)
+// If conflictNoAction is true, it appends ON CONFLICT DO NOTHING to the query.
+// If returningColumns is not empty, it appends a RETURNING clause to the query.
//
// When calling buildInsertQuery, it must be true that
// len(values) % len(columns) == 0
-func buildInsertQuery(table string, columns []string, values []interface{}, conflictAction string) string {
+func buildInsertQuery(table string, columns, returningColumns []string, values []interface{}, conflictAction string) string {
var b strings.Builder
fmt.Fprintf(&b, "INSERT INTO %s", table)
fmt.Fprintf(&b, "(%s) VALUES", strings.Join(columns, ", "))
@@ -222,7 +246,9 @@ func buildInsertQuery(table string, columns []string, values []interface{}, conf
if conflictAction != "" {
b.WriteString(" " + conflictAction)
}
-
+ if len(returningColumns) > 0 {
+ fmt.Fprintf(&b, " RETURNING %s", strings.Join(returningColumns, ", "))
+ }
return b.String()
}