diff options
| author | Jonathan Amsterdam <jba@google.com> | 2020-04-16 07:19:18 -0400 |
|---|---|---|
| committer | Jonathan Amsterdam <jba@google.com> | 2020-04-16 18:58:15 +0000 |
| commit | 26de0025b3e114c5d38c103f265d8d65ba725020 (patch) | |
| tree | b2b25f81c1315d36df574429ccc14f286fc1f5dd /internal/database/database.go | |
| parent | 890b9ef865466059e8746dfbfb0fbcbb4095fb24 (diff) | |
| download | go-x-pkgsite-26de0025b3e114c5d38c103f265d8d65ba725020.tar.xz | |
internal/database: support returning values from a bulk insert
Add DB.BulkInsertReturning, which supports the INSERT ... RETURNING
feature.
Change-Id: I8b7ca21295addde1ef29331d6f3d587bd848b4fd
Reviewed-on: https://team-review.git.corp.google.com/c/golang/discovery/+/719741
CI-Result: Cloud Build <devtools-proctor-result-processor@system.gserviceaccount.com>
Reviewed-by: Julie Qiu <julieqiu@google.com>
Diffstat (limited to 'internal/database/database.go')
| -rw-r--r-- | internal/database/database.go | 42 |
1 files changed, 34 insertions, 8 deletions
diff --git a/internal/database/database.go b/internal/database/database.go index 5c208f1f..3aefd19f 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -160,6 +160,24 @@ func (db *DB) BulkInsert(ctx context.Context, table string, columns []string, va defer derrors.Wrap(&err, "DB.BulkInsert(ctx, %q, %v, [%d values], %q)", table, columns, len(values), conflictAction) + return db.bulkInsert(ctx, table, columns, nil, values, conflictAction, nil) +} + +// BulkInsertReturning is like BulkInsert, but supports returning values from the INSERT statement. +// In addition to the arguments of BulkInsert, it takes a list of columns to return and a function +// to scan those columns. To get the returned values, provide a function that scans them as if +// they were the selected columns of a query. See TestBulkInsert for an example. +func (db *DB) BulkInsertReturning(ctx context.Context, table string, columns []string, values []interface{}, conflictAction string, returningColumns []string, scanFunc func(*sql.Rows) error) (err error) { + defer derrors.Wrap(&err, "DB.BulkInsertReturning(ctx, %q, %v, [%d values], %q, %v, scanFunc)", + table, columns, len(values), conflictAction, returningColumns) + + if returningColumns == nil || scanFunc == nil { + return errors.New("need returningColumns and scan function") + } + return db.bulkInsert(ctx, table, columns, returningColumns, values, conflictAction, scanFunc) +} + +func (db *DB) bulkInsert(ctx context.Context, table string, columns, returningColumns []string, values []interface{}, conflictAction string, scanFunc func(*sql.Rows) error) (err error) { if remainder := len(values) % len(columns); remainder != 0 { return fmt.Errorf("modulus of len(values) and len(columns) must be 0: got %d", remainder) } @@ -179,22 +197,28 @@ func (db *DB) BulkInsert(ctx context.Context, table string, columns []string, va rightBound = len(values) } valueSlice := values[leftBound:rightBound] - query := buildInsertQuery(table, columns, valueSlice, conflictAction) - if _, err := db.Exec(ctx, query, valueSlice...); err != nil { - return fmt.Errorf("tx.ExecContext(ctx, [bulk insert query], values[%d:%d]): %v", leftBound, rightBound, err) + query := buildInsertQuery(table, columns, returningColumns, valueSlice, conflictAction) + var err error + if returningColumns == nil { + _, err = db.Exec(ctx, query, valueSlice...) + } else { + err = db.RunQuery(ctx, query, scanFunc, valueSlice...) + } + if err != nil { + return fmt.Errorf("running bulk insert query, values[%d:%d]): %v", leftBound, rightBound, err) } } return nil } // buildInsertQuery builds an multi-value insert query, following the format: -// INSERT TO <table> (<columns>) VALUES -// (<placeholders-for-each-item-in-values>) If conflictNoAction is true, it -// append ON CONFLICT DO NOTHING to the end of the query. +// INSERT TO <table> (<columns>) VALUES (<placeholders-for-each-item-in-values>) +// If conflictNoAction is true, it appends ON CONFLICT DO NOTHING to the query. +// If returningColumns is not empty, it appends a RETURNING clause to the query. // // When calling buildInsertQuery, it must be true that // len(values) % len(columns) == 0 -func buildInsertQuery(table string, columns []string, values []interface{}, conflictAction string) string { +func buildInsertQuery(table string, columns, returningColumns []string, values []interface{}, conflictAction string) string { var b strings.Builder fmt.Fprintf(&b, "INSERT INTO %s", table) fmt.Fprintf(&b, "(%s) VALUES", strings.Join(columns, ", ")) @@ -222,7 +246,9 @@ func buildInsertQuery(table string, columns []string, values []interface{}, conf if conflictAction != "" { b.WriteString(" " + conflictAction) } - + if len(returningColumns) > 0 { + fmt.Fprintf(&b, " RETURNING %s", strings.Join(returningColumns, ", ")) + } return b.String() } |
