diff options
| author | Jonathan Amsterdam <jba@google.com> | 2019-11-19 19:04:01 -0500 |
|---|---|---|
| committer | Julie Qiu <julie@golang.org> | 2020-03-27 16:46:48 -0400 |
| commit | 65e2b7a804ce25942210f8db0e9552db9d7d6ff5 (patch) | |
| tree | b7e330862db226a7f431eeda083312dc52d96556 /internal/postgres | |
| parent | c77f28bf038ee25dce5a0cd513b721683bf940cc (diff) | |
| download | go-x-pkgsite-65e2b7a804ce25942210f8db0e9552db9d7d6ff5.tar.xz | |
internal/database, internal/testing/dbtest: site-agnostic DB functionality
Extract into a separate package the core functionality from
internal/postgres that doesn't depend on our particular schema.
This makes it available for other uses, like devtools commands and etl
autocomplete.
Do the same for testing functionality.
We now have three packages where before we had only one:
- internal/postgres: discovery-specific DB operations and test support
- internal/database: discovery-agnostic DB operations
- internal/testing/dbtest: discovery-agnostic DB test support
Change-Id: I54c59aee328dae71ba6c77170a72e7a83da7c785
Reviewed-on: https://team-review.git.corp.google.com/c/golang/discovery/+/602327
Reviewed-by: Robert Findley <rfindley@google.com>
Diffstat (limited to 'internal/postgres')
| -rw-r--r-- | internal/postgres/details.go | 19 | ||||
| -rw-r--r-- | internal/postgres/directory.go | 5 | ||||
| -rw-r--r-- | internal/postgres/excluded.go | 4 | ||||
| -rw-r--r-- | internal/postgres/excluded_test.go | 2 | ||||
| -rw-r--r-- | internal/postgres/insert_version.go | 18 | ||||
| -rw-r--r-- | internal/postgres/insert_version_test.go | 2 | ||||
| -rw-r--r-- | internal/postgres/package.go | 5 | ||||
| -rw-r--r-- | internal/postgres/parent_directories_test.go | 4 | ||||
| -rw-r--r-- | internal/postgres/postgres.go | 302 | ||||
| -rw-r--r-- | internal/postgres/postgres_test.go | 149 | ||||
| -rw-r--r-- | internal/postgres/search.go | 29 | ||||
| -rw-r--r-- | internal/postgres/search_test.go | 4 | ||||
| -rw-r--r-- | internal/postgres/test_helper.go | 93 | ||||
| -rw-r--r-- | internal/postgres/versionstate.go | 17 |
14 files changed, 82 insertions, 571 deletions
diff --git a/internal/postgres/details.go b/internal/postgres/details.go index 81b08227..3c982859 100644 --- a/internal/postgres/details.go +++ b/internal/postgres/details.go @@ -16,6 +16,7 @@ import ( "github.com/lib/pq" "golang.org/x/discovery/internal" + "golang.org/x/discovery/internal/database" "golang.org/x/discovery/internal/derrors" "golang.org/x/discovery/internal/license" "golang.org/x/discovery/internal/version" @@ -62,7 +63,7 @@ func (db *DB) GetPackagesInVersion(ctx context.Context, modulePath, version stri return nil } - if err := db.runQuery(ctx, query, collect, modulePath, version); err != nil { + if err := db.db.RunQuery(ctx, query, collect, modulePath, version); err != nil { return nil, xerrors.Errorf("DB.GetPackagesInVersion(ctx, %q, %q): %w", err) } return packages, nil @@ -127,7 +128,7 @@ func getPackageVersions(ctx context.Context, db *DB, pkgPath string, versionType } query := fmt.Sprintf(baseQuery, versionTypeExpr(versionTypes), queryEnd) - rows, err := db.query(ctx, query, pkgPath) + rows, err := db.db.Query(ctx, query, pkgPath) if err != nil { return nil, err } @@ -211,7 +212,7 @@ func getModuleVersions(ctx context.Context, db *DB, modulePath string, versionTy vinfos = append(vinfos, &vi) return nil } - if err := db.runQuery(ctx, query, collect, internal.SeriesPathForModule(modulePath)); err != nil { + if err := db.db.RunQuery(ctx, query, collect, internal.SeriesPathForModule(modulePath)); err != nil { return nil, err } return vinfos, nil @@ -248,7 +249,7 @@ func (db *DB) GetImports(ctx context.Context, pkgPath, modulePath, version strin imports = append(imports, toPath) return nil } - if err := db.runQuery(ctx, query, collect, pkgPath, version, modulePath); err != nil { + if err := db.db.RunQuery(ctx, query, collect, pkgPath, version, modulePath); err != nil { return nil, err } return imports, nil @@ -287,7 +288,7 @@ func (db *DB) GetImportedBy(ctx context.Context, pkgPath, modulePath string, lim importedby = append(importedby, fromPath) return nil } - if err := db.runQuery(ctx, query, collect, pkgPath, modulePath, limit); err != nil { + if err := db.db.RunQuery(ctx, query, collect, pkgPath, modulePath, limit); err != nil { return nil, err } return importedby, nil @@ -310,7 +311,7 @@ func (db *DB) GetModuleLicenses(ctx context.Context, modulePath, version string) WHERE module_path = $1 AND version = $2 AND position('/' in file_path) = 0 ` - rows, err := db.query(ctx, query, modulePath, version) + rows, err := db.db.Query(ctx, query, modulePath, version) if err != nil { return nil, err } @@ -352,7 +353,7 @@ func (db *DB) GetPackageLicenses(ctx context.Context, pkgPath, modulePath, versi AND p.version = l.version AND p.license_file_path = l.file_path;` - rows, err := db.query(ctx, query, pkgPath, modulePath, version) + rows, err := db.db.Query(ctx, query, pkgPath, modulePath, version) if err != nil { return nil, err } @@ -472,9 +473,9 @@ func (db *DB) GetVersionInfo(ctx context.Context, modulePath string, version str } var vi internal.VersionInfo - row := db.queryRow(ctx, query, args...) + row := db.db.QueryRow(ctx, query, args...) if err := row.Scan(&vi.ModulePath, &vi.Version, &vi.CommitTime, - nullIsEmpty(&vi.ReadmeFilePath), &vi.ReadmeContents, &vi.VersionType, + database.NullIsEmpty(&vi.ReadmeFilePath), &vi.ReadmeContents, &vi.VersionType, jsonbScanner{&vi.SourceInfo}); err != nil { if err == sql.ErrNoRows { return nil, xerrors.Errorf("module version %s@%s: %w", modulePath, version, derrors.NotFound) diff --git a/internal/postgres/directory.go b/internal/postgres/directory.go index 539fab83..ff7b2aab 100644 --- a/internal/postgres/directory.go +++ b/internal/postgres/directory.go @@ -12,6 +12,7 @@ import ( "github.com/lib/pq" "golang.org/x/discovery/internal" + "golang.org/x/discovery/internal/database" "golang.org/x/discovery/internal/derrors" "golang.org/x/discovery/internal/stdlib" "golang.org/x/xerrors" @@ -86,7 +87,7 @@ func (db *DB) GetDirectory(ctx context.Context, dirPath, modulePath, version str &pkg.GOARCH, &vi.Version, &vi.ModulePath, - nullIsEmpty(&vi.ReadmeFilePath), + database.NullIsEmpty(&vi.ReadmeFilePath), &vi.ReadmeContents, &vi.CommitTime, &vi.VersionType, @@ -101,7 +102,7 @@ func (db *DB) GetDirectory(ctx context.Context, dirPath, modulePath, version str packages = append(packages, &pkg) return nil } - if err := db.runQuery(ctx, query, collect, args...); err != nil { + if err := db.db.RunQuery(ctx, query, collect, args...); err != nil { return nil, err } if len(packages) == 0 { diff --git a/internal/postgres/excluded.go b/internal/postgres/excluded.go index e82e7c61..a8186aac 100644 --- a/internal/postgres/excluded.go +++ b/internal/postgres/excluded.go @@ -17,7 +17,7 @@ func (db *DB) IsExcluded(ctx context.Context, path string) (_ bool, err error) { defer derrors.Wrap(&err, "DB.IsExcluded(ctx, %q)", path) const query = "SELECT prefix FROM excluded_prefixes WHERE starts_with($1, prefix);" - row := db.queryRow(ctx, query, path) + row := db.db.QueryRow(ctx, query, path) var prefix string err = row.Scan(&prefix) switch err { @@ -37,7 +37,7 @@ func (db *DB) IsExcluded(ctx context.Context, path string) (_ bool, err error) { func (db *DB) InsertExcludedPrefix(ctx context.Context, prefix, user, reason string) (err error) { defer derrors.Wrap(&err, "DB.InsertExcludedPrefix(ctx, %q, %q)", prefix, reason) - _, err = db.exec(ctx, "INSERT INTO excluded_prefixes (prefix, created_by, reason) VALUES ($1, $2, $3)", + _, err = db.db.Exec(ctx, "INSERT INTO excluded_prefixes (prefix, created_by, reason) VALUES ($1, $2, $3)", prefix, user, reason) return err } diff --git a/internal/postgres/excluded_test.go b/internal/postgres/excluded_test.go index a355e8dc..eedb096d 100644 --- a/internal/postgres/excluded_test.go +++ b/internal/postgres/excluded_test.go @@ -14,7 +14,7 @@ func TestIsExcluded(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - if _, err := testDB.exec(ctx, "INSERT INTO excluded_prefixes (prefix, created_by, reason) VALUES ('bad', 'someone', 'because')"); err != nil { + if _, err := testDB.db.Exec(ctx, "INSERT INTO excluded_prefixes (prefix, created_by, reason) VALUES ('bad', 'someone', 'because')"); err != nil { t.Fatal(err) } diff --git a/internal/postgres/insert_version.go b/internal/postgres/insert_version.go index 1f8c2682..37c54ae8 100644 --- a/internal/postgres/insert_version.go +++ b/internal/postgres/insert_version.go @@ -16,6 +16,7 @@ import ( "github.com/lib/pq" "golang.org/x/discovery/internal" + "golang.org/x/discovery/internal/database" "golang.org/x/discovery/internal/derrors" "golang.org/x/discovery/internal/stdlib" "golang.org/x/discovery/internal/thirdparty/module" @@ -80,7 +81,7 @@ func (db *DB) saveVersion(ctx context.Context, version *internal.Version) error sort.Strings(p.Imports) } - err := db.Transact(func(tx *sql.Tx) error { + err := db.db.Transact(func(tx *sql.Tx) error { majorint, minorint, patchint, prerelease, err := extractSemverParts(version.Version) if err != nil { return fmt.Errorf("extractSemverParts(%q): %v", version.Version, err) @@ -96,7 +97,7 @@ func (db *DB) saveVersion(ctx context.Context, version *internal.Version) error if err != nil { return err } - if _, err := execTx(ctx, tx, + if _, err := database.ExecTx(ctx, tx, `INSERT INTO versions( module_path, version, @@ -145,7 +146,8 @@ func (db *DB) saveVersion(ctx context.Context, version *internal.Version) error "types", "coverage", } - if err := bulkInsert(ctx, tx, "licenses", licenseCols, licenseValues, onConflictDoNothing); err != nil { + if err := database.BulkInsert(ctx, tx, "licenses", licenseCols, licenseValues, + database.OnConflictDoNothing); err != nil { return err } } @@ -204,7 +206,7 @@ func (db *DB) saveVersion(ctx context.Context, version *internal.Version) error "goarch", "commit_time", } - if err := bulkInsert(ctx, tx, "packages", pkgCols, pkgValues, onConflictDoNothing); err != nil { + if err := database.BulkInsert(ctx, tx, "packages", pkgCols, pkgValues, database.OnConflictDoNothing); err != nil { return err } } @@ -216,7 +218,7 @@ func (db *DB) saveVersion(ctx context.Context, version *internal.Version) error "from_version", "to_path", } - if err := bulkInsert(ctx, tx, "imports", importCols, importValues, onConflictDoNothing); err != nil { + if err := database.BulkInsert(ctx, tx, "imports", importCols, importValues, database.OnConflictDoNothing); err != nil { return err } @@ -225,7 +227,7 @@ func (db *DB) saveVersion(ctx context.Context, version *internal.Version) error "from_module_path", "to_path", } - if err := bulkInsert(ctx, tx, "imports_unique", importUniqueCols, importUniqueValues, onConflictDoNothing); err != nil { + if err := database.BulkInsert(ctx, tx, "imports_unique", importUniqueCols, importUniqueValues, database.OnConflictDoNothing); err != nil { return err } } @@ -436,9 +438,9 @@ func (db *DB) DeleteVersion(ctx context.Context, tx *sql.Tx, modulePath, version // CASCADE constraints, that will trigger deletions from all other tables. const stmt = `DELETE FROM versions WHERE module_path=$1 AND version=$2` if tx == nil { - _, err = db.exec(ctx, stmt, modulePath, version) + _, err = db.db.Exec(ctx, stmt, modulePath, version) } else { - _, err = execTx(ctx, tx, stmt, modulePath, version) + _, err = database.ExecTx(ctx, tx, stmt, modulePath, version) } return err } diff --git a/internal/postgres/insert_version_test.go b/internal/postgres/insert_version_test.go index b6058c2b..7b8a4a7d 100644 --- a/internal/postgres/insert_version_test.go +++ b/internal/postgres/insert_version_test.go @@ -224,7 +224,7 @@ func TestPostgres_ReadAndWriteVersionOtherColumns(t *testing.T) { versions WHERE module_path = $1 AND version = $2` - row := testDB.queryRow(ctx, query, v.ModulePath, v.Version) + row := testDB.db.QueryRow(ctx, query, v.ModulePath, v.Version) var got other if err := row.Scan(&got.major, &got.minor, &got.patch, &got.prerelease, &got.seriesPath); err != nil { t.Fatal(err) diff --git a/internal/postgres/package.go b/internal/postgres/package.go index 0ab8d7ba..8f231435 100644 --- a/internal/postgres/package.go +++ b/internal/postgres/package.go @@ -11,6 +11,7 @@ import ( "github.com/lib/pq" "golang.org/x/discovery/internal" + "golang.org/x/discovery/internal/database" "golang.org/x/discovery/internal/derrors" "golang.org/x/discovery/internal/stdlib" "golang.org/x/xerrors" @@ -138,11 +139,11 @@ func (db *DB) GetPackage(ctx context.Context, pkgPath, modulePath, version strin pkg internal.VersionedPackage licenseTypes, licensePaths []string ) - row := db.queryRow(ctx, query, args...) + row := db.db.QueryRow(ctx, query, args...) err = row.Scan(&pkg.Path, &pkg.Name, &pkg.Synopsis, &pkg.V1Path, pq.Array(&licenseTypes), pq.Array(&licensePaths), &pkg.DocumentationHTML, &pkg.GOOS, &pkg.GOARCH, &pkg.Version, - &pkg.CommitTime, nullIsEmpty(&pkg.ReadmeFilePath), &pkg.ReadmeContents, + &pkg.CommitTime, database.NullIsEmpty(&pkg.ReadmeFilePath), &pkg.ReadmeContents, &pkg.ModulePath, &pkg.VersionType, jsonbScanner{&pkg.SourceInfo}) if err != nil { if err == sql.ErrNoRows { diff --git a/internal/postgres/parent_directories_test.go b/internal/postgres/parent_directories_test.go index 1fd297d9..b40dc02e 100644 --- a/internal/postgres/parent_directories_test.go +++ b/internal/postgres/parent_directories_test.go @@ -76,7 +76,7 @@ func TestToTsvectorParentDirectoriesStoredProcedure(t *testing.T) { } var got []string - err := testDB.queryRow(ctx, + err := testDB.db.QueryRow(ctx, `SELECT tsvector_to_array(tsv_parent_directories) FROM packages WHERE path = $1;`, tc.path).Scan(pq.Array(&got)) if err != nil { @@ -86,7 +86,7 @@ func TestToTsvectorParentDirectoriesStoredProcedure(t *testing.T) { t.Errorf("tsvector_to_array FROM packages for %q mismatch (-want +got):\n%s", tc.path, diff) } - err = testDB.queryRow(ctx, + err = testDB.db.QueryRow(ctx, `SELECT tsvector_to_array(tsv_parent_directories) FROM search_documents WHERE package_path = $1;`, tc.path).Scan(pq.Array(&got)) if err != nil { diff --git a/internal/postgres/postgres.go b/internal/postgres/postgres.go index 3a958c8c..d437838a 100644 --- a/internal/postgres/postgres.go +++ b/internal/postgres/postgres.go @@ -5,310 +5,38 @@ package postgres import ( - "context" "database/sql" - "fmt" - "regexp" - "strings" - "sync/atomic" - "time" - "unicode" - "golang.org/x/discovery/internal/config" - "golang.org/x/discovery/internal/derrors" - "golang.org/x/discovery/internal/log" + "golang.org/x/discovery/internal/database" ) -// DB wraps a sql.DB to provide an API for interacting with discovery data -// stored in Postgres. type DB struct { - db *sql.DB + db *database.DB } -// GetSQLDB returns the underlying SQL database for the postgres instance. This -// allows the ETL to perform streaming operations on the database. -func (db *DB) GetSQLDB() *sql.DB { - return db.db -} - -func (db *DB) exec(ctx context.Context, query string, args ...interface{}) (res sql.Result, err error) { - defer logQuery(query, args)(&err) - - return db.db.ExecContext(ctx, query, args...) -} - -func execTx(ctx context.Context, tx *sql.Tx, query string, args ...interface{}) (res sql.Result, err error) { - defer logQuery(query, args)(&err) - - return tx.ExecContext(ctx, query, args...) -} - -func (db *DB) query(ctx context.Context, query string, args ...interface{}) (_ *sql.Rows, err error) { - defer logQuery(query, args)(&err) - return db.db.QueryContext(ctx, query, args...) -} - -func (db *DB) queryRow(ctx context.Context, query string, args ...interface{}) *sql.Row { - defer logQuery(query, args)(nil) - return db.db.QueryRowContext(ctx, query, args...) -} - -var ( - queryCounter int64 // atomic: per-process counter for unique query IDs - queryLoggingDisabled bool // For use in tests only: not concurrency-safe. -) - -type queryEndLogEntry struct { - ID string - Query string - Args string - DurationSeconds float64 - Error string `json:",omitempty"` -} - -func logQuery(query string, args []interface{}) func(*error) { - if queryLoggingDisabled { - return func(*error) {} - } - const maxlen = 300 // maximum length of displayed query - - // To make the query more compact and readable, replace newlines with spaces - // and collapse adjacent whitespace. - var r []rune - for _, c := range query { - if c == '\n' { - c = ' ' - } - if len(r) == 0 || !unicode.IsSpace(r[len(r)-1]) || !unicode.IsSpace(c) { - r = append(r, c) - } - } - query = string(r) - if len(query) > maxlen { - query = query[:maxlen] + "..." - } - - instanceID := config.InstanceID() - if instanceID == "" { - instanceID = "local" - } else { - // Instance IDs are long strings. The low-order part seems quite random, so - // shortening the ID will still likely result in something unique. - instanceID = instanceID[len(instanceID)-4:] - } - n := atomic.AddInt64(&queryCounter, 1) - uid := fmt.Sprintf("%s-%d", instanceID, n) - - // Construct a short string of the args. - const ( - maxArgs = 20 - maxArgLen = 50 - ) - var argStrings []string - for i := 0; i < len(args) && i < maxArgs; i++ { - s := fmt.Sprint(args[i]) - if len(s) > maxArgLen { - s = s[:maxArgLen] + "..." - } - argStrings = append(argStrings, s) - } - if len(args) > maxArgs { - argStrings = append(argStrings, "...") - } - argString := strings.Join(argStrings, ", ") - - log.Debugf("%s %s args=%s", uid, query, argString) - start := time.Now() - return func(errp *error) { - dur := time.Since(start) - if errp == nil { // happens with queryRow - log.Debugf("%s done", uid) - } else { - derrors.Wrap(errp, "DB running query %s", uid) - entry := queryEndLogEntry{ - ID: uid, - Query: query, - Args: argString, - DurationSeconds: dur.Seconds(), - } - if *errp == nil { - log.Debug(entry) - } else { - entry.Error = (*errp).Error() - log.Error(entry) - } - } - } -} - -// Open creates a new DB for the given Postgres connection string. -func Open(driverName, dbinfo string) (_ *DB, err error) { - defer derrors.Wrap(&err, "postgres.Open(%q, %q)", - driverName, redactPassword(dbinfo)) - - db, err := sql.Open(driverName, dbinfo) +// Open opens a new postgres DB. +// TODO(jba): take a *sql.DB. +func Open(driverName, dbinfo string) (*DB, error) { + db, err := database.Open(driverName, dbinfo) if err != nil { return nil, err } - - if err = db.Ping(); err != nil { - return nil, err - } return &DB{db}, nil } -var passwordRegexp = regexp.MustCompile(`password=\S+`) - -func redactPassword(dbinfo string) string { - return passwordRegexp.ReplaceAllLiteralString(dbinfo, "password=REDACTED") -} - -// Transact executes the given function in the context of a SQL transaction, -// rolling back the transaction if the function panics or returns an error. -func (db *DB) Transact(txFunc func(*sql.Tx) error) (err error) { - tx, err := db.db.Begin() - if err != nil { - return fmt.Errorf("db.Begin(): %v", err) - } - - defer func() { - if p := recover(); p != nil { - tx.Rollback() - panic(p) - } else if err != nil { - tx.Rollback() - } else { - if err = tx.Commit(); err != nil { - err = fmt.Errorf("tx.Commit(): %v", err) - } - } - }() - - if err := txFunc(tx); err != nil { - return fmt.Errorf("txFunc(tx): %v", err) - } - return nil -} - -const onConflictDoNothing = "ON CONFLICT DO NOTHING" - -// bulkInsert constructs and executes a multi-value insert statement. The -// query is constructed using the format: INSERT TO <table> (<columns>) VALUES -// (<placeholders-for-each-item-in-values>) If conflictNoAction is true, it -// append ON CONFLICT DO NOTHING to the end of the query. The query is executed -// using a PREPARE statement with the provided values. -func bulkInsert(ctx context.Context, tx *sql.Tx, table string, columns []string, values []interface{}, conflictAction string) (err error) { - defer derrors.Wrap(&err, "bulkInsert(ctx, tx, %q, %v, [%d values], %q)", - table, columns, len(values), conflictAction) - - if remainder := len(values) % len(columns); remainder != 0 { - return fmt.Errorf("modulus of len(values) and len(columns) must be 0: got %d", remainder) - } - - // Postgres supports up to 65535 parameters, but stop well before that - // so we don't construct humongous queries. - const maxParameters = 1000 - stride := (maxParameters / len(columns)) * len(columns) - if stride == 0 { - // This is a pathological case (len(columns) > maxParameters), but we - // handle it cautiously. - return fmt.Errorf("too many columns to insert: %d", len(columns)) - } - for leftBound := 0; leftBound < len(values); leftBound += stride { - rightBound := leftBound + stride - if rightBound > len(values) { - rightBound = len(values) - } - valueSlice := values[leftBound:rightBound] - query, err := buildInsertQuery(table, columns, valueSlice, conflictAction) - if err != nil { - return fmt.Errorf("buildInsertQuery(%q, %v, values[%d:%d], %q): %v", table, columns, leftBound, rightBound, conflictAction, err) - } - - if _, err := execTx(ctx, tx, query, valueSlice...); err != nil { - return fmt.Errorf("tx.ExecContext(ctx, [bulk insert query], values[%d:%d]): %v", leftBound, rightBound, err) - } - } - return nil -} - -// buildInsertQuery builds an multi-value insert query, following the format: -// INSERT TO <table> (<columns>) VALUES -// (<placeholders-for-each-item-in-values>) If conflictNoAction is true, it -// append ON CONFLICT DO NOTHING to the end of the query. -// -// When calling buildInsertQuery, it must be true that -// len(values) % len(columns) == 0 -func buildInsertQuery(table string, columns []string, values []interface{}, conflictAction string) (string, error) { - var b strings.Builder - fmt.Fprintf(&b, "INSERT INTO %s", table) - fmt.Fprintf(&b, "(%s) VALUES", strings.Join(columns, ", ")) - - var placeholders []string - for i := 1; i <= len(values); i++ { - // Construct the full query by adding placeholders for each - // set of values that we want to insert. - placeholders = append(placeholders, fmt.Sprintf("$%d", i)) - if i%len(columns) != 0 { - continue - } - - // When the end of a set is reached, write it to the query - // builder and reset placeholders. - fmt.Fprintf(&b, "(%s)", strings.Join(placeholders, ", ")) - placeholders = []string{} - - // Do not add a comma delimiter after the last set of values. - if i == len(values) { - break - } - b.WriteString(", ") - } - if conflictAction != "" { - b.WriteString(" " + conflictAction) - } - - return b.String(), nil -} - -// Close closes the database connection. +// Close closes a DB. func (db *DB) Close() error { return db.db.Close() } -// runQuery executes query, then calls f on each row. -func (db *DB) runQuery(ctx context.Context, query string, f func(*sql.Rows) error, params ...interface{}) error { - rows, err := db.query(ctx, query, params...) - if err != nil { - return err - } - defer rows.Close() - - for rows.Next() { - if err := f(rows); err != nil { - return err - } - } - return rows.Err() -} - -// emptyStringScanner wraps the functionality of sql.NullString to just write -// an empty string if the value is NULL. -type emptyStringScanner struct { - ptr *string -} - -func (e emptyStringScanner) Scan(value interface{}) error { - var ns sql.NullString - if err := ns.Scan(value); err != nil { - return err - } - *e.ptr = ns.String - return nil +// Underlying returns the *database.DB inside db. +func (db *DB) Underlying() *database.DB { + return db.db } -// nullIsEmpty returns a sql.Scanner that writes the empty string to s if the -// sql.Value is NULL. -func nullIsEmpty(s *string) sql.Scanner { - return emptyStringScanner{s} +// TODO(jba): remove. +// GetSQLDB returns the underlying SQL database for the postgres instance. This +// allows the ETL to perform streaming operations on the database. +func (db *DB) GetSQLDB() *sql.DB { + return db.db.Underlying() } diff --git a/internal/postgres/postgres_test.go b/internal/postgres/postgres_test.go index 73813676..352a9c49 100644 --- a/internal/postgres/postgres_test.go +++ b/internal/postgres/postgres_test.go @@ -5,9 +5,6 @@ package postgres import ( - "context" - "database/sql" - "fmt" "testing" "time" ) @@ -19,149 +16,3 @@ var testDB *DB func TestMain(m *testing.M) { RunDBTests("discovery_postgres_test", m, &testDB) } - -func TestBulkInsert(t *testing.T) { - table := "test_bulk_insert" - for _, tc := range []struct { - name string - columns []string - values []interface{} - conflictAction string - wantErr bool - wantCount int - }{ - { - - name: "test-one-row", - columns: []string{"colA"}, - values: []interface{}{"valueA"}, - wantCount: 1, - }, - { - - name: "test-multiple-rows", - columns: []string{"colA"}, - values: []interface{}{"valueA1", "valueA2", "valueA3"}, - wantCount: 3, - }, - { - - name: "test-invalid-column-name", - columns: []string{"invalid_col"}, - values: []interface{}{"valueA"}, - wantErr: true, - }, - { - - name: "test-mismatch-num-cols-and-vals", - columns: []string{"colA", "colB"}, - values: []interface{}{"valueA1", "valueB1", "valueA2"}, - wantErr: true, - }, - { - - name: "test-conflict-no-action-true", - columns: []string{"colA"}, - values: []interface{}{"valueA", "valueA"}, - conflictAction: onConflictDoNothing, - wantCount: 1, - }, - { - - name: "test-conflict-no-action-false", - columns: []string{"colA"}, - values: []interface{}{"valueA", "valueA"}, - wantErr: true, - }, - { - - // This should execute the statement - // INSERT INTO series (path) VALUES ('''); TRUNCATE series CASCADE;)'); - // which will insert a row with path value: - // '); TRUNCATE series CASCADE;) - // Rather than the statement - // INSERT INTO series (path) VALUES (''); TRUNCATE series CASCADE;)); - // which would truncate most tables in the database. - name: "test-sql-injection", - columns: []string{"colA"}, - values: []interface{}{fmt.Sprintf("''); TRUNCATE %s CASCADE;))", table)}, - conflictAction: onConflictDoNothing, - wantCount: 1, - }, - } { - t.Run(tc.name, func(t *testing.T) { - defer ResetTestDB(testDB, t) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() - - createQuery := fmt.Sprintf(`CREATE TABLE %s ( - colA TEXT NOT NULL, - colB TEXT, - PRIMARY KEY (colA) - );`, table) - if _, err := testDB.exec(ctx, createQuery); err != nil { - t.Fatal(err) - } - defer func() { - dropTableQuery := fmt.Sprintf("DROP TABLE %s;", table) - if _, err := testDB.exec(ctx, dropTableQuery); err != nil { - t.Fatal(err) - } - }() - - if err := testDB.Transact(func(tx *sql.Tx) error { - return bulkInsert(ctx, tx, table, tc.columns, tc.values, tc.conflictAction) - }); tc.wantErr && err == nil || !tc.wantErr && err != nil { - t.Errorf("testDB.Transact: %v | wantErr = %t", err, tc.wantErr) - } - - if tc.wantCount != 0 { - var count int - query := "SELECT COUNT(*) FROM " + table - row := testDB.queryRow(ctx, query) - err := row.Scan(&count) - if err != nil { - t.Fatalf("testDB.queryRow(%q): %v", query, err) - } - if count != tc.wantCount { - t.Errorf("testDB.queryRow(%q) = %d; want = %d", query, count, tc.wantCount) - } - } - }) - } -} - -func TestLargeBulkInsert(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() - if _, err := testDB.exec(ctx, `CREATE TEMPORARY TABLE test_large_bulk (i BIGINT);`); err != nil { - t.Fatal(err) - } - const size = 150000 - vals := make([]interface{}, size) - for i := 0; i < size; i++ { - vals[i] = i + 1 - } - if err := testDB.Transact(func(tx *sql.Tx) error { - return bulkInsert(ctx, tx, "test_large_bulk", []string{"i"}, vals, "") - }); err != nil { - t.Fatal(err) - } - rows, err := testDB.query(ctx, `SELECT i FROM test_large_bulk;`) - if err != nil { - t.Fatal(err) - } - defer rows.Close() - sum := int64(0) - for rows.Next() { - var i int64 - if err := rows.Scan(&i); err != nil { - t.Fatal(err) - } - sum += i - } - var want int64 = size * (size + 1) / 2 - if sum != want { - t.Errorf("sum = %d, want %d", sum, want) - } -} diff --git a/internal/postgres/search.go b/internal/postgres/search.go index d302150a..4b731a5a 100644 --- a/internal/postgres/search.go +++ b/internal/postgres/search.go @@ -19,6 +19,7 @@ import ( "go.opencensus.io/stats" "go.opencensus.io/stats/view" "go.opencensus.io/tag" + "golang.org/x/discovery/internal/database" "golang.org/x/discovery/internal/derrors" "golang.org/x/discovery/internal/log" "golang.org/x/discovery/internal/stdlib" @@ -305,7 +306,7 @@ type estimateResponse struct { // EstimateResultsCount uses the hyperloglog algorithm to estimate the number // of results for the given search term. func (db *DB) estimateResultsCount(ctx context.Context, q string) estimateResponse { - row := db.queryRow(ctx, hllQuery, q) + row := db.db.QueryRow(ctx, hllQuery, q) var estimate sql.NullInt64 if err := row.Scan(&estimate); err != nil { return estimateResponse{err: fmt.Errorf("row.Scan(): %v", err)} @@ -353,7 +354,7 @@ func (db *DB) deepSearch(ctx context.Context, q string, limit, offset int) searc results = append(results, &r) return nil } - err := db.runQuery(ctx, query, collect, q, limit, offset) + err := db.db.RunQuery(ctx, query, collect, q, limit, offset) if err != nil { results = nil } @@ -384,7 +385,7 @@ func (db *DB) popularSearch(ctx context.Context, searchQuery string, limit, offs results = append(results, &r) return nil } - err := db.runQuery(ctx, query, collect, searchQuery, limit, offset) + err := db.db.RunQuery(ctx, query, collect, searchQuery, limit, offset) if err != nil { results = nil } @@ -445,7 +446,7 @@ func (db *DB) popularSearcher(cutoff int) searcher { results = append(results, &r) return nil } - err := db.runQuery(ctx, query, collect, searchQuery, limit, offset) + err := db.db.RunQuery(ctx, query, collect, searchQuery, limit, offset) if err != nil { results = nil } else if len(results) != limit { @@ -511,7 +512,7 @@ func (db *DB) addPackageDataToSearchResults(ctx context.Context, results []*Sear } return nil } - return db.runQuery(ctx, query, collect) + return db.db.RunQuery(ctx, query, collect) } // DeepSearch executes a full scan of the search table in two steps, by first @@ -630,7 +631,7 @@ func (db *DB) Search(ctx context.Context, q string, limit, offset int) (_ []*Sea results = append(results, &sr) return nil } - if err := db.runQuery(ctx, query, collect, q, limit, offset); err != nil { + if err := db.db.RunQuery(ctx, query, collect, q, limit, offset); err != nil { return nil, err } return results, nil @@ -719,7 +720,7 @@ func (db *DB) UpsertSearchDocument(ctx context.Context, path string) (err error) } pathTokens := strings.Join(generatePathTokens(path), " ") - _, err = db.exec(ctx, upsertSearchStatement, path, pathTokens) + _, err = db.db.Exec(ctx, upsertSearchStatement, path, pathTokens) return err } @@ -748,7 +749,7 @@ func (db *DB) GetPackagesForSearchDocumentUpsert(ctx context.Context, limit int) } return nil } - if err := db.runQuery(ctx, query, collect, limit); err != nil { + if err := db.db.RunQuery(ctx, query, collect, limit); err != nil { return nil, err } sort.Strings(paths) @@ -782,7 +783,7 @@ func (db *DB) getSearchDocument(ctx context.Context, path string) (*searchDocume FROM search_documents WHERE package_path=$1` - row := db.queryRow(ctx, query, path) + row := db.db.QueryRow(ctx, query, path) var ( sd searchDocument t pq.NullTime @@ -812,7 +813,7 @@ func (db *DB) UpdateSearchDocumentsImportedByCount(ctx context.Context) (nUpdate if err != nil { return 0, err } - err = db.Transact(func(tx *sql.Tx) error { + err = db.db.Transact(func(tx *sql.Tx) error { if err := insertImportedByCounts(ctx, tx, counts); err != nil { return err } @@ -831,7 +832,7 @@ func (db *DB) computeImportedByCounts(ctx context.Context) (counts map[string]in counts = map[string]int{} // Get all (from_path, to_path) pairs, deduped. // Also get the from_path's module path. - rows, err := db.query(ctx, ` + rows, err := db.db.Query(ctx, ` SELECT from_path, from_module_path, to_path FROM @@ -871,7 +872,7 @@ func insertImportedByCounts(ctx context.Context, tx *sql.Tx, counts map[string]i imported_by_count INTEGER DEFAULT 0 NOT NULL ) ON COMMIT DROP; ` - if _, err := execTx(ctx, tx, createTableQuery); err != nil { + if _, err := database.ExecTx(ctx, tx, createTableQuery); err != nil { return fmt.Errorf("CREATE TABLE: %v", err) } values := make([]interface{}, 0, 2*len(counts)) @@ -879,7 +880,7 @@ func insertImportedByCounts(ctx context.Context, tx *sql.Tx, counts map[string]i values = append(values, p, c) } columns := []string{"package_path", "imported_by_count"} - return bulkInsert(ctx, tx, "computed_imported_by_counts", columns, values, "") + return database.BulkInsert(ctx, tx, "computed_imported_by_counts", columns, values, "") } func compareImportedByCounts(ctx context.Context, tx *sql.Tx) (err error) { @@ -950,7 +951,7 @@ func updateImportedByCounts(ctx context.Context, tx *sql.Tx) (int64, error) { FROM computed_imported_by_counts c WHERE s.package_path = c.package_path;` - res, err := execTx(ctx, tx, updateStmt) + res, err := database.ExecTx(ctx, tx, updateStmt) if err != nil { return 0, fmt.Errorf("error updating imported_by_count and imported_by_count_updated_at for search documents: %v", err) } diff --git a/internal/postgres/search_test.go b/internal/postgres/search_test.go index f6992380..359c2f07 100644 --- a/internal/postgres/search_test.go +++ b/internal/postgres/search_test.go @@ -649,7 +649,7 @@ func TestHllHash(t *testing.T) { h := md5.New() io.WriteString(h, test) want := int64(binary.BigEndian.Uint64(h.Sum(nil)[0:8])) - row := testDB.queryRow(context.Background(), "SELECT hll_hash($1)", test) + row := testDB.db.QueryRow(context.Background(), "SELECT hll_hash($1)", test) var got int64 if err := row.Scan(&got); err != nil { t.Fatal(err) @@ -674,7 +674,7 @@ func TestHllZeros(t *testing.T) { {(1 << 63) - 1, 1}, } for _, test := range tests { - row := testDB.queryRow(context.Background(), "SELECT hll_zeros($1)", test.i) + row := testDB.db.QueryRow(context.Background(), "SELECT hll_zeros($1)", test.i) var got int if err := row.Scan(&got); err != nil { t.Fatal(err) diff --git a/internal/postgres/test_helper.go b/internal/postgres/test_helper.go index 66eae38c..8c35d8ad 100644 --- a/internal/postgres/test_helper.go +++ b/internal/postgres/test_helper.go @@ -9,15 +9,15 @@ import ( "database/sql" "fmt" "log" - "net/url" "os" "path/filepath" - "strings" "testing" "github.com/golang-migrate/migrate/v4" "golang.org/x/discovery/internal" + "golang.org/x/discovery/internal/database" "golang.org/x/discovery/internal/derrors" + "golang.org/x/discovery/internal/testing/dbtest" "golang.org/x/discovery/internal/testing/sample" "golang.org/x/discovery/internal/testing/testhelper" @@ -29,84 +29,9 @@ import ( _ "github.com/lib/pq" ) -func getEnv(key, fallback string) string { - if value, ok := os.LookupEnv(key); ok { - return value - } - return fallback -} - -// dbConnURI generates a postgres connection string in URI format. This is -// necessary as migrate expects a URI. -func dbConnURI(dbName string) string { - var ( - user = getEnv("GO_DISCOVERY_DATABASE_TEST_USER", "postgres") - password = getEnv("GO_DISCOVERY_DATABASE_TEST_PASSWORD", "") - host = getEnv("GO_DISCOVERY_DATABASE_TEST_HOST", "localhost") - port = getEnv("GO_DISCOVERY_DATABASE_TEST_PORT", "5432") - ) - cs := fmt.Sprintf("postgres://%s/%s?sslmode=disable&user=%s&password=%s&port=%s", - host, dbName, url.QueryEscape(user), url.QueryEscape(password), url.QueryEscape(port)) - return cs -} - -// multiErr can be used to combine one or more errors into a single error. -type multiErr []error - -func (m multiErr) Error() string { - var sb strings.Builder - for _, err := range m { - sep := "" - if sb.Len() > 0 { - sep = "|" - } - if err != nil { - sb.WriteString(sep + err.Error()) - } - } - return sb.String() -} - -// connectAndExecute connects to the postgres database specified by uri and -// executes dbFunc, then cleans up the database connection. -func connectAndExecute(uri string, dbFunc func(*sql.DB) error) (outerErr error) { - pg, err := sql.Open("postgres", uri) - if err != nil { - return err - } - defer func() { - if err := pg.Close(); err != nil { - outerErr = multiErr{outerErr, err} - } - }() - return dbFunc(pg) -} - -// createDBIfNotExists checks whether the given dbName is an existing database, -// and creates one if not. -func createDBIfNotExists(dbName string) error { - return connectAndExecute(dbConnURI(""), func(pg *sql.DB) error { - rows, err := pg.Query("SELECT 1 from pg_database WHERE datname = $1 LIMIT 1", dbName) - if err != nil { - return err - } - defer rows.Close() - if !rows.Next() { - if err := rows.Err(); err != nil { - return err - } - log.Printf("Test database %q does not exist, creating.", dbName) - if _, err := pg.Exec(fmt.Sprintf("CREATE DATABASE %q;", dbName)); err != nil { - return fmt.Errorf("error creating %q: %v", dbName, err) - } - } - return nil - }) -} - // recreateDB drops and recreates the database named dbName. func recreateDB(dbName string) error { - return connectAndExecute(dbConnURI(""), func(pg *sql.DB) error { + return dbtest.ConnectAndExecute(dbtest.DBConnURI(""), func(pg *sql.DB) error { if _, err := pg.Exec(fmt.Sprintf("DROP DATABASE %q;", dbName)); err != nil { return fmt.Errorf("error dropping %q: %v", dbName, err) } @@ -128,7 +53,7 @@ func migrationsSource() string { // migration. If this operation fails in the migration step, it returns // isMigrationError=true to signal that the database should be recreated. func tryToMigrate(dbName string) (isMigrationError bool, outerErr error) { - dbURI := dbConnURI(dbName) + dbURI := dbtest.DBConnURI(dbName) source := migrationsSource() m, err := migrate.New(source, dbURI) if err != nil { @@ -136,7 +61,7 @@ func tryToMigrate(dbName string) (isMigrationError bool, outerErr error) { } defer func() { if srcErr, dbErr := m.Close(); srcErr != nil || dbErr != nil { - outerErr = multiErr{outerErr, srcErr, dbErr} + outerErr = dbtest.MultiErr{outerErr, srcErr, dbErr} } }() if err := m.Up(); err != nil && err != migrate.ErrNoChange { @@ -150,7 +75,7 @@ func tryToMigrate(dbName string) (isMigrationError bool, outerErr error) { func SetupTestDB(dbName string) (_ *DB, err error) { defer derrors.Wrap(&err, "SetupTestDB(%q)", dbName) - if err := createDBIfNotExists(dbName); err != nil { + if err := dbtest.CreateDBIfNotExists(dbName); err != nil { return nil, fmt.Errorf("createDBIfNotExists(%q): %v", dbName, err) } if isMigrationError, err := tryToMigrate(dbName); err != nil { @@ -166,14 +91,14 @@ func SetupTestDB(dbName string) (_ *DB, err error) { return nil, fmt.Errorf("unfixable error migrating database: %v.\nConsider running ./devtools/drop_test_dbs.sh", err) } } - return Open("postgres", dbConnURI(dbName)) + return Open("postgres", dbtest.DBConnURI(dbName)) } // ResetTestDB truncates all data from the given test DB. It should be called // after every test that mutates the database. func ResetTestDB(db *DB, t *testing.T) { t.Helper() - if err := db.Transact(func(tx *sql.Tx) error { + if err := db.db.Transact(func(tx *sql.Tx) error { if _, err := tx.Exec(` TRUNCATE versions CASCADE; TRUNCATE imports_unique;`); err != nil { @@ -195,7 +120,7 @@ func ResetTestDB(db *DB, t *testing.T) { // named dbName. The given *DB reference will be set to the instantiated test func RunDBTests(dbName string, m *testing.M, testDB **DB) { - queryLoggingDisabled = true + database.QueryLoggingDisabled = true db, err := SetupTestDB(dbName) if err != nil { log.Fatal(err) diff --git a/internal/postgres/versionstate.go b/internal/postgres/versionstate.go index 2a0d85f4..03d1ed50 100644 --- a/internal/postgres/versionstate.go +++ b/internal/postgres/versionstate.go @@ -13,6 +13,7 @@ import ( "github.com/lib/pq" "go.opencensus.io/trace" "golang.org/x/discovery/internal" + "golang.org/x/discovery/internal/database" "golang.org/x/discovery/internal/derrors" "golang.org/x/discovery/internal/log" ) @@ -33,8 +34,8 @@ func (db *DB) InsertIndexVersions(ctx context.Context, versions []*internal.Inde DO UPDATE SET index_timestamp=excluded.index_timestamp, next_processed_after=CURRENT_TIMESTAMP` - return db.Transact(func(tx *sql.Tx) error { - return bulkInsert(ctx, tx, "module_version_states", cols, vals, conflictAction) + return db.db.Transact(func(tx *sql.Tx) error { + return database.BulkInsert(ctx, tx, "module_version_states", cols, vals, conflictAction) }) } @@ -70,7 +71,7 @@ func (db *DB) UpsertVersionState(ctx context.Context, modulePath, version, appVe if fetchErr != nil { sqlErrorMsg = sql.NullString{Valid: true, String: fetchErr.Error()} } - result, err := db.exec(ctx, query, modulePath, version, appVersion, timestamp, status, sqlErrorMsg) + result, err := db.db.Exec(ctx, query, modulePath, version, appVersion, timestamp, status, sqlErrorMsg) if err != nil { return err } @@ -95,7 +96,7 @@ func (db *DB) LatestIndexTimestamp(ctx context.Context) (_ time.Time, err error) LIMIT 1` var ts time.Time - row := db.queryRow(ctx, query) + row := db.db.QueryRow(ctx, query) switch err := row.Scan(&ts); err { case sql.ErrNoRows: return time.Time{}, nil @@ -117,7 +118,7 @@ func (db *DB) UpdateVersionStatesForReprocessing(ctx context.Context, appVersion last_processed_at = NULL WHERE app_version <= $1;` - result, err := db.exec(ctx, query, appVersion) + result, err := db.db.Exec(ctx, query, appVersion) if err != nil { return err } @@ -185,7 +186,7 @@ func scanVersionState(scan func(dest ...interface{}) error) (*internal.VersionSt // for the query columns. func (db *DB) queryVersionStates(ctx context.Context, queryFormat string, args ...interface{}) ([]*internal.VersionState, error) { query := fmt.Sprintf(queryFormat, versionStateColumns) - rows, err := db.query(ctx, query, args...) + rows, err := db.db.Query(ctx, query, args...) if err != nil { return nil, err } @@ -262,7 +263,7 @@ func (db *DB) GetVersionState(ctx context.Context, modulePath, version string) ( module_path = $1 AND version = $2;`, versionStateColumns) - row := db.queryRow(ctx, query, modulePath, version) + row := db.db.QueryRow(ctx, query, modulePath, version) v, err := scanVersionState(row.Scan) switch err { case nil: @@ -299,7 +300,7 @@ func (db *DB) GetVersionStats(ctx context.Context) (_ *VersionStats, err error) stats := &VersionStats{ VersionCounts: make(map[int]int), } - err = db.runQuery(ctx, query, func(rows *sql.Rows) error { + err = db.db.RunQuery(ctx, query, func(rows *sql.Rows) error { var ( status sql.NullInt64 indexTimestamp time.Time |
