aboutsummaryrefslogtreecommitdiff
path: root/internal/postgres
diff options
context:
space:
mode:
authorJulie Qiu <julie@golang.org>2021-07-17 18:17:26 -0400
committerJulie Qiu <julie@golang.org>2021-07-20 13:22:19 +0000
commitf471605e911992ffec5634e693b82324eaae3efa (patch)
tree93fe4be9691185bab56f0fd9590784233f29fb71 /internal/postgres
parent88d596252dbac48b59ded3a5003f55fd1fc5df4d (diff)
downloadgo-x-pkgsite-f471605e911992ffec5634e693b82324eaae3efa.tar.xz
internal/postgres/{symbolsearch}: add QueryMultiWord
QueryMultiWord is added, which is used for symbol search when there are multiple words in the search input. For golang/go#44142 Change-Id: Ic3bdac04a520464188e73abaa9be0ba5f1037ebb Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/335264 Trust: Julie Qiu <julie@golang.org> Run-TryBot: Julie Qiu <julie@golang.org> TryBot-Result: kokoro <noreply+kokoro@google.com> Reviewed-by: Jonathan Amsterdam <jba@google.com>
Diffstat (limited to 'internal/postgres')
-rw-r--r--internal/postgres/symbolsearch.go8
-rw-r--r--internal/postgres/symbolsearch/gen_query.go6
-rw-r--r--internal/postgres/symbolsearch/query.gen.go91
-rw-r--r--internal/postgres/symbolsearch/symbolsearch.go70
-rw-r--r--internal/postgres/symbolsearch_test.go27
5 files changed, 167 insertions, 35 deletions
diff --git a/internal/postgres/symbolsearch.go b/internal/postgres/symbolsearch.go
index 8ddf3270..b9be6f49 100644
--- a/internal/postgres/symbolsearch.go
+++ b/internal/postgres/symbolsearch.go
@@ -130,10 +130,14 @@ func (db *DB) symbolSearch(ctx context.Context, q string, limit, offset, maxResu
// There is only 1 element, split by 2 dots, so the search must
// be for <package>.<type>.<methodOrFieldName>.
query = symbolsearch.QueryPackageDotSymbol
+ default:
+ // There is no situation where we will get results for one element
+ // containing more than 2 dots.
+ err = fmt.Errorf("unsupported query structure: %q", q)
}
} else {
- // TODO: add additional queries based on q.
- err = fmt.Errorf("unsupported query structure: %q", q)
+ // The search query contains multiple words, separated by spaces.
+ query = symbolsearch.QueryMultiWord
}
if err == nil {
err = db.db.RunQuery(ctx, query, collect, q, limit, offset)
diff --git a/internal/postgres/symbolsearch/gen_query.go b/internal/postgres/symbolsearch/gen_query.go
index 6a97153f..0cf42ac1 100644
--- a/internal/postgres/symbolsearch/gen_query.go
+++ b/internal/postgres/symbolsearch/gen_query.go
@@ -55,10 +55,14 @@ package symbolsearch
// QueryOneDot is used when the search query is one element
// containing a dot. This means it can either be <package>.<symbol> or
// <type>.<methodOrField>.
+%s
+
+// QueryMultiWord is used when the search query is multiple elements.
%s`,
formatQuery("QuerySymbol", symbolsearch.RawQuerySymbol),
formatQuery("QueryPackageDotSymbol", symbolsearch.RawQueryPackageDotSymbol),
- formatQuery("QueryOneDot", symbolsearch.RawQueryOneDot))
+ formatQuery("QueryOneDot", symbolsearch.RawQueryOneDot),
+ formatQuery("QueryMultiWord", symbolsearch.RawQueryMultiWord))
func formatQuery(name, query string) string {
return fmt.Sprintf("const %s = `%s`", name, query)
diff --git a/internal/postgres/symbolsearch/query.gen.go b/internal/postgres/symbolsearch/query.gen.go
index 1c51d97c..baf8591d 100644
--- a/internal/postgres/symbolsearch/query.gen.go
+++ b/internal/postgres/symbolsearch/query.gen.go
@@ -24,9 +24,9 @@ WITH results AS (
ssd.package_symbol_id,
ssd.goos,
ssd.goarch,
- (ln(exp(1)+imported_by_count)
- * CASE WHEN sd.redistributable THEN 1 ELSE 0.500000 END
- * CASE WHEN COALESCE(has_go_mod, true) THEN 1 ELSE 0.800000 END) AS score
+ ln(exp(1)+imported_by_count)
+ * CASE WHEN sd.redistributable THEN 1 ELSE 0.500000 END
+ * CASE WHEN COALESCE(has_go_mod, true) THEN 1 ELSE 0.800000 END AS score
FROM symbol_search_documents ssd
INNER JOIN search_documents sd ON sd.unit_id = ssd.unit_id
INNER JOIN symbol_names s ON s.id = ssd.symbol_name_id
@@ -76,13 +76,17 @@ WITH results AS (
ssd.package_symbol_id,
ssd.goos,
ssd.goarch,
- (ln(exp(1)+imported_by_count)
- * CASE WHEN sd.redistributable THEN 1 ELSE 0.500000 END
- * CASE WHEN COALESCE(has_go_mod, true) THEN 1 ELSE 0.800000 END) AS score
+ ln(exp(1)+imported_by_count)
+ * CASE WHEN sd.redistributable THEN 1 ELSE 0.500000 END
+ * CASE WHEN COALESCE(has_go_mod, true) THEN 1 ELSE 0.800000 END AS score
FROM symbol_search_documents ssd
INNER JOIN search_documents sd ON sd.unit_id = ssd.unit_id
INNER JOIN symbol_names s ON s.id = ssd.symbol_name_id
- WHERE sd.name = split_part($1, '.', 1) AND s.tsv_name_tokens @@ to_tsquery('symbols', substring(replace($1, '_', '-') from E'[^.]*\.(.+)$'))
+ WHERE (
+ sd.name = split_part($1, '.', 1)
+ ) AND (
+ s.tsv_name_tokens @@ to_tsquery('symbols', substring(replace($1, '_', '-') from E'[^.]*\.(.+)$'))
+ )
)
SELECT
r.package_path,
@@ -128,13 +132,78 @@ WITH results AS (
ssd.package_symbol_id,
ssd.goos,
ssd.goarch,
- (ln(exp(1)+imported_by_count)
- * CASE WHEN sd.redistributable THEN 1 ELSE 0.500000 END
- * CASE WHEN COALESCE(has_go_mod, true) THEN 1 ELSE 0.800000 END) AS score
+ ln(exp(1)+imported_by_count)
+ * CASE WHEN sd.redistributable THEN 1 ELSE 0.500000 END
+ * CASE WHEN COALESCE(has_go_mod, true) THEN 1 ELSE 0.800000 END AS score
FROM symbol_search_documents ssd
INNER JOIN search_documents sd ON sd.unit_id = ssd.unit_id
INNER JOIN symbol_names s ON s.id = ssd.symbol_name_id
- WHERE (sd.name = split_part($1, '.', 1) AND s.tsv_name_tokens @@ to_tsquery('symbols', substring(replace($1, '_', '-') from E'[^.]*\.(.+)$'))) OR (s.tsv_name_tokens @@ to_tsquery('symbols', replace($1, '_', '-')))
+ WHERE (
+ sd.name = split_part($1, '.', 1)
+ ) AND (
+ s.tsv_name_tokens @@ to_tsquery('symbols', substring(replace($1, '_', '-') from E'[^.]*\.(.+)$'))
+ ) OR s.tsv_name_tokens @@ to_tsquery('symbols', replace($1, '_', '-'))
+)
+SELECT
+ r.package_path,
+ r.module_path,
+ r.version,
+ r.package_name,
+ r.synopsis,
+ r.license_types,
+ r.commit_time,
+ r.imported_by_count,
+ r.symbol_name,
+ r.goos,
+ r.goarch,
+ ps.type AS symbol_type,
+ ps.synopsis AS symbol_synopsis,
+ COUNT(*) OVER() AS total
+FROM results r
+INNER JOIN package_symbols ps ON r.package_symbol_id = ps.id
+WHERE r.score > 0.1
+ORDER BY
+ score DESC,
+ commit_time DESC,
+ symbol_name,
+ package_path
+LIMIT $2
+OFFSET $3;`
+
+// QueryMultiWord is used when the search query is multiple elements.
+const QueryMultiWord = `
+WITH results AS (
+ SELECT
+ s.name AS symbol_name,
+ sd.package_path,
+ sd.module_path,
+ sd.version,
+ sd.name AS package_name,
+ sd.synopsis,
+ sd.license_types,
+ sd.commit_time,
+ sd.imported_by_count,
+ ssd.package_symbol_id,
+ ssd.goos,
+ ssd.goarch,
+ (
+ ts_rank(
+ '{0.1, 0.2, 1.0, 1.0}',
+ sd.tsv_path_tokens,
+ to_tsquery('symbols', replace(replace($1, '_', '-'), ' ', ' | '))
+ )
+ * ln(exp(1)+imported_by_count)
+ * CASE WHEN sd.redistributable THEN 1 ELSE 0.500000 END
+ * CASE WHEN COALESCE(has_go_mod, true) THEN 1 ELSE 0.800000 END
+ ) AS score
+ FROM symbol_search_documents ssd
+ INNER JOIN search_documents sd ON sd.unit_id = ssd.unit_id
+ INNER JOIN symbol_names s ON s.id = ssd.symbol_name_id
+ WHERE (
+ s.tsv_name_tokens @@ to_tsquery('symbols', replace(replace($1, '_', '-'), ' ', ' | '))
+ ) AND (
+ sd.tsv_path_tokens @@ to_tsquery('symbols', replace(replace($1, '_', '-'), ' ', ' | '))
+ )
)
SELECT
r.package_path,
diff --git a/internal/postgres/symbolsearch/symbolsearch.go b/internal/postgres/symbolsearch/symbolsearch.go
index 5e9df1b3..bcf291b6 100644
--- a/internal/postgres/symbolsearch/symbolsearch.go
+++ b/internal/postgres/symbolsearch/symbolsearch.go
@@ -23,6 +23,7 @@ var (
RawQuerySymbol = fmt.Sprintf(symbolSearchBaseQuery, scoreMultipliers, filterSymbol)
RawQueryPackageDotSymbol = fmt.Sprintf(symbolSearchBaseQuery, scoreMultipliers, filterPackageDotSymbol)
RawQueryOneDot = fmt.Sprintf(symbolSearchBaseQuery, scoreMultipliers, filterOneDot)
+ RawQueryMultiWord = fmt.Sprintf(symbolSearchBaseQuery, formatScore(scoreMultiWord), filterMultiWord)
)
var (
@@ -30,41 +31,92 @@ var (
// <symbol> or <type>.<methodOrField>.
filterSymbol = fmt.Sprintf(`s.tsv_name_tokens @@ %s`, toTSQuery("$1"))
+ // filterSymbolOR is used when $1 contains the full symbol name, either
+ // <symbol> or <type>.<methodOrField>, and has multiple words.
+ filterSymbolOR = fmt.Sprintf(`s.tsv_name_tokens @@ %s`, toTSQuery(splitOR))
+
// filterPackageDotSymbol is used when $1 is either <package>.<symbol> OR
// <package>.<type>.<methodOrField>.
- filterPackageDotSymbol = fmt.Sprintf(
+ filterPackageDotSymbol = fmt.Sprintf("%s AND %s",
// Split the package name from $1, which can be assumed to be the
// element preceding the first dot.
- `sd.name = split_part($1, '.', 1) AND s.tsv_name_tokens @@ %s`,
+ formatFilter("sd.name = split_part($1, '.', 1)"),
// Split the symbol name from $1, which can be assumed to be everything
// following the first dot.
- toTSQuery("substring($1 from E'[^.]*\\.(.+)$')"))
+ fmt.Sprintf(formatFilter("s.tsv_name_tokens @@ %s"),
+ toTSQuery("substring($1 from E'[^.]*\\.(.+)$')")))
// filterOneDot is used when $1 is one word containing a single dot, which
// means it is either <package>.<symbol> or <type>.<methodOrField>.
- filterOneDot = fmt.Sprintf("(%s) OR (%s)", filterPackageDotSymbol, filterSymbol)
+ filterOneDot = fmt.Sprintf("%s OR %s", filterPackageDotSymbol, filterSymbol)
+
+ // filterPackage is used to filter matching elements from
+ // sd.tsv_path_tokens.
+ filterPackage = fmt.Sprintf(`sd.tsv_path_tokens @@ %s`, toTSQuery(splitOR))
+
+ // filterMultiWord is used when $1 contains multiple words, separated by spaces.
+ // One element for the query must match a symbol name, and one (could be
+ // the same element) must match the package name.
+ filterMultiWord = fmt.Sprintf("%s AND %s", formatFilter(filterSymbolOR),
+ formatFilter(filterPackage))
)
var (
+ // scoreMultiWord is the score when $1 contains multiple words.
+ scoreMultiWord = fmt.Sprintf("%s%s", rankPathTokens, formatMultiplier(scoreMultipliers))
+
// scoreMultipliers is the score of multiplying the multipliers.
//
// It is also used as the score for QuerySymbol, QueryPackageDotSymbol and
// QueryOneDot. In each case, the matching symbols will be filtered in the
// WHERE clause, and the only remaining information to rank the results by
// is the multipliers.
- scoreMultipliers = fmt.Sprintf("%s\n\t\t* %s\n\t\t* %s",
- popularityMultiplier, redistributableMultipler, goModMultipler)
+ scoreMultipliers = fmt.Sprintf("%s%s%s",
+ popularityMultiplier,
+ formatMultiplier(redistributableMultipler),
+ formatMultiplier(goModMultipler))
+
+ rankPathTokens = fmt.Sprintf(
+ "ts_rank(%s,%s,%s"+indent(")", 3),
+ indent("'{0.1, 0.2, 1.0, 1.0}'", 4),
+ indent("sd.tsv_path_tokens", 4),
+ indent(toTSQuery(splitOR), 4))
// Popularity multiplier to increase ranking of popular packages.
popularityMultiplier = `ln(exp(1)+imported_by_count)`
// Multiplier based on whether the module license is non-redistributable.
- redistributableMultipler = fmt.Sprintf(`CASE WHEN sd.redistributable THEN 1 ELSE %f END`, nonRedistributablePenalty)
+ redistributableMultipler = fmt.Sprintf(
+ `CASE WHEN sd.redistributable THEN 1 ELSE %f END`,
+ nonRedistributablePenalty)
// Multiplier based on whether the module has a go.mod file.
- goModMultipler = fmt.Sprintf(`CASE WHEN COALESCE(has_go_mod, true) THEN 1 ELSE %f END`, noGoModPenalty)
+ goModMultipler = fmt.Sprintf(
+ `CASE WHEN COALESCE(has_go_mod, true) THEN 1 ELSE %f END`,
+ noGoModPenalty)
)
+func formatScore(s string) string {
+ return fmt.Sprintf("(\n\t\t\t\t%s\n\t\t\t)", s)
+}
+
+func formatFilter(s string) string {
+ return fmt.Sprintf("(\n\t\t\t%s\n\t\t)", s)
+}
+
+func formatMultiplier(s string) string {
+ return indent(fmt.Sprintf("* %s", s), 3)
+}
+
+func indent(s string, n int) string {
+ for i := 0; i <= n; i++ {
+ s = "\t" + s
+ }
+ return "\n" + s
+}
+
+const splitOR = "replace($1, ' ', ' | ')"
+
// Penalties to search scores, applied as multipliers to the score.
const (
// Module license is non-redistributable.
@@ -102,7 +154,7 @@ WITH results AS (
ssd.package_symbol_id,
ssd.goos,
ssd.goarch,
- (%s) AS score
+ %s AS score
FROM symbol_search_documents ssd
INNER JOIN search_documents sd ON sd.unit_id = ssd.unit_id
INNER JOIN symbol_names s ON s.id = ssd.symbol_name_id
diff --git a/internal/postgres/symbolsearch_test.go b/internal/postgres/symbolsearch_test.go
index be9bc90b..ea68b7c0 100644
--- a/internal/postgres/symbolsearch_test.go
+++ b/internal/postgres/symbolsearch_test.go
@@ -81,18 +81,21 @@ func TestSymbolSearch(t *testing.T) {
q: "Type.Method",
want: checkResult(sample.Method),
},
- /*
- {
- name: "test search by <package> <identifier>",
- q: sample.PackageName + " function",
- want: checkResult(sample.Function.SymbolMeta),
- },
- {
- name: "test search by <package-subpath> <identifier>",
- q: "module_name/foo function",
- want: checkResult(sample.Function.SymbolMeta),
- },
- */
+ {
+ name: "test search by <package> <identifier>",
+ q: "foo function",
+ want: checkResult(sample.Function.SymbolMeta),
+ },
+ {
+ name: "test search by <package-subpath> <package-name> <identifier>",
+ q: "github.com/valid foo function",
+ want: checkResult(sample.Function.SymbolMeta),
+ },
+ {
+ name: "test search by <package-subpath> <identifier> subpath contains _",
+ q: "module_name/foo function",
+ want: checkResult(sample.Function.SymbolMeta),
+ },
} {
t.Run(test.name, func(t *testing.T) {
resp, err := testDB.hedgedSearch(ctx, test.q, 2, 0, 100, symbolSearchers, nil)