diff options
| author | Jonathan Amsterdam <jba@google.com> | 2020-06-03 13:15:58 -0400 |
|---|---|---|
| committer | Jonathan Amsterdam <jba@google.com> | 2020-06-04 13:31:22 +0000 |
| commit | 5d82fbe1c8532d03e2313d97fb03b7d424b3d481 (patch) | |
| tree | 9298bf66a954235c6ba8b0658901eda8a1144625 /internal/database/database.go | |
| parent | 01d5985063f47f3f8890d77d743ba9d54485e4eb (diff) | |
| download | go-x-pkgsite-5d82fbe1c8532d03e2313d97fb03b7d424b3d481.tar.xz | |
internal/database: use prepared statement for BulkInsert
Using a prepared statement for the query seems to speed up an insert significantly.
Change-Id: I32a1455b376e08fb435b2e756aa8df4b0c2bc1b3
Reviewed-on: https://team-review.git.corp.google.com/c/golang/discovery/+/761661
CI-Result: Cloud Build <devtools-proctor-result-processor@system.gserviceaccount.com>
Reviewed-by: Julie Qiu <julieqiu@google.com>
Diffstat (limited to 'internal/database/database.go')
| -rw-r--r-- | internal/database/database.go | 44 |
1 files changed, 34 insertions, 10 deletions
diff --git a/internal/database/database.go b/internal/database/database.go index 2bd6b8cc..af1ad8d4 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -101,12 +101,24 @@ func (db *DB) QueryRow(ctx context.Context, query string, args ...interface{}) * return db.db.QueryRowContext(ctx, query, args...) } +func (db *DB) Prepare(ctx context.Context, query string) (*sql.Stmt, error) { + defer logQuery(ctx, "preparing "+query, nil) + if db.tx != nil { + return db.tx.PrepareContext(ctx, query) + } + return db.db.PrepareContext(ctx, query) +} + // RunQuery executes query, then calls f on each row. func (db *DB) RunQuery(ctx context.Context, query string, f func(*sql.Rows) error, params ...interface{}) error { rows, err := db.Query(ctx, query, params...) if err != nil { return err } + return processRows(rows, f) +} + +func processRows(rows *sql.Rows, f func(*sql.Rows) error) error { defer rows.Close() for rows.Next() { if err := f(rows); err != nil { @@ -246,18 +258,32 @@ func (db *DB) bulkInsert(ctx context.Context, table string, columns, returningCo // handle it cautiously. return fmt.Errorf("too many columns to insert: %d", len(columns)) } + fullQuery := buildInsertQuery(table, columns, returningColumns, stride, conflictAction) + stmt, err := db.Prepare(ctx, fullQuery) + if err != nil { + return err + } + defer stmt.Close() for leftBound := 0; leftBound < len(values); leftBound += stride { rightBound := leftBound + stride if rightBound > len(values) { rightBound = len(values) + stmt, err = db.Prepare(ctx, buildInsertQuery(table, columns, returningColumns, rightBound-leftBound, conflictAction)) + if err != nil { + return err + } + defer stmt.Close() } valueSlice := values[leftBound:rightBound] - query := buildInsertQuery(table, columns, returningColumns, valueSlice, conflictAction) var err error if returningColumns == nil { - _, err = db.Exec(ctx, query, valueSlice...) + _, err = stmt.ExecContext(ctx, valueSlice...) } else { - err = db.RunQuery(ctx, query, scanFunc, valueSlice...) + rows, err := stmt.QueryContext(ctx, valueSlice...) + if err != nil { + return err + } + err = processRows(rows, scanFunc) } if err != nil { return fmt.Errorf("running bulk insert query, values[%d:%d]): %w", leftBound, rightBound, err) @@ -267,19 +293,17 @@ func (db *DB) bulkInsert(ctx context.Context, table string, columns, returningCo } // buildInsertQuery builds an multi-value insert query, following the format: -// INSERT TO <table> (<columns>) VALUES (<placeholders-for-each-item-in-values>) -// If conflictNoAction is true, it appends ON CONFLICT DO NOTHING to the query. +// INSERT TO <table> (<columns>) VALUES (<placeholders-for-each-item-in-values>) <conflictAction> // If returningColumns is not empty, it appends a RETURNING clause to the query. // -// When calling buildInsertQuery, it must be true that -// len(values) % len(columns) == 0 -func buildInsertQuery(table string, columns, returningColumns []string, values []interface{}, conflictAction string) string { +// When calling buildInsertQuery, it must be true that nvalues % len(columns) == 0. +func buildInsertQuery(table string, columns, returningColumns []string, nvalues int, conflictAction string) string { var b strings.Builder fmt.Fprintf(&b, "INSERT INTO %s", table) fmt.Fprintf(&b, "(%s) VALUES", strings.Join(columns, ", ")) var placeholders []string - for i := 1; i <= len(values); i++ { + for i := 1; i <= nvalues; i++ { // Construct the full query by adding placeholders for each // set of values that we want to insert. placeholders = append(placeholders, fmt.Sprintf("$%d", i)) @@ -293,7 +317,7 @@ func buildInsertQuery(table string, columns, returningColumns []string, values [ placeholders = nil // Do not add a comma delimiter after the last set of values. - if i == len(values) { + if i == nvalues { break } b.WriteString(", ") |
