diff options
Diffstat (limited to 'src/net/http/httputil/reverseproxy_test.go')
| -rw-r--r-- | src/net/http/httputil/reverseproxy_test.go | 82 |
1 files changed, 82 insertions, 0 deletions
diff --git a/src/net/http/httputil/reverseproxy_test.go b/src/net/http/httputil/reverseproxy_test.go index 549bc67b87..6805262755 100644 --- a/src/net/http/httputil/reverseproxy_test.go +++ b/src/net/http/httputil/reverseproxy_test.go @@ -16,7 +16,9 @@ import ( "log" "net/http" "net/http/httptest" + "net/http/httptrace" "net/http/internal/ascii" + "net/textproto" "net/url" "os" "reflect" @@ -1671,3 +1673,83 @@ func TestReverseProxyRewriteReplacesOut(t *testing.T) { t.Errorf("got response %q, want %q", got, want) } } + +func Test1xxResponses(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h := w.Header() + h.Add("Link", "</style.css>; rel=preload; as=style") + h.Add("Link", "</script.js>; rel=preload; as=script") + w.WriteHeader(http.StatusEarlyHints) + + h.Add("Link", "</foo.js>; rel=preload; as=script") + w.WriteHeader(http.StatusProcessing) + + w.Write([]byte("Hello")) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + frontendClient := frontend.Client() + + checkLinkHeaders := func(t *testing.T, expected, got []string) { + t.Helper() + + if len(expected) != len(got) { + t.Errorf("Expected %d link headers; got %d", len(expected), len(got)) + } + + for i := range expected { + if i >= len(got) { + t.Errorf("Expected %q link header; got nothing", expected[i]) + + continue + } + + if expected[i] != got[i] { + t.Errorf("Expected %q link header; got %q", expected[i], got[i]) + } + } + } + + var respCounter uint8 + trace := &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + switch code { + case http.StatusEarlyHints: + checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"]) + case http.StatusProcessing: + checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"]) + default: + t.Error("Unexpected 1xx response") + } + + respCounter++ + + return nil + }, + } + req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", frontend.URL, nil) + + res, err := frontendClient.Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + + defer res.Body.Close() + + if respCounter != 2 { + t.Errorf("Expected 2 1xx responses; got %d", respCounter) + } + checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"]) + + body, _ := io.ReadAll(res.Body) + if string(body) != "Hello" { + t.Errorf("Read body %q; want Hello", body) + } +} |
