diff options
Diffstat (limited to 'src/net/http/serve_test.go')
| -rw-r--r-- | src/net/http/serve_test.go | 1180 |
1 files changed, 496 insertions, 684 deletions
diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index 4fadc56c9e..a93f6eff1b 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -246,15 +246,13 @@ var vtests = []struct { {"http://someHost.com/someDir", "/someDir/"}, } -func TestHostHandlers(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestHostHandlers(t *testing.T) { run(t, testHostHandlers, []testMode{http1Mode}) } +func testHostHandlers(t *testing.T, mode testMode) { mux := NewServeMux() for _, h := range handlers { mux.Handle(h.pattern, stringHandler(h.msg)) } - ts := httptest.NewServer(mux) - defer ts.Close() + ts := newClientServerTest(t, mode, mux).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -487,9 +485,9 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) { // properly sets the query string in the redirect URL. // See Issue 17841. func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) { - setParallel(t) - defer afterTest(t) - + run(t, testServeWithSlashRedirectKeepsQueryString, []testMode{http1Mode}) +} +func testServeWithSlashRedirectKeepsQueryString(t *testing.T, mode testMode) { writeBackQuery := func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%s", r.URL.RawQuery) } @@ -502,8 +500,7 @@ func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) { fmt.Fprintf(w, "%s:bar", r.URL.RawQuery) }) - ts := httptest.NewServer(mux) - defer ts.Close() + ts := newClientServerTest(t, mode, mux).ts tests := [...]struct { path string @@ -546,7 +543,6 @@ func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) { func TestServeWithSlashRedirectForHostPatterns(t *testing.T) { setParallel(t) - defer afterTest(t) mux := NewServeMux() mux.Handle("example.com/pkg/foo/", stringHandler("example.com/pkg/foo/")) @@ -578,9 +574,6 @@ func TestServeWithSlashRedirectForHostPatterns(t *testing.T) { {"CONNECT", "http://example.com:3000/pkg/connect", 301, "/pkg/connect/", ""}, } - ts := httptest.NewServer(mux) - defer ts.Close() - for i, tt := range tests { req, _ := NewRequest(tt.method, tt.url, nil) w := httptest.NewRecorder() @@ -602,13 +595,10 @@ func TestServeWithSlashRedirectForHostPatterns(t *testing.T) { } } -func TestShouldRedirectConcurrency(t *testing.T) { - setParallel(t) - defer afterTest(t) - +func TestShouldRedirectConcurrency(t *testing.T) { run(t, testShouldRedirectConcurrency) } +func testShouldRedirectConcurrency(t *testing.T, mode testMode) { mux := NewServeMux() - ts := httptest.NewServer(mux) - defer ts.Close() + newClientServerTest(t, mode, mux) mux.HandleFunc("/", func(w ResponseWriter, r *Request) {}) } @@ -656,13 +646,12 @@ func benchmarkServeMux(b *testing.B, runHandler bool) { } } -func TestServerTimeouts(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestServerTimeouts(t *testing.T) { run(t, testServerTimeouts, []testMode{http1Mode}) } +func testServerTimeouts(t *testing.T, mode testMode) { // Try three times, with increasing timeouts. tries := []time.Duration{250 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second} for i, timeout := range tries { - err := testServerTimeouts(timeout) + err := testServerTimeoutsWithTimeout(t, timeout, mode) if err == nil { return } @@ -674,16 +663,15 @@ func TestServerTimeouts(t *testing.T) { t.Fatal("all attempts failed") } -func testServerTimeouts(timeout time.Duration) error { +func testServerTimeoutsWithTimeout(t *testing.T, timeout time.Duration, mode testMode) error { reqNum := 0 - ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) { reqNum++ fmt.Fprintf(res, "req=%d", reqNum) - })) - ts.Config.ReadTimeout = timeout - ts.Config.WriteTimeout = timeout - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.ReadTimeout = timeout + ts.Config.WriteTimeout = timeout + }).ts // Hit the HTTP server successfully. c := ts.Client() @@ -749,22 +737,20 @@ func testServerTimeouts(timeout time.Duration) error { } // Test that the HTTP/2 server handles Server.WriteTimeout (Issue 18437) -func TestHTTP2WriteDeadlineExtendedOnNewRequest(t *testing.T) { +func TestWriteDeadlineExtendedOnNewRequest(t *testing.T) { + run(t, testWriteDeadlineExtendedOnNewRequest) +} +func testWriteDeadlineExtendedOnNewRequest(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } - setParallel(t) - defer afterTest(t) - ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) {})) - ts.Config.WriteTimeout = 250 * time.Millisecond - ts.TLS = &tls.Config{NextProtos: []string{"h2"}} - ts.StartTLS() - defer ts.Close() + ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {}), + func(ts *httptest.Server) { + ts.Config.WriteTimeout = 250 * time.Millisecond + }, + ).ts c := ts.Client() - if err := ExportHttp2ConfigureTransport(c.Transport.(*Transport)); err != nil { - t.Fatal(err) - } for i := 1; i <= 3; i++ { req, err := NewRequest("GET", ts.URL, nil) @@ -785,9 +771,6 @@ func TestHTTP2WriteDeadlineExtendedOnNewRequest(t *testing.T) { t.Fatalf("http2 Get #%d: %v", i, err) } r.Body.Close() - if r.ProtoMajor != 2 { - t.Fatalf("http2 Get expected HTTP/2.0, got %q", r.Proto) - } time.Sleep(ts.Config.WriteTimeout / 2) } } @@ -810,33 +793,31 @@ func tryTimeouts(t *testing.T, testFunc func(timeout time.Duration) error) { } // Test that the HTTP/2 server RSTs stream on slow write. -func TestHTTP2WriteDeadlineEnforcedPerStream(t *testing.T) { +func TestWriteDeadlineEnforcedPerStream(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } setParallel(t) - defer afterTest(t) - tryTimeouts(t, testHTTP2WriteDeadlineEnforcedPerStream) + run(t, func(t *testing.T, mode testMode) { + tryTimeouts(t, func(timeout time.Duration) error { + return testWriteDeadlineEnforcedPerStream(t, mode, timeout) + }) + }) } -func testHTTP2WriteDeadlineEnforcedPerStream(timeout time.Duration) error { +func testWriteDeadlineEnforcedPerStream(t *testing.T, mode testMode, timeout time.Duration) error { reqNum := 0 - ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) { reqNum++ if reqNum == 1 { return // first request succeeds } time.Sleep(timeout) // second request times out - })) - ts.Config.WriteTimeout = timeout / 2 - ts.TLS = &tls.Config{NextProtos: []string{"h2"}} - ts.StartTLS() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.WriteTimeout = timeout / 2 + }).ts c := ts.Client() - if err := ExportHttp2ConfigureTransport(c.Transport.(*Transport)); err != nil { - return fmt.Errorf("ExportHttp2ConfigureTransport: %v", err) - } req, err := NewRequest("GET", ts.URL, nil) if err != nil { @@ -844,12 +825,9 @@ func testHTTP2WriteDeadlineEnforcedPerStream(timeout time.Duration) error { } r, err := c.Do(req) if err != nil { - return fmt.Errorf("http2 Get #1: %v", err) + return fmt.Errorf("Get #1: %v", err) } r.Body.Close() - if r.ProtoMajor != 2 { - return fmt.Errorf("http2 Get expected HTTP/2.0, got %q", r.Proto) - } req, err = NewRequest("GET", ts.URL, nil) if err != nil { @@ -858,45 +836,42 @@ func testHTTP2WriteDeadlineEnforcedPerStream(timeout time.Duration) error { r, err = c.Do(req) if err == nil { r.Body.Close() - if r.ProtoMajor != 2 { - return fmt.Errorf("http2 Get expected HTTP/2.0, got %q", r.Proto) - } - return fmt.Errorf("http2 Get #2 expected error, got nil") + return fmt.Errorf("Get #2 expected error, got nil") } - expected := "stream ID 3; INTERNAL_ERROR" // client IDs are odd, second stream should be 3 - if !strings.Contains(err.Error(), expected) { - return fmt.Errorf("http2 Get #2: expected error to contain %q, got %q", expected, err) + if mode == http2Mode { + expected := "stream ID 3; INTERNAL_ERROR" // client IDs are odd, second stream should be 3 + if !strings.Contains(err.Error(), expected) { + return fmt.Errorf("http2 Get #2: expected error to contain %q, got %q", expected, err) + } } return nil } // Test that the HTTP/2 server does not send RST when WriteDeadline not set. -func TestHTTP2NoWriteDeadline(t *testing.T) { +func TestNoWriteDeadline(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } setParallel(t) defer afterTest(t) - tryTimeouts(t, testHTTP2NoWriteDeadline) + run(t, func(t *testing.T, mode testMode) { + tryTimeouts(t, func(timeout time.Duration) error { + return testNoWriteDeadline(t, mode, timeout) + }) + }) } -func testHTTP2NoWriteDeadline(timeout time.Duration) error { +func testNoWriteDeadline(t *testing.T, mode testMode, timeout time.Duration) error { reqNum := 0 - ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) { reqNum++ if reqNum == 1 { return // first request succeeds } time.Sleep(timeout) // second request timesout - })) - ts.TLS = &tls.Config{NextProtos: []string{"h2"}} - ts.StartTLS() - defer ts.Close() + })).ts c := ts.Client() - if err := ExportHttp2ConfigureTransport(c.Transport.(*Transport)); err != nil { - return fmt.Errorf("ExportHttp2ConfigureTransport: %v", err) - } for i := 0; i < 2; i++ { req, err := NewRequest("GET", ts.URL, nil) @@ -905,12 +880,9 @@ func testHTTP2NoWriteDeadline(timeout time.Duration) error { } r, err := c.Do(req) if err != nil { - return fmt.Errorf("http2 Get #%d: %v", i, err) + return fmt.Errorf("Get #%d: %v", i, err) } r.Body.Close() - if r.ProtoMajor != 2 { - return fmt.Errorf("http2 Get expected HTTP/2.0, got %q", r.Proto) - } } return nil } @@ -918,15 +890,14 @@ func testHTTP2NoWriteDeadline(timeout time.Duration) error { // golang.org/issue/4741 -- setting only a write timeout that triggers // shouldn't cause a handler to block forever on reads (next HTTP // request) that will never happen. -func TestOnlyWriteTimeout(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestOnlyWriteTimeout(t *testing.T) { run(t, testOnlyWriteTimeout, []testMode{http1Mode}) } +func testOnlyWriteTimeout(t *testing.T, mode testMode) { var ( mu sync.RWMutex conn net.Conn ) var afterTimeoutErrc = make(chan error, 1) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) { buf := make([]byte, 512<<10) _, err := w.Write(buf) if err != nil { @@ -942,10 +913,9 @@ func TestOnlyWriteTimeout(t *testing.T) { conn.SetWriteDeadline(time.Now().Add(-30 * time.Second)) _, err = w.Write(buf) afterTimeoutErrc <- err - })) - ts.Listener = trackLastConnListener{ts.Listener, &mu, &conn} - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Listener = trackLastConnListener{ts.Listener, &mu, &conn} + }).ts c := ts.Client() @@ -992,9 +962,12 @@ func (l trackLastConnListener) Accept() (c net.Conn, err error) { } // TestIdentityResponse verifies that a handler can unset -func TestIdentityResponse(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestIdentityResponse(t *testing.T) { run(t, testIdentityResponse) } +func testIdentityResponse(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("https://go.dev/issue/56019") + } + handler := HandlerFunc(func(rw ResponseWriter, req *Request) { rw.Header().Set("Content-Length", "3") rw.Header().Set("Transfer-Encoding", req.FormValue("te")) @@ -1012,9 +985,7 @@ func TestIdentityResponse(t *testing.T) { } }) - ts := httptest.NewServer(handler) - defer ts.Close() - + ts := newClientServerTest(t, mode, handler).ts c := ts.Client() // Note: this relies on the assumption (which is true) that @@ -1048,6 +1019,10 @@ func TestIdentityResponse(t *testing.T) { } res.Body.Close() + if mode != http1Mode { + return + } + // Verify that the connection is closed when the declared Content-Length // is larger than what the handler wrote. conn, err := net.Dial("tcp", ts.Listener.Addr().String()) @@ -1070,9 +1045,7 @@ func TestIdentityResponse(t *testing.T) { func testTCPConnectionCloses(t *testing.T, req string, h Handler) { setParallel(t) - defer afterTest(t) - s := httptest.NewServer(h) - defer s.Close() + s := newClientServerTest(t, http1Mode, h).ts conn, err := net.Dial("tcp", s.Listener.Addr().String()) if err != nil { @@ -1114,9 +1087,7 @@ func testTCPConnectionCloses(t *testing.T, req string, h Handler) { func testTCPConnectionStaysOpen(t *testing.T, req string, handler Handler) { setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(handler) - defer ts.Close() + ts := newClientServerTest(t, http1Mode, handler).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatal(err) @@ -1192,14 +1163,12 @@ func TestHTTP10KeepAlive304Response(t *testing.T) { } // Issue 15703 -func TestKeepAliveFinalChunkWithEOF(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, false /* h1 */, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestKeepAliveFinalChunkWithEOF(t *testing.T) { run(t, testKeepAliveFinalChunkWithEOF) } +func testKeepAliveFinalChunkWithEOF(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.(Flusher).Flush() // force chunked encoding w.Write([]byte("{\"Addr\": \"" + r.RemoteAddr + "\"}")) })) - defer cst.close() type data struct { Addr string } @@ -1222,16 +1191,11 @@ func TestKeepAliveFinalChunkWithEOF(t *testing.T) { } } -func TestSetsRemoteAddr_h1(t *testing.T) { testSetsRemoteAddr(t, h1Mode) } -func TestSetsRemoteAddr_h2(t *testing.T) { testSetsRemoteAddr(t, h2Mode) } - -func testSetsRemoteAddr(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestSetsRemoteAddr(t *testing.T) { run(t, testSetsRemoteAddr) } +func testSetsRemoteAddr(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%s", r.RemoteAddr) })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -1276,17 +1240,18 @@ func (c *blockingRemoteAddrConn) RemoteAddr() net.Addr { // Issue 12943 func TestServerAllowsBlockingRemoteAddr(t *testing.T) { - defer afterTest(t) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { - fmt.Fprintf(w, "RA:%s", r.RemoteAddr) - })) + run(t, testServerAllowsBlockingRemoteAddr, []testMode{http1Mode}) +} +func testServerAllowsBlockingRemoteAddr(t *testing.T, mode testMode) { conns := make(chan net.Conn) - ts.Listener = &blockingRemoteAddrListener{ - Listener: ts.Listener, - conns: conns, - } - ts.Start() - defer ts.Close() + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "RA:%s", r.RemoteAddr) + }), func(ts *httptest.Server) { + ts.Listener = &blockingRemoteAddrListener{ + Listener: ts.Listener, + conns: conns, + } + }).ts c := ts.Client() c.Timeout = time.Second @@ -1351,13 +1316,9 @@ func TestServerAllowsBlockingRemoteAddr(t *testing.T) { // TestHeadResponses verifies that all MIME type sniffing and Content-Length // counting of GET requests also happens on HEAD requests. -func TestHeadResponses_h1(t *testing.T) { testHeadResponses(t, h1Mode) } -func TestHeadResponses_h2(t *testing.T) { testHeadResponses(t, h2Mode) } - -func testHeadResponses(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestHeadResponses(t *testing.T) { run(t, testHeadResponses) } +func testHeadResponses(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { _, err := w.Write([]byte("<html>")) if err != nil { t.Errorf("ResponseWriter.Write: %v", err) @@ -1369,7 +1330,6 @@ func testHeadResponses(t *testing.T, h2 bool) { t.Errorf("Copy(ResponseWriter, ...): %v", err) } })) - defer cst.close() res, err := cst.c.Head(cst.ts.URL) if err != nil { t.Error(err) @@ -1393,14 +1353,16 @@ func testHeadResponses(t *testing.T, h2 bool) { } func TestTLSHandshakeTimeout(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + run(t, testTLSHandshakeTimeout, []testMode{https1Mode, http2Mode}) +} +func testTLSHandshakeTimeout(t *testing.T, mode testMode) { errc := make(chanWriter, 10) // but only expecting 1 - ts.Config.ReadTimeout = 250 * time.Millisecond - ts.Config.ErrorLog = log.New(errc, "", 0) - ts.StartTLS() - defer ts.Close() + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), + func(ts *httptest.Server) { + ts.Config.ReadTimeout = 250 * time.Millisecond + ts.Config.ErrorLog = log.New(errc, "", 0) + }, + ).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("Dial: %v", err) @@ -1423,19 +1385,18 @@ func TestTLSHandshakeTimeout(t *testing.T) { } } -func TestTLSServer(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTLSServer(t *testing.T) { run(t, testTLSServer, []testMode{https1Mode, http2Mode}) } +func testTLSServer(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.TLS != nil { w.Header().Set("X-TLS-Set", "true") if r.TLS.HandshakeComplete { w.Header().Set("X-TLS-HandshakeComplete", "true") } } - })) - ts.Config.ErrorLog = log.New(io.Discard, "", 0) - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.ErrorLog = log.New(io.Discard, "", 0) + }).ts // Connect an idle TCP connection to this server before we run // our real tests. This idle connection used to block forever @@ -1528,14 +1489,15 @@ func TestServeTLS(t *testing.T) { // Test that the HTTPS server nicely rejects plaintext HTTP/1.x requests. func TestTLSServerRejectHTTPRequests(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTLSServerRejectHTTPRequests, []testMode{https1Mode, http2Mode}) +} +func testTLSServerRejectHTTPRequests(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { t.Error("unexpected HTTPS request") - })) - var errBuf bytes.Buffer - ts.Config.ErrorLog = log.New(&errBuf, "", 0) - defer ts.Close() + }), func(ts *httptest.Server) { + var errBuf bytes.Buffer + ts.Config.ErrorLog = log.New(&errBuf, "", 0) + }).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatal(err) @@ -1727,11 +1689,9 @@ var serverExpectTests = []serverExpectTest{ // Tests that the server responds to the "Expect" request header // correctly. -// http2 test: TestServer_Response_Automatic100Continue -func TestServerExpect(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServerExpect(t *testing.T) { run(t, testServerExpect, []testMode{http1Mode}) } +func testServerExpect(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { // Note using r.FormValue("readbody") because for POST // requests that would read from r.Body, which we only // conditionally want to do. @@ -1741,8 +1701,7 @@ func TestServerExpect(t *testing.T) { } else { w.WriteHeader(StatusUnauthorized) } - })) - defer ts.Close() + })).ts runTest := func(test serverExpectTest) { conn, err := net.Dial("tcp", ts.Listener.Addr().String()) @@ -2287,11 +2246,8 @@ func (c cancelableTimeoutContext) Err() error { return nil } -func TestTimeoutHandler_h1(t *testing.T) { testTimeoutHandler(t, h1Mode) } -func TestTimeoutHandler_h2(t *testing.T) { testTimeoutHandler(t, h2Mode) } -func testTimeoutHandler(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func TestTimeoutHandler(t *testing.T) { run(t, testTimeoutHandler) } +func testTimeoutHandler(t *testing.T, mode testMode) { sendHi := make(chan bool, 1) writeErrors := make(chan error, 1) sayHi := HandlerFunc(func(w ResponseWriter, r *Request) { @@ -2301,8 +2257,7 @@ func testTimeoutHandler(t *testing.T, h2 bool) { }) ctx, cancel := context.WithCancel(context.Background()) h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx}) - cst := newClientServerTest(t, h2, h) - defer cst.close() + cst := newClientServerTest(t, mode, h) // Succeed without timing out: sendHi <- true @@ -2348,10 +2303,8 @@ func testTimeoutHandler(t *testing.T, h2 bool) { } // See issues 8209 and 8414. -func TestTimeoutHandlerRace(t *testing.T) { - setParallel(t) - defer afterTest(t) - +func TestTimeoutHandlerRace(t *testing.T) { run(t, testTimeoutHandlerRace) } +func testTimeoutHandlerRace(t *testing.T, mode testMode) { delayHi := HandlerFunc(func(w ResponseWriter, r *Request) { ms, _ := strconv.Atoi(r.URL.Path[1:]) if ms == 0 { @@ -2363,8 +2316,7 @@ func TestTimeoutHandlerRace(t *testing.T) { } }) - ts := httptest.NewServer(TimeoutHandler(delayHi, 20*time.Millisecond, "")) - defer ts.Close() + ts := newClientServerTest(t, mode, TimeoutHandler(delayHi, 20*time.Millisecond, "")).ts c := ts.Client() @@ -2393,16 +2345,13 @@ func TestTimeoutHandlerRace(t *testing.T) { // See issues 8209 and 8414. // Both issues involved panics in the implementation of TimeoutHandler. -func TestTimeoutHandlerRaceHeader(t *testing.T) { - setParallel(t) - defer afterTest(t) - +func TestTimeoutHandlerRaceHeader(t *testing.T) { run(t, testTimeoutHandlerRaceHeader) } +func testTimeoutHandlerRaceHeader(t *testing.T, mode testMode) { delay204 := HandlerFunc(func(w ResponseWriter, r *Request) { w.WriteHeader(204) }) - ts := httptest.NewServer(TimeoutHandler(delay204, time.Nanosecond, "")) - defer ts.Close() + ts := newClientServerTest(t, mode, TimeoutHandler(delay204, time.Nanosecond, "")).ts var wg sync.WaitGroup gate := make(chan bool, 50) @@ -2433,9 +2382,8 @@ func TestTimeoutHandlerRaceHeader(t *testing.T) { } // Issue 9162 -func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { run(t, testTimeoutHandlerRaceHeaderTimeout) } +func testTimeoutHandlerRaceHeaderTimeout(t *testing.T, mode testMode) { sendHi := make(chan bool, 1) writeErrors := make(chan error, 1) sayHi := HandlerFunc(func(w ResponseWriter, r *Request) { @@ -2446,8 +2394,7 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { }) ctx, cancel := context.WithCancel(context.Background()) h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx}) - cst := newClientServerTest(t, h1Mode, h) - defer cst.close() + cst := newClientServerTest(t, mode, h) // Succeed without timing out: sendHi <- true @@ -2491,15 +2438,17 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { // Issue 14568. func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) { + run(t, testTimeoutHandlerStartTimerWhenServing) +} +func testTimeoutHandlerStartTimerWhenServing(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping sleeping test in -short mode") } - defer afterTest(t) var handler HandlerFunc = func(w ResponseWriter, _ *Request) { w.WriteHeader(StatusNoContent) } timeout := 300 * time.Millisecond - ts := httptest.NewServer(TimeoutHandler(handler, timeout, "")) + ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts defer ts.Close() c := ts.Client() @@ -2518,9 +2467,8 @@ func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) { } } -func TestTimeoutHandlerContextCanceled(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestTimeoutHandlerContextCanceled(t *testing.T) { run(t, testTimeoutHandlerContextCanceled) } +func testTimeoutHandlerContextCanceled(t *testing.T, mode testMode) { writeErrors := make(chan error, 1) sayHi := HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Type", "text/plain") @@ -2540,7 +2488,7 @@ func TestTimeoutHandlerContextCanceled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() h := NewTestTimeoutHandler(sayHi, ctx) - cst := newClientServerTest(t, h1Mode, h) + cst := newClientServerTest(t, mode, h) defer cst.close() res, err := cst.c.Get(cst.ts.URL) @@ -2560,15 +2508,13 @@ func TestTimeoutHandlerContextCanceled(t *testing.T) { } // https://golang.org/issue/15948 -func TestTimeoutHandlerEmptyResponse(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestTimeoutHandlerEmptyResponse(t *testing.T) { run(t, testTimeoutHandlerEmptyResponse) } +func testTimeoutHandlerEmptyResponse(t *testing.T, mode testMode) { var handler HandlerFunc = func(w ResponseWriter, _ *Request) { // No response. } timeout := 300 * time.Millisecond - ts := httptest.NewServer(TimeoutHandler(handler, timeout, "")) - defer ts.Close() + ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts c := ts.Client() @@ -2587,7 +2533,9 @@ func TestTimeoutHandlerPanicRecovery(t *testing.T) { wrapper := func(h Handler) Handler { return TimeoutHandler(h, time.Second, "") } - testHandlerPanic(t, false, false, wrapper, "intentional death for testing") + run(t, func(t *testing.T, mode testMode) { + testHandlerPanic(t, false, mode, wrapper, "intentional death for testing") + }, testNotParallel) } func TestRedirectBadPath(t *testing.T) { @@ -2705,17 +2653,10 @@ func TestRedirectContentTypeAndBody(t *testing.T) { // connection immediately. But when it re-uses the connection, it typically closes // the previous request's body, which is not optimal for zero-lengthed bodies, // as the client would then see http.ErrBodyReadAfterClose and not 0, io.EOF. -func TestZeroLengthPostAndResponse_h1(t *testing.T) { - testZeroLengthPostAndResponse(t, h1Mode) -} -func TestZeroLengthPostAndResponse_h2(t *testing.T) { - testZeroLengthPostAndResponse(t, h2Mode) -} +func TestZeroLengthPostAndResponse(t *testing.T) { run(t, testZeroLengthPostAndResponse) } -func testZeroLengthPostAndResponse(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, r *Request) { +func testZeroLengthPostAndResponse(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) { all, err := io.ReadAll(r.Body) if err != nil { t.Fatalf("handler ReadAll: %v", err) @@ -2725,7 +2666,6 @@ func testZeroLengthPostAndResponse(t *testing.T, h2 bool) { } rw.Header().Set("Content-Length", "0") })) - defer cst.close() req, err := NewRequest("POST", cst.ts.URL, strings.NewReader("")) if err != nil { @@ -2752,23 +2692,26 @@ func testZeroLengthPostAndResponse(t *testing.T, h2 bool) { } } -func TestHandlerPanicNil_h1(t *testing.T) { testHandlerPanic(t, false, h1Mode, nil, nil) } -func TestHandlerPanicNil_h2(t *testing.T) { testHandlerPanic(t, false, h2Mode, nil, nil) } - -func TestHandlerPanic_h1(t *testing.T) { - testHandlerPanic(t, false, h1Mode, nil, "intentional death for testing") +func TestHandlerPanicNil(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testHandlerPanic(t, false, mode, nil, nil) + }, testNotParallel) } -func TestHandlerPanic_h2(t *testing.T) { - testHandlerPanic(t, false, h2Mode, nil, "intentional death for testing") + +func TestHandlerPanic(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testHandlerPanic(t, false, mode, nil, "intentional death for testing") + }, testNotParallel) } func TestHandlerPanicWithHijack(t *testing.T) { // Only testing HTTP/1, and our http2 server doesn't support hijacking. - testHandlerPanic(t, true, h1Mode, nil, "intentional death for testing") + run(t, func(t *testing.T, mode testMode) { + testHandlerPanic(t, true, mode, nil, "intentional death for testing") + }, testNotParallel, []testMode{http1Mode}) } -func testHandlerPanic(t *testing.T, withHijack, h2 bool, wrapper func(Handler) Handler, panicValue any) { - defer afterTest(t) +func testHandlerPanic(t *testing.T, withHijack bool, mode testMode, wrapper func(Handler) Handler, panicValue any) { // Unlike the other tests that set the log output to io.Discard // to quiet the output, this test uses a pipe. The pipe serves three // purposes: @@ -2803,8 +2746,7 @@ func testHandlerPanic(t *testing.T, withHijack, h2 bool, wrapper func(Handler) H if wrapper != nil { handler = wrapper(handler) } - cst := newClientServerTest(t, h2, handler) - defer cst.close() + cst := newClientServerTest(t, mode, handler) // Do a blocking read on the log output pipe so its logging // doesn't bleed into the next test. But wait only 5 seconds @@ -2847,9 +2789,11 @@ func (w terrorWriter) Write(p []byte) (int, error) { // Issue 16456: allow writing 0 bytes on hijacked conn to test hijack // without any log spam. func TestServerWriteHijackZeroBytes(t *testing.T) { - defer afterTest(t) + run(t, testServerWriteHijackZeroBytes, []testMode{http1Mode}) +} +func testServerWriteHijackZeroBytes(t *testing.T, mode testMode) { done := make(chan struct{}) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { defer close(done) w.(Flusher).Flush() conn, _, err := w.(Hijacker).Hijack() @@ -2862,10 +2806,9 @@ func TestServerWriteHijackZeroBytes(t *testing.T) { if err != ErrHijacked { t.Errorf("Write error = %v; want ErrHijacked", err) } - })) - ts.Config.ErrorLog = log.New(terrorWriter{t}, "Unexpected write: ", 0) - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.ErrorLog = log.New(terrorWriter{t}, "Unexpected write: ", 0) + }).ts c := ts.Client() res, err := c.Get(ts.URL) @@ -2880,19 +2823,23 @@ func TestServerWriteHijackZeroBytes(t *testing.T) { } } -func TestServerNoDate_h1(t *testing.T) { testServerNoHeader(t, h1Mode, "Date") } -func TestServerNoDate_h2(t *testing.T) { testServerNoHeader(t, h2Mode, "Date") } -func TestServerNoContentType_h1(t *testing.T) { testServerNoHeader(t, h1Mode, "Content-Type") } -func TestServerNoContentType_h2(t *testing.T) { testServerNoHeader(t, h2Mode, "Content-Type") } +func TestServerNoDate(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testServerNoHeader(t, mode, "Date") + }) +} -func testServerNoHeader(t *testing.T, h2 bool, header string) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServerContentType(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testServerNoHeader(t, mode, "Content-Type") + }) +} + +func testServerNoHeader(t *testing.T, mode testMode, header string) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header()[header] = nil io.WriteString(w, "<html>foo</html>") // non-empty })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -2903,15 +2850,13 @@ func testServerNoHeader(t *testing.T, h2 bool, header string) { } } -func TestStripPrefix(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestStripPrefix(t *testing.T) { run(t, testStripPrefix) } +func testStripPrefix(t *testing.T, mode testMode) { h := HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("X-Path", r.URL.Path) w.Header().Set("X-RawPath", r.URL.RawPath) }) - ts := httptest.NewServer(StripPrefix("/foo/bar", h)) - defer ts.Close() + ts := newClientServerTest(t, mode, StripPrefix("/foo/bar", h)).ts c := ts.Client() @@ -2961,15 +2906,11 @@ func TestStripPrefixNotModifyRequest(t *testing.T) { } } -func TestRequestLimit_h1(t *testing.T) { testRequestLimit(t, h1Mode) } -func TestRequestLimit_h2(t *testing.T) { testRequestLimit(t, h2Mode) } -func testRequestLimit(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestRequestLimit(t *testing.T) { run(t, testRequestLimit) } +func testRequestLimit(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { t.Fatalf("didn't expect to get request in Handler") }), optQuietLog) - defer cst.close() req, _ := NewRequest("GET", cst.ts.URL, nil) var bytesPerHeader = len("header12345: val12345\r\n") for i := 0; i < ((DefaultMaxHeaderBytes+4096)/bytesPerHeader)+1; i++ { @@ -2979,7 +2920,7 @@ func testRequestLimit(t *testing.T, h2 bool) { if res != nil { defer res.Body.Close() } - if h2 { + if mode == http2Mode { // In HTTP/2, the result depends on a race. If the client has received the // server's SETTINGS before RoundTrip starts sending the request, then RoundTrip // will fail with an error. Otherwise, the client should receive a 431 from the @@ -3021,13 +2962,10 @@ func (cr countReader) Read(p []byte) (n int, err error) { return } -func TestRequestBodyLimit_h1(t *testing.T) { testRequestBodyLimit(t, h1Mode) } -func TestRequestBodyLimit_h2(t *testing.T) { testRequestBodyLimit(t, h2Mode) } -func testRequestBodyLimit(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func TestRequestBodyLimit(t *testing.T) { run(t, testRequestBodyLimit) } +func testRequestBodyLimit(t *testing.T, mode testMode) { const limit = 1 << 20 - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { r.Body = MaxBytesReader(w, r.Body, limit) n, err := io.Copy(io.Discard, r.Body) if err == nil { @@ -3044,7 +2982,6 @@ func testRequestBodyLimit(t *testing.T, h2 bool) { t.Errorf("MaxBytesError.Limit = %d, want %d", mbErr.Limit, limit) } })) - defer cst.close() nWritten := new(int64) req, _ := NewRequest("POST", cst.ts.URL, io.LimitReader(countReader{neverEnding('a'), nWritten}, limit*200)) @@ -3068,13 +3005,12 @@ func testRequestBodyLimit(t *testing.T, h2 bool) { // TestClientWriteShutdown tests that if the client shuts down the write // side of their TCP connection, the server doesn't send a 400 Bad Request. -func TestClientWriteShutdown(t *testing.T) { +func TestClientWriteShutdown(t *testing.T) { run(t, testClientWriteShutdown) } +func testClientWriteShutdown(t *testing.T, mode testMode) { if runtime.GOOS == "plan9" { t.Skip("skipping test; see https://golang.org/issue/17906") } - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) - defer ts.Close() + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("Dial: %v", err) @@ -3119,12 +3055,12 @@ func TestServerBufferedChunking(t *testing.T) { // closing the TCP connection, causing the client to get a RST. // See https://golang.org/issue/3595 func TestServerGracefulClose(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testServerGracefulClose, []testMode{http1Mode}) +} +func testServerGracefulClose(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { Error(w, "bye", StatusUnauthorized) - })) - defer ts.Close() + })).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -3162,11 +3098,9 @@ func TestServerGracefulClose(t *testing.T) { <-writeErr } -func TestCaseSensitiveMethod_h1(t *testing.T) { testCaseSensitiveMethod(t, h1Mode) } -func TestCaseSensitiveMethod_h2(t *testing.T) { testCaseSensitiveMethod(t, h2Mode) } -func testCaseSensitiveMethod(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestCaseSensitiveMethod(t *testing.T) { run(t, testCaseSensitiveMethod) } +func testCaseSensitiveMethod(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "get" { t.Errorf(`Got method %q; want "get"`, r.Method) } @@ -3187,8 +3121,10 @@ func testCaseSensitiveMethod(t *testing.T, h2 bool) { // response, the net/http package adds a "Content-Length: 0" response // header. func TestContentLengthZero(t *testing.T) { - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {})) - defer ts.Close() + run(t, testContentLengthZero, []testMode{http1Mode}) +} +func testContentLengthZero(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {})).ts for _, version := range []string{"HTTP/1.0", "HTTP/1.1"} { conn, err := net.Dial("tcp", ts.Listener.Addr().String()) @@ -3215,15 +3151,17 @@ func TestContentLengthZero(t *testing.T) { } func TestCloseNotifier(t *testing.T) { - defer afterTest(t) + run(t, testCloseNotifier, []testMode{http1Mode}) +} +func testCloseNotifier(t *testing.T, mode testMode) { gotReq := make(chan bool, 1) sawClose := make(chan bool, 1) - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { gotReq <- true cc := rw.(CloseNotifier).CloseNotify() <-cc sawClose <- true - })) + })).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("error dialing: %v", err) @@ -3257,11 +3195,12 @@ For: // // Issue 13165 (where it used to deadlock), but behavior changed in Issue 23921. func TestCloseNotifierPipelined(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testCloseNotifierPipelined, []testMode{http1Mode}) +} +func testCloseNotifierPipelined(t *testing.T, mode testMode) { gotReq := make(chan bool, 2) sawClose := make(chan bool, 2) - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { gotReq <- true cc := rw.(CloseNotifier).CloseNotify() select { @@ -3270,8 +3209,7 @@ func TestCloseNotifierPipelined(t *testing.T) { case <-time.After(100 * time.Millisecond): } sawClose <- true - })) - defer ts.Close() + })).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("error dialing: %v", err) @@ -3341,12 +3279,14 @@ func TestCloseNotifierChanLeak(t *testing.T) { // Issue 9763. // HTTP/1-only test. (http2 doesn't have Hijack) func TestHijackAfterCloseNotifier(t *testing.T) { - defer afterTest(t) + run(t, testHijackAfterCloseNotifier, []testMode{http1Mode}) +} +func testHijackAfterCloseNotifier(t *testing.T, mode testMode) { script := make(chan string, 2) script <- "closenotify" script <- "hijack" close(script) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { plan := <-script switch plan { default: @@ -3369,13 +3309,12 @@ func TestHijackAfterCloseNotifier(t *testing.T) { c.Close() return } - })) - defer ts.Close() - res1, err := Get(ts.URL) + })).ts + res1, err := ts.Client().Get(ts.URL) if err != nil { log.Fatal(err) } - res2, err := Get(ts.URL) + res2, err := ts.Client().Get(ts.URL) if err != nil { log.Fatal(err) } @@ -3387,12 +3326,13 @@ func TestHijackAfterCloseNotifier(t *testing.T) { } func TestHijackBeforeRequestBodyRead(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testHijackBeforeRequestBodyRead, []testMode{http1Mode}) +} +func testHijackBeforeRequestBodyRead(t *testing.T, mode testMode) { var requestBody = bytes.Repeat([]byte("a"), 1<<20) bodyOkay := make(chan bool, 1) gotCloseNotify := make(chan bool, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { defer close(bodyOkay) // caller will read false if nothing else reqBody := r.Body @@ -3419,8 +3359,7 @@ func TestHijackBeforeRequestBodyRead(t *testing.T) { case <-time.After(5 * time.Second): gotCloseNotify <- false } - })) - defer ts.Close() + })).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -3440,14 +3379,14 @@ func TestHijackBeforeRequestBodyRead(t *testing.T) { } } -func TestOptions(t *testing.T) { +func TestOptions(t *testing.T) { run(t, testOptions, []testMode{http1Mode}) } +func testOptions(t *testing.T, mode testMode) { uric := make(chan string, 2) // only expect 1, but leave space for 2 mux := NewServeMux() mux.HandleFunc("/", func(w ResponseWriter, r *Request) { uric <- r.RequestURI }) - ts := httptest.NewServer(mux) - defer ts.Close() + ts := newClientServerTest(t, mode, mux).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -3492,15 +3431,15 @@ func TestOptions(t *testing.T) { } } -func TestOptionsHandler(t *testing.T) { +func TestOptionsHandler(t *testing.T) { run(t, testOptionsHandler, []testMode{http1Mode}) } +func testOptionsHandler(t *testing.T, mode testMode) { rc := make(chan *Request, 1) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { rc <- r - })) - ts.Config.DisableGeneralOptionsHandler = true - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.DisableGeneralOptionsHandler = true + }).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -3804,12 +3743,12 @@ func TestDoubleHijack(t *testing.T) { // optimization and is pointless if dealing with a // badly behaved client. func TestHTTP10ConnectionHeader(t *testing.T) { - defer afterTest(t) - + run(t, testHTTP10ConnectionHeader, []testMode{http1Mode}) +} +func testHTTP10ConnectionHeader(t *testing.T, mode testMode) { mux := NewServeMux() mux.Handle("/", HandlerFunc(func(ResponseWriter, *Request) {})) - ts := httptest.NewServer(mux) - defer ts.Close() + ts := newClientServerTest(t, mode, mux).ts // net/http uses HTTP/1.1 for requests, so write requests manually tests := []struct { @@ -3856,14 +3795,11 @@ func TestHTTP10ConnectionHeader(t *testing.T) { } // See golang.org/issue/5660 -func TestServerReaderFromOrder_h1(t *testing.T) { testServerReaderFromOrder(t, h1Mode) } -func TestServerReaderFromOrder_h2(t *testing.T) { testServerReaderFromOrder(t, h2Mode) } -func testServerReaderFromOrder(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func TestServerReaderFromOrder(t *testing.T) { run(t, testServerReaderFromOrder) } +func testServerReaderFromOrder(t *testing.T, mode testMode) { pr, pw := io.Pipe() const size = 3 << 20 - cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { rw.Header().Set("Content-Type", "text/plain") // prevent sniffing path done := make(chan bool) go func() { @@ -3883,7 +3819,6 @@ func testServerReaderFromOrder(t *testing.T, h2 bool) { pw.Close() <-done })) - defer cst.close() req, err := NewRequest("POST", cst.ts.URL, io.LimitReader(neverEnding('a'), size)) if err != nil { @@ -3957,16 +3892,10 @@ func TestContentTypeOkayOn204(t *testing.T) { // proxy). So then two people own that Request.Body (both the server // and the http client), and both think they can close it on failure. // Therefore, all incoming server requests Bodies need to be thread-safe. -func TestTransportAndServerSharedBodyRace_h1(t *testing.T) { - testTransportAndServerSharedBodyRace(t, h1Mode) +func TestTransportAndServerSharedBodyRace(t *testing.T) { + run(t, testTransportAndServerSharedBodyRace) } -func TestTransportAndServerSharedBodyRace_h2(t *testing.T) { - testTransportAndServerSharedBodyRace(t, h2Mode) -} -func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - +func testTransportAndServerSharedBodyRace(t *testing.T, mode testMode) { const bodySize = 1 << 20 // errorf is like t.Errorf, but also writes to println. When @@ -3980,7 +3909,7 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { } unblockBackend := make(chan bool) - backend := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) { + backend := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { gone := rw.(CloseNotifier).CloseNotify() didCopy := make(chan any) go func() { @@ -4007,7 +3936,7 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { backendRespc := make(chan *Response, 1) var proxy *clientServerTest - proxy = newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) { + proxy = newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { req2, _ := NewRequest("POST", backend.ts.URL, req.Body) req2.ContentLength = bodySize cancel := make(chan struct{}) @@ -4027,7 +3956,7 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { // Try to cause a race: Both the Transport and the proxy handler's Server // will try to read/close req.Body (aka req2.Body) - if h2 { + if mode == http2Mode { close(cancel) } else { proxy.c.Transport.(*Transport).CancelRequest(req2) @@ -4071,22 +4000,23 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { // cause the Handler goroutine's Request.Body.Close to block. // See issue 7121. func TestRequestBodyCloseDoesntBlock(t *testing.T) { + run(t, testRequestBodyCloseDoesntBlock, []testMode{http1Mode}) +} +func testRequestBodyCloseDoesntBlock(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in -short mode") } - defer afterTest(t) readErrCh := make(chan error, 1) errCh := make(chan error, 2) - server := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + server := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { go func(body io.Reader) { _, err := body.Read(make([]byte, 100)) readErrCh <- err }(req.Body) time.Sleep(500 * time.Millisecond) - })) - defer server.Close() + })).ts closeConn := make(chan bool) defer close(closeConn) @@ -4149,9 +4079,8 @@ func TestAppendTime(t *testing.T) { } } -func TestServerConnState(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestServerConnState(t *testing.T) { run(t, testServerConnState, []testMode{http1Mode}) } +func testServerConnState(t *testing.T, mode testMode) { handler := map[string]func(w ResponseWriter, r *Request){ "/": func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "Hello.") @@ -4217,37 +4146,36 @@ func TestServerConnState(t *testing.T) { // next call to wantLog. } - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { handler[r.URL.Path](w, r) - })) + }), func(ts *httptest.Server) { + ts.Config.ErrorLog = log.New(io.Discard, "", 0) + ts.Config.ConnState = func(c net.Conn, state ConnState) { + if c == nil { + t.Errorf("nil conn seen in state %s", state) + return + } + sl := <-activeLog + if sl.active == nil && state == StateNew { + sl.active = c + } else if sl.active != c { + t.Errorf("unexpected conn in state %s", state) + activeLog <- sl + return + } + sl.got = append(sl.got, state) + if sl.complete != nil && (len(sl.got) >= len(sl.want) || !reflect.DeepEqual(sl.got, sl.want[:len(sl.got)])) { + close(sl.complete) + sl.complete = nil + } + activeLog <- sl + } + }).ts defer func() { activeLog <- &stateLog{} // If the test failed, allow any remaining ConnState callbacks to complete. ts.Close() }() - ts.Config.ErrorLog = log.New(io.Discard, "", 0) - ts.Config.ConnState = func(c net.Conn, state ConnState) { - if c == nil { - t.Errorf("nil conn seen in state %s", state) - return - } - sl := <-activeLog - if sl.active == nil && state == StateNew { - sl.active = c - } else if sl.active != c { - t.Errorf("unexpected conn in state %s", state) - activeLog <- sl - return - } - sl.got = append(sl.got, state) - if sl.complete != nil && (len(sl.got) >= len(sl.want) || !reflect.DeepEqual(sl.got, sl.want[:len(sl.got)])) { - close(sl.complete) - sl.complete = nil - } - activeLog <- sl - } - - ts.Start() c := ts.Client() mustGet := func(url string, headers ...string) { @@ -4329,13 +4257,15 @@ func TestServerConnState(t *testing.T) { }, StateNew, StateActive, StateIdle, StateClosed) } -func TestServerKeepAlivesEnabled(t *testing.T) { - defer afterTest(t) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) - ts.Config.SetKeepAlivesEnabled(false) - ts.Start() - defer ts.Close() - res, err := Get(ts.URL) +func TestServerKeepAlivesEnabledResultClose(t *testing.T) { + run(t, testServerKeepAlivesEnabledResultClose, []testMode{http1Mode}) +} +func testServerKeepAlivesEnabledResultClose(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + }), func(ts *httptest.Server) { + ts.Config.SetKeepAlivesEnabled(false) + }).ts + res, err := ts.Client().Get(ts.URL) if err != nil { t.Fatal(err) } @@ -4346,16 +4276,12 @@ func TestServerKeepAlivesEnabled(t *testing.T) { } // golang.org/issue/7856 -func TestServerEmptyBodyRace_h1(t *testing.T) { testServerEmptyBodyRace(t, h1Mode) } -func TestServerEmptyBodyRace_h2(t *testing.T) { testServerEmptyBodyRace(t, h2Mode) } -func testServerEmptyBodyRace(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func TestServerEmptyBodyRace(t *testing.T) { run(t, testServerEmptyBodyRace) } +func testServerEmptyBodyRace(t *testing.T, mode testMode) { var n int32 - cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { atomic.AddInt32(&n, 1) }), optQuietLog) - defer cst.close() var wg sync.WaitGroup const reqs = 20 for i := 0; i < reqs; i++ { @@ -4436,9 +4362,9 @@ func TestCloseWrite(t *testing.T) { // fixed. // // So add an explicit test for this. -func TestServerFlushAndHijack(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServerFlushAndHijack(t *testing.T) { run(t, testServerFlushAndHijack, []testMode{http1Mode}) } +func testServerFlushAndHijack(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, "Hello, ") w.(Flusher).Flush() conn, buf, _ := w.(Hijacker).Hijack() @@ -4449,8 +4375,7 @@ func TestServerFlushAndHijack(t *testing.T) { if err := conn.Close(); err != nil { t.Error(err) } - })) - defer ts.Close() + })).ts res, err := Get(ts.URL) if err != nil { t.Fatal(err) @@ -4472,20 +4397,21 @@ func TestServerFlushAndHijack(t *testing.T) { // To test, verify we don't timeout or see fewer unique client // addresses (== unique connections) than requests. func TestServerKeepAliveAfterWriteError(t *testing.T) { + run(t, testServerKeepAliveAfterWriteError, []testMode{http1Mode}) +} +func testServerKeepAliveAfterWriteError(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in -short mode") } - defer afterTest(t) const numReq = 3 addrc := make(chan string, numReq) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { addrc <- r.RemoteAddr time.Sleep(500 * time.Millisecond) w.(Flusher).Flush() - })) - ts.Config.WriteTimeout = 250 * time.Millisecond - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.WriteTimeout = 250 * time.Millisecond + }).ts errc := make(chan error, numReq) go func() { @@ -4529,12 +4455,13 @@ func TestServerKeepAliveAfterWriteError(t *testing.T) { // Issue 9987: shouldn't add automatic Content-Length (or // Content-Type) if a Transfer-Encoding was set by the handler. func TestNoContentLengthIfTransferEncoding(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testNoContentLengthIfTransferEncoding, []testMode{http1Mode}) +} +func testNoContentLengthIfTransferEncoding(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Transfer-Encoding", "foo") io.WriteString(w, "<html>") - })) - defer ts.Close() + })).ts c, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("Dial: %v", err) @@ -4682,15 +4609,12 @@ func TestHandlerFinishSkipBigContentLengthRead(t *testing.T) { } } -func TestHandlerSetsBodyNil_h1(t *testing.T) { testHandlerSetsBodyNil(t, h1Mode) } -func TestHandlerSetsBodyNil_h2(t *testing.T) { testHandlerSetsBodyNil(t, h2Mode) } -func testHandlerSetsBodyNil(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestHandlerSetsBodyNil(t *testing.T) { run(t, testHandlerSetsBodyNil) } +func testHandlerSetsBodyNil(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { r.Body = nil fmt.Fprintf(w, "%v", r.RemoteAddr) })) - defer cst.close() get := func() string { res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -4780,9 +4704,11 @@ func TestServerValidatesHostHeader(t *testing.T) { } func TestServerHandlersCanHandleH2PRI(t *testing.T) { + run(t, testServerHandlersCanHandleH2PRI, []testMode{http1Mode}) +} +func testServerHandlersCanHandleH2PRI(t *testing.T, mode testMode) { const upgradeResponse = "upgrade here" - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, br, err := w.(Hijacker).Hijack() if err != nil { t.Error(err) @@ -4804,8 +4730,7 @@ func TestServerHandlersCanHandleH2PRI(t *testing.T) { return } io.WriteString(conn, upgradeResponse) - })) - defer ts.Close() + })).ts c, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -4872,17 +4797,12 @@ func TestServerValidatesHeaders(t *testing.T) { } } -func TestServerRequestContextCancel_ServeHTTPDone_h1(t *testing.T) { - testServerRequestContextCancel_ServeHTTPDone(t, h1Mode) -} -func TestServerRequestContextCancel_ServeHTTPDone_h2(t *testing.T) { - testServerRequestContextCancel_ServeHTTPDone(t, h2Mode) +func TestServerRequestContextCancel_ServeHTTPDone(t *testing.T) { + run(t, testServerRequestContextCancel_ServeHTTPDone) } -func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, mode testMode) { ctxc := make(chan context.Context, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { ctx := r.Context() select { case <-ctx.Done(): @@ -4891,7 +4811,6 @@ func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, h2 bool) { } ctxc <- ctx })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -4910,16 +4829,16 @@ func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, h2 bool) { // is always blocked in a Read call so it notices the EOF from the client. // See issues 15927 and 15224. func TestServerRequestContextCancel_ConnClose(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testServerRequestContextCancel_ConnClose, []testMode{http1Mode}) +} +func testServerRequestContextCancel_ConnClose(t *testing.T, mode testMode) { inHandler := make(chan struct{}) handlerDone := make(chan struct{}) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { close(inHandler) <-r.Context().Done() close(handlerDone) - })) - defer ts.Close() + })).ts c, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatal(err) @@ -4931,23 +4850,17 @@ func TestServerRequestContextCancel_ConnClose(t *testing.T) { <-handlerDone } -func TestServerContext_ServerContextKey_h1(t *testing.T) { - testServerContext_ServerContextKey(t, h1Mode) -} -func TestServerContext_ServerContextKey_h2(t *testing.T) { - testServerContext_ServerContextKey(t, h2Mode) +func TestServerContext_ServerContextKey(t *testing.T) { + run(t, testServerContext_ServerContextKey) } -func testServerContext_ServerContextKey(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func testServerContext_ServerContextKey(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { ctx := r.Context() got := ctx.Value(ServerContextKey) if _, ok := got.(*Server); !ok { t.Errorf("context value = %T; want *http.Server", got) } })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -4955,20 +4868,14 @@ func testServerContext_ServerContextKey(t *testing.T, h2 bool) { res.Body.Close() } -func TestServerContext_LocalAddrContextKey_h1(t *testing.T) { - testServerContext_LocalAddrContextKey(t, h1Mode) -} -func TestServerContext_LocalAddrContextKey_h2(t *testing.T) { - testServerContext_LocalAddrContextKey(t, h2Mode) +func TestServerContext_LocalAddrContextKey(t *testing.T) { + run(t, testServerContext_LocalAddrContextKey) } -func testServerContext_LocalAddrContextKey(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func testServerContext_LocalAddrContextKey(t *testing.T, mode testMode) { ch := make(chan any, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { ch <- r.Context().Value(LocalAddrContextKey) })) - defer cst.close() if _, err := cst.c.Head(cst.ts.URL); err != nil { t.Fatal(err) } @@ -5021,16 +4928,19 @@ func TestHandlerSetTransferEncodingGzip(t *testing.T) { } func BenchmarkClientServer(b *testing.B) { + run(b, benchmarkClientServer, []testMode{http1Mode, https1Mode, http2Mode}) +} +func benchmarkClientServer(b *testing.B, mode testMode) { b.ReportAllocs() b.StopTimer() - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) { fmt.Fprintf(rw, "Hello world.\n") - })) - defer ts.Close() + })).ts b.StartTimer() + c := ts.Client() for i := 0; i < b.N; i++ { - res, err := Get(ts.URL) + res, err := c.Get(ts.URL) if err != nil { b.Fatal("Get:", err) } @@ -5048,33 +4958,21 @@ func BenchmarkClientServer(b *testing.B) { b.StopTimer() } -func BenchmarkClientServerParallel4(b *testing.B) { - benchmarkClientServerParallel(b, 4, false) -} - -func BenchmarkClientServerParallel64(b *testing.B) { - benchmarkClientServerParallel(b, 64, false) -} - -func BenchmarkClientServerParallelTLS4(b *testing.B) { - benchmarkClientServerParallel(b, 4, true) -} - -func BenchmarkClientServerParallelTLS64(b *testing.B) { - benchmarkClientServerParallel(b, 64, true) +func BenchmarkClientServerParallel(b *testing.B) { + for _, parallelism := range []int{4, 64} { + b.Run(fmt.Sprint(parallelism), func(b *testing.B) { + run(b, func(b *testing.B, mode testMode) { + benchmarkClientServerParallel(b, parallelism, mode) + }, []testMode{http1Mode, https1Mode, http2Mode}) + }) + } } -func benchmarkClientServerParallel(b *testing.B, parallelism int, useTLS bool) { +func benchmarkClientServerParallel(b *testing.B, parallelism int, mode testMode) { b.ReportAllocs() - ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) { fmt.Fprintf(rw, "Hello world.\n") - })) - if useTLS { - ts.StartTLS() - } else { - ts.Start() - } - defer ts.Close() + })).ts b.ResetTimer() b.SetParallelism(parallelism) b.RunParallel(func(pb *testing.PB) { @@ -5464,15 +5362,15 @@ Host: golang.org } } -func BenchmarkCloseNotifier(b *testing.B) { +func BenchmarkCloseNotifier(b *testing.B) { run(b, benchmarkCloseNotifier, []testMode{http1Mode}) } +func benchmarkCloseNotifier(b *testing.B, mode testMode) { b.ReportAllocs() b.StopTimer() sawClose := make(chan bool) - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { <-rw.(CloseNotifier).CloseNotify() sawClose <- true - })) - defer ts.Close() + })).ts tot := time.NewTimer(5 * time.Second) defer tot.Stop() b.StartTimer() @@ -5508,20 +5406,18 @@ func TestConcurrentServerServe(t *testing.T) { } } -func TestServerIdleTimeout(t *testing.T) { +func TestServerIdleTimeout(t *testing.T) { run(t, testServerIdleTimeout, []testMode{http1Mode}) } +func testServerIdleTimeout(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } - setParallel(t) - defer afterTest(t) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.Copy(io.Discard, r.Body) io.WriteString(w, r.RemoteAddr) - })) - ts.Config.ReadHeaderTimeout = 1 * time.Second - ts.Config.IdleTimeout = 2 * time.Second - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.ReadHeaderTimeout = 1 * time.Second + ts.Config.IdleTimeout = 2 * time.Second + }).ts c := ts.Client() get := func() string { @@ -5576,12 +5472,12 @@ func get(t *testing.T, c *Client, url string) string { // Tests that calls to Server.SetKeepAlivesEnabled(false) closes any // currently-open connections. func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testServerSetKeepAlivesEnabledClosesConns, []testMode{http1Mode}) +} +func testServerSetKeepAlivesEnabledClosesConns(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, r.RemoteAddr) - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -5620,16 +5516,8 @@ func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) { } } -func TestServerShutdown_h1(t *testing.T) { - testServerShutdown(t, h1Mode) -} -func TestServerShutdown_h2(t *testing.T) { - testServerShutdown(t, h2Mode) -} - -func testServerShutdown(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func TestServerShutdown(t *testing.T) { run(t, testServerShutdown) } +func testServerShutdown(t *testing.T, mode testMode) { var doShutdown func() // set later var doStateCount func() var shutdownRes = make(chan error, 1) @@ -5645,10 +5533,9 @@ func testServerShutdown(t *testing.T, h2 bool) { time.Sleep(20 * time.Millisecond) io.WriteString(w, r.RemoteAddr) }) - cst := newClientServerTest(t, h2, handler, func(srv *httptest.Server) { + cst := newClientServerTest(t, mode, handler, func(srv *httptest.Server) { srv.Config.RegisterOnShutdown(func() { gotOnShutdown <- struct{}{} }) }) - defer cst.close() doShutdown = func() { shutdownRes <- cst.ts.Config.Shutdown(context.Background()) @@ -5678,24 +5565,22 @@ func testServerShutdown(t *testing.T, h2 bool) { } } -func TestServerShutdownStateNew(t *testing.T) { +func TestServerShutdownStateNew(t *testing.T) { run(t, testServerShutdownStateNew) } +func testServerShutdownStateNew(t *testing.T, mode testMode) { if testing.Short() { t.Skip("test takes 5-6 seconds; skipping in short mode") } - setParallel(t) - defer afterTest(t) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { - // nothing. - })) var connAccepted sync.WaitGroup - ts.Config.ConnState = func(conn net.Conn, state ConnState) { - if state == StateNew { - connAccepted.Done() + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + // nothing. + }), func(ts *httptest.Server) { + ts.Config.ConnState = func(conn net.Conn, state ConnState) { + if state == StateNew { + connAccepted.Done() + } } - } - ts.Start() - defer ts.Close() + }).ts // Start a connection but never write to it. connAccepted.Add(1) @@ -5757,16 +5642,14 @@ func TestServerCloseDeadlock(t *testing.T) { // Issue 17717: tests that Server.SetKeepAlivesEnabled is respected by // both HTTP/1 and HTTP/2. -func TestServerKeepAlivesEnabled_h1(t *testing.T) { testServerKeepAlivesEnabled(t, h1Mode) } -func TestServerKeepAlivesEnabled_h2(t *testing.T) { testServerKeepAlivesEnabled(t, h2Mode) } -func testServerKeepAlivesEnabled(t *testing.T, h2 bool) { - if h2 { +func TestServerKeepAlivesEnabled(t *testing.T) { run(t, testServerKeepAlivesEnabled, testNotParallel) } +func testServerKeepAlivesEnabled(t *testing.T, mode testMode) { + if mode == http2Mode { restore := ExportSetH2GoawayTimeout(10 * time.Millisecond) defer restore() } // Not parallel: messes with global variable. (http2goAwayTimeout) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {})) + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})) defer cst.close() srv := cst.ts.Config srv.SetKeepAlivesEnabled(false) @@ -5803,9 +5686,8 @@ func testServerKeepAlivesEnabled(t *testing.T, h2 bool) { // Issue 18447: test that the Server's ReadTimeout is stopped while // the server's doing its 1-byte background read between requests, // waiting for the connection to maybe close. -func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { run(t, testServerCancelsReadTimeoutWhenIdle) } +func testServerCancelsReadTimeoutWhenIdle(t *testing.T, mode testMode) { runTimeSensitiveTest(t, []time.Duration{ 10 * time.Millisecond, 50 * time.Millisecond, @@ -5813,17 +5695,16 @@ func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { time.Second, 2 * time.Second, }, func(t *testing.T, timeout time.Duration) error { - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { select { case <-time.After(2 * timeout): fmt.Fprint(w, "ok") case <-r.Context().Done(): fmt.Fprint(w, r.Context().Err()) } - })) - ts.Config.ReadTimeout = timeout - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.ReadTimeout = timeout + }).ts c := ts.Client() @@ -5847,8 +5728,9 @@ func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { // beginning of a request has been received, rather than including time the // connection spent idle. func TestServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testServerCancelsReadHeaderTimeoutWhenIdle, []testMode{http1Mode}) +} +func testServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T, mode testMode) { runTimeSensitiveTest(t, []time.Duration{ 10 * time.Millisecond, 50 * time.Millisecond, @@ -5856,11 +5738,10 @@ func TestServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T) { time.Second, 2 * time.Second, }, func(t *testing.T, timeout time.Duration) error { - ts := httptest.NewUnstartedServer(serve(200)) - ts.Config.ReadHeaderTimeout = timeout - ts.Config.IdleTimeout = 0 // disable idle timeout - ts.Start() - defer ts.Close() + ts := newClientServerTest(t, mode, serve(200), func(ts *httptest.Server) { + ts.Config.ReadHeaderTimeout = timeout + ts.Config.IdleTimeout = 0 // disable idle timeout + }).ts // rather than using an http.Client, create a single connection, so that // we can ensure this connection is not closed. @@ -5912,13 +5793,13 @@ func runTimeSensitiveTest(t *testing.T, durations []time.Duration, test func(t * // Issue 18535: test that the Server doesn't try to do a background // read if it's already done one. func TestServerDuplicateBackgroundRead(t *testing.T) { + run(t, testServerDuplicateBackgroundRead, []testMode{http1Mode}) +} +func testServerDuplicateBackgroundRead(t *testing.T, mode testMode) { if runtime.GOOS == "netbsd" && runtime.GOARCH == "arm" { testenv.SkipFlaky(t, 24826) } - setParallel(t) - defer afterTest(t) - goroutines := 5 requests := 2000 if testing.Short() { @@ -5926,8 +5807,7 @@ func TestServerDuplicateBackgroundRead(t *testing.T) { requests = 100 } - hts := httptest.NewServer(HandlerFunc(NotFound)) - defer hts.Close() + hts := newClientServerTest(t, mode, HandlerFunc(NotFound)).ts reqBytes := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n") @@ -5970,14 +5850,15 @@ func TestServerDuplicateBackgroundRead(t *testing.T) { // bufio.Reader.Buffered(), without resorting to Reading it // (potentially blocking) to get at it. func TestServerHijackGetsBackgroundByte(t *testing.T) { + run(t, testServerHijackGetsBackgroundByte, []testMode{http1Mode}) +} +func testServerHijackGetsBackgroundByte(t *testing.T, mode testMode) { if runtime.GOOS == "plan9" { t.Skip("skipping test; see https://golang.org/issue/18657") } - setParallel(t) - defer afterTest(t) done := make(chan struct{}) inHandler := make(chan bool, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { defer close(done) // Tell the client to send more data after the GET request. @@ -6000,8 +5881,7 @@ func TestServerHijackGetsBackgroundByte(t *testing.T) { t.Error("context unexpectedly canceled") default: } - })) - defer ts.Close() + })).ts cn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -6030,14 +5910,15 @@ func TestServerHijackGetsBackgroundByte(t *testing.T) { // immediate 1MB of data to the server to fill up the server's 4KB // buffer. func TestServerHijackGetsBackgroundByte_big(t *testing.T) { + run(t, testServerHijackGetsBackgroundByte_big, []testMode{http1Mode}) +} +func testServerHijackGetsBackgroundByte_big(t *testing.T, mode testMode) { if runtime.GOOS == "plan9" { t.Skip("skipping test; see https://golang.org/issue/18657") } - setParallel(t) - defer afterTest(t) done := make(chan struct{}) const size = 8 << 10 - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { defer close(done) conn, buf, err := w.(Hijacker).Hijack() @@ -6061,8 +5942,7 @@ func TestServerHijackGetsBackgroundByte_big(t *testing.T) { } else if !allX { t.Errorf("read %q; want %d 'x'", slurp, size) } - })) - defer ts.Close() + })).ts cn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -6198,73 +6078,27 @@ func TestStripPortFromHost(t *testing.T) { } } -func TestServerContexts(t *testing.T) { - setParallel(t) - defer afterTest(t) - type baseKey struct{} - type connKey struct{} - ch := make(chan context.Context, 1) - ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) { - ch <- r.Context() - })) - ts.Config.BaseContext = func(ln net.Listener) context.Context { - if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") { - t.Errorf("unexpected onceClose listener type %T", ln) - } - return context.WithValue(context.Background(), baseKey{}, "base") - } - ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context { - if got, want := ctx.Value(baseKey{}), "base"; got != want { - t.Errorf("in ConnContext, base context key = %#v; want %q", got, want) - } - return context.WithValue(ctx, connKey{}, "conn") - } - ts.Start() - defer ts.Close() - res, err := ts.Client().Get(ts.URL) - if err != nil { - t.Fatal(err) - } - res.Body.Close() - ctx := <-ch - if got, want := ctx.Value(baseKey{}), "base"; got != want { - t.Errorf("base context key = %#v; want %q", got, want) - } - if got, want := ctx.Value(connKey{}), "conn"; got != want { - t.Errorf("conn context key = %#v; want %q", got, want) - } -} - -func TestServerContextsHTTP2(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestServerContexts(t *testing.T) { run(t, testServerContexts) } +func testServerContexts(t *testing.T, mode testMode) { type baseKey struct{} type connKey struct{} ch := make(chan context.Context, 1) - ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) { - if r.ProtoMajor != 2 { - t.Errorf("unexpected HTTP/1.x request") - } + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) { ch <- r.Context() - })) - ts.Config.BaseContext = func(ln net.Listener) context.Context { - if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") { - t.Errorf("unexpected onceClose listener type %T", ln) + }), func(ts *httptest.Server) { + ts.Config.BaseContext = func(ln net.Listener) context.Context { + if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") { + t.Errorf("unexpected onceClose listener type %T", ln) + } + return context.WithValue(context.Background(), baseKey{}, "base") } - return context.WithValue(context.Background(), baseKey{}, "base") - } - ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context { - if got, want := ctx.Value(baseKey{}), "base"; got != want { - t.Errorf("in ConnContext, base context key = %#v; want %q", got, want) + ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context { + if got, want := ctx.Value(baseKey{}), "base"; got != want { + t.Errorf("in ConnContext, base context key = %#v; want %q", got, want) + } + return context.WithValue(ctx, connKey{}, "conn") } - return context.WithValue(ctx, connKey{}, "conn") - } - ts.TLS = &tls.Config{ - NextProtos: []string{"h2", "http/1.1"}, - } - ts.StartTLS() - defer ts.Close() - ts.Client().Transport.(*Transport).ForceAttemptHTTP2 = true + }).ts res, err := ts.Client().Get(ts.URL) if err != nil { t.Fatal(err) @@ -6281,20 +6115,20 @@ func TestServerContextsHTTP2(t *testing.T) { // Issue 35750: check ConnContext not modifying context for other connections func TestConnContextNotModifyingAllContexts(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testConnContextNotModifyingAllContexts) +} +func testConnContextNotModifyingAllContexts(t *testing.T, mode testMode) { type connKey struct{} - ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) { rw.Header().Set("Connection", "close") - })) - ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context { - if got := ctx.Value(connKey{}); got != nil { - t.Errorf("in ConnContext, unexpected context key = %#v", got) + }), func(ts *httptest.Server) { + ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context { + if got := ctx.Value(connKey{}); got != nil { + t.Errorf("in ConnContext, unexpected context key = %#v", got) + } + return context.WithValue(ctx, connKey{}, "conn") } - return context.WithValue(ctx, connKey{}, "conn") - } - ts.Start() - defer ts.Close() + }).ts var res *Response var err error @@ -6315,10 +6149,12 @@ func TestConnContextNotModifyingAllContexts(t *testing.T) { // Issue 30710: ensure that as per the spec, a server responds // with 501 Not Implemented for unsupported transfer-encodings. func TestUnsupportedTransferEncodingsReturn501(t *testing.T) { - cst := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testUnsupportedTransferEncodingsReturn501, []testMode{http1Mode}) +} +func testUnsupportedTransferEncodingsReturn501(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("Hello, World!")) - })) - defer cst.Close() + })).ts serverURL, err := url.Parse(cst.URL) if err != nil { @@ -6353,19 +6189,9 @@ func TestUnsupportedTransferEncodingsReturn501(t *testing.T) { } } -func TestContentEncodingNoSniffing_h1(t *testing.T) { - testContentEncodingNoSniffing(t, h1Mode) -} - -func TestContentEncodingNoSniffing_h2(t *testing.T) { - testContentEncodingNoSniffing(t, h2Mode) -} - // Issue 31753: don't sniff when Content-Encoding is set -func testContentEncodingNoSniffing(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - +func TestContentEncodingNoSniffing(t *testing.T) { run(t, testContentEncodingNoSniffing) } +func testContentEncodingNoSniffing(t *testing.T, mode testMode) { type setting struct { name string body []byte @@ -6428,13 +6254,12 @@ func testContentEncodingNoSniffing(t *testing.T, h2 bool) { for _, tt := range settings { t.Run(tt.name, func(t *testing.T) { - cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) { if tt.contentEncoding != nil { rw.Header().Set("Content-Encoding", tt.contentEncoding.(string)) } rw.Write(tt.body) })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -6460,13 +6285,13 @@ func testContentEncodingNoSniffing(t *testing.T, h2 bool) { // Issue 30803: ensure that TimeoutHandler logs spurious // WriteHeader calls, for consistency with other Handlers. func TestTimeoutHandlerSuperfluousLogs(t *testing.T) { + run(t, testTimeoutHandlerSuperfluousLogs, []testMode{http1Mode}) +} +func testTimeoutHandlerSuperfluousLogs(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } - setParallel(t) - defer afterTest(t) - pc, curFile, _, _ := runtime.Caller(0) curFileBaseName := filepath.Base(curFile) testFuncName := runtime.FuncForPC(pc).Name() @@ -6520,7 +6345,7 @@ func TestTimeoutHandlerSuperfluousLogs(t *testing.T) { dur = 10 * time.Second } th := TimeoutHandler(sh, dur, timeoutMsg) - cst := newClientServerTest(t, h1Mode /* the test is protocol-agnostic */, th, optWithServerLog(srvLog)) + cst := newClientServerTest(t, mode, th, optWithServerLog(srvLog)) defer cst.close() res, err := cst.c.Get(cst.ts.URL) @@ -6590,15 +6415,16 @@ func BenchmarkResponseStatusLine(b *testing.B) { } }) } + func TestDisableKeepAliveUpgrade(t *testing.T) { + run(t, testDisableKeepAliveUpgrade, []testMode{http1Mode}) +} +func testDisableKeepAliveUpgrade(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } - setParallel(t) - defer afterTest(t) - - s := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + s := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "someProto") w.WriteHeader(StatusSwitchingProtocols) @@ -6611,10 +6437,9 @@ func TestDisableKeepAliveUpgrade(t *testing.T) { // Copy from the *bufio.ReadWriter, which may contain buffered data. // Copy to the net.Conn, to avoid buffering the output. io.Copy(c, buf) - })) - s.Config.SetKeepAlivesEnabled(false) - s.Start() - defer s.Close() + }), func(ts *httptest.Server) { + ts.Config.SetKeepAlivesEnabled(false) + }).ts cl := s.Client() cl.Transport.(*Transport).DisableKeepAlives = true @@ -6683,21 +6508,21 @@ func TestQuerySemicolon(t *testing.T) { {"?a=1;x=good;x=bad", "", "good", true}, } - for _, tt := range tests { - t.Run(tt.query+"/allow=false", func(t *testing.T) { - allowSemicolons := false - testQuerySemicolon(t, tt.query, tt.xNoSemicolons, allowSemicolons, tt.warning) - }) - t.Run(tt.query+"/allow=true", func(t *testing.T) { - allowSemicolons, expectWarning := true, false - testQuerySemicolon(t, tt.query, tt.xWithSemicolons, allowSemicolons, expectWarning) - }) - } + run(t, func(t *testing.T, mode testMode) { + for _, tt := range tests { + t.Run(tt.query+"/allow=false", func(t *testing.T) { + allowSemicolons := false + testQuerySemicolon(t, mode, tt.query, tt.xNoSemicolons, allowSemicolons, tt.warning) + }) + t.Run(tt.query+"/allow=true", func(t *testing.T) { + allowSemicolons, expectWarning := true, false + testQuerySemicolon(t, mode, tt.query, tt.xWithSemicolons, allowSemicolons, expectWarning) + }) + } + }) } -func testQuerySemicolon(t *testing.T, query string, wantX string, allowSemicolons, expectWarning bool) { - setParallel(t) - +func testQuerySemicolon(t *testing.T, mode testMode, query string, wantX string, allowSemicolons, expectWarning bool) { writeBackX := func(w ResponseWriter, r *Request) { x := r.URL.Query().Get("x") if expectWarning { @@ -6720,11 +6545,10 @@ func testQuerySemicolon(t *testing.T, query string, wantX string, allowSemicolon h = AllowQuerySemicolons(h) } - ts := httptest.NewUnstartedServer(h) logBuf := &strings.Builder{} - ts.Config.ErrorLog = log.New(logBuf, "", 0) - ts.Start() - defer ts.Close() + ts := newClientServerTest(t, mode, h, func(ts *httptest.Server) { + ts.Config.ErrorLog = log.New(logBuf, "", 0) + }).ts req, _ := NewRequest("GET", ts.URL+query, nil) res, err := ts.Client().Do(req) @@ -6759,13 +6583,15 @@ func TestMaxBytesHandler(t *testing.T) { for _, requestSize := range []int64{100, 1_000, 1_000_000} { t.Run(fmt.Sprintf("max size %d request size %d", maxSize, requestSize), func(t *testing.T) { - testMaxBytesHandler(t, maxSize, requestSize) + run(t, func(t *testing.T, mode testMode) { + testMaxBytesHandler(t, mode, maxSize, requestSize) + }) }) } } } -func testMaxBytesHandler(t *testing.T, maxSize, requestSize int64) { +func testMaxBytesHandler(t *testing.T, mode testMode, maxSize, requestSize int64) { var ( handlerN int64 handlerErr error @@ -6776,7 +6602,7 @@ func testMaxBytesHandler(t *testing.T, maxSize, requestSize int64) { io.Copy(w, &buf) }) - ts := httptest.NewServer(MaxBytesHandler(echo, maxSize)) + ts := newClientServerTest(t, mode, MaxBytesHandler(echo, maxSize)).ts defer ts.Close() c := ts.Client() @@ -6843,13 +6669,12 @@ func TestProcessing(t *testing.T) { } } -func TestParseFormCleanup_h1(t *testing.T) { testParseFormCleanup(t, h1Mode) } -func TestParseFormCleanup_h2(t *testing.T) { - t.Skip("https://go.dev/issue/20253") - testParseFormCleanup(t, h2Mode) -} +func TestParseFormCleanup(t *testing.T) { run(t, testParseFormCleanup) } +func testParseFormCleanup(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("https://go.dev/issue/20253") + } -func testParseFormCleanup(t *testing.T, h2 bool) { const maxMemory = 1024 const key = "file" @@ -6858,9 +6683,7 @@ func testParseFormCleanup(t *testing.T, h2 bool) { t.Skip("https://go.dev/issue/25965") } - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { r.ParseMultipartForm(maxMemory) f, _, err := r.FormFile(key) if err != nil { @@ -6874,7 +6697,6 @@ func testParseFormCleanup(t *testing.T, h2 bool) { } w.Write([]byte(of.Name())) })) - defer cst.close() fBuf := new(bytes.Buffer) mw := multipart.NewWriter(fBuf) @@ -6911,33 +6733,23 @@ func testParseFormCleanup(t *testing.T, h2 bool) { func TestHeadBody(t *testing.T) { const identityMode = false const chunkedMode = true - t.Run("h1", func(t *testing.T) { - t.Run("identity", func(t *testing.T) { testHeadBody(t, h1Mode, identityMode, "HEAD") }) - t.Run("chunked", func(t *testing.T) { testHeadBody(t, h1Mode, chunkedMode, "HEAD") }) - }) - t.Run("h2", func(t *testing.T) { - t.Run("identity", func(t *testing.T) { testHeadBody(t, h2Mode, identityMode, "HEAD") }) - t.Run("chunked", func(t *testing.T) { testHeadBody(t, h2Mode, chunkedMode, "HEAD") }) + run(t, func(t *testing.T, mode testMode) { + t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "HEAD") }) + t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "HEAD") }) }) } func TestGetBody(t *testing.T) { const identityMode = false const chunkedMode = true - t.Run("h1", func(t *testing.T) { - t.Run("identity", func(t *testing.T) { testHeadBody(t, h1Mode, identityMode, "GET") }) - t.Run("chunked", func(t *testing.T) { testHeadBody(t, h1Mode, chunkedMode, "GET") }) - }) - t.Run("h2", func(t *testing.T) { - t.Run("identity", func(t *testing.T) { testHeadBody(t, h2Mode, identityMode, "GET") }) - t.Run("chunked", func(t *testing.T) { testHeadBody(t, h2Mode, chunkedMode, "GET") }) + run(t, func(t *testing.T, mode testMode) { + t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "GET") }) + t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "GET") }) }) } -func testHeadBody(t *testing.T, h2, chunked bool, method string) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func testHeadBody(t *testing.T, mode testMode, chunked bool, method string) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { b, err := io.ReadAll(r.Body) if err != nil { t.Errorf("server reading body: %v", err) |
