aboutsummaryrefslogtreecommitdiff
path: root/internal/postgres
diff options
context:
space:
mode:
authorJonathan Amsterdam <jba@google.com>2019-08-09 16:14:46 -0400
committerJulie Qiu <julie@golang.org>2020-03-27 16:46:41 -0400
commita25876d462acdf3e89ba568487f396e491621a38 (patch)
tree55c6d8b0028141f6e8fe7e1cbe09e65b944d28dc /internal/postgres
parent7cdd419f0dc966646a92081b643cd25cafda00d9 (diff)
downloadgo-x-pkgsite-a25876d462acdf3e89ba568487f396e491621a38.tar.xz
internal/postgres: make sql.DB a regular field
By not embedding it, we don't expose all its methods. Add selected methods back. Fixes b/139178399 Change-Id: Ic146dc01c8531e1fd5b56e085f53a7735d44c146 Reviewed-on: https://team-review.git.corp.google.com/c/golang/discovery/+/525149 Reviewed-by: Robert Findley <rfindley@google.com>
Diffstat (limited to 'internal/postgres')
-rw-r--r--internal/postgres/details.go30
-rw-r--r--internal/postgres/insert_version.go2
-rw-r--r--internal/postgres/postgres.go21
-rw-r--r--internal/postgres/postgres_test.go14
-rw-r--r--internal/postgres/search.go4
-rw-r--r--internal/postgres/versionstate.go20
6 files changed, 54 insertions, 37 deletions
diff --git a/internal/postgres/details.go b/internal/postgres/details.go
index 914a7964..f84abf77 100644
--- a/internal/postgres/details.go
+++ b/internal/postgres/details.go
@@ -65,7 +65,7 @@ func (db *DB) GetPackage(ctx context.Context, path string, version string) (*int
AND p.version = $2
LIMIT 1;`
- row := db.QueryRowContext(ctx, query, path, version)
+ row := db.queryRowContext(ctx, query, path, version)
if err := row.Scan(&commitTime, pq.Array(&licenseTypes),
pq.Array(&licensePaths), &readmeFilePath, &readmeContents, &modulePath,
&name, &synopsis, &v1path, &versionType, &documentation, &repositoryURL, &vcsType, &homepageURL); err != nil {
@@ -155,7 +155,7 @@ func (db *DB) GetLatestPackage(ctx context.Context, path string) (*internal.Vers
v.prerelease DESC
LIMIT 1;`
- row := db.QueryRowContext(ctx, query, path)
+ row := db.queryRowContext(ctx, query, path)
if err := row.Scan(&modulePath, pq.Array(&licenseTypes), pq.Array(&licensePaths), &version, &commitTime, &name, &synopsis, &v1path, &readmeFilePath, &readmeContents, &documentation, &repositoryURL, &vcsType, &homepageURL, &versionType); err != nil {
if err == sql.ErrNoRows {
return nil, xerrors.Errorf("package %s@%s: %w", path, version, derrors.NotFound)
@@ -230,9 +230,9 @@ func (db *DB) GetVersion(ctx context.Context, modulePath, version string) (*inte
licenseTypes, licensePaths []string
)
- rows, err := db.QueryContext(ctx, query, modulePath, version)
+ rows, err := db.queryContext(ctx, query, modulePath, version)
if err != nil {
- return nil, fmt.Errorf("db.QueryContext(ctx, %s, %q, %q): %v", query, modulePath, version, err)
+ return nil, fmt.Errorf("db.queryContext(ctx, %s, %q, %q): %v", query, modulePath, version, err)
}
defer rows.Close()
@@ -348,9 +348,9 @@ func getVersions(ctx context.Context, db *DB, path string, versionTypes []intern
query := fmt.Sprintf(baseQuery, strings.Join(vtQuery, " OR "), queryEnd)
- rows, err := db.QueryContext(ctx, query, params...)
+ rows, err := db.queryContext(ctx, query, params...)
if err != nil {
- return nil, fmt.Errorf("db.QueryContext(ctx, %q, %q): %v", query, path, err)
+ return nil, fmt.Errorf("db.queryContext(ctx, %q, %q): %v", query, path, err)
}
defer rows.Close()
@@ -395,9 +395,9 @@ func (db *DB) GetImports(ctx context.Context, path, version string) ([]string, e
ORDER BY
to_path;`
- rows, err := db.QueryContext(ctx, query, path, version)
+ rows, err := db.queryContext(ctx, query, path, version)
if err != nil {
- return nil, fmt.Errorf("db.QueryContext(ctx, %q, %q, %q): %v", query, path, version, err)
+ return nil, fmt.Errorf("db.queryContext(ctx, %q, %q, %q): %v", query, path, version, err)
}
defer rows.Close()
@@ -439,9 +439,9 @@ func (db *DB) GetImportedBy(ctx context.Context, path, modulePath string, limit,
LIMIT $3
OFFSET $4;`
- rows, err := db.QueryContext(ctx, query, path, modulePath, limit, offset)
+ rows, err := db.queryContext(ctx, query, path, modulePath, limit, offset)
if err != nil {
- return nil, 0, fmt.Errorf("db.Query(%q, %q) returned error: %v", query, path, err)
+ return nil, 0, fmt.Errorf("db.query(%q, %q) returned error: %v", query, path, err)
}
defer rows.Close()
@@ -473,9 +473,9 @@ func (db *DB) GetModuleLicenses(ctx context.Context, modulePath, version string)
licenses
WHERE
module_path = $1 AND version = $2;`
- rows, err := db.QueryContext(ctx, query, modulePath, version)
+ rows, err := db.queryContext(ctx, query, modulePath, version)
if err != nil {
- return nil, fmt.Errorf("db.QueryContext(ctx, %q, %q, %q): %v", query, modulePath, version, err)
+ return nil, fmt.Errorf("db.queryContext(ctx, %q, %q, %q): %v", query, modulePath, version, err)
}
defer rows.Close()
return collectLicenses(rows)
@@ -512,9 +512,9 @@ func (db *DB) GetPackageLicenses(ctx context.Context, pkgPath, modulePath, versi
AND p.version = l.version
AND p.license_file_path = l.file_path;`
- rows, err := db.QueryContext(ctx, query, pkgPath, modulePath, version)
+ rows, err := db.queryContext(ctx, query, pkgPath, modulePath, version)
if err != nil {
- return nil, fmt.Errorf("db.QueryContext(ctx, %q, %q): %v", query, pkgPath, err)
+ return nil, fmt.Errorf("db.queryContext(ctx, %q, %q): %v", query, pkgPath, err)
}
defer rows.Close()
return collectLicenses(rows)
@@ -611,7 +611,7 @@ func (db *DB) GetVersionInfo(ctx context.Context, modulePath string, version str
FROM
versions v
WHERE module_path = $1 and version = $2;`
- row := db.QueryRowContext(ctx, query, modulePath, version)
+ row := db.queryRowContext(ctx, query, modulePath, version)
if err := row.Scan(&commitTime, &readmeFilePath, &readmeContents, &versionType, &repositoryURL, &vcsType, &homepageURL); err != nil {
if err == sql.ErrNoRows {
return nil, xerrors.Errorf("module version %s@%s: %w", modulePath, version, derrors.NotFound)
diff --git a/internal/postgres/insert_version.go b/internal/postgres/insert_version.go
index 0930044e..c1b74ea8 100644
--- a/internal/postgres/insert_version.go
+++ b/internal/postgres/insert_version.go
@@ -369,7 +369,7 @@ func (db *DB) DeleteVersion(ctx context.Context, tx *sql.Tx, modulePath, version
if tx != nil {
_, err = tx.ExecContext(ctx, stmt, modulePath, version)
} else {
- _, err = db.ExecContext(ctx, stmt, modulePath, version)
+ _, err = db.execContext(ctx, stmt, modulePath, version)
}
return err
}
diff --git a/internal/postgres/postgres.go b/internal/postgres/postgres.go
index fe2d5a1a..17db450c 100644
--- a/internal/postgres/postgres.go
+++ b/internal/postgres/postgres.go
@@ -14,7 +14,19 @@ import (
// DB wraps a sql.DB to provide an API for interacting with discovery data
// stored in Postgres.
type DB struct {
- *sql.DB
+ db *sql.DB
+}
+
+func (db *DB) execContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
+ return db.db.ExecContext(ctx, query, args...)
+}
+
+func (db *DB) queryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
+ return db.db.QueryContext(ctx, query, args...)
+}
+
+func (db *DB) queryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
+ return db.db.QueryRowContext(ctx, query, args...)
}
// Open creates a new DB for the given Postgres connection string.
@@ -33,7 +45,7 @@ func Open(driverName, dbinfo string) (*DB, error) {
// Transact executes the given function in the context of a SQL transaction,
// rolling back the transaction if the function panics or returns an error.
func (db *DB) Transact(txFunc func(*sql.Tx) error) (err error) {
- tx, err := db.Begin()
+ tx, err := db.db.Begin()
if err != nil {
return fmt.Errorf("db.Begin(): %v", err)
}
@@ -152,3 +164,8 @@ func buildInsertQuery(table string, columns []string, values []interface{}, conf
return b.String(), nil
}
+
+// Close closes the database connection.
+func (db *DB) Close() error {
+ return db.db.Close()
+}
diff --git a/internal/postgres/postgres_test.go b/internal/postgres/postgres_test.go
index e7771f0d..6d991610 100644
--- a/internal/postgres/postgres_test.go
+++ b/internal/postgres/postgres_test.go
@@ -99,13 +99,13 @@ func TestBulkInsert(t *testing.T) {
colB TEXT,
PRIMARY KEY (colA)
);`, table)
- if _, err := testDB.ExecContext(ctx, createQuery); err != nil {
- t.Fatalf("testDB.ExecContext(ctx, %q): %v", createQuery, err)
+ if _, err := testDB.execContext(ctx, createQuery); err != nil {
+ t.Fatalf("testDB.execContext(ctx, %q): %v", createQuery, err)
}
defer func() {
dropTableQuery := fmt.Sprintf("DROP TABLE %s;", table)
- if _, err := testDB.ExecContext(ctx, dropTableQuery); err != nil {
- t.Fatalf("testDB.ExecContext(ctx, %q): %v", dropTableQuery, err)
+ if _, err := testDB.execContext(ctx, dropTableQuery); err != nil {
+ t.Fatalf("testDB.execContext(ctx, %q): %v", dropTableQuery, err)
}
}()
@@ -118,7 +118,7 @@ func TestBulkInsert(t *testing.T) {
if tc.wantCount != 0 {
var count int
query := "SELECT COUNT(*) FROM " + table
- row := testDB.QueryRow(query)
+ row := testDB.queryRowContext(ctx, query)
err := row.Scan(&count)
if err != nil {
t.Fatalf("testDB.QueryRow(%q): %v", query, err)
@@ -134,7 +134,7 @@ func TestBulkInsert(t *testing.T) {
func TestLargeBulkInsert(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
- if _, err := testDB.ExecContext(ctx, `CREATE TEMPORARY TABLE test_large_bulk (i BIGINT);`); err != nil {
+ if _, err := testDB.execContext(ctx, `CREATE TEMPORARY TABLE test_large_bulk (i BIGINT);`); err != nil {
t.Fatal(err)
}
const size = 150000
@@ -147,7 +147,7 @@ func TestLargeBulkInsert(t *testing.T) {
}); err != nil {
t.Fatal(err)
}
- rows, err := testDB.QueryContext(ctx, `SELECT i FROM test_large_bulk;`)
+ rows, err := testDB.queryContext(ctx, `SELECT i FROM test_large_bulk;`)
if err != nil {
t.Fatal(err)
}
diff --git a/internal/postgres/search.go b/internal/postgres/search.go
index c84c09d6..1dc4f7f3 100644
--- a/internal/postgres/search.go
+++ b/internal/postgres/search.go
@@ -93,7 +93,7 @@ func (db *DB) Search(ctx context.Context, searchQuery string, limit, offset int)
package_path
LIMIT $2
OFFSET $3;`
- rows, err := db.QueryContext(ctx, query, searchQuery, limit, offset)
+ rows, err := db.queryContext(ctx, query, searchQuery, limit, offset)
if err != nil {
return nil, fmt.Errorf("db.QueryContext(ctx, %s, %q, %d, %d): %v", query, searchQuery, limit, offset, err)
}
@@ -139,7 +139,7 @@ func (db *DB) Search(ctx context.Context, searchQuery string, limit, offset int)
// locking out concurrent selects on the materialized view.
func (db *DB) RefreshSearchDocuments(ctx context.Context) error {
query := "REFRESH MATERIALIZED VIEW CONCURRENTLY mvw_search_documents;"
- if _, err := db.ExecContext(ctx, query); err != nil {
+ if _, err := db.execContext(ctx, query); err != nil {
return fmt.Errorf("db.ExecContext(ctx, %q): %v", query, err)
}
return nil
diff --git a/internal/postgres/versionstate.go b/internal/postgres/versionstate.go
index b5ecf299..e69a2429 100644
--- a/internal/postgres/versionstate.go
+++ b/internal/postgres/versionstate.go
@@ -65,9 +65,9 @@ func (db *DB) UpsertVersionState(ctx context.Context, modulePath, version, appVe
if fetchErr != nil {
sqlErrorMsg = sql.NullString{Valid: true, String: fetchErr.Error()}
}
- result, err := db.ExecContext(ctx, query, modulePath, version, appVersion, timestamp, status, sqlErrorMsg)
+ result, err := db.execContext(ctx, query, modulePath, version, appVersion, timestamp, status, sqlErrorMsg)
if err != nil {
- return fmt.Errorf("db.ExecContext(ctx, %q, %q, %q, %q, %q, %q, %v): %v", query, modulePath, version, appVersion, timestamp, status, sqlErrorMsg, err)
+ return fmt.Errorf("db.execContext(ctx, %q, %q, %q, %q, %q, %q, %v): %v", query, modulePath, version, appVersion, timestamp, status, sqlErrorMsg, err)
}
affected, err := result.RowsAffected()
if err != nil {
@@ -88,7 +88,7 @@ func (db *DB) LatestIndexTimestamp(ctx context.Context) (time.Time, error) {
LIMIT 1`
var ts time.Time
- row := db.QueryRowContext(ctx, query)
+ row := db.queryRowContext(ctx, query)
switch err := row.Scan(&ts); err {
case sql.ErrNoRows:
return time.Time{}, nil
@@ -108,9 +108,9 @@ func (db *DB) UpdateVersionStatesForReprocessing(ctx context.Context, appVersion
last_processed_at = NULL
WHERE
app_version <= $1;`
- result, err := db.ExecContext(ctx, query, appVersion)
+ result, err := db.execContext(ctx, query, appVersion)
if err != nil {
- return fmt.Errorf("db.ExecContext(ctx, %q, %q): %v", query, appVersion, err)
+ return fmt.Errorf("db.execContext(ctx, %q, %q): %v", query, appVersion, err)
}
affected, err := result.RowsAffected()
if err != nil {
@@ -176,9 +176,9 @@ func scanVersionState(scan func(dest ...interface{}) error) (*internal.VersionSt
// for the query columns.
func (db *DB) queryVersionStates(ctx context.Context, queryFormat string, args ...interface{}) ([]*internal.VersionState, error) {
query := fmt.Sprintf(queryFormat, versionStateColumns)
- rows, err := db.QueryContext(ctx, query, args...)
+ rows, err := db.queryContext(ctx, query, args...)
if err != nil {
- return nil, fmt.Errorf("db.QueryContext(ctx, %q, %v): %v", query, args, err)
+ return nil, fmt.Errorf("db.queryContext(ctx, %q, %v): %v", query, args, err)
}
defer rows.Close()
@@ -245,7 +245,7 @@ func (db *DB) GetVersionState(ctx context.Context, modulePath, version string) (
module_path = $1
AND version = $2;`, versionStateColumns)
- row := db.QueryRowContext(ctx, query, modulePath, version)
+ row := db.queryRowContext(ctx, query, modulePath, version)
v, err := scanVersionState(row.Scan)
switch err {
case nil:
@@ -282,9 +282,9 @@ func (db *DB) GetVersionStats(ctx context.Context) (*VersionStats, error) {
indexTimestamp time.Time
count int
)
- rows, err := db.QueryContext(ctx, query)
+ rows, err := db.queryContext(ctx, query)
if err != nil {
- return nil, fmt.Errorf("db.QueryContext(ctx, %q): %v", query, err)
+ return nil, fmt.Errorf("db.queryContext(ctx, %q): %v", query, err)
}
defer rows.Close()
stats := &VersionStats{