diff options
| author | Damien Neil <dneil@google.com> | 2022-10-03 16:07:48 -0700 |
|---|---|---|
| committer | Damien Neil <dneil@google.com> | 2022-10-07 16:53:14 +0000 |
| commit | 747e1961e95c2eb3df62e045b90b111c2ceea337 (patch) | |
| tree | 59a7c933ffe695cac3c9c48e8c8e8afa068b0985 /src/net/http/serve_test.go | |
| parent | 5ca0cd3f1824f189b6c5edf59b669f22a393e2e1 (diff) | |
| download | go-747e1961e95c2eb3df62e045b90b111c2ceea337.tar.xz | |
net/http: refactor tests to run most in HTTP/1 and HTTP/2 modes
Replace the ad-hoc approach to running tests in HTTP/1 and HTTP/2
modes with a 'run' function that executes a test in various modes.
By default, these modes are HTTP/1 and HTTP/2, but tests can
opt-in to HTTPS/1 as well.
The 'run' function also takes care of post-test cleanup (running the
afterTest function).
The 'run' function runs tests in parallel by default. Tests which
can't run in parallel (generally because they use global test hooks)
pass a testNotParallel option to disable parallelism.
Update clientServerTest to use t.Cleanup to clean up after itself,
rather than leaving this up to tests to handle.
Drop an unnecessary mutex in SetReadLoopBeforeNextReadHook.
Test hooks can't be set in parallel, and we want the race detector
to notify us if two simultaneous tests try to set a hook.
Fixes #56032
Change-Id: I16be64913c426fc93d84abc6ad85dbd3bc191224
Reviewed-on: https://go-review.googlesource.com/c/go/+/438137
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Damien Neil <dneil@google.com>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: David Chase <drchase@google.com>
Diffstat (limited to 'src/net/http/serve_test.go')
| -rw-r--r-- | src/net/http/serve_test.go | 1180 |
1 files changed, 496 insertions, 684 deletions
diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index 4fadc56c9e..a93f6eff1b 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -246,15 +246,13 @@ var vtests = []struct { {"http://someHost.com/someDir", "/someDir/"}, } -func TestHostHandlers(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestHostHandlers(t *testing.T) { run(t, testHostHandlers, []testMode{http1Mode}) } +func testHostHandlers(t *testing.T, mode testMode) { mux := NewServeMux() for _, h := range handlers { mux.Handle(h.pattern, stringHandler(h.msg)) } - ts := httptest.NewServer(mux) - defer ts.Close() + ts := newClientServerTest(t, mode, mux).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -487,9 +485,9 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) { // properly sets the query string in the redirect URL. // See Issue 17841. func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) { - setParallel(t) - defer afterTest(t) - + run(t, testServeWithSlashRedirectKeepsQueryString, []testMode{http1Mode}) +} +func testServeWithSlashRedirectKeepsQueryString(t *testing.T, mode testMode) { writeBackQuery := func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%s", r.URL.RawQuery) } @@ -502,8 +500,7 @@ func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) { fmt.Fprintf(w, "%s:bar", r.URL.RawQuery) }) - ts := httptest.NewServer(mux) - defer ts.Close() + ts := newClientServerTest(t, mode, mux).ts tests := [...]struct { path string @@ -546,7 +543,6 @@ func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) { func TestServeWithSlashRedirectForHostPatterns(t *testing.T) { setParallel(t) - defer afterTest(t) mux := NewServeMux() mux.Handle("example.com/pkg/foo/", stringHandler("example.com/pkg/foo/")) @@ -578,9 +574,6 @@ func TestServeWithSlashRedirectForHostPatterns(t *testing.T) { {"CONNECT", "http://example.com:3000/pkg/connect", 301, "/pkg/connect/", ""}, } - ts := httptest.NewServer(mux) - defer ts.Close() - for i, tt := range tests { req, _ := NewRequest(tt.method, tt.url, nil) w := httptest.NewRecorder() @@ -602,13 +595,10 @@ func TestServeWithSlashRedirectForHostPatterns(t *testing.T) { } } -func TestShouldRedirectConcurrency(t *testing.T) { - setParallel(t) - defer afterTest(t) - +func TestShouldRedirectConcurrency(t *testing.T) { run(t, testShouldRedirectConcurrency) } +func testShouldRedirectConcurrency(t *testing.T, mode testMode) { mux := NewServeMux() - ts := httptest.NewServer(mux) - defer ts.Close() + newClientServerTest(t, mode, mux) mux.HandleFunc("/", func(w ResponseWriter, r *Request) {}) } @@ -656,13 +646,12 @@ func benchmarkServeMux(b *testing.B, runHandler bool) { } } -func TestServerTimeouts(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestServerTimeouts(t *testing.T) { run(t, testServerTimeouts, []testMode{http1Mode}) } +func testServerTimeouts(t *testing.T, mode testMode) { // Try three times, with increasing timeouts. tries := []time.Duration{250 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second} for i, timeout := range tries { - err := testServerTimeouts(timeout) + err := testServerTimeoutsWithTimeout(t, timeout, mode) if err == nil { return } @@ -674,16 +663,15 @@ func TestServerTimeouts(t *testing.T) { t.Fatal("all attempts failed") } -func testServerTimeouts(timeout time.Duration) error { +func testServerTimeoutsWithTimeout(t *testing.T, timeout time.Duration, mode testMode) error { reqNum := 0 - ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) { reqNum++ fmt.Fprintf(res, "req=%d", reqNum) - })) - ts.Config.ReadTimeout = timeout - ts.Config.WriteTimeout = timeout - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.ReadTimeout = timeout + ts.Config.WriteTimeout = timeout + }).ts // Hit the HTTP server successfully. c := ts.Client() @@ -749,22 +737,20 @@ func testServerTimeouts(timeout time.Duration) error { } // Test that the HTTP/2 server handles Server.WriteTimeout (Issue 18437) -func TestHTTP2WriteDeadlineExtendedOnNewRequest(t *testing.T) { +func TestWriteDeadlineExtendedOnNewRequest(t *testing.T) { + run(t, testWriteDeadlineExtendedOnNewRequest) +} +func testWriteDeadlineExtendedOnNewRequest(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } - setParallel(t) - defer afterTest(t) - ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) {})) - ts.Config.WriteTimeout = 250 * time.Millisecond - ts.TLS = &tls.Config{NextProtos: []string{"h2"}} - ts.StartTLS() - defer ts.Close() + ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {}), + func(ts *httptest.Server) { + ts.Config.WriteTimeout = 250 * time.Millisecond + }, + ).ts c := ts.Client() - if err := ExportHttp2ConfigureTransport(c.Transport.(*Transport)); err != nil { - t.Fatal(err) - } for i := 1; i <= 3; i++ { req, err := NewRequest("GET", ts.URL, nil) @@ -785,9 +771,6 @@ func TestHTTP2WriteDeadlineExtendedOnNewRequest(t *testing.T) { t.Fatalf("http2 Get #%d: %v", i, err) } r.Body.Close() - if r.ProtoMajor != 2 { - t.Fatalf("http2 Get expected HTTP/2.0, got %q", r.Proto) - } time.Sleep(ts.Config.WriteTimeout / 2) } } @@ -810,33 +793,31 @@ func tryTimeouts(t *testing.T, testFunc func(timeout time.Duration) error) { } // Test that the HTTP/2 server RSTs stream on slow write. -func TestHTTP2WriteDeadlineEnforcedPerStream(t *testing.T) { +func TestWriteDeadlineEnforcedPerStream(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } setParallel(t) - defer afterTest(t) - tryTimeouts(t, testHTTP2WriteDeadlineEnforcedPerStream) + run(t, func(t *testing.T, mode testMode) { + tryTimeouts(t, func(timeout time.Duration) error { + return testWriteDeadlineEnforcedPerStream(t, mode, timeout) + }) + }) } -func testHTTP2WriteDeadlineEnforcedPerStream(timeout time.Duration) error { +func testWriteDeadlineEnforcedPerStream(t *testing.T, mode testMode, timeout time.Duration) error { reqNum := 0 - ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) { reqNum++ if reqNum == 1 { return // first request succeeds } time.Sleep(timeout) // second request times out - })) - ts.Config.WriteTimeout = timeout / 2 - ts.TLS = &tls.Config{NextProtos: []string{"h2"}} - ts.StartTLS() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.WriteTimeout = timeout / 2 + }).ts c := ts.Client() - if err := ExportHttp2ConfigureTransport(c.Transport.(*Transport)); err != nil { - return fmt.Errorf("ExportHttp2ConfigureTransport: %v", err) - } req, err := NewRequest("GET", ts.URL, nil) if err != nil { @@ -844,12 +825,9 @@ func testHTTP2WriteDeadlineEnforcedPerStream(timeout time.Duration) error { } r, err := c.Do(req) if err != nil { - return fmt.Errorf("http2 Get #1: %v", err) + return fmt.Errorf("Get #1: %v", err) } r.Body.Close() - if r.ProtoMajor != 2 { - return fmt.Errorf("http2 Get expected HTTP/2.0, got %q", r.Proto) - } req, err = NewRequest("GET", ts.URL, nil) if err != nil { @@ -858,45 +836,42 @@ func testHTTP2WriteDeadlineEnforcedPerStream(timeout time.Duration) error { r, err = c.Do(req) if err == nil { r.Body.Close() - if r.ProtoMajor != 2 { - return fmt.Errorf("http2 Get expected HTTP/2.0, got %q", r.Proto) - } - return fmt.Errorf("http2 Get #2 expected error, got nil") + return fmt.Errorf("Get #2 expected error, got nil") } - expected := "stream ID 3; INTERNAL_ERROR" // client IDs are odd, second stream should be 3 - if !strings.Contains(err.Error(), expected) { - return fmt.Errorf("http2 Get #2: expected error to contain %q, got %q", expected, err) + if mode == http2Mode { + expected := "stream ID 3; INTERNAL_ERROR" // client IDs are odd, second stream should be 3 + if !strings.Contains(err.Error(), expected) { + return fmt.Errorf("http2 Get #2: expected error to contain %q, got %q", expected, err) + } } return nil } // Test that the HTTP/2 server does not send RST when WriteDeadline not set. -func TestHTTP2NoWriteDeadline(t *testing.T) { +func TestNoWriteDeadline(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } setParallel(t) defer afterTest(t) - tryTimeouts(t, testHTTP2NoWriteDeadline) + run(t, func(t *testing.T, mode testMode) { + tryTimeouts(t, func(timeout time.Duration) error { + return testNoWriteDeadline(t, mode, timeout) + }) + }) } -func testHTTP2NoWriteDeadline(timeout time.Duration) error { +func testNoWriteDeadline(t *testing.T, mode testMode, timeout time.Duration) error { reqNum := 0 - ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) { reqNum++ if reqNum == 1 { return // first request succeeds } time.Sleep(timeout) // second request timesout - })) - ts.TLS = &tls.Config{NextProtos: []string{"h2"}} - ts.StartTLS() - defer ts.Close() + })).ts c := ts.Client() - if err := ExportHttp2ConfigureTransport(c.Transport.(*Transport)); err != nil { - return fmt.Errorf("ExportHttp2ConfigureTransport: %v", err) - } for i := 0; i < 2; i++ { req, err := NewRequest("GET", ts.URL, nil) @@ -905,12 +880,9 @@ func testHTTP2NoWriteDeadline(timeout time.Duration) error { } r, err := c.Do(req) if err != nil { - return fmt.Errorf("http2 Get #%d: %v", i, err) + return fmt.Errorf("Get #%d: %v", i, err) } r.Body.Close() - if r.ProtoMajor != 2 { - return fmt.Errorf("http2 Get expected HTTP/2.0, got %q", r.Proto) - } } return nil } @@ -918,15 +890,14 @@ func testHTTP2NoWriteDeadline(timeout time.Duration) error { // golang.org/issue/4741 -- setting only a write timeout that triggers // shouldn't cause a handler to block forever on reads (next HTTP // request) that will never happen. -func TestOnlyWriteTimeout(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestOnlyWriteTimeout(t *testing.T) { run(t, testOnlyWriteTimeout, []testMode{http1Mode}) } +func testOnlyWriteTimeout(t *testing.T, mode testMode) { var ( mu sync.RWMutex conn net.Conn ) var afterTimeoutErrc = make(chan error, 1) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) { buf := make([]byte, 512<<10) _, err := w.Write(buf) if err != nil { @@ -942,10 +913,9 @@ func TestOnlyWriteTimeout(t *testing.T) { conn.SetWriteDeadline(time.Now().Add(-30 * time.Second)) _, err = w.Write(buf) afterTimeoutErrc <- err - })) - ts.Listener = trackLastConnListener{ts.Listener, &mu, &conn} - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Listener = trackLastConnListener{ts.Listener, &mu, &conn} + }).ts c := ts.Client() @@ -992,9 +962,12 @@ func (l trackLastConnListener) Accept() (c net.Conn, err error) { } // TestIdentityResponse verifies that a handler can unset -func TestIdentityResponse(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestIdentityResponse(t *testing.T) { run(t, testIdentityResponse) } +func testIdentityResponse(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("https://go.dev/issue/56019") + } + handler := HandlerFunc(func(rw ResponseWriter, req *Request) { rw.Header().Set("Content-Length", "3") rw.Header().Set("Transfer-Encoding", req.FormValue("te")) @@ -1012,9 +985,7 @@ func TestIdentityResponse(t *testing.T) { } }) - ts := httptest.NewServer(handler) - defer ts.Close() - + ts := newClientServerTest(t, mode, handler).ts c := ts.Client() // Note: this relies on the assumption (which is true) that @@ -1048,6 +1019,10 @@ func TestIdentityResponse(t *testing.T) { } res.Body.Close() + if mode != http1Mode { + return + } + // Verify that the connection is closed when the declared Content-Length // is larger than what the handler wrote. conn, err := net.Dial("tcp", ts.Listener.Addr().String()) @@ -1070,9 +1045,7 @@ func TestIdentityResponse(t *testing.T) { func testTCPConnectionCloses(t *testing.T, req string, h Handler) { setParallel(t) - defer afterTest(t) - s := httptest.NewServer(h) - defer s.Close() + s := newClientServerTest(t, http1Mode, h).ts conn, err := net.Dial("tcp", s.Listener.Addr().String()) if err != nil { @@ -1114,9 +1087,7 @@ func testTCPConnectionCloses(t *testing.T, req string, h Handler) { func testTCPConnectionStaysOpen(t *testing.T, req string, handler Handler) { setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(handler) - defer ts.Close() + ts := newClientServerTest(t, http1Mode, handler).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatal(err) @@ -1192,14 +1163,12 @@ func TestHTTP10KeepAlive304Response(t *testing.T) { } // Issue 15703 -func TestKeepAliveFinalChunkWithEOF(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, false /* h1 */, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestKeepAliveFinalChunkWithEOF(t *testing.T) { run(t, testKeepAliveFinalChunkWithEOF) } +func testKeepAliveFinalChunkWithEOF(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.(Flusher).Flush() // force chunked encoding w.Write([]byte("{\"Addr\": \"" + r.RemoteAddr + "\"}")) })) - defer cst.close() type data struct { Addr string } @@ -1222,16 +1191,11 @@ func TestKeepAliveFinalChunkWithEOF(t *testing.T) { } } -func TestSetsRemoteAddr_h1(t *testing.T) { testSetsRemoteAddr(t, h1Mode) } -func TestSetsRemoteAddr_h2(t *testing.T) { testSetsRemoteAddr(t, h2Mode) } - -func testSetsRemoteAddr(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestSetsRemoteAddr(t *testing.T) { run(t, testSetsRemoteAddr) } +func testSetsRemoteAddr(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%s", r.RemoteAddr) })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -1276,17 +1240,18 @@ func (c *blockingRemoteAddrConn) RemoteAddr() net.Addr { // Issue 12943 func TestServerAllowsBlockingRemoteAddr(t *testing.T) { - defer afterTest(t) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { - fmt.Fprintf(w, "RA:%s", r.RemoteAddr) - })) + run(t, testServerAllowsBlockingRemoteAddr, []testMode{http1Mode}) +} +func testServerAllowsBlockingRemoteAddr(t *testing.T, mode testMode) { conns := make(chan net.Conn) - ts.Listener = &blockingRemoteAddrListener{ - Listener: ts.Listener, - conns: conns, - } - ts.Start() - defer ts.Close() + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "RA:%s", r.RemoteAddr) + }), func(ts *httptest.Server) { + ts.Listener = &blockingRemoteAddrListener{ + Listener: ts.Listener, + conns: conns, + } + }).ts c := ts.Client() c.Timeout = time.Second @@ -1351,13 +1316,9 @@ func TestServerAllowsBlockingRemoteAddr(t *testing.T) { // TestHeadResponses verifies that all MIME type sniffing and Content-Length // counting of GET requests also happens on HEAD requests. -func TestHeadResponses_h1(t *testing.T) { testHeadResponses(t, h1Mode) } -func TestHeadResponses_h2(t *testing.T) { testHeadResponses(t, h2Mode) } - -func testHeadResponses(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestHeadResponses(t *testing.T) { run(t, testHeadResponses) } +func testHeadResponses(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { _, err := w.Write([]byte("<html>")) if err != nil { t.Errorf("ResponseWriter.Write: %v", err) @@ -1369,7 +1330,6 @@ func testHeadResponses(t *testing.T, h2 bool) { t.Errorf("Copy(ResponseWriter, ...): %v", err) } })) - defer cst.close() res, err := cst.c.Head(cst.ts.URL) if err != nil { t.Error(err) @@ -1393,14 +1353,16 @@ func testHeadResponses(t *testing.T, h2 bool) { } func TestTLSHandshakeTimeout(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + run(t, testTLSHandshakeTimeout, []testMode{https1Mode, http2Mode}) +} +func testTLSHandshakeTimeout(t *testing.T, mode testMode) { errc := make(chanWriter, 10) // but only expecting 1 - ts.Config.ReadTimeout = 250 * time.Millisecond - ts.Config.ErrorLog = log.New(errc, "", 0) - ts.StartTLS() - defer ts.Close() + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), + func(ts *httptest.Server) { + ts.Config.ReadTimeout = 250 * time.Millisecond + ts.Config.ErrorLog = log.New(errc, "", 0) + }, + ).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("Dial: %v", err) @@ -1423,19 +1385,18 @@ func TestTLSHandshakeTimeout(t *testing.T) { } } -func TestTLSServer(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTLSServer(t *testing.T) { run(t, testTLSServer, []testMode{https1Mode, http2Mode}) } +func testTLSServer(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.TLS != nil { w.Header().Set("X-TLS-Set", "true") if r.TLS.HandshakeComplete { w.Header().Set("X-TLS-HandshakeComplete", "true") } } - })) - ts.Config.ErrorLog = log.New(io.Discard, "", 0) - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.ErrorLog = log.New(io.Discard, "", 0) + }).ts // Connect an idle TCP connection to this server before we run // our real tests. This idle connection used to block forever @@ -1528,14 +1489,15 @@ func TestServeTLS(t *testing.T) { // Test that the HTTPS server nicely rejects plaintext HTTP/1.x requests. func TestTLSServerRejectHTTPRequests(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTLSServerRejectHTTPRequests, []testMode{https1Mode, http2Mode}) +} +func testTLSServerRejectHTTPRequests(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { t.Error("unexpected HTTPS request") - })) - var errBuf bytes.Buffer - ts.Config.ErrorLog = log.New(&errBuf, "", 0) - defer ts.Close() + }), func(ts *httptest.Server) { + var errBuf bytes.Buffer + ts.Config.ErrorLog = log.New(&errBuf, "", 0) + }).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatal(err) @@ -1727,11 +1689,9 @@ var serverExpectTests = []serverExpectTest{ // Tests that the server responds to the "Expect" request header // correctly. -// http2 test: TestServer_Response_Automatic100Continue -func TestServerExpect(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServerExpect(t *testing.T) { run(t, testServerExpect, []testMode{http1Mode}) } +func testServerExpect(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { // Note using r.FormValue("readbody") because for POST // requests that would read from r.Body, which we only // conditionally want to do. @@ -1741,8 +1701,7 @@ func TestServerExpect(t *testing.T) { } else { w.WriteHeader(StatusUnauthorized) } - })) - defer ts.Close() + })).ts runTest := func(test serverExpectTest) { conn, err := net.Dial("tcp", ts.Listener.Addr().String()) @@ -2287,11 +2246,8 @@ func (c cancelableTimeoutContext) Err() error { return nil } -func TestTimeoutHandler_h1(t *testing.T) { testTimeoutHandler(t, h1Mode) } -func TestTimeoutHandler_h2(t *testing.T) { testTimeoutHandler(t, h2Mode) } -func testTimeoutHandler(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func TestTimeoutHandler(t *testing.T) { run(t, testTimeoutHandler) } +func testTimeoutHandler(t *testing.T, mode testMode) { sendHi := make(chan bool, 1) writeErrors := make(chan error, 1) sayHi := HandlerFunc(func(w ResponseWriter, r *Request) { @@ -2301,8 +2257,7 @@ func testTimeoutHandler(t *testing.T, h2 bool) { }) ctx, cancel := context.WithCancel(context.Background()) h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx}) - cst := newClientServerTest(t, h2, h) - defer cst.close() + cst := newClientServerTest(t, mode, h) // Succeed without timing out: sendHi <- true @@ -2348,10 +2303,8 @@ func testTimeoutHandler(t *testing.T, h2 bool) { } // See issues 8209 and 8414. -func TestTimeoutHandlerRace(t *testing.T) { - setParallel(t) - defer afterTest(t) - +func TestTimeoutHandlerRace(t *testing.T) { run(t, testTimeoutHandlerRace) } +func testTimeoutHandlerRace(t *testing.T, mode testMode) { delayHi := HandlerFunc(func(w ResponseWriter, r *Request) { ms, _ := strconv.Atoi(r.URL.Path[1:]) if ms == 0 { @@ -2363,8 +2316,7 @@ func TestTimeoutHandlerRace(t *testing.T) { } }) - ts := httptest.NewServer(TimeoutHandler(delayHi, 20*time.Millisecond, "")) - defer ts.Close() + ts := newClientServerTest(t, mode, TimeoutHandler(delayHi, 20*time.Millisecond, "")).ts c := ts.Client() @@ -2393,16 +2345,13 @@ func TestTimeoutHandlerRace(t *testing.T) { // See issues 8209 and 8414. // Both issues involved panics in the implementation of TimeoutHandler. -func TestTimeoutHandlerRaceHeader(t *testing.T) { - setParallel(t) - defer afterTest(t) - +func TestTimeoutHandlerRaceHeader(t *testing.T) { run(t, testTimeoutHandlerRaceHeader) } +func testTimeoutHandlerRaceHeader(t *testing.T, mode testMode) { delay204 := HandlerFunc(func(w ResponseWriter, r *Request) { w.WriteHeader(204) }) - ts := httptest.NewServer(TimeoutHandler(delay204, time.Nanosecond, "")) - defer ts.Close() + ts := newClientServerTest(t, mode, TimeoutHandler(delay204, time.Nanosecond, "")).ts var wg sync.WaitGroup gate := make(chan bool, 50) @@ -2433,9 +2382,8 @@ func TestTimeoutHandlerRaceHeader(t *testing.T) { } // Issue 9162 -func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { run(t, testTimeoutHandlerRaceHeaderTimeout) } +func testTimeoutHandlerRaceHeaderTimeout(t *testing.T, mode testMode) { sendHi := make(chan bool, 1) writeErrors := make(chan error, 1) sayHi := HandlerFunc(func(w ResponseWriter, r *Request) { @@ -2446,8 +2394,7 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { }) ctx, cancel := context.WithCancel(context.Background()) h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx}) - cst := newClientServerTest(t, h1Mode, h) - defer cst.close() + cst := newClientServerTest(t, mode, h) // Succeed without timing out: sendHi <- true @@ -2491,15 +2438,17 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { // Issue 14568. func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) { + run(t, testTimeoutHandlerStartTimerWhenServing) +} +func testTimeoutHandlerStartTimerWhenServing(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping sleeping test in -short mode") } - defer afterTest(t) var handler HandlerFunc = func(w ResponseWriter, _ *Request) { w.WriteHeader(StatusNoContent) } timeout := 300 * time.Millisecond - ts := httptest.NewServer(TimeoutHandler(handler, timeout, "")) + ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts defer ts.Close() c := ts.Client() @@ -2518,9 +2467,8 @@ func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) { } } -func TestTimeoutHandlerContextCanceled(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestTimeoutHandlerContextCanceled(t *testing.T) { run(t, testTimeoutHandlerContextCanceled) } +func testTimeoutHandlerContextCanceled(t *testing.T, mode testMode) { writeErrors := make(chan error, 1) sayHi := HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Type", "text/plain") @@ -2540,7 +2488,7 @@ func TestTimeoutHandlerContextCanceled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() h := NewTestTimeoutHandler(sayHi, ctx) - cst := newClientServerTest(t, h1Mode, h) + cst := newClientServerTest(t, mode, h) defer cst.close() res, err := cst.c.Get(cst.ts.URL) @@ -2560,15 +2508,13 @@ func TestTimeoutHandlerContextCanceled(t *testing.T) { } // https://golang.org/issue/15948 -func TestTimeoutHandlerEmptyResponse(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestTimeoutHandlerEmptyResponse(t *testing.T) { run(t, testTimeoutHandlerEmptyResponse) } +func testTimeoutHandlerEmptyResponse(t *testing.T, mode testMode) { var handler HandlerFunc = func(w ResponseWriter, _ *Request) { // No response. } timeout := 300 * time.Millisecond - ts := httptest.NewServer(TimeoutHandler(handler, timeout, "")) - defer ts.Close() + ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts c := ts.Client() @@ -2587,7 +2533,9 @@ func TestTimeoutHandlerPanicRecovery(t *testing.T) { wrapper := func(h Handler) Handler { return TimeoutHandler(h, time.Second, "") } - testHandlerPanic(t, false, false, wrapper, "intentional death for testing") + run(t, func(t *testing.T, mode testMode) { + testHandlerPanic(t, false, mode, wrapper, "intentional death for testing") + }, testNotParallel) } func TestRedirectBadPath(t *testing.T) { @@ -2705,17 +2653,10 @@ func TestRedirectContentTypeAndBody(t *testing.T) { // connection immediately. But when it re-uses the connection, it typically closes // the previous request's body, which is not optimal for zero-lengthed bodies, // as the client would then see http.ErrBodyReadAfterClose and not 0, io.EOF. -func TestZeroLengthPostAndResponse_h1(t *testing.T) { - testZeroLengthPostAndResponse(t, h1Mode) -} -func TestZeroLengthPostAndResponse_h2(t *testing.T) { - testZeroLengthPostAndResponse(t, h2Mode) -} +func TestZeroLengthPostAndResponse(t *testing.T) { run(t, testZeroLengthPostAndResponse) } -func testZeroLengthPostAndResponse(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, r *Request) { +func testZeroLengthPostAndResponse(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) { all, err := io.ReadAll(r.Body) if err != nil { t.Fatalf("handler ReadAll: %v", err) @@ -2725,7 +2666,6 @@ func testZeroLengthPostAndResponse(t *testing.T, h2 bool) { } rw.Header().Set("Content-Length", "0") })) - defer cst.close() req, err := NewRequest("POST", cst.ts.URL, strings.NewReader("")) if err != nil { @@ -2752,23 +2692,26 @@ func testZeroLengthPostAndResponse(t *testing.T, h2 bool) { } } -func TestHandlerPanicNil_h1(t *testing.T) { testHandlerPanic(t, false, h1Mode, nil, nil) } -func TestHandlerPanicNil_h2(t *testing.T) { testHandlerPanic(t, false, h2Mode, nil, nil) } - -func TestHandlerPanic_h1(t *testing.T) { - testHandlerPanic(t, false, h1Mode, nil, "intentional death for testing") +func TestHandlerPanicNil(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testHandlerPanic(t, false, mode, nil, nil) + }, testNotParallel) } -func TestHandlerPanic_h2(t *testing.T) { - testHandlerPanic(t, false, h2Mode, nil, "intentional death for testing") + +func TestHandlerPanic(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testHandlerPanic(t, false, mode, nil, "intentional death for testing") + }, testNotParallel) } func TestHandlerPanicWithHijack(t *testing.T) { // Only testing HTTP/1, and our http2 server doesn't support hijacking. - testHandlerPanic(t, true, h1Mode, nil, "intentional death for testing") + run(t, func(t *testing.T, mode testMode) { + testHandlerPanic(t, true, mode, nil, "intentional death for testing") + }, testNotParallel, []testMode{http1Mode}) } -func testHandlerPanic(t *testing.T, withHijack, h2 bool, wrapper func(Handler) Handler, panicValue any) { - defer afterTest(t) +func testHandlerPanic(t *testing.T, withHijack bool, mode testMode, wrapper func(Handler) Handler, panicValue any) { // Unlike the other tests that set the log output to io.Discard // to quiet the output, this test uses a pipe. The pipe serves three // purposes: @@ -2803,8 +2746,7 @@ func testHandlerPanic(t *testing.T, withHijack, h2 bool, wrapper func(Handler) H if wrapper != nil { handler = wrapper(handler) } - cst := newClientServerTest(t, h2, handler) - defer cst.close() + cst := newClientServerTest(t, mode, handler) // Do a blocking read on the log output pipe so its logging // doesn't bleed into the next test. But wait only 5 seconds @@ -2847,9 +2789,11 @@ func (w terrorWriter) Write(p []byte) (int, error) { // Issue 16456: allow writing 0 bytes on hijacked conn to test hijack // without any log spam. func TestServerWriteHijackZeroBytes(t *testing.T) { - defer afterTest(t) + run(t, testServerWriteHijackZeroBytes, []testMode{http1Mode}) +} +func testServerWriteHijackZeroBytes(t *testing.T, mode testMode) { done := make(chan struct{}) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { defer close(done) w.(Flusher).Flush() conn, _, err := w.(Hijacker).Hijack() @@ -2862,10 +2806,9 @@ func TestServerWriteHijackZeroBytes(t *testing.T) { if err != ErrHijacked { t.Errorf("Write error = %v; want ErrHijacked", err) } - })) - ts.Config.ErrorLog = log.New(terrorWriter{t}, "Unexpected write: ", 0) - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.ErrorLog = log.New(terrorWriter{t}, "Unexpected write: ", 0) + }).ts c := ts.Client() res, err := c.Get(ts.URL) @@ -2880,19 +2823,23 @@ func TestServerWriteHijackZeroBytes(t *testing.T) { } } -func TestServerNoDate_h1(t *testing.T) { testServerNoHeader(t, h1Mode, "Date") } -func TestServerNoDate_h2(t *testing.T) { testServerNoHeader(t, h2Mode, "Date") } -func TestServerNoContentType_h1(t *testing.T) { testServerNoHeader(t, h1Mode, "Content-Type") } -func TestServerNoContentType_h2(t *testing.T) { testServerNoHeader(t, h2Mode, "Content-Type") } +func TestServerNoDate(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testServerNoHeader(t, mode, "Date") + }) +} -func testServerNoHeader(t *testing.T, h2 bool, header string) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServerContentType(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testServerNoHeader(t, mode, "Content-Type") + }) +} + +func testServerNoHeader(t *testing.T, mode testMode, header string) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header()[header] = nil io.WriteString(w, "<html>foo</html>") // non-empty })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -2903,15 +2850,13 @@ func testServerNoHeader(t *testing.T, h2 bool, header string) { } } -func TestStripPrefix(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestStripPrefix(t *testing.T) { run(t, testStripPrefix) } +func testStripPrefix(t *testing.T, mode testMode) { h := HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("X-Path", r.URL.Path) w.Header().Set("X-RawPath", r.URL.RawPath) }) - ts := httptest.NewServer(StripPrefix("/foo/bar", h)) - defer ts.Close() + ts := newClientServerTest(t, mode, StripPrefix("/foo/bar", h)).ts c := ts.Client() @@ -2961,15 +2906,11 @@ func TestStripPrefixNotModifyRequest(t *testing.T) { } } -func TestRequestLimit_h1(t *testing.T) { testRequestLimit(t, h1Mode) } -func TestRequestLimit_h2(t *testing.T) { testRequestLimit(t, h2Mode) } -func testRequestLimit(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestRequestLimit(t *testing.T) { run(t, testRequestLimit) } +func testRequestLimit(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { t.Fatalf("didn't expect to get request in Handler") }), optQuietLog) - defer cst.close() req, _ := NewRequest("GET", cst.ts.URL, nil) var bytesPerHeader = len("header12345: val12345\r\n") for i := 0; i < ((DefaultMaxHeaderBytes+4096)/bytesPerHeader)+1; i++ { @@ -2979,7 +2920,7 @@ func testRequestLimit(t *testing.T, h2 bool) { if res != nil { defer res.Body.Close() } - if h2 { + if mode == http2Mode { // In HTTP/2, the result depends on a race. If the client has received the // server's SETTINGS before RoundTrip starts sending the request, then RoundTrip // will fail with an error. Otherwise, the client should receive a 431 from the @@ -3021,13 +2962,10 @@ func (cr countReader) Read(p []byte) (n int, err error) { return } -func TestRequestBodyLimit_h1(t *testing.T) { testRequestBodyLimit(t, h1Mode) } -func TestRequestBodyLimit_h2(t *testing.T) { testRequestBodyLimit(t, h2Mode) } -func testRequestBodyLimit(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func TestRequestBodyLimit(t *testing.T) { run(t, testRequestBodyLimit) } +func testRequestBodyLimit(t *testing.T, mode testMode) { const limit = 1 << 20 - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { r.Body = MaxBytesReader(w, r.Body, limit) n, err := io.Copy(io.Discard, r.Body) if err == nil { @@ -3044,7 +2982,6 @@ func testRequestBodyLimit(t *testing.T, h2 bool) { t.Errorf("MaxBytesError.Limit = %d, want %d", mbErr.Limit, limit) } })) - defer cst.close() nWritten := new(int64) req, _ := NewRequest("POST", cst.ts.URL, io.LimitReader(countReader{neverEnding('a'), nWritten}, limit*200)) @@ -3068,13 +3005,12 @@ func testRequestBodyLimit(t *testing.T, h2 bool) { // TestClientWriteShutdown tests that if the client shuts down the write // side of their TCP connection, the server doesn't send a 400 Bad Request. -func TestClientWriteShutdown(t *testing.T) { +func TestClientWriteShutdown(t *testing.T) { run(t, testClientWriteShutdown) } +func testClientWriteShutdown(t *testing.T, mode testMode) { if runtime.GOOS == "plan9" { t.Skip("skipping test; see https://golang.org/issue/17906") } - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) - defer ts.Close() + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("Dial: %v", err) @@ -3119,12 +3055,12 @@ func TestServerBufferedChunking(t *testing.T) { // closing the TCP connection, causing the client to get a RST. // See https://golang.org/issue/3595 func TestServerGracefulClose(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testServerGracefulClose, []testMode{http1Mode}) +} +func testServerGracefulClose(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { Error(w, "bye", StatusUnauthorized) - })) - defer ts.Close() + })).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -3162,11 +3098,9 @@ func TestServerGracefulClose(t *testing.T) { <-writeErr } -func TestCaseSensitiveMethod_h1(t *testing.T) { testCaseSensitiveMethod(t, h1Mode) } -func TestCaseSensitiveMethod_h2(t *testing.T) { testCaseSensitiveMethod(t, h2Mode) } -func testCaseSensitiveMethod(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestCaseSensitiveMethod(t *testing.T) { run(t, testCaseSensitiveMethod) } +func testCaseSensitiveMethod(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "get" { t.Errorf(`Got method %q; want "get"`, r.Method) } @@ -3187,8 +3121,10 @@ func testCaseSensitiveMethod(t *testing.T, h2 bool) { // response, the net/http package adds a "Content-Length: 0" response // header. func TestContentLengthZero(t *testing.T) { - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {})) - defer ts.Close() + run(t, testContentLengthZero, []testMode{http1Mode}) +} +func testContentLengthZero(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {})).ts for _, version := range []string{"HTTP/1.0", "HTTP/1.1"} { conn, err := net.Dial("tcp", ts.Listener.Addr().String()) @@ -3215,15 +3151,17 @@ func TestContentLengthZero(t *testing.T) { } func TestCloseNotifier(t *testing.T) { - defer afterTest(t) + run(t, testCloseNotifier, []testMode{http1Mode}) +} +func testCloseNotifier(t *testing.T, mode testMode) { gotReq := make(chan bool, 1) sawClose := make(chan bool, 1) - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { gotReq <- true cc := rw.(CloseNotifier).CloseNotify() <-cc sawClose <- true - })) + })).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("error dialing: %v", err) @@ -3257,11 +3195,12 @@ For: // // Issue 13165 (where it used to deadlock), but behavior changed in Issue 23921. func TestCloseNotifierPipelined(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testCloseNotifierPipelined, []testMode{http1Mode}) +} +func testCloseNotifierPipelined(t *testing.T, mode testMode) { gotReq := make(chan bool, 2) sawClose := make(chan bool, 2) - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { gotReq <- true cc := rw.(CloseNotifier).CloseNotify() select { @@ -3270,8 +3209,7 @@ func TestCloseNotifierPipelined(t *testing.T) { case <-time.After(100 * time.Millisecond): } sawClose <- true - })) - defer ts.Close() + })).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("error dialing: %v", err) @@ -3341,12 +3279,14 @@ func TestCloseNotifierChanLeak(t *testing.T) { // Issue 9763. // HTTP/1-only test. (http2 doesn't have Hijack) func TestHijackAfterCloseNotifier(t *testing.T) { - defer afterTest(t) + run(t, testHijackAfterCloseNotifier, []testMode{http1Mode}) +} +func testHijackAfterCloseNotifier(t *testing.T, mode testMode) { script := make(chan string, 2) script <- "closenotify" script <- "hijack" close(script) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { plan := <-script switch plan { default: @@ -3369,13 +3309,12 @@ func TestHijackAfterCloseNotifier(t *testing.T) { c.Close() return } - })) - defer ts.Close() - res1, err := Get(ts.URL) + })).ts + res1, err := ts.Client().Get(ts.URL) if err != nil { log.Fatal(err) } - res2, err := Get(ts.URL) + res2, err := ts.Client().Get(ts.URL) if err != nil { log.Fatal(err) } @@ -3387,12 +3326,13 @@ func TestHijackAfterCloseNotifier(t *testing.T) { } func TestHijackBeforeRequestBodyRead(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testHijackBeforeRequestBodyRead, []testMode{http1Mode}) +} +func testHijackBeforeRequestBodyRead(t *testing.T, mode testMode) { var requestBody = bytes.Repeat([]byte("a"), 1<<20) bodyOkay := make(chan bool, 1) gotCloseNotify := make(chan bool, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { defer close(bodyOkay) // caller will read false if nothing else reqBody := r.Body @@ -3419,8 +3359,7 @@ func TestHijackBeforeRequestBodyRead(t *testing.T) { case <-time.After(5 * time.Second): gotCloseNotify <- false } - })) - defer ts.Close() + })).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -3440,14 +3379,14 @@ func TestHijackBeforeRequestBodyRead(t *testing.T) { } } -func TestOptions(t *testing.T) { +func TestOptions(t *testing.T) { run(t, testOptions, []testMode{http1Mode}) } +func testOptions(t *testing.T, mode testMode) { uric := make(chan string, 2) // only expect 1, but leave space for 2 mux := NewServeMux() mux.HandleFunc("/", func(w ResponseWriter, r *Request) { uric <- r.RequestURI }) - ts := httptest.NewServer(mux) - defer ts.Close() + ts := newClientServerTest(t, mode, mux).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -3492,15 +3431,15 @@ func TestOptions(t *testing.T) { } } -func TestOptionsHandler(t *testing.T) { +func TestOptionsHandler(t *testing.T) { run(t, testOptionsHandler, []testMode{http1Mode}) } +func testOptionsHandler(t *testing.T, mode testMode) { rc := make(chan *Request, 1) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { rc <- r - })) - ts.Config.DisableGeneralOptionsHandler = true - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.DisableGeneralOptionsHandler = true + }).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -3804,12 +3743,12 @@ func TestDoubleHijack(t *testing.T) { // optimization and is pointless if dealing with a // badly behaved client. func TestHTTP10ConnectionHeader(t *testing.T) { - defer afterTest(t) - + run(t, testHTTP10ConnectionHeader, []testMode{http1Mode}) +} +func testHTTP10ConnectionHeader(t *testing.T, mode testMode) { mux := NewServeMux() mux.Handle("/", HandlerFunc(func(ResponseWriter, *Request) {})) - ts := httptest.NewServer(mux) - defer ts.Close() + ts := newClientServerTest(t, mode, mux).ts // net/http uses HTTP/1.1 for requests, so write requests manually tests := []struct { @@ -3856,14 +3795,11 @@ func TestHTTP10ConnectionHeader(t *testing.T) { } // See golang.org/issue/5660 -func TestServerReaderFromOrder_h1(t *testing.T) { testServerReaderFromOrder(t, h1Mode) } -func TestServerReaderFromOrder_h2(t *testing.T) { testServerReaderFromOrder(t, h2Mode) } -func testServerReaderFromOrder(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func TestServerReaderFromOrder(t *testing.T) { run(t, testServerReaderFromOrder) } +func testServerReaderFromOrder(t *testing.T, mode testMode) { pr, pw := io.Pipe() const size = 3 << 20 - cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { rw.Header().Set("Content-Type", "text/plain") // prevent sniffing path done := make(chan bool) go func() { @@ -3883,7 +3819,6 @@ func testServerReaderFromOrder(t *testing.T, h2 bool) { pw.Close() <-done })) - defer cst.close() req, err := NewRequest("POST", cst.ts.URL, io.LimitReader(neverEnding('a'), size)) if err != nil { @@ -3957,16 +3892,10 @@ func TestContentTypeOkayOn204(t *testing.T) { // proxy). So then two people own that Request.Body (both the server // and the http client), and both think they can close it on failure. // Therefore, all incoming server requests Bodies need to be thread-safe. -func TestTransportAndServerSharedBodyRace_h1(t *testing.T) { - testTransportAndServerSharedBodyRace(t, h1Mode) +func TestTransportAndServerSharedBodyRace(t *testing.T) { + run(t, testTransportAndServerSharedBodyRace) } -func TestTransportAndServerSharedBodyRace_h2(t *testing.T) { - testTransportAndServerSharedBodyRace(t, h2Mode) -} -func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - +func testTransportAndServerSharedBodyRace(t *testing.T, mode testMode) { const bodySize = 1 << 20 // errorf is like t.Errorf, but also writes to println. When @@ -3980,7 +3909,7 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { } unblockBackend := make(chan bool) - backend := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) { + backend := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { gone := rw.(CloseNotifier).CloseNotify() didCopy := make(chan any) go func() { @@ -4007,7 +3936,7 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { backendRespc := make(chan *Response, 1) var proxy *clientServerTest - proxy = newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) { + proxy = newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { req2, _ := NewRequest("POST", backend.ts.URL, req.Body) req2.ContentLength = bodySize cancel := make(chan struct{}) @@ -4027,7 +3956,7 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { // Try to cause a race: Both the Transport and the proxy handler's Server // will try to read/close req.Body (aka req2.Body) - if h2 { + if mode == http2Mode { close(cancel) } else { proxy.c.Transport.(*Transport).CancelRequest(req2) @@ -4071,22 +4000,23 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { // cause the Handler goroutine's Request.Body.Close to block. // See issue 7121. func TestRequestBodyCloseDoesntBlock(t *testing.T) { + run(t, testRequestBodyCloseDoesntBlock, []testMode{http1Mode}) +} +func testRequestBodyCloseDoesntBlock(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in -short mode") } - defer afterTest(t) readErrCh := make(chan error, 1) errCh := make(chan error, 2) - server := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + server := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { go func(body io.Reader) { _, err := body.Read(make([]byte, 100)) readErrCh <- err }(req.Body) time.Sleep(500 * time.Millisecond) - })) - defer server.Close() + })).ts closeConn := make(chan bool) defer close(closeConn) @@ -4149,9 +4079,8 @@ func TestAppendTime(t *testing.T) { } } -func TestServerConnState(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestServerConnState(t *testing.T) { run(t, testServerConnState, []testMode{http1Mode}) } +func testServerConnState(t *testing.T, mode testMode) { handler := map[string]func(w ResponseWriter, r *Request){ "/": func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "Hello.") @@ -4217,37 +4146,36 @@ func TestServerConnState(t *testing.T) { // next call to wantLog. } - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { handler[r.URL.Path](w, r) - })) + }), func(ts *httptest.Server) { + ts.Config.ErrorLog = log.New(io.Discard, "", 0) + ts.Config.ConnState = func(c net.Conn, state ConnState) { + if c == nil { + t.Errorf("nil conn seen in state %s", state) + return + } + sl := <-activeLog + if sl.active == nil && state == StateNew { + sl.active = c + } else if sl.active != c { + t.Errorf("unexpected conn in state %s", state) + activeLog <- sl + return + } + sl.got = append(sl.got, state) + if sl.complete != nil && (len(sl.got) >= len(sl.want) || !reflect.DeepEqual(sl.got, sl.want[:len(sl.got)])) { + close(sl.complete) + sl.complete = nil + } + activeLog <- sl + } + }).ts defer func() { activeLog <- &stateLog{} // If the test failed, allow any remaining ConnState callbacks to complete. ts.Close() }() - ts.Config.ErrorLog = log.New(io.Discard, "", 0) - ts.Config.ConnState = func(c net.Conn, state ConnState) { - if c == nil { - t.Errorf("nil conn seen in state %s", state) - return - } - sl := <-activeLog - if sl.active == nil && state == StateNew { - sl.active = c - } else if sl.active != c { - t.Errorf("unexpected conn in state %s", state) - activeLog <- sl - return - } - sl.got = append(sl.got, state) - if sl.complete != nil && (len(sl.got) >= len(sl.want) || !reflect.DeepEqual(sl.got, sl.want[:len(sl.got)])) { - close(sl.complete) - sl.complete = nil - } - activeLog <- sl - } - - ts.Start() c := ts.Client() mustGet := func(url string, headers ...string) { @@ -4329,13 +4257,15 @@ func TestServerConnState(t *testing.T) { }, StateNew, StateActive, StateIdle, StateClosed) } -func TestServerKeepAlivesEnabled(t *testing.T) { - defer afterTest(t) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) - ts.Config.SetKeepAlivesEnabled(false) - ts.Start() - defer ts.Close() - res, err := Get(ts.URL) +func TestServerKeepAlivesEnabledResultClose(t *testing.T) { + run(t, testServerKeepAlivesEnabledResultClose, []testMode{http1Mode}) +} +func testServerKeepAlivesEnabledResultClose(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + }), func(ts *httptest.Server) { + ts.Config.SetKeepAlivesEnabled(false) + }).ts + res, err := ts.Client().Get(ts.URL) if err != nil { t.Fatal(err) } @@ -4346,16 +4276,12 @@ func TestServerKeepAlivesEnabled(t *testing.T) { } // golang.org/issue/7856 -func TestServerEmptyBodyRace_h1(t *testing.T) { testServerEmptyBodyRace(t, h1Mode) } -func TestServerEmptyBodyRace_h2(t *testing.T) { testServerEmptyBodyRace(t, h2Mode) } -func testServerEmptyBodyRace(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func TestServerEmptyBodyRace(t *testing.T) { run(t, testServerEmptyBodyRace) } +func testServerEmptyBodyRace(t *testing.T, mode testMode) { var n int32 - cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { atomic.AddInt32(&n, 1) }), optQuietLog) - defer cst.close() var wg sync.WaitGroup const reqs = 20 for i := 0; i < reqs; i++ { @@ -4436,9 +4362,9 @@ func TestCloseWrite(t *testing.T) { // fixed. // // So add an explicit test for this. -func TestServerFlushAndHijack(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServerFlushAndHijack(t *testing.T) { run(t, testServerFlushAndHijack, []testMode{http1Mode}) } +func testServerFlushAndHijack(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, "Hello, ") w.(Flusher).Flush() conn, buf, _ := w.(Hijacker).Hijack() @@ -4449,8 +4375,7 @@ func TestServerFlushAndHijack(t *testing.T) { if err := conn.Close(); err != nil { t.Error(err) } - })) - defer ts.Close() + })).ts res, err := Get(ts.URL) if err != nil { t.Fatal(err) @@ -4472,20 +4397,21 @@ func TestServerFlushAndHijack(t *testing.T) { // To test, verify we don't timeout or see fewer unique client // addresses (== unique connections) than requests. func TestServerKeepAliveAfterWriteError(t *testing.T) { + run(t, testServerKeepAliveAfterWriteError, []testMode{http1Mode}) +} +func testServerKeepAliveAfterWriteError(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in -short mode") } - defer afterTest(t) const numReq = 3 addrc := make(chan string, numReq) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { addrc <- r.RemoteAddr time.Sleep(500 * time.Millisecond) w.(Flusher).Flush() - })) - ts.Config.WriteTimeout = 250 * time.Millisecond - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.WriteTimeout = 250 * time.Millisecond + }).ts errc := make(chan error, numReq) go func() { @@ -4529,12 +4455,13 @@ func TestServerKeepAliveAfterWriteError(t *testing.T) { // Issue 9987: shouldn't add automatic Content-Length (or // Content-Type) if a Transfer-Encoding was set by the handler. func TestNoContentLengthIfTransferEncoding(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testNoContentLengthIfTransferEncoding, []testMode{http1Mode}) +} +func testNoContentLengthIfTransferEncoding(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Transfer-Encoding", "foo") io.WriteString(w, "<html>") - })) - defer ts.Close() + })).ts c, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("Dial: %v", err) @@ -4682,15 +4609,12 @@ func TestHandlerFinishSkipBigContentLengthRead(t *testing.T) { } } -func TestHandlerSetsBodyNil_h1(t *testing.T) { testHandlerSetsBodyNil(t, h1Mode) } -func TestHandlerSetsBodyNil_h2(t *testing.T) { testHandlerSetsBodyNil(t, h2Mode) } -func testHandlerSetsBodyNil(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestHandlerSetsBodyNil(t *testing.T) { run(t, testHandlerSetsBodyNil) } +func testHandlerSetsBodyNil(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { r.Body = nil fmt.Fprintf(w, "%v", r.RemoteAddr) })) - defer cst.close() get := func() string { res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -4780,9 +4704,11 @@ func TestServerValidatesHostHeader(t *testing.T) { } func TestServerHandlersCanHandleH2PRI(t *testing.T) { + run(t, testServerHandlersCanHandleH2PRI, []testMode{http1Mode}) +} +func testServerHandlersCanHandleH2PRI(t *testing.T, mode testMode) { const upgradeResponse = "upgrade here" - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, br, err := w.(Hijacker).Hijack() if err != nil { t.Error(err) @@ -4804,8 +4730,7 @@ func TestServerHandlersCanHandleH2PRI(t *testing.T) { return } io.WriteString(conn, upgradeResponse) - })) - defer ts.Close() + })).ts c, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -4872,17 +4797,12 @@ func TestServerValidatesHeaders(t *testing.T) { } } -func TestServerRequestContextCancel_ServeHTTPDone_h1(t *testing.T) { - testServerRequestContextCancel_ServeHTTPDone(t, h1Mode) -} -func TestServerRequestContextCancel_ServeHTTPDone_h2(t *testing.T) { - testServerRequestContextCancel_ServeHTTPDone(t, h2Mode) +func TestServerRequestContextCancel_ServeHTTPDone(t *testing.T) { + run(t, testServerRequestContextCancel_ServeHTTPDone) } -func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, mode testMode) { ctxc := make(chan context.Context, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { ctx := r.Context() select { case <-ctx.Done(): @@ -4891,7 +4811,6 @@ func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, h2 bool) { } ctxc <- ctx })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -4910,16 +4829,16 @@ func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, h2 bool) { // is always blocked in a Read call so it notices the EOF from the client. // See issues 15927 and 15224. func TestServerRequestContextCancel_ConnClose(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testServerRequestContextCancel_ConnClose, []testMode{http1Mode}) +} +func testServerRequestContextCancel_ConnClose(t *testing.T, mode testMode) { inHandler := make(chan struct{}) handlerDone := make(chan struct{}) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { close(inHandler) <-r.Context().Done() close(handlerDone) - })) - defer ts.Close() + })).ts c, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatal(err) @@ -4931,23 +4850,17 @@ func TestServerRequestContextCancel_ConnClose(t *testing.T) { <-handlerDone } -func TestServerContext_ServerContextKey_h1(t *testing.T) { - testServerContext_ServerContextKey(t, h1Mode) -} -func TestServerContext_ServerContextKey_h2(t *testing.T) { - testServerContext_ServerContextKey(t, h2Mode) +func TestServerContext_ServerContextKey(t *testing.T) { + run(t, testServerContext_ServerContextKey) } -func testServerContext_ServerContextKey(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func testServerContext_ServerContextKey(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { ctx := r.Context() got := ctx.Value(ServerContextKey) if _, ok := got.(*Server); !ok { t.Errorf("context value = %T; want *http.Server", got) } })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -4955,20 +4868,14 @@ func testServerContext_ServerContextKey(t *testing.T, h2 bool) { res.Body.Close() } -func TestServerContext_LocalAddrContextKey_h1(t *testing.T) { - testServerContext_LocalAddrContextKey(t, h1Mode) -} -func TestServerContext_LocalAddrContextKey_h2(t *testing.T) { - testServerContext_LocalAddrContextKey(t, h2Mode) +func TestServerContext_LocalAddrContextKey(t *testing.T) { + run(t, testServerContext_LocalAddrContextKey) } -func testServerContext_LocalAddrContextKey(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func testServerContext_LocalAddrContextKey(t *testing.T, mode testMode) { ch := make(chan any, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { ch <- r.Context().Value(LocalAddrContextKey) })) - defer cst.close() if _, err := cst.c.Head(cst.ts.URL); err != nil { t.Fatal(err) } @@ -5021,16 +4928,19 @@ func TestHandlerSetTransferEncodingGzip(t *testing.T) { } func BenchmarkClientServer(b *testing.B) { + run(b, benchmarkClientServer, []testMode{http1Mode, https1Mode, http2Mode}) +} +func benchmarkClientServer(b *testing.B, mode testMode) { b.ReportAllocs() b.StopTimer() - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) { fmt.Fprintf(rw, "Hello world.\n") - })) - defer ts.Close() + })).ts b.StartTimer() + c := ts.Client() for i := 0; i < b.N; i++ { - res, err := Get(ts.URL) + res, err := c.Get(ts.URL) if err != nil { b.Fatal("Get:", err) } @@ -5048,33 +4958,21 @@ func BenchmarkClientServer(b *testing.B) { b.StopTimer() } -func BenchmarkClientServerParallel4(b *testing.B) { - benchmarkClientServerParallel(b, 4, false) -} - -func BenchmarkClientServerParallel64(b *testing.B) { - benchmarkClientServerParallel(b, 64, false) -} - -func BenchmarkClientServerParallelTLS4(b *testing.B) { - benchmarkClientServerParallel(b, 4, true) -} - -func BenchmarkClientServerParallelTLS64(b *testing.B) { - benchmarkClientServerParallel(b, 64, true) +func BenchmarkClientServerParallel(b *testing.B) { + for _, parallelism := range []int{4, 64} { + b.Run(fmt.Sprint(parallelism), func(b *testing.B) { + run(b, func(b *testing.B, mode testMode) { + benchmarkClientServerParallel(b, parallelism, mode) + }, []testMode{http1Mode, https1Mode, http2Mode}) + }) + } } -func benchmarkClientServerParallel(b *testing.B, parallelism int, useTLS bool) { +func benchmarkClientServerParallel(b *testing.B, parallelism int, mode testMode) { b.ReportAllocs() - ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) { fmt.Fprintf(rw, "Hello world.\n") - })) - if useTLS { - ts.StartTLS() - } else { - ts.Start() - } - defer ts.Close() + })).ts b.ResetTimer() b.SetParallelism(parallelism) b.RunParallel(func(pb *testing.PB) { @@ -5464,15 +5362,15 @@ Host: golang.org } } -func BenchmarkCloseNotifier(b *testing.B) { +func BenchmarkCloseNotifier(b *testing.B) { run(b, benchmarkCloseNotifier, []testMode{http1Mode}) } +func benchmarkCloseNotifier(b *testing.B, mode testMode) { b.ReportAllocs() b.StopTimer() sawClose := make(chan bool) - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { <-rw.(CloseNotifier).CloseNotify() sawClose <- true - })) - defer ts.Close() + })).ts tot := time.NewTimer(5 * time.Second) defer tot.Stop() b.StartTimer() @@ -5508,20 +5406,18 @@ func TestConcurrentServerServe(t *testing.T) { } } -func TestServerIdleTimeout(t *testing.T) { +func TestServerIdleTimeout(t *testing.T) { run(t, testServerIdleTimeout, []testMode{http1Mode}) } +func testServerIdleTimeout(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } - setParallel(t) - defer afterTest(t) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.Copy(io.Discard, r.Body) io.WriteString(w, r.RemoteAddr) - })) - ts.Config.ReadHeaderTimeout = 1 * time.Second - ts.Config.IdleTimeout = 2 * time.Second - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.ReadHeaderTimeout = 1 * time.Second + ts.Config.IdleTimeout = 2 * time.Second + }).ts c := ts.Client() get := func() string { @@ -5576,12 +5472,12 @@ func get(t *testing.T, c *Client, url string) string { // Tests that calls to Server.SetKeepAlivesEnabled(false) closes any // currently-open connections. func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testServerSetKeepAlivesEnabledClosesConns, []testMode{http1Mode}) +} +func testServerSetKeepAlivesEnabledClosesConns(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, r.RemoteAddr) - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -5620,16 +5516,8 @@ func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) { } } -func TestServerShutdown_h1(t *testing.T) { - testServerShutdown(t, h1Mode) -} -func TestServerShutdown_h2(t *testing.T) { - testServerShutdown(t, h2Mode) -} - -func testServerShutdown(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func TestServerShutdown(t *testing.T) { run(t, testServerShutdown) } +func testServerShutdown(t *testing.T, mode testMode) { var doShutdown func() // set later var doStateCount func() var shutdownRes = make(chan error, 1) @@ -5645,10 +5533,9 @@ func testServerShutdown(t *testing.T, h2 bool) { time.Sleep(20 * time.Millisecond) io.WriteString(w, r.RemoteAddr) }) - cst := newClientServerTest(t, h2, handler, func(srv *httptest.Server) { + cst := newClientServerTest(t, mode, handler, func(srv *httptest.Server) { srv.Config.RegisterOnShutdown(func() { gotOnShutdown <- struct{}{} }) }) - defer cst.close() doShutdown = func() { shutdownRes <- cst.ts.Config.Shutdown(context.Background()) @@ -5678,24 +5565,22 @@ func testServerShutdown(t *testing.T, h2 bool) { } } -func TestServerShutdownStateNew(t *testing.T) { +func TestServerShutdownStateNew(t *testing.T) { run(t, testServerShutdownStateNew) } +func testServerShutdownStateNew(t *testing.T, mode testMode) { if testing.Short() { t.Skip("test takes 5-6 seconds; skipping in short mode") } - setParallel(t) - defer afterTest(t) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { - // nothing. - })) var connAccepted sync.WaitGroup - ts.Config.ConnState = func(conn net.Conn, state ConnState) { - if state == StateNew { - connAccepted.Done() + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + // nothing. + }), func(ts *httptest.Server) { + ts.Config.ConnState = func(conn net.Conn, state ConnState) { + if state == StateNew { + connAccepted.Done() + } } - } - ts.Start() - defer ts.Close() + }).ts // Start a connection but never write to it. connAccepted.Add(1) @@ -5757,16 +5642,14 @@ func TestServerCloseDeadlock(t *testing.T) { // Issue 17717: tests that Server.SetKeepAlivesEnabled is respected by // both HTTP/1 and HTTP/2. -func TestServerKeepAlivesEnabled_h1(t *testing.T) { testServerKeepAlivesEnabled(t, h1Mode) } -func TestServerKeepAlivesEnabled_h2(t *testing.T) { testServerKeepAlivesEnabled(t, h2Mode) } -func testServerKeepAlivesEnabled(t *testing.T, h2 bool) { - if h2 { +func TestServerKeepAlivesEnabled(t *testing.T) { run(t, testServerKeepAlivesEnabled, testNotParallel) } +func testServerKeepAlivesEnabled(t *testing.T, mode testMode) { + if mode == http2Mode { restore := ExportSetH2GoawayTimeout(10 * time.Millisecond) defer restore() } // Not parallel: messes with global variable. (http2goAwayTimeout) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {})) + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})) defer cst.close() srv := cst.ts.Config srv.SetKeepAlivesEnabled(false) @@ -5803,9 +5686,8 @@ func testServerKeepAlivesEnabled(t *testing.T, h2 bool) { // Issue 18447: test that the Server's ReadTimeout is stopped while // the server's doing its 1-byte background read between requests, // waiting for the connection to maybe close. -func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { run(t, testServerCancelsReadTimeoutWhenIdle) } +func testServerCancelsReadTimeoutWhenIdle(t *testing.T, mode testMode) { runTimeSensitiveTest(t, []time.Duration{ 10 * time.Millisecond, 50 * time.Millisecond, @@ -5813,17 +5695,16 @@ func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { time.Second, 2 * time.Second, }, func(t *testing.T, timeout time.Duration) error { - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { select { case <-time.After(2 * timeout): fmt.Fprint(w, "ok") case <-r.Context().Done(): fmt.Fprint(w, r.Context().Err()) } - })) - ts.Config.ReadTimeout = timeout - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.ReadTimeout = timeout + }).ts c := ts.Client() @@ -5847,8 +5728,9 @@ func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { // beginning of a request has been received, rather than including time the // connection spent idle. func TestServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testServerCancelsReadHeaderTimeoutWhenIdle, []testMode{http1Mode}) +} +func testServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T, mode testMode) { runTimeSensitiveTest(t, []time.Duration{ 10 * time.Millisecond, 50 * time.Millisecond, @@ -5856,11 +5738,10 @@ func TestServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T) { time.Second, 2 * time.Second, }, func(t *testing.T, timeout time.Duration) error { - ts := httptest.NewUnstartedServer(serve(200)) - ts.Config.ReadHeaderTimeout = timeout - ts.Config.IdleTimeout = 0 // disable idle timeout - ts.Start() - defer ts.Close() + ts := newClientServerTest(t, mode, serve(200), func(ts *httptest.Server) { + ts.Config.ReadHeaderTimeout = timeout + ts.Config.IdleTimeout = 0 // disable idle timeout + }).ts // rather than using an http.Client, create a single connection, so that // we can ensure this connection is not closed. @@ -5912,13 +5793,13 @@ func runTimeSensitiveTest(t *testing.T, durations []time.Duration, test func(t * // Issue 18535: test that the Server doesn't try to do a background // read if it's already done one. func TestServerDuplicateBackgroundRead(t *testing.T) { + run(t, testServerDuplicateBackgroundRead, []testMode{http1Mode}) +} +func testServerDuplicateBackgroundRead(t *testing.T, mode testMode) { if runtime.GOOS == "netbsd" && runtime.GOARCH == "arm" { testenv.SkipFlaky(t, 24826) } - setParallel(t) - defer afterTest(t) - goroutines := 5 requests := 2000 if testing.Short() { @@ -5926,8 +5807,7 @@ func TestServerDuplicateBackgroundRead(t *testing.T) { requests = 100 } - hts := httptest.NewServer(HandlerFunc(NotFound)) - defer hts.Close() + hts := newClientServerTest(t, mode, HandlerFunc(NotFound)).ts reqBytes := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n") @@ -5970,14 +5850,15 @@ func TestServerDuplicateBackgroundRead(t *testing.T) { // bufio.Reader.Buffered(), without resorting to Reading it // (potentially blocking) to get at it. func TestServerHijackGetsBackgroundByte(t *testing.T) { + run(t, testServerHijackGetsBackgroundByte, []testMode{http1Mode}) +} +func testServerHijackGetsBackgroundByte(t *testing.T, mode testMode) { if runtime.GOOS == "plan9" { t.Skip("skipping test; see https://golang.org/issue/18657") } - setParallel(t) - defer afterTest(t) done := make(chan struct{}) inHandler := make(chan bool, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { defer close(done) // Tell the client to send more data after the GET request. @@ -6000,8 +5881,7 @@ func TestServerHijackGetsBackgroundByte(t *testing.T) { t.Error("context unexpectedly canceled") default: } - })) - defer ts.Close() + })).ts cn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -6030,14 +5910,15 @@ func TestServerHijackGetsBackgroundByte(t *testing.T) { // immediate 1MB of data to the server to fill up the server's 4KB // buffer. func TestServerHijackGetsBackgroundByte_big(t *testing.T) { + run(t, testServerHijackGetsBackgroundByte_big, []testMode{http1Mode}) +} +func testServerHijackGetsBackgroundByte_big(t *testing.T, mode testMode) { if runtime.GOOS == "plan9" { t.Skip("skipping test; see https://golang.org/issue/18657") } - setParallel(t) - defer afterTest(t) done := make(chan struct{}) const size = 8 << 10 - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { defer close(done) conn, buf, err := w.(Hijacker).Hijack() @@ -6061,8 +5942,7 @@ func TestServerHijackGetsBackgroundByte_big(t *testing.T) { } else if !allX { t.Errorf("read %q; want %d 'x'", slurp, size) } - })) - defer ts.Close() + })).ts cn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -6198,73 +6078,27 @@ func TestStripPortFromHost(t *testing.T) { } } -func TestServerContexts(t *testing.T) { - setParallel(t) - defer afterTest(t) - type baseKey struct{} - type connKey struct{} - ch := make(chan context.Context, 1) - ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) { - ch <- r.Context() - })) - ts.Config.BaseContext = func(ln net.Listener) context.Context { - if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") { - t.Errorf("unexpected onceClose listener type %T", ln) - } - return context.WithValue(context.Background(), baseKey{}, "base") - } - ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context { - if got, want := ctx.Value(baseKey{}), "base"; got != want { - t.Errorf("in ConnContext, base context key = %#v; want %q", got, want) - } - return context.WithValue(ctx, connKey{}, "conn") - } - ts.Start() - defer ts.Close() - res, err := ts.Client().Get(ts.URL) - if err != nil { - t.Fatal(err) - } - res.Body.Close() - ctx := <-ch - if got, want := ctx.Value(baseKey{}), "base"; got != want { - t.Errorf("base context key = %#v; want %q", got, want) - } - if got, want := ctx.Value(connKey{}), "conn"; got != want { - t.Errorf("conn context key = %#v; want %q", got, want) - } -} - -func TestServerContextsHTTP2(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestServerContexts(t *testing.T) { run(t, testServerContexts) } +func testServerContexts(t *testing.T, mode testMode) { type baseKey struct{} type connKey struct{} ch := make(chan context.Context, 1) - ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) { - if r.ProtoMajor != 2 { - t.Errorf("unexpected HTTP/1.x request") - } + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) { ch <- r.Context() - })) - ts.Config.BaseContext = func(ln net.Listener) context.Context { - if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") { - t.Errorf("unexpected onceClose listener type %T", ln) + }), func(ts *httptest.Server) { + ts.Config.BaseContext = func(ln net.Listener) context.Context { + if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") { + t.Errorf("unexpected onceClose listener type %T", ln) + } + return context.WithValue(context.Background(), baseKey{}, "base") } - return context.WithValue(context.Background(), baseKey{}, "base") - } - ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context { - if got, want := ctx.Value(baseKey{}), "base"; got != want { - t.Errorf("in ConnContext, base context key = %#v; want %q", got, want) + ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context { + if got, want := ctx.Value(baseKey{}), "base"; got != want { + t.Errorf("in ConnContext, base context key = %#v; want %q", got, want) + } + return context.WithValue(ctx, connKey{}, "conn") } - return context.WithValue(ctx, connKey{}, "conn") - } - ts.TLS = &tls.Config{ - NextProtos: []string{"h2", "http/1.1"}, - } - ts.StartTLS() - defer ts.Close() - ts.Client().Transport.(*Transport).ForceAttemptHTTP2 = true + }).ts res, err := ts.Client().Get(ts.URL) if err != nil { t.Fatal(err) @@ -6281,20 +6115,20 @@ func TestServerContextsHTTP2(t *testing.T) { // Issue 35750: check ConnContext not modifying context for other connections func TestConnContextNotModifyingAllContexts(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testConnContextNotModifyingAllContexts) +} +func testConnContextNotModifyingAllContexts(t *testing.T, mode testMode) { type connKey struct{} - ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) { rw.Header().Set("Connection", "close") - })) - ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context { - if got := ctx.Value(connKey{}); got != nil { - t.Errorf("in ConnContext, unexpected context key = %#v", got) + }), func(ts *httptest.Server) { + ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context { + if got := ctx.Value(connKey{}); got != nil { + t.Errorf("in ConnContext, unexpected context key = %#v", got) + } + return context.WithValue(ctx, connKey{}, "conn") } - return context.WithValue(ctx, connKey{}, "conn") - } - ts.Start() - defer ts.Close() + }).ts var res *Response var err error @@ -6315,10 +6149,12 @@ func TestConnContextNotModifyingAllContexts(t *testing.T) { // Issue 30710: ensure that as per the spec, a server responds // with 501 Not Implemented for unsupported transfer-encodings. func TestUnsupportedTransferEncodingsReturn501(t *testing.T) { - cst := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testUnsupportedTransferEncodingsReturn501, []testMode{http1Mode}) +} +func testUnsupportedTransferEncodingsReturn501(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("Hello, World!")) - })) - defer cst.Close() + })).ts serverURL, err := url.Parse(cst.URL) if err != nil { @@ -6353,19 +6189,9 @@ func TestUnsupportedTransferEncodingsReturn501(t *testing.T) { } } -func TestContentEncodingNoSniffing_h1(t *testing.T) { - testContentEncodingNoSniffing(t, h1Mode) -} - -func TestContentEncodingNoSniffing_h2(t *testing.T) { - testContentEncodingNoSniffing(t, h2Mode) -} - // Issue 31753: don't sniff when Content-Encoding is set -func testContentEncodingNoSniffing(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - +func TestContentEncodingNoSniffing(t *testing.T) { run(t, testContentEncodingNoSniffing) } +func testContentEncodingNoSniffing(t *testing.T, mode testMode) { type setting struct { name string body []byte @@ -6428,13 +6254,12 @@ func testContentEncodingNoSniffing(t *testing.T, h2 bool) { for _, tt := range settings { t.Run(tt.name, func(t *testing.T) { - cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) { if tt.contentEncoding != nil { rw.Header().Set("Content-Encoding", tt.contentEncoding.(string)) } rw.Write(tt.body) })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -6460,13 +6285,13 @@ func testContentEncodingNoSniffing(t *testing.T, h2 bool) { // Issue 30803: ensure that TimeoutHandler logs spurious // WriteHeader calls, for consistency with other Handlers. func TestTimeoutHandlerSuperfluousLogs(t *testing.T) { + run(t, testTimeoutHandlerSuperfluousLogs, []testMode{http1Mode}) +} +func testTimeoutHandlerSuperfluousLogs(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } - setParallel(t) - defer afterTest(t) - pc, curFile, _, _ := runtime.Caller(0) curFileBaseName := filepath.Base(curFile) testFuncName := runtime.FuncForPC(pc).Name() @@ -6520,7 +6345,7 @@ func TestTimeoutHandlerSuperfluousLogs(t *testing.T) { dur = 10 * time.Second } th := TimeoutHandler(sh, dur, timeoutMsg) - cst := newClientServerTest(t, h1Mode /* the test is protocol-agnostic */, th, optWithServerLog(srvLog)) + cst := newClientServerTest(t, mode, th, optWithServerLog(srvLog)) defer cst.close() res, err := cst.c.Get(cst.ts.URL) @@ -6590,15 +6415,16 @@ func BenchmarkResponseStatusLine(b *testing.B) { } }) } + func TestDisableKeepAliveUpgrade(t *testing.T) { + run(t, testDisableKeepAliveUpgrade, []testMode{http1Mode}) +} +func testDisableKeepAliveUpgrade(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } - setParallel(t) - defer afterTest(t) - - s := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + s := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "someProto") w.WriteHeader(StatusSwitchingProtocols) @@ -6611,10 +6437,9 @@ func TestDisableKeepAliveUpgrade(t *testing.T) { // Copy from the *bufio.ReadWriter, which may contain buffered data. // Copy to the net.Conn, to avoid buffering the output. io.Copy(c, buf) - })) - s.Config.SetKeepAlivesEnabled(false) - s.Start() - defer s.Close() + }), func(ts *httptest.Server) { + ts.Config.SetKeepAlivesEnabled(false) + }).ts cl := s.Client() cl.Transport.(*Transport).DisableKeepAlives = true @@ -6683,21 +6508,21 @@ func TestQuerySemicolon(t *testing.T) { {"?a=1;x=good;x=bad", "", "good", true}, } - for _, tt := range tests { - t.Run(tt.query+"/allow=false", func(t *testing.T) { - allowSemicolons := false - testQuerySemicolon(t, tt.query, tt.xNoSemicolons, allowSemicolons, tt.warning) - }) - t.Run(tt.query+"/allow=true", func(t *testing.T) { - allowSemicolons, expectWarning := true, false - testQuerySemicolon(t, tt.query, tt.xWithSemicolons, allowSemicolons, expectWarning) - }) - } + run(t, func(t *testing.T, mode testMode) { + for _, tt := range tests { + t.Run(tt.query+"/allow=false", func(t *testing.T) { + allowSemicolons := false + testQuerySemicolon(t, mode, tt.query, tt.xNoSemicolons, allowSemicolons, tt.warning) + }) + t.Run(tt.query+"/allow=true", func(t *testing.T) { + allowSemicolons, expectWarning := true, false + testQuerySemicolon(t, mode, tt.query, tt.xWithSemicolons, allowSemicolons, expectWarning) + }) + } + }) } -func testQuerySemicolon(t *testing.T, query string, wantX string, allowSemicolons, expectWarning bool) { - setParallel(t) - +func testQuerySemicolon(t *testing.T, mode testMode, query string, wantX string, allowSemicolons, expectWarning bool) { writeBackX := func(w ResponseWriter, r *Request) { x := r.URL.Query().Get("x") if expectWarning { @@ -6720,11 +6545,10 @@ func testQuerySemicolon(t *testing.T, query string, wantX string, allowSemicolon h = AllowQuerySemicolons(h) } - ts := httptest.NewUnstartedServer(h) logBuf := &strings.Builder{} - ts.Config.ErrorLog = log.New(logBuf, "", 0) - ts.Start() - defer ts.Close() + ts := newClientServerTest(t, mode, h, func(ts *httptest.Server) { + ts.Config.ErrorLog = log.New(logBuf, "", 0) + }).ts req, _ := NewRequest("GET", ts.URL+query, nil) res, err := ts.Client().Do(req) @@ -6759,13 +6583,15 @@ func TestMaxBytesHandler(t *testing.T) { for _, requestSize := range []int64{100, 1_000, 1_000_000} { t.Run(fmt.Sprintf("max size %d request size %d", maxSize, requestSize), func(t *testing.T) { - testMaxBytesHandler(t, maxSize, requestSize) + run(t, func(t *testing.T, mode testMode) { + testMaxBytesHandler(t, mode, maxSize, requestSize) + }) }) } } } -func testMaxBytesHandler(t *testing.T, maxSize, requestSize int64) { +func testMaxBytesHandler(t *testing.T, mode testMode, maxSize, requestSize int64) { var ( handlerN int64 handlerErr error @@ -6776,7 +6602,7 @@ func testMaxBytesHandler(t *testing.T, maxSize, requestSize int64) { io.Copy(w, &buf) }) - ts := httptest.NewServer(MaxBytesHandler(echo, maxSize)) + ts := newClientServerTest(t, mode, MaxBytesHandler(echo, maxSize)).ts defer ts.Close() c := ts.Client() @@ -6843,13 +6669,12 @@ func TestProcessing(t *testing.T) { } } -func TestParseFormCleanup_h1(t *testing.T) { testParseFormCleanup(t, h1Mode) } -func TestParseFormCleanup_h2(t *testing.T) { - t.Skip("https://go.dev/issue/20253") - testParseFormCleanup(t, h2Mode) -} +func TestParseFormCleanup(t *testing.T) { run(t, testParseFormCleanup) } +func testParseFormCleanup(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("https://go.dev/issue/20253") + } -func testParseFormCleanup(t *testing.T, h2 bool) { const maxMemory = 1024 const key = "file" @@ -6858,9 +6683,7 @@ func testParseFormCleanup(t *testing.T, h2 bool) { t.Skip("https://go.dev/issue/25965") } - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { r.ParseMultipartForm(maxMemory) f, _, err := r.FormFile(key) if err != nil { @@ -6874,7 +6697,6 @@ func testParseFormCleanup(t *testing.T, h2 bool) { } w.Write([]byte(of.Name())) })) - defer cst.close() fBuf := new(bytes.Buffer) mw := multipart.NewWriter(fBuf) @@ -6911,33 +6733,23 @@ func testParseFormCleanup(t *testing.T, h2 bool) { func TestHeadBody(t *testing.T) { const identityMode = false const chunkedMode = true - t.Run("h1", func(t *testing.T) { - t.Run("identity", func(t *testing.T) { testHeadBody(t, h1Mode, identityMode, "HEAD") }) - t.Run("chunked", func(t *testing.T) { testHeadBody(t, h1Mode, chunkedMode, "HEAD") }) - }) - t.Run("h2", func(t *testing.T) { - t.Run("identity", func(t *testing.T) { testHeadBody(t, h2Mode, identityMode, "HEAD") }) - t.Run("chunked", func(t *testing.T) { testHeadBody(t, h2Mode, chunkedMode, "HEAD") }) + run(t, func(t *testing.T, mode testMode) { + t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "HEAD") }) + t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "HEAD") }) }) } func TestGetBody(t *testing.T) { const identityMode = false const chunkedMode = true - t.Run("h1", func(t *testing.T) { - t.Run("identity", func(t *testing.T) { testHeadBody(t, h1Mode, identityMode, "GET") }) - t.Run("chunked", func(t *testing.T) { testHeadBody(t, h1Mode, chunkedMode, "GET") }) - }) - t.Run("h2", func(t *testing.T) { - t.Run("identity", func(t *testing.T) { testHeadBody(t, h2Mode, identityMode, "GET") }) - t.Run("chunked", func(t *testing.T) { testHeadBody(t, h2Mode, chunkedMode, "GET") }) + run(t, func(t *testing.T, mode testMode) { + t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "GET") }) + t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "GET") }) }) } -func testHeadBody(t *testing.T, h2, chunked bool, method string) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func testHeadBody(t *testing.T, mode testMode, chunked bool, method string) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { b, err := io.ReadAll(r.Body) if err != nil { t.Errorf("server reading body: %v", err) |
