From 5d82fbe1c8532d03e2313d97fb03b7d424b3d481 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 3 Jun 2020 13:15:58 -0400 Subject: internal/database: use prepared statement for BulkInsert Using a prepared statement for the query seems to speed up an insert significantly. Change-Id: I32a1455b376e08fb435b2e756aa8df4b0c2bc1b3 Reviewed-on: https://team-review.git.corp.google.com/c/golang/discovery/+/761661 CI-Result: Cloud Build Reviewed-by: Julie Qiu --- internal/database/database.go | 44 +++++++++++++++++++++++++++++++++---------- 1 file changed, 34 insertions(+), 10 deletions(-) (limited to 'internal/database/database.go') diff --git a/internal/database/database.go b/internal/database/database.go index 2bd6b8cc..af1ad8d4 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -101,12 +101,24 @@ func (db *DB) QueryRow(ctx context.Context, query string, args ...interface{}) * return db.db.QueryRowContext(ctx, query, args...) } +func (db *DB) Prepare(ctx context.Context, query string) (*sql.Stmt, error) { + defer logQuery(ctx, "preparing "+query, nil) + if db.tx != nil { + return db.tx.PrepareContext(ctx, query) + } + return db.db.PrepareContext(ctx, query) +} + // RunQuery executes query, then calls f on each row. func (db *DB) RunQuery(ctx context.Context, query string, f func(*sql.Rows) error, params ...interface{}) error { rows, err := db.Query(ctx, query, params...) if err != nil { return err } + return processRows(rows, f) +} + +func processRows(rows *sql.Rows, f func(*sql.Rows) error) error { defer rows.Close() for rows.Next() { if err := f(rows); err != nil { @@ -246,18 +258,32 @@ func (db *DB) bulkInsert(ctx context.Context, table string, columns, returningCo // handle it cautiously. 
return fmt.Errorf("too many columns to insert: %d", len(columns)) } + fullQuery := buildInsertQuery(table, columns, returningColumns, stride, conflictAction) + stmt, err := db.Prepare(ctx, fullQuery) + if err != nil { + return err + } + defer stmt.Close() for leftBound := 0; leftBound < len(values); leftBound += stride { rightBound := leftBound + stride if rightBound > len(values) { rightBound = len(values) + stmt, err = db.Prepare(ctx, buildInsertQuery(table, columns, returningColumns, rightBound-leftBound, conflictAction)) + if err != nil { + return err + } + defer stmt.Close() } valueSlice := values[leftBound:rightBound] - query := buildInsertQuery(table, columns, returningColumns, valueSlice, conflictAction) var err error if returningColumns == nil { - _, err = db.Exec(ctx, query, valueSlice...) + _, err = stmt.ExecContext(ctx, valueSlice...) } else { - err = db.RunQuery(ctx, query, scanFunc, valueSlice...) + rows, err := stmt.QueryContext(ctx, valueSlice...) + if err != nil { + return err + } + err = processRows(rows, scanFunc) } if err != nil { return fmt.Errorf("running bulk insert query, values[%d:%d]): %w", leftBound, rightBound, err) @@ -267,19 +293,17 @@ func (db *DB) bulkInsert(ctx context.Context, table string, columns, returningCo } // buildInsertQuery builds an multi-value insert query, following the format: -// INSERT TO <table> (<columns>) VALUES (<placeholders-for-each-item-in-values>) -// If conflictNoAction is true, it appends ON CONFLICT DO NOTHING to the query. +// INSERT TO <table> (<columns>) VALUES (<placeholders-for-each-item-in-values>) // If returningColumns is not empty, it appends a RETURNING clause to the query. // -// When calling buildInsertQuery, it must be true that -// len(values) % len(columns) == 0 -func buildInsertQuery(table string, columns, returningColumns []string, values []interface{}, conflictAction string) string { +// When calling buildInsertQuery, it must be true that nvalues % len(columns) == 0. +func buildInsertQuery(table string, columns, returningColumns []string, nvalues int, conflictAction string) string { var b strings.Builder fmt.Fprintf(&b, "INSERT INTO %s", table) fmt.Fprintf(&b, "(%s) VALUES", strings.Join(columns, ", ")) var placeholders []string - for i := 1; i <= len(values); i++ { + for i := 1; i <= nvalues; i++ { // Construct the full query by adding placeholders for each // set of values that we want to insert. placeholders = append(placeholders, fmt.Sprintf("$%d", i)) @@ -293,7 +317,7 @@ func buildInsertQuery(table string, columns, returningColumns []string, values [ placeholders = nil // Do not add a comma delimiter after the last set of values. - if i == len(values) { + if i == nvalues { break } b.WriteString(", ") -- cgit v1.3