diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/api/api.go | 61 | ||||
| -rw-r--r-- | internal/api/api_test.go | 157 | ||||
| -rw-r--r-- | internal/frontend/server.go | 1 |
3 files changed, 215 insertions, 4 deletions
diff --git a/internal/api/api.go b/internal/api/api.go index 1d21ba7e..d1419c40 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -19,6 +19,13 @@ import ( "golang.org/x/pkgsite/internal/version" ) +const ( + // maxSearchResults is the maximum number of search results to return for a search query. + maxSearchResults = 1000 + // searchResultsPerPage is the number of search results to return per page for paginated search results. + searchResultsPerPage = 100 +) + // ServePackage handles requests for the v1 package metadata endpoint. func ServePackage(w http.ResponseWriter, r *http.Request, ds internal.DataSource) (err error) { defer derrors.Wrap(&err, "ServePackage") @@ -206,6 +213,7 @@ func ServeModule(w http.ResponseWriter, r *http.Request, ds internal.DataSource) defer derrors.Wrap(&err, "ServeModule") modulePath := strings.TrimPrefix(r.URL.Path, "/v1/module/") + modulePath = strings.Trim(modulePath, "/") if modulePath == "" { return serveErrorJSON(w, http.StatusBadRequest, "missing module path", nil) } @@ -339,6 +347,51 @@ func ServeModulePackages(w http.ResponseWriter, r *http.Request, ds internal.Dat return serveJSON(w, http.StatusOK, resp) } +// ServeSearch handles requests for the v1 search endpoint. +func ServeSearch(w http.ResponseWriter, r *http.Request, ds internal.DataSource) (err error) { + defer derrors.Wrap(&err, "ServeSearch") + + var params SearchParams + if err := ParseParams(r.URL.Query(), ¶ms); err != nil { + return serveErrorJSON(w, http.StatusBadRequest, err.Error(), nil) + } + + if params.Query == "" { + return serveErrorJSON(w, http.StatusBadRequest, "missing query", nil) + } + + dbresults, err := ds.Search(r.Context(), params.Query, internal.SearchOptions{ + MaxResults: maxSearchResults, + SearchSymbols: params.Symbol != "", + SymbolFilter: params.Symbol, + }) + if err != nil { + return err + } + + var results []SearchResult + for _, r := range dbresults { + if params.Filter != "" { + if !strings.Contains(r.Synopsis, params.Filter) && !strings.Contains(r.PackagePath, params.Filter) { + continue + } + } + results = append(results, SearchResult{ + PackagePath: r.PackagePath, + ModulePath: r.ModulePath, + Version: r.Version, + Synopsis: r.Synopsis, + }) + } + + resp, err := paginate(results, params.ListParams, searchResultsPerPage) + if err != nil { + return serveErrorJSON(w, http.StatusBadRequest, err.Error(), nil) + } + + return serveJSON(w, http.StatusOK, resp) +} + // needsResolution reports whether the version string is a sentinel like "latest" or "master". func needsResolution(v string) bool { return v == version.Latest || v == version.Master || v == version.Main @@ -363,6 +416,9 @@ func serveErrorJSON(w http.ResponseWriter, status int, message string, candidate }) } +// paginate returns a paginated response for the given list of items and pagination parameters. +// It uses offset-based pagination with a token that encodes the offset. +// The default limit is used if the provided limit is non-positive. func paginate[T any](all []T, lp ListParams, defaultLimit int) (PaginatedResponse[T], error) { limit := lp.Limit if limit <= 0 { @@ -381,10 +437,7 @@ func paginate[T any](all []T, lp ListParams, defaultLimit int) (PaginatedRespons if offset > len(all) { offset = len(all) } - end := offset + limit - if end > len(all) { - end = len(all) - } + end := min(offset+limit, len(all)) var nextToken string if end < len(all) { diff --git a/internal/api/api_test.go b/internal/api/api_test.go index a0b63904..16a92278 100644 --- a/internal/api/api_test.go +++ b/internal/api/api_test.go @@ -9,6 +9,7 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "strconv" "testing" "github.com/google/go-cmp/cmp" @@ -382,3 +383,159 @@ func TestServeModulePackages(t *testing.T) { }) } } + +func TestServeSearch(t *testing.T) { + ctx := context.Background() + ds := fakedatasource.New() + + ds.MustInsertModule(ctx, &internal.Module{ + ModuleInfo: internal.ModuleInfo{ModulePath: "example.com", Version: "v1.0.0"}, + Units: []*internal.Unit{{ + UnitMeta: internal.UnitMeta{ + Path: "example.com/pkg", + ModuleInfo: internal.ModuleInfo{ModulePath: "example.com", Version: "v1.0.0"}, + Name: "pkg", + }, + Documentation: []*internal.Documentation{{Synopsis: "A great package."}}, + }}, + }) + + for _, test := range []struct { + name string + url string + wantStatus int + wantCount int + }{ + { + name: "basic search", + url: "/v1/search?q=great", + wantStatus: http.StatusOK, + wantCount: 1, + }, + { + name: "no results", + url: "/v1/search?q=nonexistent", + wantStatus: http.StatusOK, + wantCount: 0, + }, + { + name: "missing query", + url: "/v1/search", + wantStatus: http.StatusBadRequest, + }, + { + name: "search with filter", + url: "/v1/search?q=great&filter=example.com", + wantStatus: http.StatusOK, + wantCount: 1, + }, + { + name: "search with non-matching filter", + url: "/v1/search?q=great&filter=nomatch", + wantStatus: http.StatusOK, + wantCount: 0, + }, + } { + t.Run(test.name, func(t *testing.T) { + r := httptest.NewRequest("GET", test.url, nil) + w := httptest.NewRecorder() + + err := ServeSearch(w, r, ds) + if err != nil { + t.Fatalf("ServeSearch returned error: %v", err) + } + + if w.Code != test.wantStatus { + t.Errorf("%s: status = %d, want %d", test.name, w.Code, test.wantStatus) + } + + if test.wantStatus == http.StatusOK { + var got PaginatedResponse[SearchResult] + if err := json.Unmarshal(w.Body.Bytes(), &got); err != nil { + t.Fatalf("%s: json.Unmarshal: %v", test.name, err) + } + if len(got.Items) != test.wantCount { + t.Errorf("%s: count = %d, want %d", test.name, len(got.Items), test.wantCount) + } + } + }) + } +} + +func TestServeSearchPagination(t *testing.T) { + ctx := context.Background() + ds := fakedatasource.New() + + for i := 0; i < 10; i++ { + pkgPath := "example.com/pkg" + strconv.Itoa(i) + ds.MustInsertModule(ctx, &internal.Module{ + ModuleInfo: internal.ModuleInfo{ModulePath: pkgPath, Version: "v1.0.0"}, + Units: []*internal.Unit{{ + UnitMeta: internal.UnitMeta{ + Path: pkgPath, + ModuleInfo: internal.ModuleInfo{ModulePath: pkgPath, Version: "v1.0.0"}, + Name: "pkg", + }, + Documentation: []*internal.Documentation{{Synopsis: "Synopsis" + strconv.Itoa(i)}}, + }}, + }) + } + + for _, test := range []struct { + name string + url string + wantCount int + wantTotal int + wantNextToken string + }{ + { + name: "first page", + url: "/v1/search?q=Synopsis&limit=3", + wantCount: 3, + wantTotal: 10, + wantNextToken: "3", + }, + { + name: "second page", + url: "/v1/search?q=Synopsis&limit=3&token=3", + wantCount: 3, + wantTotal: 10, + wantNextToken: "6", + }, + { + name: "last page", + url: "/v1/search?q=Synopsis&limit=3&token=9", + wantCount: 1, + wantTotal: 10, + wantNextToken: "", + }, + } { + t.Run(test.name, func(t *testing.T) { + r := httptest.NewRequest("GET", test.url, nil) + w := httptest.NewRecorder() + + if err := ServeSearch(w, r, ds); err != nil { + t.Fatalf("ServeSearch error: %v", err) + } + + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", w.Code) + } + + var got PaginatedResponse[SearchResult] + if err := json.Unmarshal(w.Body.Bytes(), &got); err != nil { + t.Fatalf("json.Unmarshal: %v", err) + } + + if len(got.Items) != test.wantCount { + t.Errorf("count = %d, want %d", len(got.Items), test.wantCount) + } + if got.Total != test.wantTotal { + t.Errorf("total = %d, want %d", got.Total, test.wantTotal) + } + if got.NextPageToken != test.wantNextToken { + t.Errorf("nextToken = %q, want %q", got.NextPageToken, test.wantNextToken) + } + }) + } +} diff --git a/internal/frontend/server.go b/internal/frontend/server.go index acd9c909..a1fd4516 100644 --- a/internal/frontend/server.go +++ b/internal/frontend/server.go @@ -240,6 +240,7 @@ func (s *Server) Install(handle func(string, http.Handler), cacher Cacher, authV handle("GET /v1/module/", s.errorHandler(api.ServeModule)) handle("GET /v1/versions/", s.errorHandler(api.ServeModuleVersions)) handle("GET /v1/packages/", s.errorHandler(api.ServeModulePackages)) + handle("GET /v1/search", s.errorHandler(api.ServeSearch)) handle("/opensearch.xml", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { serveFileFS(w, r, s.staticFS, "shared/opensearch.xml") })) |
