aboutsummaryrefslogtreecommitdiff
path: root/internal/database/database.go
diff options
context:
space:
mode:
author Jonathan Amsterdam <jba@google.com> 2020-06-03 13:15:58 -0400
committer Jonathan Amsterdam <jba@google.com> 2020-06-04 13:31:22 +0000
commit 5d82fbe1c8532d03e2313d97fb03b7d424b3d481 (patch)
tree 9298bf66a954235c6ba8b0658901eda8a1144625 /internal/database/database.go
parent 01d5985063f47f3f8890d77d743ba9d54485e4eb (diff)
download go-x-pkgsite-5d82fbe1c8532d03e2313d97fb03b7d424b3d481.tar.xz
internal/database: use prepared statement for BulkInsert
Using a prepared statement for the query seems to speed up an insert significantly.

Change-Id: I32a1455b376e08fb435b2e756aa8df4b0c2bc1b3
Reviewed-on: https://team-review.git.corp.google.com/c/golang/discovery/+/761661
CI-Result: Cloud Build <devtools-proctor-result-processor@system.gserviceaccount.com>
Reviewed-by: Julie Qiu <julieqiu@google.com>
Diffstat (limited to 'internal/database/database.go')
-rw-r--r-- internal/database/database.go 44
1 files changed, 34 insertions, 10 deletions
diff --git a/internal/database/database.go b/internal/database/database.go
index 2bd6b8cc..af1ad8d4 100644
--- a/internal/database/database.go
+++ b/internal/database/database.go
@@ -101,12 +101,24 @@ func (db *DB) QueryRow(ctx context.Context, query string, args ...interface{}) *
return db.db.QueryRowContext(ctx, query, args...)
}
// Prepare creates a prepared statement for query and returns it along with
// any error from the underlying PrepareContext call. The caller owns the
// returned *sql.Stmt and should Close it when finished.
+func (db *DB) Prepare(ctx context.Context, query string) (*sql.Stmt, error) {
// NOTE(review): this defers the call to logQuery itself and discards
// whatever it returns. If logQuery returns a completion callback that is
// meant to be invoked when the operation ends, that callback never runs
// here. Confirm against logQuery's definition.
+ defer logQuery(ctx, "preparing "+query, nil)
// When db wraps a transaction, prepare on the transaction so the
// statement is bound to it; otherwise prepare on the plain connection pool.
+ if db.tx != nil {
+ return db.tx.PrepareContext(ctx, query)
+ }
+ return db.db.PrepareContext(ctx, query)
+}
+
// RunQuery executes query with params, then calls f on each row of the
// result. An error from starting the query is returned unchanged. Iterating
// over the rows and closing them is delegated to processRows.
func (db *DB) RunQuery(ctx context.Context, query string, f func(*sql.Rows) error, params ...interface{}) error {
rows, err := db.Query(ctx, query, params...)
if err != nil {
return err
}
// processRows takes ownership of rows and closes them.
+ return processRows(rows, f)
+}
+
+func processRows(rows *sql.Rows, f func(*sql.Rows) error) error {
defer rows.Close()
for rows.Next() {
if err := f(rows); err != nil {
@@ -246,18 +258,32 @@ func (db *DB) bulkInsert(ctx context.Context, table string, columns, returningCo
// handle it cautiously.
return fmt.Errorf("too many columns to insert: %d", len(columns))
}
+ fullQuery := buildInsertQuery(table, columns, returningColumns, stride, conflictAction)
+ stmt, err := db.Prepare(ctx, fullQuery)
+ if err != nil {
+ return err
+ }
+ defer stmt.Close()
for leftBound := 0; leftBound < len(values); leftBound += stride {
rightBound := leftBound + stride
if rightBound > len(values) {
rightBound = len(values)
+ stmt, err = db.Prepare(ctx, buildInsertQuery(table, columns, returningColumns, rightBound-leftBound, conflictAction))
+ if err != nil {
+ return err
+ }
+ defer stmt.Close()
}
valueSlice := values[leftBound:rightBound]
- query := buildInsertQuery(table, columns, returningColumns, valueSlice, conflictAction)
var err error
if returningColumns == nil {
- _, err = db.Exec(ctx, query, valueSlice...)
+ _, err = stmt.ExecContext(ctx, valueSlice...)
} else {
- err = db.RunQuery(ctx, query, scanFunc, valueSlice...)
+ rows, err := stmt.QueryContext(ctx, valueSlice...)
+ if err != nil {
+ return err
+ }
+ err = processRows(rows, scanFunc)
}
if err != nil {
return fmt.Errorf("running bulk insert query, values[%d:%d]): %w", leftBound, rightBound, err)
@@ -267,19 +293,17 @@ func (db *DB) bulkInsert(ctx context.Context, table string, columns, returningCo
}
// buildInsertQuery builds a multi-value insert query, following the format:
-// INSERT TO <table> (<columns>) VALUES (<placeholders-for-each-item-in-values>)
-// If conflictNoAction is true, it appends ON CONFLICT DO NOTHING to the query.
+// INSERT INTO <table> (<columns>) VALUES (<placeholders-for-each-item-in-values>) <conflictAction>
// If returningColumns is not empty, it appends a RETURNING clause to the query.
//
-// When calling buildInsertQuery, it must be true that
-// len(values) % len(columns) == 0
-func buildInsertQuery(table string, columns, returningColumns []string, values []interface{}, conflictAction string) string {
+// When calling buildInsertQuery, it must be true that nvalues % len(columns) == 0.
+func buildInsertQuery(table string, columns, returningColumns []string, nvalues int, conflictAction string) string {
var b strings.Builder
fmt.Fprintf(&b, "INSERT INTO %s", table)
fmt.Fprintf(&b, "(%s) VALUES", strings.Join(columns, ", "))
var placeholders []string
- for i := 1; i <= len(values); i++ {
+ for i := 1; i <= nvalues; i++ {
// Construct the full query by adding placeholders for each
// set of values that we want to insert.
placeholders = append(placeholders, fmt.Sprintf("$%d", i))
@@ -293,7 +317,7 @@ func buildInsertQuery(table string, columns, returningColumns []string, values [
placeholders = nil
// Do not add a comma delimiter after the last set of values.
- if i == len(values) {
+ if i == nvalues {
break
}
b.WriteString(", ")