diff options
| author | Julie Qiu <julie@golang.org> | 2021-07-17 18:17:26 -0400 |
|---|---|---|
| committer | Julie Qiu <julie@golang.org> | 2021-07-20 13:22:19 +0000 |
| commit | f471605e911992ffec5634e693b82324eaae3efa (patch) | |
| tree | 93fe4be9691185bab56f0fd9590784233f29fb71 /internal/postgres | |
| parent | 88d596252dbac48b59ded3a5003f55fd1fc5df4d (diff) | |
| download | go-x-pkgsite-f471605e911992ffec5634e693b82324eaae3efa.tar.xz | |
internal/postgres/{symbolsearch}: add QueryMultiWord
QueryMultiWord is added, which is used for symbol search when there are
multiple words in the search input.
For golang/go#44142
Change-Id: Ic3bdac04a520464188e73abaa9be0ba5f1037ebb
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/335264
Trust: Julie Qiu <julie@golang.org>
Run-TryBot: Julie Qiu <julie@golang.org>
TryBot-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
Diffstat (limited to 'internal/postgres')
| -rw-r--r-- | internal/postgres/symbolsearch.go | 8 | ||||
| -rw-r--r-- | internal/postgres/symbolsearch/gen_query.go | 6 | ||||
| -rw-r--r-- | internal/postgres/symbolsearch/query.gen.go | 91 | ||||
| -rw-r--r-- | internal/postgres/symbolsearch/symbolsearch.go | 70 | ||||
| -rw-r--r-- | internal/postgres/symbolsearch_test.go | 27 |
5 files changed, 167 insertions, 35 deletions
diff --git a/internal/postgres/symbolsearch.go b/internal/postgres/symbolsearch.go index 8ddf3270..b9be6f49 100644 --- a/internal/postgres/symbolsearch.go +++ b/internal/postgres/symbolsearch.go @@ -130,10 +130,14 @@ func (db *DB) symbolSearch(ctx context.Context, q string, limit, offset, maxResu // There is only 1 element, split by 2 dots, so the search must // be for <package>.<type>.<methodOrFieldName>. query = symbolsearch.QueryPackageDotSymbol + default: + // There is no situation where we will get results for oe element + // containing more than 2 dots. + err = fmt.Errorf("unsupported query structure: %q", q) } } else { - // TODO: add additional queries based on q. - err = fmt.Errorf("unsupported query structure: %q", q) + // The search query contains multiple words, separated by spaces. + query = symbolsearch.QueryMultiWord } if err == nil { err = db.db.RunQuery(ctx, query, collect, q, limit, offset) diff --git a/internal/postgres/symbolsearch/gen_query.go b/internal/postgres/symbolsearch/gen_query.go index 6a97153f..0cf42ac1 100644 --- a/internal/postgres/symbolsearch/gen_query.go +++ b/internal/postgres/symbolsearch/gen_query.go @@ -55,10 +55,14 @@ package symbolsearch // QueryOneDot is used when the search query is one element // containing a dot. This means it can either be <package>.<symbol> or // <type>.<methodOrField>. +%s + +// QueryMultiWord is used when the search query is multiple elements. %s`, formatQuery("QuerySymbol", symbolsearch.RawQuerySymbol), formatQuery("QueryPackageDotSymbol", symbolsearch.RawQueryPackageDotSymbol), - formatQuery("QueryOneDot", symbolsearch.RawQueryOneDot)) + formatQuery("QueryOneDot", symbolsearch.RawQueryOneDot), + formatQuery("QueryMultiWord", symbolsearch.RawQueryMultiWord)) func formatQuery(name, query string) string { return fmt.Sprintf("const %s = `%s`", name, query) diff --git a/internal/postgres/symbolsearch/query.gen.go b/internal/postgres/symbolsearch/query.gen.go index 1c51d97c..baf8591d 100644 --- a/internal/postgres/symbolsearch/query.gen.go +++ b/internal/postgres/symbolsearch/query.gen.go @@ -24,9 +24,9 @@ WITH results AS ( ssd.package_symbol_id, ssd.goos, ssd.goarch, - (ln(exp(1)+imported_by_count) - * CASE WHEN sd.redistributable THEN 1 ELSE 0.500000 END - * CASE WHEN COALESCE(has_go_mod, true) THEN 1 ELSE 0.800000 END) AS score + ln(exp(1)+imported_by_count) + * CASE WHEN sd.redistributable THEN 1 ELSE 0.500000 END + * CASE WHEN COALESCE(has_go_mod, true) THEN 1 ELSE 0.800000 END AS score FROM symbol_search_documents ssd INNER JOIN search_documents sd ON sd.unit_id = ssd.unit_id INNER JOIN symbol_names s ON s.id = ssd.symbol_name_id @@ -76,13 +76,17 @@ WITH results AS ( ssd.package_symbol_id, ssd.goos, ssd.goarch, - (ln(exp(1)+imported_by_count) - * CASE WHEN sd.redistributable THEN 1 ELSE 0.500000 END - * CASE WHEN COALESCE(has_go_mod, true) THEN 1 ELSE 0.800000 END) AS score + ln(exp(1)+imported_by_count) + * CASE WHEN sd.redistributable THEN 1 ELSE 0.500000 END + * CASE WHEN COALESCE(has_go_mod, true) THEN 1 ELSE 0.800000 END AS score FROM symbol_search_documents ssd INNER JOIN search_documents sd ON sd.unit_id = ssd.unit_id INNER JOIN symbol_names s ON s.id = ssd.symbol_name_id - WHERE sd.name = split_part($1, '.', 1) AND s.tsv_name_tokens @@ to_tsquery('symbols', substring(replace($1, '_', '-') from E'[^.]*\.(.+)$')) + WHERE ( + sd.name = split_part($1, '.', 1) + ) AND ( + s.tsv_name_tokens @@ to_tsquery('symbols', substring(replace($1, '_', '-') from E'[^.]*\.(.+)$')) + ) ) SELECT r.package_path, @@ -128,13 +132,78 @@ WITH results AS ( ssd.package_symbol_id, ssd.goos, ssd.goarch, - (ln(exp(1)+imported_by_count) - * CASE WHEN sd.redistributable THEN 1 ELSE 0.500000 END - * CASE WHEN COALESCE(has_go_mod, true) THEN 1 ELSE 0.800000 END) AS score + ln(exp(1)+imported_by_count) + * CASE WHEN sd.redistributable THEN 1 ELSE 0.500000 END + * CASE WHEN COALESCE(has_go_mod, true) THEN 1 ELSE 0.800000 END AS score FROM symbol_search_documents ssd INNER JOIN search_documents sd ON sd.unit_id = ssd.unit_id INNER JOIN symbol_names s ON s.id = ssd.symbol_name_id - WHERE (sd.name = split_part($1, '.', 1) AND s.tsv_name_tokens @@ to_tsquery('symbols', substring(replace($1, '_', '-') from E'[^.]*\.(.+)$'))) OR (s.tsv_name_tokens @@ to_tsquery('symbols', replace($1, '_', '-'))) + WHERE ( + sd.name = split_part($1, '.', 1) + ) AND ( + s.tsv_name_tokens @@ to_tsquery('symbols', substring(replace($1, '_', '-') from E'[^.]*\.(.+)$')) + ) OR s.tsv_name_tokens @@ to_tsquery('symbols', replace($1, '_', '-')) +) +SELECT + r.package_path, + r.module_path, + r.version, + r.package_name, + r.synopsis, + r.license_types, + r.commit_time, + r.imported_by_count, + r.symbol_name, + r.goos, + r.goarch, + ps.type AS symbol_type, + ps.synopsis AS symbol_synopsis, + COUNT(*) OVER() AS total +FROM results r +INNER JOIN package_symbols ps ON r.package_symbol_id = ps.id +WHERE r.score > 0.1 +ORDER BY + score DESC, + commit_time DESC, + symbol_name, + package_path +LIMIT $2 +OFFSET $3;` + +// QueryMultiWord is used when the search query is multiple elements. +const QueryMultiWord = ` +WITH results AS ( + SELECT + s.name AS symbol_name, + sd.package_path, + sd.module_path, + sd.version, + sd.name AS package_name, + sd.synopsis, + sd.license_types, + sd.commit_time, + sd.imported_by_count, + ssd.package_symbol_id, + ssd.goos, + ssd.goarch, + ( + ts_rank( + '{0.1, 0.2, 1.0, 1.0}', + sd.tsv_path_tokens, + to_tsquery('symbols', replace(replace($1, '_', '-'), ' ', ' | ')) + ) + * ln(exp(1)+imported_by_count) + * CASE WHEN sd.redistributable THEN 1 ELSE 0.500000 END + * CASE WHEN COALESCE(has_go_mod, true) THEN 1 ELSE 0.800000 END + ) AS score + FROM symbol_search_documents ssd + INNER JOIN search_documents sd ON sd.unit_id = ssd.unit_id + INNER JOIN symbol_names s ON s.id = ssd.symbol_name_id + WHERE ( + s.tsv_name_tokens @@ to_tsquery('symbols', replace(replace($1, '_', '-'), ' ', ' | ')) + ) AND ( + sd.tsv_path_tokens @@ to_tsquery('symbols', replace(replace($1, '_', '-'), ' ', ' | ')) + ) ) SELECT r.package_path, diff --git a/internal/postgres/symbolsearch/symbolsearch.go b/internal/postgres/symbolsearch/symbolsearch.go index 5e9df1b3..bcf291b6 100644 --- a/internal/postgres/symbolsearch/symbolsearch.go +++ b/internal/postgres/symbolsearch/symbolsearch.go @@ -23,6 +23,7 @@ var ( RawQuerySymbol = fmt.Sprintf(symbolSearchBaseQuery, scoreMultipliers, filterSymbol) RawQueryPackageDotSymbol = fmt.Sprintf(symbolSearchBaseQuery, scoreMultipliers, filterPackageDotSymbol) RawQueryOneDot = fmt.Sprintf(symbolSearchBaseQuery, scoreMultipliers, filterOneDot) + RawQueryMultiWord = fmt.Sprintf(symbolSearchBaseQuery, formatScore(scoreMultiWord), filterMultiWord) ) var ( @@ -30,41 +31,92 @@ var ( // <symbol> or <type>.<methodOrField>. filterSymbol = fmt.Sprintf(`s.tsv_name_tokens @@ %s`, toTSQuery("$1")) + // filterSymbol is used when $1 contains the full symbol name, either + // <symbol> or <type>.<methodOrField>, and has multiple words. + filterSymbolOR = fmt.Sprintf(`s.tsv_name_tokens @@ %s`, toTSQuery(splitOR)) + // filterPackageDotSymbol is used when $1 is either <package>.<symbol> OR // <package>.<type>.<methodOrField>. - filterPackageDotSymbol = fmt.Sprintf( + filterPackageDotSymbol = fmt.Sprintf("%s AND %s", // Split the package name from $1, which can be assumed to be the // element preceding the first dot. - `sd.name = split_part($1, '.', 1) AND s.tsv_name_tokens @@ %s`, + formatFilter("sd.name = split_part($1, '.', 1)"), // Split the symbol name from $1, which can be assumed to be everything // following the first dot. - toTSQuery("substring($1 from E'[^.]*\\.(.+)$')")) + fmt.Sprintf(formatFilter("s.tsv_name_tokens @@ %s"), + toTSQuery("substring($1 from E'[^.]*\\.(.+)$')"))) // filterOneDot is used when $1 is one word containing a single dot, which // means it is either <package>.<symbol> or <type>.<methodOrField>. - filterOneDot = fmt.Sprintf("(%s) OR (%s)", filterPackageDotSymbol, filterSymbol) + filterOneDot = fmt.Sprintf("%s OR %s", filterPackageDotSymbol, filterSymbol) + + // filterPackage is used to filter matching elements from + // sd.tsv_path_tokens. + filterPackage = fmt.Sprintf(`sd.tsv_path_tokens @@ %s`, toTSQuery(splitOR)) + + // filterMultiWord when $1 contains multiple words, separated by spaces. + // One element for the query must match a symbol name, and one (could be + // the same element) must match the package name. + filterMultiWord = fmt.Sprintf("%s AND %s", formatFilter(filterSymbolOR), + formatFilter(filterPackage)) ) var ( + // scoreMultiWord is the score when $1 contains multiple words. + scoreMultiWord = fmt.Sprintf("%s%s", rankPathTokens, formatMultiplier(scoreMultipliers)) + // scoreMultipliers is the score of multiplying the multiplers. // // It is also used as the score for QuerySymbol and QueryPackageDotIdentifier. // In both cases, the matching symbols will be filtered in the WHERE // clause, and the only remaining information to rank the results by are // the multiplers. - scoreMultipliers = fmt.Sprintf("%s\n\t\t* %s\n\t\t* %s", - popularityMultiplier, redistributableMultipler, goModMultipler) + scoreMultipliers = fmt.Sprintf("%s%s%s", + popularityMultiplier, + formatMultiplier(redistributableMultipler), + formatMultiplier(goModMultipler)) + + rankPathTokens = fmt.Sprintf( + "ts_rank(%s,%s,%s"+indent(")", 3), + indent("'{0.1, 0.2, 1.0, 1.0}'", 4), + indent("sd.tsv_path_tokens", 4), + indent(toTSQuery(splitOR), 4)) // Popularity multipler to increase ranking of popular packages. popularityMultiplier = `ln(exp(1)+imported_by_count)` // Multipler based on whether the module license is non-redistributable. - redistributableMultipler = fmt.Sprintf(`CASE WHEN sd.redistributable THEN 1 ELSE %f END`, nonRedistributablePenalty) + redistributableMultipler = fmt.Sprintf( + `CASE WHEN sd.redistributable THEN 1 ELSE %f END`, + nonRedistributablePenalty) // Multipler based on wehther the module has a go.mod file. - goModMultipler = fmt.Sprintf(`CASE WHEN COALESCE(has_go_mod, true) THEN 1 ELSE %f END`, noGoModPenalty) + goModMultipler = fmt.Sprintf( + `CASE WHEN COALESCE(has_go_mod, true) THEN 1 ELSE %f END`, + noGoModPenalty) ) +func formatScore(s string) string { + return fmt.Sprintf("(\n\t\t\t\t%s\n\t\t\t)", s) +} + +func formatFilter(s string) string { + return fmt.Sprintf("(\n\t\t\t%s\n\t\t)", s) +} + +func formatMultiplier(s string) string { + return indent(fmt.Sprintf("* %s", s), 3) +} + +func indent(s string, n int) string { + for i := 0; i <= n; i++ { + s = "\t" + s + } + return "\n" + s +} + +const splitOR = "replace($1, ' ', ' | ')" + // Penalties to search scores, applied as multipliers to the score. const ( // Module license is non-redistributable. @@ -102,7 +154,7 @@ WITH results AS ( ssd.package_symbol_id, ssd.goos, ssd.goarch, - (%s) AS score + %s AS score FROM symbol_search_documents ssd INNER JOIN search_documents sd ON sd.unit_id = ssd.unit_id INNER JOIN symbol_names s ON s.id = ssd.symbol_name_id diff --git a/internal/postgres/symbolsearch_test.go b/internal/postgres/symbolsearch_test.go index be9bc90b..ea68b7c0 100644 --- a/internal/postgres/symbolsearch_test.go +++ b/internal/postgres/symbolsearch_test.go @@ -81,18 +81,21 @@ func TestSymbolSearch(t *testing.T) { q: "Type.Method", want: checkResult(sample.Method), }, - /* - { - name: "test search by <package> <identifier>", - q: sample.PackageName + " function", - want: checkResult(sample.Function.SymbolMeta), - }, - { - name: "test search by <package-subpath> <identifier>", - q: "module_name/foo function", - want: checkResult(sample.Function.SymbolMeta), - }, - */ + { + name: "test search by <package> <identifier>", + q: "foo function", + want: checkResult(sample.Function.SymbolMeta), + }, + { + name: "test search by <package-subpath> <package-name> <identifier>", + q: "github.com/valid foo function", + want: checkResult(sample.Function.SymbolMeta), + }, + { + name: "test search by <package-subpath> <identifier> subpath contains _", + q: "module_name/foo function", + want: checkResult(sample.Function.SymbolMeta), + }, } { t.Run(test.name, func(t *testing.T) { resp, err := testDB.hedgedSearch(ctx, test.q, 2, 0, 100, symbolSearchers, nil) |
