diff options
| author | Jonathan Amsterdam <jba@google.com> | 2019-08-09 16:14:46 -0400 |
|---|---|---|
| committer | Julie Qiu <julie@golang.org> | 2020-03-27 16:46:41 -0400 |
| commit | a25876d462acdf3e89ba568487f396e491621a38 (patch) | |
| tree | 55c6d8b0028141f6e8fe7e1cbe09e65b944d28dc /internal/postgres | |
| parent | 7cdd419f0dc966646a92081b643cd25cafda00d9 (diff) | |
| download | go-x-pkgsite-a25876d462acdf3e89ba568487f396e491621a38.tar.xz | |
internal/postgres: make sql.DB a regular field
By not embedding it, we don't expose all its methods.
Add selected methods back.
Fixes b/139178399
Change-Id: Ic146dc01c8531e1fd5b56e085f53a7735d44c146
Reviewed-on: https://team-review.git.corp.google.com/c/golang/discovery/+/525149
Reviewed-by: Robert Findley <rfindley@google.com>
Diffstat (limited to 'internal/postgres')
| -rw-r--r-- | internal/postgres/details.go | 30 | ||||
| -rw-r--r-- | internal/postgres/insert_version.go | 2 | ||||
| -rw-r--r-- | internal/postgres/postgres.go | 21 | ||||
| -rw-r--r-- | internal/postgres/postgres_test.go | 14 | ||||
| -rw-r--r-- | internal/postgres/search.go | 4 | ||||
| -rw-r--r-- | internal/postgres/versionstate.go | 20 |
6 files changed, 54 insertions, 37 deletions
diff --git a/internal/postgres/details.go b/internal/postgres/details.go index 914a7964..f84abf77 100644 --- a/internal/postgres/details.go +++ b/internal/postgres/details.go @@ -65,7 +65,7 @@ func (db *DB) GetPackage(ctx context.Context, path string, version string) (*int AND p.version = $2 LIMIT 1;` - row := db.QueryRowContext(ctx, query, path, version) + row := db.queryRowContext(ctx, query, path, version) if err := row.Scan(&commitTime, pq.Array(&licenseTypes), pq.Array(&licensePaths), &readmeFilePath, &readmeContents, &modulePath, &name, &synopsis, &v1path, &versionType, &documentation, &repositoryURL, &vcsType, &homepageURL); err != nil { @@ -155,7 +155,7 @@ func (db *DB) GetLatestPackage(ctx context.Context, path string) (*internal.Vers v.prerelease DESC LIMIT 1;` - row := db.QueryRowContext(ctx, query, path) + row := db.queryRowContext(ctx, query, path) if err := row.Scan(&modulePath, pq.Array(&licenseTypes), pq.Array(&licensePaths), &version, &commitTime, &name, &synopsis, &v1path, &readmeFilePath, &readmeContents, &documentation, &repositoryURL, &vcsType, &homepageURL, &versionType); err != nil { if err == sql.ErrNoRows { return nil, xerrors.Errorf("package %s@%s: %w", path, version, derrors.NotFound) @@ -230,9 +230,9 @@ func (db *DB) GetVersion(ctx context.Context, modulePath, version string) (*inte licenseTypes, licensePaths []string ) - rows, err := db.QueryContext(ctx, query, modulePath, version) + rows, err := db.queryContext(ctx, query, modulePath, version) if err != nil { - return nil, fmt.Errorf("db.QueryContext(ctx, %s, %q, %q): %v", query, modulePath, version, err) + return nil, fmt.Errorf("db.queryContext(ctx, %s, %q, %q): %v", query, modulePath, version, err) } defer rows.Close() @@ -348,9 +348,9 @@ func getVersions(ctx context.Context, db *DB, path string, versionTypes []intern query := fmt.Sprintf(baseQuery, strings.Join(vtQuery, " OR "), queryEnd) - rows, err := db.QueryContext(ctx, query, params...) + rows, err := db.queryContext(ctx, query, params...) if err != nil { - return nil, fmt.Errorf("db.QueryContext(ctx, %q, %q): %v", query, path, err) + return nil, fmt.Errorf("db.queryContext(ctx, %q, %q): %v", query, path, err) } defer rows.Close() @@ -395,9 +395,9 @@ func (db *DB) GetImports(ctx context.Context, path, version string) ([]string, e ORDER BY to_path;` - rows, err := db.QueryContext(ctx, query, path, version) + rows, err := db.queryContext(ctx, query, path, version) if err != nil { - return nil, fmt.Errorf("db.QueryContext(ctx, %q, %q, %q): %v", query, path, version, err) + return nil, fmt.Errorf("db.queryContext(ctx, %q, %q, %q): %v", query, path, version, err) } defer rows.Close() @@ -439,9 +439,9 @@ func (db *DB) GetImportedBy(ctx context.Context, path, modulePath string, limit, LIMIT $3 OFFSET $4;` - rows, err := db.QueryContext(ctx, query, path, modulePath, limit, offset) + rows, err := db.queryContext(ctx, query, path, modulePath, limit, offset) if err != nil { - return nil, 0, fmt.Errorf("db.Query(%q, %q) returned error: %v", query, path, err) + return nil, 0, fmt.Errorf("db.query(%q, %q) returned error: %v", query, path, err) } defer rows.Close() @@ -473,9 +473,9 @@ func (db *DB) GetModuleLicenses(ctx context.Context, modulePath, version string) licenses WHERE module_path = $1 AND version = $2;` - rows, err := db.QueryContext(ctx, query, modulePath, version) + rows, err := db.queryContext(ctx, query, modulePath, version) if err != nil { - return nil, fmt.Errorf("db.QueryContext(ctx, %q, %q, %q): %v", query, modulePath, version, err) + return nil, fmt.Errorf("db.queryContext(ctx, %q, %q, %q): %v", query, modulePath, version, err) } defer rows.Close() return collectLicenses(rows) @@ -512,9 +512,9 @@ func (db *DB) GetPackageLicenses(ctx context.Context, pkgPath, modulePath, versi AND p.version = l.version AND p.license_file_path = l.file_path;` - rows, err := db.QueryContext(ctx, query, pkgPath, modulePath, version) + rows, err := db.queryContext(ctx, query, pkgPath, modulePath, version) if err != nil { - return nil, fmt.Errorf("db.QueryContext(ctx, %q, %q): %v", query, pkgPath, err) + return nil, fmt.Errorf("db.queryContext(ctx, %q, %q): %v", query, pkgPath, err) } defer rows.Close() return collectLicenses(rows) @@ -611,7 +611,7 @@ func (db *DB) GetVersionInfo(ctx context.Context, modulePath string, version str FROM versions v WHERE module_path = $1 and version = $2;` - row := db.QueryRowContext(ctx, query, modulePath, version) + row := db.queryRowContext(ctx, query, modulePath, version) if err := row.Scan(&commitTime, &readmeFilePath, &readmeContents, &versionType, &repositoryURL, &vcsType, &homepageURL); err != nil { if err == sql.ErrNoRows { return nil, xerrors.Errorf("module version %s@%s: %w", modulePath, version, derrors.NotFound) diff --git a/internal/postgres/insert_version.go b/internal/postgres/insert_version.go index 0930044e..c1b74ea8 100644 --- a/internal/postgres/insert_version.go +++ b/internal/postgres/insert_version.go @@ -369,7 +369,7 @@ func (db *DB) DeleteVersion(ctx context.Context, tx *sql.Tx, modulePath, version if tx != nil { _, err = tx.ExecContext(ctx, stmt, modulePath, version) } else { - _, err = db.ExecContext(ctx, stmt, modulePath, version) + _, err = db.execContext(ctx, stmt, modulePath, version) } return err } diff --git a/internal/postgres/postgres.go b/internal/postgres/postgres.go index fe2d5a1a..17db450c 100644 --- a/internal/postgres/postgres.go +++ b/internal/postgres/postgres.go @@ -14,7 +14,19 @@ import ( // DB wraps a sql.DB to provide an API for interacting with discovery data // stored in Postgres. type DB struct { - *sql.DB + db *sql.DB +} + +func (db *DB) execContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + return db.db.ExecContext(ctx, query, args...) +} + +func (db *DB) queryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + return db.db.QueryContext(ctx, query, args...) +} + +func (db *DB) queryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + return db.db.QueryRowContext(ctx, query, args...) } // Open creates a new DB for the given Postgres connection string. @@ -33,7 +45,7 @@ func Open(driverName, dbinfo string) (*DB, error) { // Transact executes the given function in the context of a SQL transaction, // rolling back the transaction if the function panics or returns an error. func (db *DB) Transact(txFunc func(*sql.Tx) error) (err error) { - tx, err := db.Begin() + tx, err := db.db.Begin() if err != nil { return fmt.Errorf("db.Begin(): %v", err) } @@ -152,3 +164,8 @@ func buildInsertQuery(table string, columns []string, values []interface{}, conf return b.String(), nil } + +// Close closes the database connection. +func (db *DB) Close() error { + return db.db.Close() +} diff --git a/internal/postgres/postgres_test.go b/internal/postgres/postgres_test.go index e7771f0d..6d991610 100644 --- a/internal/postgres/postgres_test.go +++ b/internal/postgres/postgres_test.go @@ -99,13 +99,13 @@ func TestBulkInsert(t *testing.T) { colB TEXT, PRIMARY KEY (colA) );`, table) - if _, err := testDB.ExecContext(ctx, createQuery); err != nil { - t.Fatalf("testDB.ExecContext(ctx, %q): %v", createQuery, err) + if _, err := testDB.execContext(ctx, createQuery); err != nil { + t.Fatalf("testDB.execContext(ctx, %q): %v", createQuery, err) } defer func() { dropTableQuery := fmt.Sprintf("DROP TABLE %s;", table) - if _, err := testDB.ExecContext(ctx, dropTableQuery); err != nil { - t.Fatalf("testDB.ExecContext(ctx, %q): %v", dropTableQuery, err) + if _, err := testDB.execContext(ctx, dropTableQuery); err != nil { + t.Fatalf("testDB.execContext(ctx, %q): %v", dropTableQuery, err) } }() @@ -118,7 +118,7 @@ func TestBulkInsert(t *testing.T) { if tc.wantCount != 0 { var count int query := "SELECT COUNT(*) FROM " + table - row := testDB.QueryRow(query) + row := testDB.queryRowContext(ctx, query) err := row.Scan(&count) if err != nil { t.Fatalf("testDB.QueryRow(%q): %v", query, err) @@ -134,7 +134,7 @@ func TestBulkInsert(t *testing.T) { func TestLargeBulkInsert(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - if _, err := testDB.ExecContext(ctx, `CREATE TEMPORARY TABLE test_large_bulk (i BIGINT);`); err != nil { + if _, err := testDB.execContext(ctx, `CREATE TEMPORARY TABLE test_large_bulk (i BIGINT);`); err != nil { t.Fatal(err) } const size = 150000 @@ -147,7 +147,7 @@ func TestLargeBulkInsert(t *testing.T) { }); err != nil { t.Fatal(err) } - rows, err := testDB.QueryContext(ctx, `SELECT i FROM test_large_bulk;`) + rows, err := testDB.queryContext(ctx, `SELECT i FROM test_large_bulk;`) if err != nil { t.Fatal(err) } diff --git a/internal/postgres/search.go b/internal/postgres/search.go index c84c09d6..1dc4f7f3 100644 --- a/internal/postgres/search.go +++ b/internal/postgres/search.go @@ -93,7 +93,7 @@ func (db *DB) Search(ctx context.Context, searchQuery string, limit, offset int) package_path LIMIT $2 OFFSET $3;` - rows, err := db.QueryContext(ctx, query, searchQuery, limit, offset) + rows, err := db.queryContext(ctx, query, searchQuery, limit, offset) if err != nil { return nil, fmt.Errorf("db.QueryContext(ctx, %s, %q, %d, %d): %v", query, searchQuery, limit, offset, err) } @@ -139,7 +139,7 @@ func (db *DB) Search(ctx context.Context, searchQuery string, limit, offset int) // locking out concurrent selects on the materialized view. func (db *DB) RefreshSearchDocuments(ctx context.Context) error { query := "REFRESH MATERIALIZED VIEW CONCURRENTLY mvw_search_documents;" - if _, err := db.ExecContext(ctx, query); err != nil { + if _, err := db.execContext(ctx, query); err != nil { return fmt.Errorf("db.ExecContext(ctx, %q): %v", query, err) } return nil diff --git a/internal/postgres/versionstate.go b/internal/postgres/versionstate.go index b5ecf299..e69a2429 100644 --- a/internal/postgres/versionstate.go +++ b/internal/postgres/versionstate.go @@ -65,9 +65,9 @@ func (db *DB) UpsertVersionState(ctx context.Context, modulePath, version, appVe if fetchErr != nil { sqlErrorMsg = sql.NullString{Valid: true, String: fetchErr.Error()} } - result, err := db.ExecContext(ctx, query, modulePath, version, appVersion, timestamp, status, sqlErrorMsg) + result, err := db.execContext(ctx, query, modulePath, version, appVersion, timestamp, status, sqlErrorMsg) if err != nil { - return fmt.Errorf("db.ExecContext(ctx, %q, %q, %q, %q, %q, %q, %v): %v", query, modulePath, version, appVersion, timestamp, status, sqlErrorMsg, err) + return fmt.Errorf("db.execContext(ctx, %q, %q, %q, %q, %q, %q, %v): %v", query, modulePath, version, appVersion, timestamp, status, sqlErrorMsg, err) } affected, err := result.RowsAffected() if err != nil { @@ -88,7 +88,7 @@ func (db *DB) LatestIndexTimestamp(ctx context.Context) (time.Time, error) { LIMIT 1` var ts time.Time - row := db.QueryRowContext(ctx, query) + row := db.queryRowContext(ctx, query) switch err := row.Scan(&ts); err { case sql.ErrNoRows: return time.Time{}, nil @@ -108,9 +108,9 @@ func (db *DB) UpdateVersionStatesForReprocessing(ctx context.Context, appVersion last_processed_at = NULL WHERE app_version <= $1;` - result, err := db.ExecContext(ctx, query, appVersion) + result, err := db.execContext(ctx, query, appVersion) if err != nil { - return fmt.Errorf("db.ExecContext(ctx, %q, %q): %v", query, appVersion, err) + return fmt.Errorf("db.execContext(ctx, %q, %q): %v", query, appVersion, err) } affected, err := result.RowsAffected() if err != nil { @@ -176,9 +176,9 @@ func scanVersionState(scan func(dest ...interface{}) error) (*internal.VersionSt // for the query columns. func (db *DB) queryVersionStates(ctx context.Context, queryFormat string, args ...interface{}) ([]*internal.VersionState, error) { query := fmt.Sprintf(queryFormat, versionStateColumns) - rows, err := db.QueryContext(ctx, query, args...) + rows, err := db.queryContext(ctx, query, args...) if err != nil { - return nil, fmt.Errorf("db.QueryContext(ctx, %q, %v): %v", query, args, err) + return nil, fmt.Errorf("db.queryContext(ctx, %q, %v): %v", query, args, err) } defer rows.Close() @@ -245,7 +245,7 @@ func (db *DB) GetVersionState(ctx context.Context, modulePath, version string) ( module_path = $1 AND version = $2;`, versionStateColumns) - row := db.QueryRowContext(ctx, query, modulePath, version) + row := db.queryRowContext(ctx, query, modulePath, version) v, err := scanVersionState(row.Scan) switch err { case nil: @@ -282,9 +282,9 @@ func (db *DB) GetVersionStats(ctx context.Context) (*VersionStats, error) { indexTimestamp time.Time count int ) - rows, err := db.QueryContext(ctx, query) + rows, err := db.queryContext(ctx, query) if err != nil { - return nil, fmt.Errorf("db.QueryContext(ctx, %q): %v", query, err) + return nil, fmt.Errorf("db.queryContext(ctx, %q): %v", query, err) } defer rows.Close() stats := &VersionStats{ |
