aboutsummaryrefslogtreecommitdiff
path: root/internal/database/database.go
diff options
context:
space:
mode:
author: Jonathan Amsterdam <jba@google.com> 2020-04-11 06:47:20 -0400
committer: Jonathan Amsterdam <jba@google.com> 2020-04-16 12:36:42 +0000
commit: f3d7078125541e8902382be09e4a770414d90a59 (patch)
tree: 6cc3d97ffa7c5cc073fe52d8cc902e3f81a9435d /internal/database/database.go
parent: e90fb1c4adde446b8e963b04e01d8bc4f42a2b2b (diff)
download: go-x-pkgsite-f3d7078125541e8902382be09e4a770414d90a59.tar.xz
internal/database: support bulk updates
Implement BulkUpdate. See the function's doc comment for details.

Change-Id: I050227a16ebe2f93cfc535fe95d1cd310fd6d00b
Reviewed-on: https://team-review.git.corp.google.com/c/golang/discovery/+/716478
CI-Result: Cloud Build <devtools-proctor-result-processor@system.gserviceaccount.com>
Reviewed-by: Julie Qiu <julieqiu@google.com>
Diffstat (limited to 'internal/database/database.go')
-rw-r--r-- internal/database/database.go | 74
1 files changed, 74 insertions, 0 deletions
diff --git a/internal/database/database.go b/internal/database/database.go
index 31e3b59a..57804c9a 100644
--- a/internal/database/database.go
+++ b/internal/database/database.go
@@ -18,6 +18,7 @@ import (
"time"
"unicode"
+ "github.com/lib/pq"
"golang.org/x/discovery/internal/config"
"golang.org/x/discovery/internal/derrors"
"golang.org/x/discovery/internal/log"
@@ -224,6 +225,79 @@ func buildInsertQuery(table string, columns []string, values []interface{}, conf
return b.String()
}
+// maxBulkUpdateArrayLen is the maximum number of elements in each array that
+// BulkUpdate sends to Postgres; BulkUpdate splits its rows into chunks of at
+// most this many and issues one statement per chunk. (Postgres has no size
+// limit on arrays, but we want to keep the statements to a reasonable size.)
+//
+// It is a variable rather than a constant so that tests can change it.
+var maxBulkUpdateArrayLen = 10000
+
+// BulkUpdate updates many rows of table using a small number of UPDATE
+// statements, each of which receives its data as Postgres arrays.
+//
+// Columns must contain the names of some of table's columns. The first is treated
+// as a key; that is, the values to update are matched with existing rows by comparing
+// the values of the first column. At least two columns are required: the key and
+// one column to update.
+//
+// Types holds the database type of each column, in the same order as columns.
+// For example,
+//   []string{"INT", "TEXT"}
+// There must be at least one type per column.
+//
+// Values contains one slice of values per column, all of the same length. (Note
+// that this is unlike BulkInsert, which takes a single slice of interleaved values.)
+//
+// The rows are processed in chunks of at most maxBulkUpdateArrayLen, with one
+// call to db.Exec per chunk. If a chunk fails, BulkUpdate stops and returns an
+// error that wraps the one from db.Exec.
+//
+// NOTE(review): this function does not begin a transaction itself. If the
+// chunks must be applied atomically, confirm that db is already running
+// inside a transaction when BulkUpdate is called.
+//
+// The table name, column names and types are inserted into the SQL text as-is
+// (see buildBulkUpdateQuery), so they must not come from untrusted input.
+func (db *DB) BulkUpdate(ctx context.Context, table string, columns, types []string, values [][]interface{}) (err error) {
+	defer derrors.Wrap(&err, "DB.BulkUpdate(ctx, tx, %q, %v, [%d values])",
+		table, columns, len(values))
+
+	if len(columns) < 2 {
+		return errors.New("need at least two columns")
+	}
+	if len(columns) != len(values) {
+		return errors.New("len(values) != len(columns)")
+	}
+	// buildBulkUpdateQuery reads types[i] for every column, so a short types
+	// slice would panic there; report it as an ordinary error instead.
+	if len(types) < len(columns) {
+		return errors.New("len(types) < len(columns)")
+	}
+	nRows := len(values[0])
+	for _, v := range values[1:] {
+		if len(v) != nRows {
+			return errors.New("all values slices must be the same length")
+		}
+	}
+	// The query text depends only on the table and columns, so build it once
+	// and reuse it for every chunk.
+	query := buildBulkUpdateQuery(table, columns, types)
+	for left := 0; left < nRows; left += maxBulkUpdateArrayLen {
+		right := left + maxBulkUpdateArrayLen
+		if right > nRows {
+			right = nRows
+		}
+		// One array argument per column, holding that column's values for
+		// rows [left, right).
+		args := make([]interface{}, 0, len(values))
+		for _, vs := range values {
+			args = append(args, pq.Array(vs[left:right]))
+		}
+		if _, err := db.Exec(ctx, query, args...); err != nil {
+			return fmt.Errorf("db.Exec(%q, values[%d:%d]): %w", query, left, right, err)
+		}
+	}
+	return nil
+}
+
+// buildBulkUpdateQuery returns the SQL text of a single UPDATE statement that
+// sets every column after the first from array parameters, matching rows on
+// the first (key) column.
+//
+// Parameter $N is the array of values for columns[N-1], cast to an array of
+// types[N-1]. columns must be non-empty and types must have an entry for every
+// column; otherwise the function panics on an out-of-range index. The table
+// name, column names and types are written into the statement verbatim.
+func buildBulkUpdateQuery(table string, columns, types []string) string {
+	// One "c = data.c" assignment per non-key column.
+	valueCols := columns[1:]
+	assignments := make([]string, len(valueCols))
+	for i, name := range valueCols {
+		assignments[i] = name + " = data." + name
+	}
+
+	// One "UNNEST($n::TYPE[]) AS c" per column, key included. The explicit
+	// cast is required: without it Postgres complains that UNNEST is not unique.
+	sources := make([]string, len(columns))
+	for i := range columns {
+		sources[i] = fmt.Sprintf("UNNEST($%d::%s[]) AS %s", i+1, types[i], columns[i])
+	}
+
+	// Verbs: 1 = table, 2 = SET list, 3 = UNNEST list, 4 = key column.
+	const queryFormat = `
+ UPDATE %[1]s
+ SET %[2]s
+ FROM (SELECT %[3]s) AS data
+ WHERE %[1]s.%[4]s = data.%[4]s`
+
+	setList := strings.Join(assignments, ", ")
+	selectList := strings.Join(sources, ", ")
+	return fmt.Sprintf(queryFormat, table, setList, selectList, columns[0])
+}
+
// QueryLoggingDisabled stops logging of queries when true.
// For use in tests only: not concurrency-safe.
var QueryLoggingDisabled bool