diff options
| author | Jonathan Amsterdam <jba@google.com> | 2019-11-19 19:04:01 -0500 |
|---|---|---|
| committer | Julie Qiu <julie@golang.org> | 2020-03-27 16:46:48 -0400 |
| commit | 65e2b7a804ce25942210f8db0e9552db9d7d6ff5 (patch) | |
| tree | b7e330862db226a7f431eeda083312dc52d96556 /internal/database | |
| parent | c77f28bf038ee25dce5a0cd513b721683bf940cc (diff) | |
| download | go-x-pkgsite-65e2b7a804ce25942210f8db0e9552db9d7d6ff5.tar.xz | |
internal/database, internal/testing/dbtest: site-agnostic DB functionality
Extract into a separate package the core functionality from
internal/postgres that doesn't depend on our particular schema.
This makes it available for other uses, like devtools commands and etl
autocomplete.
Do the same for testing functionality.
We now have three packages where before we had only one:
- internal/postgres: discovery-specific DB operations and test support
- internal/database: discovery-agnostic DB operations
- internal/testing/dbtest: discovery-agnostic DB test support
Change-Id: I54c59aee328dae71ba6c77170a72e7a83da7c785
Reviewed-on: https://team-review.git.corp.google.com/c/golang/discovery/+/602327
Reviewed-by: Robert Findley <rfindley@google.com>
Diffstat (limited to 'internal/database')
| -rw-r--r-- | internal/database/database.go | 321 | ||||
| -rw-r--r-- | internal/database/database_test.go | 185 |
2 files changed, 506 insertions, 0 deletions
diff --git a/internal/database/database.go b/internal/database/database.go new file mode 100644 index 00000000..b8b33f95 --- /dev/null +++ b/internal/database/database.go @@ -0,0 +1,321 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package database adds some useful functionality to a sql.DB. +// It is independent of the database driver and the +// DB schema. +package database + +import ( + "context" + "database/sql" + "fmt" + "regexp" + "strings" + "sync/atomic" + "time" + "unicode" + + "golang.org/x/discovery/internal/config" + "golang.org/x/discovery/internal/derrors" + "golang.org/x/discovery/internal/log" +) + +// DB wraps a sql.DB. The methods it exports correspond closely to those of +// sql.DB. They enhance the original by requiring a context argument, and by +// logging the query and any resulting errors. +type DB struct { + db *sql.DB +} + +// TODO(jba): remove +func (db *DB) Underlying() *sql.DB { + return db.db +} + +// Open creates a new DB for the given connection string. +func Open(driverName, dbinfo string) (_ *DB, err error) { + defer derrors.Wrap(&err, "database.Open(%q, %q)", + driverName, redactPassword(dbinfo)) + + db, err := sql.Open(driverName, dbinfo) + if err != nil { + return nil, err + } + if err = db.Ping(); err != nil { + return nil, err + } + return &DB{db}, nil +} + +var passwordRegexp = regexp.MustCompile(`password=\S+`) + +func redactPassword(dbinfo string) string { + return passwordRegexp.ReplaceAllLiteralString(dbinfo, "password=REDACTED") +} + +// Close closes the database connection. +func (db *DB) Close() error { + return db.db.Close() +} + +// Exec executes a SQL statement. +func (db *DB) Exec(ctx context.Context, query string, args ...interface{}) (res sql.Result, err error) { + defer logQuery(query, args)(&err) + + return db.db.ExecContext(ctx, query, args...) +} + +// ExecTx runs a query in a transaction. +func ExecTx(ctx context.Context, tx *sql.Tx, query string, args ...interface{}) (res sql.Result, err error) { + defer logQuery(query, args)(&err) + + return tx.ExecContext(ctx, query, args...) +} + +// Query runs the DB query. +func (db *DB) Query(ctx context.Context, query string, args ...interface{}) (_ *sql.Rows, err error) { + defer logQuery(query, args)(&err) + return db.db.QueryContext(ctx, query, args...) +} + +// QueryRow runs the query and returns a single row. +func (db *DB) QueryRow(ctx context.Context, query string, args ...interface{}) *sql.Row { + defer logQuery(query, args)(nil) + return db.db.QueryRowContext(ctx, query, args...) +} + +// RunQuery executes query, then calls f on each row. +func (db *DB) RunQuery(ctx context.Context, query string, f func(*sql.Rows) error, params ...interface{}) error { + rows, err := db.Query(ctx, query, params...) + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + if err := f(rows); err != nil { + return err + } + } + return rows.Err() +} + +// Transact executes the given function in the context of a SQL transaction, +// rolling back the transaction if the function panics or returns an error. +func (db *DB) Transact(txFunc func(*sql.Tx) error) (err error) { + tx, err := db.db.Begin() + if err != nil { + return fmt.Errorf("db.Begin(): %v", err) + } + + defer func() { + if p := recover(); p != nil { + tx.Rollback() + panic(p) + } else if err != nil { + tx.Rollback() + } else { + if err = tx.Commit(); err != nil { + err = fmt.Errorf("tx.Commit(): %v", err) + } + } + }() + + if err := txFunc(tx); err != nil { + return fmt.Errorf("txFunc(tx): %v", err) + } + return nil +} + +const OnConflictDoNothing = "ON CONFLICT DO NOTHING" + +// BulkInsert constructs and executes a multi-value insert statement. The +// query is constructed using the format: INSERT TO <table> (<columns>) VALUES +// (<placeholders-for-each-item-in-values>) If conflictNoAction is true, it +// append ON CONFLICT DO NOTHING to the end of the query. The query is executed +// using a PREPARE statement with the provided values. +func BulkInsert(ctx context.Context, tx *sql.Tx, table string, columns []string, values []interface{}, conflictAction string) (err error) { + defer derrors.Wrap(&err, "bulkInsert(ctx, tx, %q, %v, [%d values], %q)", + table, columns, len(values), conflictAction) + + if remainder := len(values) % len(columns); remainder != 0 { + return fmt.Errorf("modulus of len(values) and len(columns) must be 0: got %d", remainder) + } + + // Postgres supports up to 65535 parameters, but stop well before that + // so we don't construct humongous queries. + const maxParameters = 1000 + stride := (maxParameters / len(columns)) * len(columns) + if stride == 0 { + // This is a pathological case (len(columns) > maxParameters), but we + // handle it cautiously. + return fmt.Errorf("too many columns to insert: %d", len(columns)) + } + for leftBound := 0; leftBound < len(values); leftBound += stride { + rightBound := leftBound + stride + if rightBound > len(values) { + rightBound = len(values) + } + valueSlice := values[leftBound:rightBound] + query, err := buildInsertQuery(table, columns, valueSlice, conflictAction) + if err != nil { + return fmt.Errorf("buildInsertQuery(%q, %v, values[%d:%d], %q): %v", table, columns, leftBound, rightBound, conflictAction, err) + } + + if _, err := ExecTx(ctx, tx, query, valueSlice...); err != nil { + return fmt.Errorf("tx.ExecContext(ctx, [bulk insert query], values[%d:%d]): %v", leftBound, rightBound, err) + } + } + return nil +} + +// buildInsertQuery builds an multi-value insert query, following the format: +// INSERT TO <table> (<columns>) VALUES +// (<placeholders-for-each-item-in-values>) If conflictNoAction is true, it +// append ON CONFLICT DO NOTHING to the end of the query. +// +// When calling buildInsertQuery, it must be true that +// len(values) % len(columns) == 0 +func buildInsertQuery(table string, columns []string, values []interface{}, conflictAction string) (string, error) { + var b strings.Builder + fmt.Fprintf(&b, "INSERT INTO %s", table) + fmt.Fprintf(&b, "(%s) VALUES", strings.Join(columns, ", ")) + + var placeholders []string + for i := 1; i <= len(values); i++ { + // Construct the full query by adding placeholders for each + // set of values that we want to insert. + placeholders = append(placeholders, fmt.Sprintf("$%d", i)) + if i%len(columns) != 0 { + continue + } + + // When the end of a set is reached, write it to the query + // builder and reset placeholders. + fmt.Fprintf(&b, "(%s)", strings.Join(placeholders, ", ")) + placeholders = []string{} + + // Do not add a comma delimiter after the last set of values. + if i == len(values) { + break + } + b.WriteString(", ") + } + if conflictAction != "" { + b.WriteString(" " + conflictAction) + } + + return b.String(), nil +} + +// QueryLoggingDisabled stops logging of queries when true. +// For use in tests only: not concurrency-safe. +var QueryLoggingDisabled bool + +var queryCounter int64 // atomic: per-process counter for unique query IDs + +type queryEndLogEntry struct { + ID string + Query string + Args string + DurationSeconds float64 + Error string `json:",omitempty"` +} + +func logQuery(query string, args []interface{}) func(*error) { + if QueryLoggingDisabled { + return func(*error) {} + } + const maxlen = 300 // maximum length of displayed query + + // To make the query more compact and readable, replace newlines with spaces + // and collapse adjacent whitespace. + var r []rune + for _, c := range query { + if c == '\n' { + c = ' ' + } + if len(r) == 0 || !unicode.IsSpace(r[len(r)-1]) || !unicode.IsSpace(c) { + r = append(r, c) + } + } + query = string(r) + if len(query) > maxlen { + query = query[:maxlen] + "..." + } + + instanceID := config.InstanceID() + if instanceID == "" { + instanceID = "local" + } else { + // Instance IDs are long strings. The low-order part seems quite random, so + // shortening the ID will still likely result in something unique. + instanceID = instanceID[len(instanceID)-4:] + } + n := atomic.AddInt64(&queryCounter, 1) + uid := fmt.Sprintf("%s-%d", instanceID, n) + + // Construct a short string of the args. + const ( + maxArgs = 20 + maxArgLen = 50 + ) + var argStrings []string + for i := 0; i < len(args) && i < maxArgs; i++ { + s := fmt.Sprint(args[i]) + if len(s) > maxArgLen { + s = s[:maxArgLen] + "..." + } + argStrings = append(argStrings, s) + } + if len(args) > maxArgs { + argStrings = append(argStrings, "...") + } + argString := strings.Join(argStrings, ", ") + + log.Debugf("%s %s args=%s", uid, query, argString) + start := time.Now() + return func(errp *error) { + dur := time.Since(start) + if errp == nil { // happens with queryRow + log.Debugf("%s done", uid) + } else { + derrors.Wrap(errp, "DB running query %s", uid) + entry := queryEndLogEntry{ + ID: uid, + Query: query, + Args: argString, + DurationSeconds: dur.Seconds(), + } + if *errp == nil { + log.Debug(entry) + } else { + entry.Error = (*errp).Error() + log.Error(entry) + } + } + } +} + +// emptyStringScanner wraps the functionality of sql.NullString to just write +// an empty string if the value is NULL. +type emptyStringScanner struct { + ptr *string +} + +func (e emptyStringScanner) Scan(value interface{}) error { + var ns sql.NullString + if err := ns.Scan(value); err != nil { + return err + } + *e.ptr = ns.String + return nil +} + +// NullIsEmpty returns a sql.Scanner that writes the empty string to s if the +// sql.Value is NULL. +func NullIsEmpty(s *string) sql.Scanner { + return emptyStringScanner{s} +} diff --git a/internal/database/database_test.go b/internal/database/database_test.go new file mode 100644 index 00000000..ef72ace7 --- /dev/null +++ b/internal/database/database_test.go @@ -0,0 +1,185 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package database + +import ( + "context" + "database/sql" + "fmt" + "log" + "os" + "testing" + "time" + + "golang.org/x/discovery/internal/testing/dbtest" +) + +const testTimeout = 5 * time.Second + +var testDB *DB + +func TestMain(m *testing.M) { + const dbName = "discovery_postgres_test" + + if err := dbtest.CreateDBIfNotExists(dbName); err != nil { + log.Fatal(err) + } + var err error + testDB, err = Open("postgres", dbtest.DBConnURI(dbName)) + if err != nil { + log.Fatal(err) + } + code := m.Run() + if err := testDB.Close(); err != nil { + log.Fatal(err) + } + os.Exit(code) +} + +func TestBulkInsert(t *testing.T) { + table := "test_bulk_insert" + + for _, tc := range []struct { + name string + columns []string + values []interface{} + conflictAction string + wantErr bool + wantCount int + }{ + { + + name: "test-one-row", + columns: []string{"colA"}, + values: []interface{}{"valueA"}, + wantCount: 1, + }, + { + + name: "test-multiple-rows", + columns: []string{"colA"}, + values: []interface{}{"valueA1", "valueA2", "valueA3"}, + wantCount: 3, + }, + { + + name: "test-invalid-column-name", + columns: []string{"invalid_col"}, + values: []interface{}{"valueA"}, + wantErr: true, + }, + { + + name: "test-mismatch-num-cols-and-vals", + columns: []string{"colA", "colB"}, + values: []interface{}{"valueA1", "valueB1", "valueA2"}, + wantErr: true, + }, + { + + name: "test-conflict-no-action-true", + columns: []string{"colA"}, + values: []interface{}{"valueA", "valueA"}, + conflictAction: OnConflictDoNothing, + wantCount: 1, + }, + { + + name: "test-conflict-no-action-false", + columns: []string{"colA"}, + values: []interface{}{"valueA", "valueA"}, + wantErr: true, + }, + { + + // This should execute the statement + // INSERT INTO series (path) VALUES ('''); TRUNCATE series CASCADE;)'); + // which will insert a row with path value: + // '); TRUNCATE series CASCADE;) + // Rather than the statement + // INSERT INTO series (path) VALUES (''); TRUNCATE series CASCADE;)); + // which would truncate most tables in the database. + name: "test-sql-injection", + columns: []string{"colA"}, + values: []interface{}{fmt.Sprintf("''); TRUNCATE %s CASCADE;))", table)}, + conflictAction: OnConflictDoNothing, + wantCount: 1, + }, + } { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + createQuery := fmt.Sprintf(`CREATE TABLE %s ( + colA TEXT NOT NULL, + colB TEXT, + PRIMARY KEY (colA) + );`, table) + if _, err := testDB.Exec(ctx, createQuery); err != nil { + t.Fatal(err) + } + defer func() { + dropTableQuery := fmt.Sprintf("DROP TABLE %s;", table) + if _, err := testDB.Exec(ctx, dropTableQuery); err != nil { + t.Fatal(err) + } + }() + + if err := testDB.Transact(func(tx *sql.Tx) error { + return BulkInsert(ctx, tx, table, tc.columns, tc.values, tc.conflictAction) + }); tc.wantErr && err == nil || !tc.wantErr && err != nil { + t.Errorf("testDB.Transact: %v | wantErr = %t", err, tc.wantErr) + } + + if tc.wantCount != 0 { + var count int + query := "SELECT COUNT(*) FROM " + table + row := testDB.QueryRow(ctx, query) + err := row.Scan(&count) + if err != nil { + t.Fatalf("testDB.queryRow(%q): %v", query, err) + } + if count != tc.wantCount { + t.Errorf("testDB.queryRow(%q) = %d; want = %d", query, count, tc.wantCount) + } + } + }) + } +} + +func TestLargeBulkInsert(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + if _, err := testDB.Exec(ctx, `CREATE TEMPORARY TABLE test_large_bulk (i BIGINT);`); err != nil { + t.Fatal(err) + } + const size = 150000 + vals := make([]interface{}, size) + for i := 0; i < size; i++ { + vals[i] = i + 1 + } + if err := testDB.Transact(func(tx *sql.Tx) error { + return BulkInsert(ctx, tx, "test_large_bulk", []string{"i"}, vals, "") + }); err != nil { + t.Fatal(err) + } + rows, err := testDB.Query(ctx, `SELECT i FROM test_large_bulk;`) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + sum := int64(0) + for rows.Next() { + var i int64 + if err := rows.Scan(&i); err != nil { + t.Fatal(err) + } + sum += i + } + var want int64 = size * (size + 1) / 2 + if sum != want { + t.Errorf("sum = %d, want %d", sum, want) + } +} |
