From 747e1961e95c2eb3df62e045b90b111c2ceea337 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Mon, 3 Oct 2022 16:07:48 -0700 Subject: net/http: refactor tests to run most in HTTP/1 and HTTP/2 modes Replace the ad-hoc approach to running tests in HTTP/1 and HTTP/2 modes with a 'run' function that executes a test in various modes. By default, these modes are HTTP/1 and HTTP/2, but tests can opt-in to HTTPS/1 as well. The 'run' function also takes care of post-test cleanup (running the afterTest function). The 'run' function runs tests in parallel by default. Tests which can't run in parallel (generally because they use global test hooks) pass a testNotParallel option to disable parallelism. Update clientServerTest to use t.Cleanup to clean up after itself, rather than leaving this up to tests to handle. Drop an unnecessary mutex in SetReadLoopBeforeNextReadHook. Test hooks can't be set in parallel, and we want the race detector to notify us if two simultaneous tests try to set a hook. Fixes #56032 Change-Id: I16be64913c426fc93d84abc6ad85dbd3bc191224 Reviewed-on: https://go-review.googlesource.com/c/go/+/438137 TryBot-Result: Gopher Robot Run-TryBot: Damien Neil Reviewed-by: Brad Fitzpatrick Reviewed-by: David Chase --- src/net/http/clientserver_test.go | 441 ++++++++++++++++++++------------------ 1 file changed, 237 insertions(+), 204 deletions(-) (limited to 'src/net/http/clientserver_test.go') diff --git a/src/net/http/clientserver_test.go b/src/net/http/clientserver_test.go index b472ca4b78..87e34cef85 100644 --- a/src/net/http/clientserver_test.go +++ b/src/net/http/clientserver_test.go @@ -35,8 +35,65 @@ import ( "time" ) +type testMode string + +const ( + http1Mode = testMode("h1") // HTTP/1.1 + https1Mode = testMode("https1") // HTTPS/1.1 + http2Mode = testMode("h2") // HTTP/2 +) + +type testNotParallelOpt struct{} + +var ( + testNotParallel = testNotParallelOpt{} +) + +type TBRun[T any] interface { + testing.TB + Run(string, func(T)) bool +} + +// run runs a client/server test in a variety of test configurations. +// +// Tests execute in HTTP/1.1 and HTTP/2 modes by default. +// To run in a different set of configurations, pass a []testMode option. +// +// Tests call t.Parallel() by default. +// To disable parallel execution, pass the testNotParallel option. +func run[T TBRun[T]](t T, f func(t T, mode testMode), opts ...any) { + t.Helper() + modes := []testMode{http1Mode, http2Mode} + parallel := true + for _, opt := range opts { + switch opt := opt.(type) { + case []testMode: + modes = opt + case testNotParallelOpt: + parallel = false + default: + t.Fatalf("unknown option type %T", opt) + } + } + if t, ok := any(t).(*testing.T); ok && parallel { + setParallel(t) + } + for _, mode := range modes { + t.Run(string(mode), func(t T) { + t.Helper() + if t, ok := any(t).(*testing.T); ok && parallel { + setParallel(t) + } + t.Cleanup(func() { + afterTest(t) + }) + f(t, mode) + }) + } +} + type clientServerTest struct { - t *testing.T + t testing.TB h2 bool h Handler ts *httptest.Server @@ -69,11 +126,6 @@ func (t *clientServerTest) scheme() string { return "http" } -const ( - h1Mode = false - h2Mode = true -) - var optQuietLog = func(ts *httptest.Server) { ts.Config.ErrorLog = quietLog } @@ -84,23 +136,33 @@ func optWithServerLog(lg *log.Logger) func(*httptest.Server) { } } -func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...any) *clientServerTest { - if h2 { +// newClientServerTest creates and starts an httptest.Server. +// +// The mode parameter selects the implementation to test: +// HTTP/1, HTTP/2, etc. Tests using newClientServerTest should use +// the 'run' function, which will start a subtests for each tested mode. +// +// The vararg opts parameter can include functions to configure the +// test server or transport. +// +// func(*httptest.Server) // run before starting the server +// func(*http.Transport) +func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *clientServerTest { + if mode == http2Mode { CondSkipHTTP2(t) } cst := &clientServerTest{ t: t, - h2: h2, + h2: mode == http2Mode, h: h, - tr: &Transport{}, } - cst.c = &Client{Transport: cst.tr} cst.ts = httptest.NewUnstartedServer(h) + var transportFuncs []func(*Transport) for _, opt := range opts { switch opt := opt.(type) { case func(*Transport): - opt(cst.tr) + transportFuncs = append(transportFuncs, opt) case func(*httptest.Server): opt(cst.ts) default: @@ -108,60 +170,84 @@ func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...any) *clientS } } - if !h2 { + switch mode { + case http1Mode: cst.ts.Start() - return cst + case https1Mode: + cst.ts.StartTLS() + case http2Mode: + ExportHttp2ConfigureServer(cst.ts.Config, nil) + cst.ts.TLS = cst.ts.Config.TLSConfig + cst.ts.StartTLS() + default: + t.Fatalf("unknown test mode %v", mode) } - ExportHttp2ConfigureServer(cst.ts.Config, nil) - cst.ts.TLS = cst.ts.Config.TLSConfig - cst.ts.StartTLS() - - cst.tr.TLSClientConfig = &tls.Config{ - InsecureSkipVerify: true, + cst.c = cst.ts.Client() + cst.tr = cst.c.Transport.(*Transport) + if mode == http2Mode { + if err := ExportHttp2ConfigureTransport(cst.tr); err != nil { + t.Fatal(err) + } } - if err := ExportHttp2ConfigureTransport(cst.tr); err != nil { - t.Fatal(err) + for _, f := range transportFuncs { + f(cst.tr) } + t.Cleanup(func() { + cst.close() + }) return cst } // Testing the newClientServerTest helper itself. func TestNewClientServerTest(t *testing.T) { + run(t, testNewClientServerTest, []testMode{http1Mode, https1Mode, http2Mode}) +} +func testNewClientServerTest(t *testing.T, mode testMode) { var got struct { sync.Mutex - log []string + proto string + hasTLS bool } h := HandlerFunc(func(w ResponseWriter, r *Request) { got.Lock() defer got.Unlock() - got.log = append(got.log, r.Proto) + got.proto = r.Proto + got.hasTLS = r.TLS != nil }) - for _, v := range [2]bool{false, true} { - cst := newClientServerTest(t, v, h) - if _, err := cst.c.Head(cst.ts.URL); err != nil { - t.Fatal(err) - } - cst.close() + cst := newClientServerTest(t, mode, h) + if _, err := cst.c.Head(cst.ts.URL); err != nil { + t.Fatal(err) + } + var wantProto string + var wantTLS bool + switch mode { + case http1Mode: + wantProto = "HTTP/1.1" + wantTLS = false + case https1Mode: + wantProto = "HTTP/1.1" + wantTLS = true + case http2Mode: + wantProto = "HTTP/2.0" + wantTLS = true + } + if got.proto != wantProto { + t.Errorf("req.Proto = %q, want %q", got.proto, wantProto) } - got.Lock() // no need to unlock - if want := []string{"HTTP/1.1", "HTTP/2.0"}; !reflect.DeepEqual(got.log, want) { - t.Errorf("got %q; want %q", got.log, want) + if got.hasTLS != wantTLS { + t.Errorf("req.TLS set: %v, want %v", got.hasTLS, wantTLS) } } -func TestChunkedResponseHeaders_h1(t *testing.T) { testChunkedResponseHeaders(t, h1Mode) } -func TestChunkedResponseHeaders_h2(t *testing.T) { testChunkedResponseHeaders(t, h2Mode) } - -func testChunkedResponseHeaders(t *testing.T, h2 bool) { - defer afterTest(t) +func TestChunkedResponseHeaders(t *testing.T) { run(t, testChunkedResponseHeaders) } +func testChunkedResponseHeaders(t *testing.T, mode testMode) { log.SetOutput(io.Discard) // is noisy otherwise defer log.SetOutput(os.Stderr) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted w.(Flusher).Flush() fmt.Fprintf(w, "I am a chunked response.") })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -172,7 +258,7 @@ func testChunkedResponseHeaders(t *testing.T, h2 bool) { t.Errorf("expected ContentLength of %d; got %d", e, g) } wantTE := []string{"chunked"} - if h2 { + if mode == http2Mode { wantTE = nil } if !reflect.DeepEqual(res.TransferEncoding, wantTE) { @@ -204,9 +290,9 @@ func (tt h12Compare) reqFunc() reqFunc { func (tt h12Compare) run(t *testing.T) { setParallel(t) - cst1 := newClientServerTest(t, false, HandlerFunc(tt.Handler), tt.Opts...) + cst1 := newClientServerTest(t, http1Mode, HandlerFunc(tt.Handler), tt.Opts...) defer cst1.close() - cst2 := newClientServerTest(t, true, HandlerFunc(tt.Handler), tt.Opts...) + cst2 := newClientServerTest(t, http2Mode, HandlerFunc(tt.Handler), tt.Opts...) defer cst2.close() res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL) @@ -459,12 +545,9 @@ func TestH12_AutoGzip_Disabled(t *testing.T) { // Test304Responses verifies that 304s don't declare that they're // chunking in their response headers and aren't allowed to produce // output. -func Test304Responses_h1(t *testing.T) { test304Responses(t, h1Mode) } -func Test304Responses_h2(t *testing.T) { test304Responses(t, h2Mode) } - -func test304Responses(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func Test304Responses(t *testing.T) { run(t, test304Responses) } +func test304Responses(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.WriteHeader(StatusNotModified) _, err := w.Write([]byte("illegal body")) if err != ErrBodyNotAllowed { @@ -528,20 +611,17 @@ func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int6 // Tests that closing the Request.Cancel channel also while still // reading the response body. Issue 13159. -func TestCancelRequestMidBody_h1(t *testing.T) { testCancelRequestMidBody(t, h1Mode) } -func TestCancelRequestMidBody_h2(t *testing.T) { testCancelRequestMidBody(t, h2Mode) } -func testCancelRequestMidBody(t *testing.T, h2 bool) { - defer afterTest(t) +func TestCancelRequestMidBody(t *testing.T) { run(t, testCancelRequestMidBody) } +func testCancelRequestMidBody(t *testing.T, mode testMode) { unblock := make(chan bool) didFlush := make(chan bool, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, "Hello") w.(Flusher).Flush() didFlush <- true <-unblock io.WriteString(w, ", world.") })) - defer cst.close() defer close(unblock) req, _ := NewRequest("GET", cst.ts.URL, nil) @@ -577,12 +657,9 @@ func testCancelRequestMidBody(t *testing.T, h2 bool) { } // Tests that clients can send trailers to a server and that the server can read them. -func TestTrailersClientToServer_h1(t *testing.T) { testTrailersClientToServer(t, h1Mode) } -func TestTrailersClientToServer_h2(t *testing.T) { testTrailersClientToServer(t, h2Mode) } - -func testTrailersClientToServer(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTrailersClientToServer(t *testing.T) { run(t, testTrailersClientToServer) } +func testTrailersClientToServer(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { var decl []string for k := range r.Trailer { decl = append(decl, k) @@ -605,7 +682,6 @@ func testTrailersClientToServer(t *testing.T, h2 bool) { r.Trailer.Get("Client-Trailer-B")) } })) - defer cst.close() var req *Request req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader( @@ -632,15 +708,20 @@ func testTrailersClientToServer(t *testing.T, h2 bool) { } // Tests that servers send trailers to a client and that the client can read them. -func TestTrailersServerToClient_h1(t *testing.T) { testTrailersServerToClient(t, h1Mode, false) } -func TestTrailersServerToClient_h2(t *testing.T) { testTrailersServerToClient(t, h2Mode, false) } -func TestTrailersServerToClient_Flush_h1(t *testing.T) { testTrailersServerToClient(t, h1Mode, true) } -func TestTrailersServerToClient_Flush_h2(t *testing.T) { testTrailersServerToClient(t, h2Mode, true) } +func TestTrailersServerToClient(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testTrailersServerToClient(t, mode, false) + }) +} +func TestTrailersServerToClientFlush(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testTrailersServerToClient(t, mode, true) + }) +} -func testTrailersServerToClient(t *testing.T, h2, flush bool) { - defer afterTest(t) +func testTrailersServerToClient(t *testing.T, mode testMode, flush bool) { const body = "Some body" - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B") w.Header().Add("Trailer", "Server-Trailer-C") @@ -657,7 +738,6 @@ func testTrailersServerToClient(t *testing.T, h2, flush bool) { w.Header().Set("Server-Trailer-C", "valuec") // skipping B w.Header().Set("Server-Trailer-NotDeclared", "should be omitted") })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -668,7 +748,7 @@ func testTrailersServerToClient(t *testing.T, h2, flush bool) { "Content-Type": {"text/plain; charset=utf-8"}, } wantLen := -1 - if h2 && !flush { + if mode == http2Mode && !flush { // In HTTP/1.1, any use of trailers forces HTTP/1.1 // chunking and a flush at the first write. That's // unnecessary with HTTP/2's framing, so the server @@ -708,16 +788,12 @@ func testTrailersServerToClient(t *testing.T, h2, flush bool) { } // Don't allow a Body.Read after Body.Close. Issue 13648. -func TestResponseBodyReadAfterClose_h1(t *testing.T) { testResponseBodyReadAfterClose(t, h1Mode) } -func TestResponseBodyReadAfterClose_h2(t *testing.T) { testResponseBodyReadAfterClose(t, h2Mode) } - -func testResponseBodyReadAfterClose(t *testing.T, h2 bool) { - defer afterTest(t) +func TestResponseBodyReadAfterClose(t *testing.T) { run(t, testResponseBodyReadAfterClose) } +func testResponseBodyReadAfterClose(t *testing.T, mode testMode) { const body = "Some body" - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, body) })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -729,13 +805,11 @@ func testResponseBodyReadAfterClose(t *testing.T, h2 bool) { } } -func TestConcurrentReadWriteReqBody_h1(t *testing.T) { testConcurrentReadWriteReqBody(t, h1Mode) } -func TestConcurrentReadWriteReqBody_h2(t *testing.T) { testConcurrentReadWriteReqBody(t, h2Mode) } -func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) { - defer afterTest(t) +func TestConcurrentReadWriteReqBody(t *testing.T) { run(t, testConcurrentReadWriteReqBody) } +func testConcurrentReadWriteReqBody(t *testing.T, mode testMode) { const reqBody = "some request body" const resBody = "some response body" - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { var wg sync.WaitGroup wg.Add(2) didRead := make(chan bool, 1) @@ -754,7 +828,7 @@ func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) { // Write in another goroutine. go func() { defer wg.Done() - if !h2 { + if mode != http2Mode { // our HTTP/1 implementation intentionally // doesn't permit writes during read (mostly // due to it being undefined); if that is ever @@ -765,7 +839,6 @@ func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) { }() wg.Wait() })) - defer cst.close() req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody)) req.Header.Add("Expect", "100-continue") // just to complicate things res, err := cst.c.Do(req) @@ -782,15 +855,12 @@ func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) { } } -func TestConnectRequest_h1(t *testing.T) { testConnectRequest(t, h1Mode) } -func TestConnectRequest_h2(t *testing.T) { testConnectRequest(t, h2Mode) } -func testConnectRequest(t *testing.T, h2 bool) { - defer afterTest(t) +func TestConnectRequest(t *testing.T) { run(t, testConnectRequest) } +func testConnectRequest(t *testing.T, mode testMode) { gotc := make(chan *Request, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { gotc <- r })) - defer cst.close() u, err := url.Parse(cst.ts.URL) if err != nil { @@ -840,17 +910,14 @@ func testConnectRequest(t *testing.T, h2 bool) { } } -func TestTransportUserAgent_h1(t *testing.T) { testTransportUserAgent(t, h1Mode) } -func TestTransportUserAgent_h2(t *testing.T) { testTransportUserAgent(t, h2Mode) } -func testTransportUserAgent(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportUserAgent(t *testing.T) { run(t, testTransportUserAgent) } +func testTransportUserAgent(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%q", r.Header["User-Agent"]) })) - defer cst.close() either := func(a, b string) string { - if h2 { + if mode == http2Mode { return b } return a @@ -901,19 +968,22 @@ func testTransportUserAgent(t *testing.T, h2 bool) { } } -func TestStarRequestFoo_h1(t *testing.T) { testStarRequest(t, "FOO", h1Mode) } -func TestStarRequestFoo_h2(t *testing.T) { testStarRequest(t, "FOO", h2Mode) } -func TestStarRequestOptions_h1(t *testing.T) { testStarRequest(t, "OPTIONS", h1Mode) } -func TestStarRequestOptions_h2(t *testing.T) { testStarRequest(t, "OPTIONS", h2Mode) } -func testStarRequest(t *testing.T, method string, h2 bool) { - defer afterTest(t) +func TestStarRequestMethod(t *testing.T) { + for _, method := range []string{"FOO", "OPTIONS"} { + t.Run(method, func(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testStarRequest(t, method, mode) + }) + }) + } +} +func testStarRequest(t *testing.T, method string, mode testMode) { gotc := make(chan *Request, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("foo", "bar") gotc <- r w.(Flusher).Flush() })) - defer cst.close() u, err := url.Parse(cst.ts.URL) if err != nil { @@ -972,9 +1042,10 @@ func testStarRequest(t *testing.T, method string, h2 bool) { // Issue 13957 func TestTransportDiscardsUnneededConns(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportDiscardsUnneededConns, []testMode{http2Mode}) +} +func testTransportDiscardsUnneededConns(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "Hello, %v", r.RemoteAddr) })) defer cst.close() @@ -1058,20 +1129,19 @@ func TestTransportDiscardsUnneededConns(t *testing.T) { } // tests that Transport doesn't retain a pointer to the provided request. -func TestTransportGCRequest_Body_h1(t *testing.T) { testTransportGCRequest(t, h1Mode, true) } -func TestTransportGCRequest_Body_h2(t *testing.T) { testTransportGCRequest(t, h2Mode, true) } -func TestTransportGCRequest_NoBody_h1(t *testing.T) { testTransportGCRequest(t, h1Mode, false) } -func TestTransportGCRequest_NoBody_h2(t *testing.T) { testTransportGCRequest(t, h2Mode, false) } -func testTransportGCRequest(t *testing.T, h2, body bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportGCRequest(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + t.Run("Body", func(t *testing.T) { testTransportGCRequest(t, mode, true) }) + t.Run("NoBody", func(t *testing.T) { testTransportGCRequest(t, mode, false) }) + }) +} +func testTransportGCRequest(t *testing.T, mode testMode, body bool) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.ReadAll(r.Body) if body { io.WriteString(w, "Hello.") } })) - defer cst.close() didGC := make(chan struct{}) (func() { @@ -1103,19 +1173,11 @@ func testTransportGCRequest(t *testing.T, h2, body bool) { } } -func TestTransportRejectsInvalidHeaders_h1(t *testing.T) { - testTransportRejectsInvalidHeaders(t, h1Mode) -} -func TestTransportRejectsInvalidHeaders_h2(t *testing.T) { - testTransportRejectsInvalidHeaders(t, h2Mode) -} -func testTransportRejectsInvalidHeaders(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportRejectsInvalidHeaders(t *testing.T) { run(t, testTransportRejectsInvalidHeaders) } +func testTransportRejectsInvalidHeaders(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "Handler saw headers: %q", r.Header) }), optQuietLog) - defer cst.close() cst.tr.DisableKeepAlives = true tests := []struct { @@ -1161,27 +1223,22 @@ func testTransportRejectsInvalidHeaders(t *testing.T, h2 bool) { } } -func TestInterruptWithPanic_h1(t *testing.T) { testInterruptWithPanic(t, h1Mode, "boom") } -func TestInterruptWithPanic_h2(t *testing.T) { testInterruptWithPanic(t, h2Mode, "boom") } -func TestInterruptWithPanic_nil_h1(t *testing.T) { testInterruptWithPanic(t, h1Mode, nil) } -func TestInterruptWithPanic_nil_h2(t *testing.T) { testInterruptWithPanic(t, h2Mode, nil) } -func TestInterruptWithPanic_ErrAbortHandler_h1(t *testing.T) { - testInterruptWithPanic(t, h1Mode, ErrAbortHandler) -} -func TestInterruptWithPanic_ErrAbortHandler_h2(t *testing.T) { - testInterruptWithPanic(t, h2Mode, ErrAbortHandler) +func TestInterruptWithPanic(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + t.Run("boom", func(t *testing.T) { testInterruptWithPanic(t, mode, "boom") }) + t.Run("nil", func(t *testing.T) { testInterruptWithPanic(t, mode, nil) }) + t.Run("ErrAbortHandler", func(t *testing.T) { testInterruptWithPanic(t, mode, ErrAbortHandler) }) + }) } -func testInterruptWithPanic(t *testing.T, h2 bool, panicValue any) { - setParallel(t) +func testInterruptWithPanic(t *testing.T, mode testMode, panicValue any) { const msg = "hello" - defer afterTest(t) testDone := make(chan struct{}) defer close(testDone) var errorLog lockedBytesBuffer gotHeaders := make(chan bool, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, msg) w.(Flusher).Flush() @@ -1193,7 +1250,6 @@ func testInterruptWithPanic(t *testing.T, h2 bool, panicValue any) { }), func(ts *httptest.Server) { ts.Config.ErrorLog = log.New(&errorLog, "", 0) }) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -1274,15 +1330,11 @@ func TestH12_AutoGzipWithDumpResponse(t *testing.T) { } // Issue 14607 -func TestCloseIdleConnections_h1(t *testing.T) { testCloseIdleConnections(t, h1Mode) } -func TestCloseIdleConnections_h2(t *testing.T) { testCloseIdleConnections(t, h2Mode) } -func testCloseIdleConnections(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestCloseIdleConnections(t *testing.T) { run(t, testCloseIdleConnections) } +func testCloseIdleConnections(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("X-Addr", r.RemoteAddr) })) - defer cst.close() get := func() string { res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -1320,15 +1372,11 @@ func (r testErrorReader) Read(p []byte) (n int, err error) { return 0, io.EOF } -func TestNoSniffExpectRequestBody_h1(t *testing.T) { testNoSniffExpectRequestBody(t, h1Mode) } -func TestNoSniffExpectRequestBody_h2(t *testing.T) { testNoSniffExpectRequestBody(t, h2Mode) } - -func testNoSniffExpectRequestBody(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestNoSniffExpectRequestBody(t *testing.T) { run(t, testNoSniffExpectRequestBody) } +func testNoSniffExpectRequestBody(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.WriteHeader(StatusUnauthorized) })) - defer cst.close() // Set ExpectContinueTimeout non-zero so RoundTrip won't try to write it. cst.tr.ExpectContinueTimeout = 10 * time.Second @@ -1349,18 +1397,15 @@ func testNoSniffExpectRequestBody(t *testing.T, h2 bool) { } } -func TestServerUndeclaredTrailers_h1(t *testing.T) { testServerUndeclaredTrailers(t, h1Mode) } -func TestServerUndeclaredTrailers_h2(t *testing.T) { testServerUndeclaredTrailers(t, h2Mode) } -func testServerUndeclaredTrailers(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServerUndeclaredTrailers(t *testing.T) { run(t, testServerUndeclaredTrailers) } +func testServerUndeclaredTrailers(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Foo", "Bar") w.Header().Set("Trailer:Foo", "Baz") w.(Flusher).Flush() w.Header().Add("Trailer:Foo", "Baz2") w.Header().Set("Trailer:Bar", "Quux") })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -1381,8 +1426,10 @@ func testServerUndeclaredTrailers(t *testing.T, h2 bool) { } func TestBadResponseAfterReadingBody(t *testing.T) { - defer afterTest(t) - cst := newClientServerTest(t, false, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testBadResponseAfterReadingBody, []testMode{http1Mode}) +} +func testBadResponseAfterReadingBody(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { _, err := io.Copy(io.Discard, r.Body) if err != nil { t.Fatal(err) @@ -1394,7 +1441,6 @@ func TestBadResponseAfterReadingBody(t *testing.T) { defer c.Close() fmt.Fprintln(c, "some bogus crap") })) - defer cst.close() closes := 0 res, err := cst.c.Post(cst.ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")}) @@ -1407,12 +1453,10 @@ func TestBadResponseAfterReadingBody(t *testing.T) { } } -func TestWriteHeader0_h1(t *testing.T) { testWriteHeader0(t, h1Mode) } -func TestWriteHeader0_h2(t *testing.T) { testWriteHeader0(t, h2Mode) } -func testWriteHeader0(t *testing.T, h2 bool) { - defer afterTest(t) +func TestWriteHeader0(t *testing.T) { run(t, testWriteHeader0) } +func testWriteHeader0(t *testing.T, mode testMode) { gotpanic := make(chan bool, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { defer close(gotpanic) defer func() { if e := recover(); e != nil { @@ -1431,7 +1475,6 @@ func testWriteHeader0(t *testing.T, h2 bool) { }() w.WriteHeader(0) })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -1446,15 +1489,17 @@ func testWriteHeader0(t *testing.T, h2 bool) { // Issue 23010: don't be super strict checking WriteHeader's code if // it's not even valid to call WriteHeader then anyway. -func TestWriteHeaderNoCodeCheck_h1(t *testing.T) { testWriteHeaderAfterWrite(t, h1Mode, false) } -func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) { testWriteHeaderAfterWrite(t, h1Mode, true) } -func TestWriteHeaderNoCodeCheck_h2(t *testing.T) { testWriteHeaderAfterWrite(t, h2Mode, false) } -func testWriteHeaderAfterWrite(t *testing.T, h2, hijack bool) { - setParallel(t) - defer afterTest(t) - +func TestWriteHeaderNoCodeCheck(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testWriteHeaderAfterWrite(t, mode, false) + }) +} +func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) { + testWriteHeaderAfterWrite(t, http1Mode, true) +} +func testWriteHeaderAfterWrite(t *testing.T, mode testMode, hijack bool) { var errorLog lockedBytesBuffer - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if hijack { conn, _, _ := w.(Hijacker).Hijack() defer conn.Close() @@ -1470,7 +1515,6 @@ func testWriteHeaderAfterWrite(t *testing.T, h2, hijack bool) { }), func(ts *httptest.Server) { ts.Config.ErrorLog = log.New(&errorLog, "", 0) }) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -1485,7 +1529,7 @@ func testWriteHeaderAfterWrite(t *testing.T, h2, hijack bool) { } // Also check the stderr output: - if h2 { + if mode == http2Mode { // TODO: also emit this log message for HTTP/2? // We historically haven't, so don't check. return @@ -1501,14 +1545,14 @@ func testWriteHeaderAfterWrite(t *testing.T, h2, hijack bool) { } func TestBidiStreamReverseProxy(t *testing.T) { - setParallel(t) - defer afterTest(t) - backend := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testBidiStreamReverseProxy, []testMode{http2Mode}) +} +func testBidiStreamReverseProxy(t *testing.T, mode testMode) { + backend := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if _, err := io.Copy(w, r.Body); err != nil { log.Printf("bidi backend copy: %v", err) } })) - defer backend.close() backURL, err := url.Parse(backend.ts.URL) if err != nil { @@ -1516,10 +1560,9 @@ func TestBidiStreamReverseProxy(t *testing.T) { } rp := httputil.NewSingleHostReverseProxy(backURL) rp.Transport = backend.tr - proxy := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + proxy := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { rp.ServeHTTP(w, r) })) - defer proxy.close() bodyRes := make(chan any, 1) // error or hash.Hash pr, pw := io.Pipe() @@ -1586,15 +1629,10 @@ func TestH12_WebSocketUpgrade(t *testing.T) { }.run(t) } -func TestIdentityTransferEncoding_h1(t *testing.T) { testIdentityTransferEncoding(t, h1Mode) } -func TestIdentityTransferEncoding_h2(t *testing.T) { testIdentityTransferEncoding(t, h2Mode) } - -func testIdentityTransferEncoding(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - +func TestIdentityTransferEncoding(t *testing.T) { run(t, testIdentityTransferEncoding) } +func testIdentityTransferEncoding(t *testing.T, mode testMode) { const body = "body" - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { gotBody, _ := io.ReadAll(r.Body) if got, want := string(gotBody), body; got != want { t.Errorf("got request body = %q; want %q", got, want) @@ -1604,7 +1642,6 @@ func testIdentityTransferEncoding(t *testing.T, h2 bool) { w.(Flusher).Flush() io.WriteString(w, body) })) - defer cst.close() req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader(body)) res, err := cst.c.Do(req) if err != nil { @@ -1620,14 +1657,11 @@ func testIdentityTransferEncoding(t *testing.T, h2 bool) { } } -func TestEarlyHintsRequest_h1(t *testing.T) { testEarlyHintsRequest(t, h1Mode) } -func TestEarlyHintsRequest_h2(t *testing.T) { testEarlyHintsRequest(t, h2Mode) } -func testEarlyHintsRequest(t *testing.T, h2 bool) { - defer afterTest(t) - +func TestEarlyHintsRequest(t *testing.T) { run(t, testEarlyHintsRequest) } +func testEarlyHintsRequest(t *testing.T, mode testMode) { var wg sync.WaitGroup wg.Add(1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { h := w.Header() h.Add("Content-Length", "123") // must be ignored @@ -1642,7 +1676,6 @@ func testEarlyHintsRequest(t *testing.T, h2 bool) { w.Write([]byte("Hello")) })) - defer cst.close() checkLinkHeaders := func(t *testing.T, expected, got []string) { t.Helper() -- cgit v1.3