diff options
Diffstat (limited to 'src/net/http/clientserver_test.go')
| -rw-r--r-- | src/net/http/clientserver_test.go | 441 |
1 files changed, 237 insertions, 204 deletions
diff --git a/src/net/http/clientserver_test.go b/src/net/http/clientserver_test.go index b472ca4b78..87e34cef85 100644 --- a/src/net/http/clientserver_test.go +++ b/src/net/http/clientserver_test.go @@ -35,8 +35,65 @@ import ( "time" ) +type testMode string + +const ( + http1Mode = testMode("h1") // HTTP/1.1 + https1Mode = testMode("https1") // HTTPS/1.1 + http2Mode = testMode("h2") // HTTP/2 +) + +type testNotParallelOpt struct{} + +var ( + testNotParallel = testNotParallelOpt{} +) + +type TBRun[T any] interface { + testing.TB + Run(string, func(T)) bool +} + +// run runs a client/server test in a variety of test configurations. +// +// Tests execute in HTTP/1.1 and HTTP/2 modes by default. +// To run in a different set of configurations, pass a []testMode option. +// +// Tests call t.Parallel() by default. +// To disable parallel execution, pass the testNotParallel option. +func run[T TBRun[T]](t T, f func(t T, mode testMode), opts ...any) { + t.Helper() + modes := []testMode{http1Mode, http2Mode} + parallel := true + for _, opt := range opts { + switch opt := opt.(type) { + case []testMode: + modes = opt + case testNotParallelOpt: + parallel = false + default: + t.Fatalf("unknown option type %T", opt) + } + } + if t, ok := any(t).(*testing.T); ok && parallel { + setParallel(t) + } + for _, mode := range modes { + t.Run(string(mode), func(t T) { + t.Helper() + if t, ok := any(t).(*testing.T); ok && parallel { + setParallel(t) + } + t.Cleanup(func() { + afterTest(t) + }) + f(t, mode) + }) + } +} + type clientServerTest struct { - t *testing.T + t testing.TB h2 bool h Handler ts *httptest.Server @@ -69,11 +126,6 @@ func (t *clientServerTest) scheme() string { return "http" } -const ( - h1Mode = false - h2Mode = true -) - var optQuietLog = func(ts *httptest.Server) { ts.Config.ErrorLog = quietLog } @@ -84,23 +136,33 @@ func optWithServerLog(lg *log.Logger) func(*httptest.Server) { } } -func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...any) *clientServerTest { - if h2 { +// newClientServerTest creates and starts an httptest.Server. +// +// The mode parameter selects the implementation to test: +// HTTP/1, HTTP/2, etc. Tests using newClientServerTest should use +// the 'run' function, which will start a subtests for each tested mode. +// +// The vararg opts parameter can include functions to configure the +// test server or transport. +// +// func(*httptest.Server) // run before starting the server +// func(*http.Transport) +func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *clientServerTest { + if mode == http2Mode { CondSkipHTTP2(t) } cst := &clientServerTest{ t: t, - h2: h2, + h2: mode == http2Mode, h: h, - tr: &Transport{}, } - cst.c = &Client{Transport: cst.tr} cst.ts = httptest.NewUnstartedServer(h) + var transportFuncs []func(*Transport) for _, opt := range opts { switch opt := opt.(type) { case func(*Transport): - opt(cst.tr) + transportFuncs = append(transportFuncs, opt) case func(*httptest.Server): opt(cst.ts) default: @@ -108,60 +170,84 @@ func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...any) *clientS } } - if !h2 { + switch mode { + case http1Mode: cst.ts.Start() - return cst + case https1Mode: + cst.ts.StartTLS() + case http2Mode: + ExportHttp2ConfigureServer(cst.ts.Config, nil) + cst.ts.TLS = cst.ts.Config.TLSConfig + cst.ts.StartTLS() + default: + t.Fatalf("unknown test mode %v", mode) } - ExportHttp2ConfigureServer(cst.ts.Config, nil) - cst.ts.TLS = cst.ts.Config.TLSConfig - cst.ts.StartTLS() - - cst.tr.TLSClientConfig = &tls.Config{ - InsecureSkipVerify: true, + cst.c = cst.ts.Client() + cst.tr = cst.c.Transport.(*Transport) + if mode == http2Mode { + if err := ExportHttp2ConfigureTransport(cst.tr); err != nil { + t.Fatal(err) + } } - if err := ExportHttp2ConfigureTransport(cst.tr); err != nil { - t.Fatal(err) + for _, f := range transportFuncs { + f(cst.tr) } + t.Cleanup(func() { + cst.close() + }) return cst } // Testing the newClientServerTest helper itself. func TestNewClientServerTest(t *testing.T) { + run(t, testNewClientServerTest, []testMode{http1Mode, https1Mode, http2Mode}) +} +func testNewClientServerTest(t *testing.T, mode testMode) { var got struct { sync.Mutex - log []string + proto string + hasTLS bool } h := HandlerFunc(func(w ResponseWriter, r *Request) { got.Lock() defer got.Unlock() - got.log = append(got.log, r.Proto) + got.proto = r.Proto + got.hasTLS = r.TLS != nil }) - for _, v := range [2]bool{false, true} { - cst := newClientServerTest(t, v, h) - if _, err := cst.c.Head(cst.ts.URL); err != nil { - t.Fatal(err) - } - cst.close() + cst := newClientServerTest(t, mode, h) + if _, err := cst.c.Head(cst.ts.URL); err != nil { + t.Fatal(err) + } + var wantProto string + var wantTLS bool + switch mode { + case http1Mode: + wantProto = "HTTP/1.1" + wantTLS = false + case https1Mode: + wantProto = "HTTP/1.1" + wantTLS = true + case http2Mode: + wantProto = "HTTP/2.0" + wantTLS = true } - got.Lock() // no need to unlock - if want := []string{"HTTP/1.1", "HTTP/2.0"}; !reflect.DeepEqual(got.log, want) { - t.Errorf("got %q; want %q", got.log, want) + if got.proto != wantProto { + t.Errorf("req.Proto = %q, want %q", got.proto, wantProto) + } + if got.hasTLS != wantTLS { + t.Errorf("req.TLS set: %v, want %v", got.hasTLS, wantTLS) } } -func TestChunkedResponseHeaders_h1(t *testing.T) { testChunkedResponseHeaders(t, h1Mode) } -func TestChunkedResponseHeaders_h2(t *testing.T) { testChunkedResponseHeaders(t, h2Mode) } - -func testChunkedResponseHeaders(t *testing.T, h2 bool) { - defer afterTest(t) +func TestChunkedResponseHeaders(t *testing.T) { run(t, testChunkedResponseHeaders) } +func testChunkedResponseHeaders(t *testing.T, mode testMode) { log.SetOutput(io.Discard) // is noisy otherwise defer log.SetOutput(os.Stderr) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted w.(Flusher).Flush() fmt.Fprintf(w, "I am a chunked response.") })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -172,7 +258,7 @@ func testChunkedResponseHeaders(t *testing.T, h2 bool) { t.Errorf("expected ContentLength of %d; got %d", e, g) } wantTE := []string{"chunked"} - if h2 { + if mode == http2Mode { wantTE = nil } if !reflect.DeepEqual(res.TransferEncoding, wantTE) { @@ -204,9 +290,9 @@ func (tt h12Compare) reqFunc() reqFunc { func (tt h12Compare) run(t *testing.T) { setParallel(t) - cst1 := newClientServerTest(t, false, HandlerFunc(tt.Handler), tt.Opts...) + cst1 := newClientServerTest(t, http1Mode, HandlerFunc(tt.Handler), tt.Opts...) defer cst1.close() - cst2 := newClientServerTest(t, true, HandlerFunc(tt.Handler), tt.Opts...) + cst2 := newClientServerTest(t, http2Mode, HandlerFunc(tt.Handler), tt.Opts...) defer cst2.close() res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL) @@ -459,12 +545,9 @@ func TestH12_AutoGzip_Disabled(t *testing.T) { // Test304Responses verifies that 304s don't declare that they're // chunking in their response headers and aren't allowed to produce // output. -func Test304Responses_h1(t *testing.T) { test304Responses(t, h1Mode) } -func Test304Responses_h2(t *testing.T) { test304Responses(t, h2Mode) } - -func test304Responses(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func Test304Responses(t *testing.T) { run(t, test304Responses) } +func test304Responses(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.WriteHeader(StatusNotModified) _, err := w.Write([]byte("illegal body")) if err != ErrBodyNotAllowed { @@ -528,20 +611,17 @@ func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int6 // Tests that closing the Request.Cancel channel also while still // reading the response body. Issue 13159. -func TestCancelRequestMidBody_h1(t *testing.T) { testCancelRequestMidBody(t, h1Mode) } -func TestCancelRequestMidBody_h2(t *testing.T) { testCancelRequestMidBody(t, h2Mode) } -func testCancelRequestMidBody(t *testing.T, h2 bool) { - defer afterTest(t) +func TestCancelRequestMidBody(t *testing.T) { run(t, testCancelRequestMidBody) } +func testCancelRequestMidBody(t *testing.T, mode testMode) { unblock := make(chan bool) didFlush := make(chan bool, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, "Hello") w.(Flusher).Flush() didFlush <- true <-unblock io.WriteString(w, ", world.") })) - defer cst.close() defer close(unblock) req, _ := NewRequest("GET", cst.ts.URL, nil) @@ -577,12 +657,9 @@ func testCancelRequestMidBody(t *testing.T, h2 bool) { } // Tests that clients can send trailers to a server and that the server can read them. -func TestTrailersClientToServer_h1(t *testing.T) { testTrailersClientToServer(t, h1Mode) } -func TestTrailersClientToServer_h2(t *testing.T) { testTrailersClientToServer(t, h2Mode) } - -func testTrailersClientToServer(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTrailersClientToServer(t *testing.T) { run(t, testTrailersClientToServer) } +func testTrailersClientToServer(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { var decl []string for k := range r.Trailer { decl = append(decl, k) @@ -605,7 +682,6 @@ func testTrailersClientToServer(t *testing.T, h2 bool) { r.Trailer.Get("Client-Trailer-B")) } })) - defer cst.close() var req *Request req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader( @@ -632,15 +708,20 @@ func testTrailersClientToServer(t *testing.T, h2 bool) { } // Tests that servers send trailers to a client and that the client can read them. -func TestTrailersServerToClient_h1(t *testing.T) { testTrailersServerToClient(t, h1Mode, false) } -func TestTrailersServerToClient_h2(t *testing.T) { testTrailersServerToClient(t, h2Mode, false) } -func TestTrailersServerToClient_Flush_h1(t *testing.T) { testTrailersServerToClient(t, h1Mode, true) } -func TestTrailersServerToClient_Flush_h2(t *testing.T) { testTrailersServerToClient(t, h2Mode, true) } +func TestTrailersServerToClient(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testTrailersServerToClient(t, mode, false) + }) +} +func TestTrailersServerToClientFlush(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testTrailersServerToClient(t, mode, true) + }) +} -func testTrailersServerToClient(t *testing.T, h2, flush bool) { - defer afterTest(t) +func testTrailersServerToClient(t *testing.T, mode testMode, flush bool) { const body = "Some body" - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B") w.Header().Add("Trailer", "Server-Trailer-C") @@ -657,7 +738,6 @@ func testTrailersServerToClient(t *testing.T, h2, flush bool) { w.Header().Set("Server-Trailer-C", "valuec") // skipping B w.Header().Set("Server-Trailer-NotDeclared", "should be omitted") })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -668,7 +748,7 @@ func testTrailersServerToClient(t *testing.T, h2, flush bool) { "Content-Type": {"text/plain; charset=utf-8"}, } wantLen := -1 - if h2 && !flush { + if mode == http2Mode && !flush { // In HTTP/1.1, any use of trailers forces HTTP/1.1 // chunking and a flush at the first write. That's // unnecessary with HTTP/2's framing, so the server @@ -708,16 +788,12 @@ func testTrailersServerToClient(t *testing.T, h2, flush bool) { } // Don't allow a Body.Read after Body.Close. Issue 13648. -func TestResponseBodyReadAfterClose_h1(t *testing.T) { testResponseBodyReadAfterClose(t, h1Mode) } -func TestResponseBodyReadAfterClose_h2(t *testing.T) { testResponseBodyReadAfterClose(t, h2Mode) } - -func testResponseBodyReadAfterClose(t *testing.T, h2 bool) { - defer afterTest(t) +func TestResponseBodyReadAfterClose(t *testing.T) { run(t, testResponseBodyReadAfterClose) } +func testResponseBodyReadAfterClose(t *testing.T, mode testMode) { const body = "Some body" - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, body) })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -729,13 +805,11 @@ func testResponseBodyReadAfterClose(t *testing.T, h2 bool) { } } -func TestConcurrentReadWriteReqBody_h1(t *testing.T) { testConcurrentReadWriteReqBody(t, h1Mode) } -func TestConcurrentReadWriteReqBody_h2(t *testing.T) { testConcurrentReadWriteReqBody(t, h2Mode) } -func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) { - defer afterTest(t) +func TestConcurrentReadWriteReqBody(t *testing.T) { run(t, testConcurrentReadWriteReqBody) } +func testConcurrentReadWriteReqBody(t *testing.T, mode testMode) { const reqBody = "some request body" const resBody = "some response body" - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { var wg sync.WaitGroup wg.Add(2) didRead := make(chan bool, 1) @@ -754,7 +828,7 @@ func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) { // Write in another goroutine. go func() { defer wg.Done() - if !h2 { + if mode != http2Mode { // our HTTP/1 implementation intentionally // doesn't permit writes during read (mostly // due to it being undefined); if that is ever @@ -765,7 +839,6 @@ func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) { }() wg.Wait() })) - defer cst.close() req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody)) req.Header.Add("Expect", "100-continue") // just to complicate things res, err := cst.c.Do(req) @@ -782,15 +855,12 @@ func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) { } } -func TestConnectRequest_h1(t *testing.T) { testConnectRequest(t, h1Mode) } -func TestConnectRequest_h2(t *testing.T) { testConnectRequest(t, h2Mode) } -func testConnectRequest(t *testing.T, h2 bool) { - defer afterTest(t) +func TestConnectRequest(t *testing.T) { run(t, testConnectRequest) } +func testConnectRequest(t *testing.T, mode testMode) { gotc := make(chan *Request, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { gotc <- r })) - defer cst.close() u, err := url.Parse(cst.ts.URL) if err != nil { @@ -840,17 +910,14 @@ func testConnectRequest(t *testing.T, h2 bool) { } } -func TestTransportUserAgent_h1(t *testing.T) { testTransportUserAgent(t, h1Mode) } -func TestTransportUserAgent_h2(t *testing.T) { testTransportUserAgent(t, h2Mode) } -func testTransportUserAgent(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportUserAgent(t *testing.T) { run(t, testTransportUserAgent) } +func testTransportUserAgent(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%q", r.Header["User-Agent"]) })) - defer cst.close() either := func(a, b string) string { - if h2 { + if mode == http2Mode { return b } return a @@ -901,19 +968,22 @@ func testTransportUserAgent(t *testing.T, h2 bool) { } } -func TestStarRequestFoo_h1(t *testing.T) { testStarRequest(t, "FOO", h1Mode) } -func TestStarRequestFoo_h2(t *testing.T) { testStarRequest(t, "FOO", h2Mode) } -func TestStarRequestOptions_h1(t *testing.T) { testStarRequest(t, "OPTIONS", h1Mode) } -func TestStarRequestOptions_h2(t *testing.T) { testStarRequest(t, "OPTIONS", h2Mode) } -func testStarRequest(t *testing.T, method string, h2 bool) { - defer afterTest(t) +func TestStarRequestMethod(t *testing.T) { + for _, method := range []string{"FOO", "OPTIONS"} { + t.Run(method, func(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testStarRequest(t, method, mode) + }) + }) + } +} +func testStarRequest(t *testing.T, method string, mode testMode) { gotc := make(chan *Request, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("foo", "bar") gotc <- r w.(Flusher).Flush() })) - defer cst.close() u, err := url.Parse(cst.ts.URL) if err != nil { @@ -972,9 +1042,10 @@ func testStarRequest(t *testing.T, method string, h2 bool) { // Issue 13957 func TestTransportDiscardsUnneededConns(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportDiscardsUnneededConns, []testMode{http2Mode}) +} +func testTransportDiscardsUnneededConns(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "Hello, %v", r.RemoteAddr) })) defer cst.close() @@ -1058,20 +1129,19 @@ func TestTransportDiscardsUnneededConns(t *testing.T) { } // tests that Transport doesn't retain a pointer to the provided request. -func TestTransportGCRequest_Body_h1(t *testing.T) { testTransportGCRequest(t, h1Mode, true) } -func TestTransportGCRequest_Body_h2(t *testing.T) { testTransportGCRequest(t, h2Mode, true) } -func TestTransportGCRequest_NoBody_h1(t *testing.T) { testTransportGCRequest(t, h1Mode, false) } -func TestTransportGCRequest_NoBody_h2(t *testing.T) { testTransportGCRequest(t, h2Mode, false) } -func testTransportGCRequest(t *testing.T, h2, body bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportGCRequest(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + t.Run("Body", func(t *testing.T) { testTransportGCRequest(t, mode, true) }) + t.Run("NoBody", func(t *testing.T) { testTransportGCRequest(t, mode, false) }) + }) +} +func testTransportGCRequest(t *testing.T, mode testMode, body bool) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.ReadAll(r.Body) if body { io.WriteString(w, "Hello.") } })) - defer cst.close() didGC := make(chan struct{}) (func() { @@ -1103,19 +1173,11 @@ func testTransportGCRequest(t *testing.T, h2, body bool) { } } -func TestTransportRejectsInvalidHeaders_h1(t *testing.T) { - testTransportRejectsInvalidHeaders(t, h1Mode) -} -func TestTransportRejectsInvalidHeaders_h2(t *testing.T) { - testTransportRejectsInvalidHeaders(t, h2Mode) -} -func testTransportRejectsInvalidHeaders(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportRejectsInvalidHeaders(t *testing.T) { run(t, testTransportRejectsInvalidHeaders) } +func testTransportRejectsInvalidHeaders(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "Handler saw headers: %q", r.Header) }), optQuietLog) - defer cst.close() cst.tr.DisableKeepAlives = true tests := []struct { @@ -1161,27 +1223,22 @@ func testTransportRejectsInvalidHeaders(t *testing.T, h2 bool) { } } -func TestInterruptWithPanic_h1(t *testing.T) { testInterruptWithPanic(t, h1Mode, "boom") } -func TestInterruptWithPanic_h2(t *testing.T) { testInterruptWithPanic(t, h2Mode, "boom") } -func TestInterruptWithPanic_nil_h1(t *testing.T) { testInterruptWithPanic(t, h1Mode, nil) } -func TestInterruptWithPanic_nil_h2(t *testing.T) { testInterruptWithPanic(t, h2Mode, nil) } -func TestInterruptWithPanic_ErrAbortHandler_h1(t *testing.T) { - testInterruptWithPanic(t, h1Mode, ErrAbortHandler) -} -func TestInterruptWithPanic_ErrAbortHandler_h2(t *testing.T) { - testInterruptWithPanic(t, h2Mode, ErrAbortHandler) +func TestInterruptWithPanic(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + t.Run("boom", func(t *testing.T) { testInterruptWithPanic(t, mode, "boom") }) + t.Run("nil", func(t *testing.T) { testInterruptWithPanic(t, mode, nil) }) + t.Run("ErrAbortHandler", func(t *testing.T) { testInterruptWithPanic(t, mode, ErrAbortHandler) }) + }) } -func testInterruptWithPanic(t *testing.T, h2 bool, panicValue any) { - setParallel(t) +func testInterruptWithPanic(t *testing.T, mode testMode, panicValue any) { const msg = "hello" - defer afterTest(t) testDone := make(chan struct{}) defer close(testDone) var errorLog lockedBytesBuffer gotHeaders := make(chan bool, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, msg) w.(Flusher).Flush() @@ -1193,7 +1250,6 @@ func testInterruptWithPanic(t *testing.T, h2 bool, panicValue any) { }), func(ts *httptest.Server) { ts.Config.ErrorLog = log.New(&errorLog, "", 0) }) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -1274,15 +1330,11 @@ func TestH12_AutoGzipWithDumpResponse(t *testing.T) { } // Issue 14607 -func TestCloseIdleConnections_h1(t *testing.T) { testCloseIdleConnections(t, h1Mode) } -func TestCloseIdleConnections_h2(t *testing.T) { testCloseIdleConnections(t, h2Mode) } -func testCloseIdleConnections(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestCloseIdleConnections(t *testing.T) { run(t, testCloseIdleConnections) } +func testCloseIdleConnections(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("X-Addr", r.RemoteAddr) })) - defer cst.close() get := func() string { res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -1320,15 +1372,11 @@ func (r testErrorReader) Read(p []byte) (n int, err error) { return 0, io.EOF } -func TestNoSniffExpectRequestBody_h1(t *testing.T) { testNoSniffExpectRequestBody(t, h1Mode) } -func TestNoSniffExpectRequestBody_h2(t *testing.T) { testNoSniffExpectRequestBody(t, h2Mode) } - -func testNoSniffExpectRequestBody(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestNoSniffExpectRequestBody(t *testing.T) { run(t, testNoSniffExpectRequestBody) } +func testNoSniffExpectRequestBody(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.WriteHeader(StatusUnauthorized) })) - defer cst.close() // Set ExpectContinueTimeout non-zero so RoundTrip won't try to write it. cst.tr.ExpectContinueTimeout = 10 * time.Second @@ -1349,18 +1397,15 @@ func testNoSniffExpectRequestBody(t *testing.T, h2 bool) { } } -func TestServerUndeclaredTrailers_h1(t *testing.T) { testServerUndeclaredTrailers(t, h1Mode) } -func TestServerUndeclaredTrailers_h2(t *testing.T) { testServerUndeclaredTrailers(t, h2Mode) } -func testServerUndeclaredTrailers(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServerUndeclaredTrailers(t *testing.T) { run(t, testServerUndeclaredTrailers) } +func testServerUndeclaredTrailers(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Foo", "Bar") w.Header().Set("Trailer:Foo", "Baz") w.(Flusher).Flush() w.Header().Add("Trailer:Foo", "Baz2") w.Header().Set("Trailer:Bar", "Quux") })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -1381,8 +1426,10 @@ func testServerUndeclaredTrailers(t *testing.T, h2 bool) { } func TestBadResponseAfterReadingBody(t *testing.T) { - defer afterTest(t) - cst := newClientServerTest(t, false, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testBadResponseAfterReadingBody, []testMode{http1Mode}) +} +func testBadResponseAfterReadingBody(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { _, err := io.Copy(io.Discard, r.Body) if err != nil { t.Fatal(err) @@ -1394,7 +1441,6 @@ func TestBadResponseAfterReadingBody(t *testing.T) { defer c.Close() fmt.Fprintln(c, "some bogus crap") })) - defer cst.close() closes := 0 res, err := cst.c.Post(cst.ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")}) @@ -1407,12 +1453,10 @@ func TestBadResponseAfterReadingBody(t *testing.T) { } } -func TestWriteHeader0_h1(t *testing.T) { testWriteHeader0(t, h1Mode) } -func TestWriteHeader0_h2(t *testing.T) { testWriteHeader0(t, h2Mode) } -func testWriteHeader0(t *testing.T, h2 bool) { - defer afterTest(t) +func TestWriteHeader0(t *testing.T) { run(t, testWriteHeader0) } +func testWriteHeader0(t *testing.T, mode testMode) { gotpanic := make(chan bool, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { defer close(gotpanic) defer func() { if e := recover(); e != nil { @@ -1431,7 +1475,6 @@ func testWriteHeader0(t *testing.T, h2 bool) { }() w.WriteHeader(0) })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -1446,15 +1489,17 @@ func testWriteHeader0(t *testing.T, h2 bool) { // Issue 23010: don't be super strict checking WriteHeader's code if // it's not even valid to call WriteHeader then anyway. -func TestWriteHeaderNoCodeCheck_h1(t *testing.T) { testWriteHeaderAfterWrite(t, h1Mode, false) } -func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) { testWriteHeaderAfterWrite(t, h1Mode, true) } -func TestWriteHeaderNoCodeCheck_h2(t *testing.T) { testWriteHeaderAfterWrite(t, h2Mode, false) } -func testWriteHeaderAfterWrite(t *testing.T, h2, hijack bool) { - setParallel(t) - defer afterTest(t) - +func TestWriteHeaderNoCodeCheck(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testWriteHeaderAfterWrite(t, mode, false) + }) +} +func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) { + testWriteHeaderAfterWrite(t, http1Mode, true) +} +func testWriteHeaderAfterWrite(t *testing.T, mode testMode, hijack bool) { var errorLog lockedBytesBuffer - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if hijack { conn, _, _ := w.(Hijacker).Hijack() defer conn.Close() @@ -1470,7 +1515,6 @@ func testWriteHeaderAfterWrite(t *testing.T, h2, hijack bool) { }), func(ts *httptest.Server) { ts.Config.ErrorLog = log.New(&errorLog, "", 0) }) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -1485,7 +1529,7 @@ func testWriteHeaderAfterWrite(t *testing.T, h2, hijack bool) { } // Also check the stderr output: - if h2 { + if mode == http2Mode { // TODO: also emit this log message for HTTP/2? // We historically haven't, so don't check. return @@ -1501,14 +1545,14 @@ func testWriteHeaderAfterWrite(t *testing.T, h2, hijack bool) { } func TestBidiStreamReverseProxy(t *testing.T) { - setParallel(t) - defer afterTest(t) - backend := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testBidiStreamReverseProxy, []testMode{http2Mode}) +} +func testBidiStreamReverseProxy(t *testing.T, mode testMode) { + backend := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if _, err := io.Copy(w, r.Body); err != nil { log.Printf("bidi backend copy: %v", err) } })) - defer backend.close() backURL, err := url.Parse(backend.ts.URL) if err != nil { @@ -1516,10 +1560,9 @@ func TestBidiStreamReverseProxy(t *testing.T) { } rp := httputil.NewSingleHostReverseProxy(backURL) rp.Transport = backend.tr - proxy := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + proxy := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { rp.ServeHTTP(w, r) })) - defer proxy.close() bodyRes := make(chan any, 1) // error or hash.Hash pr, pw := io.Pipe() @@ -1586,15 +1629,10 @@ func TestH12_WebSocketUpgrade(t *testing.T) { }.run(t) } -func TestIdentityTransferEncoding_h1(t *testing.T) { testIdentityTransferEncoding(t, h1Mode) } -func TestIdentityTransferEncoding_h2(t *testing.T) { testIdentityTransferEncoding(t, h2Mode) } - -func testIdentityTransferEncoding(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - +func TestIdentityTransferEncoding(t *testing.T) { run(t, testIdentityTransferEncoding) } +func testIdentityTransferEncoding(t *testing.T, mode testMode) { const body = "body" - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { gotBody, _ := io.ReadAll(r.Body) if got, want := string(gotBody), body; got != want { t.Errorf("got request body = %q; want %q", got, want) @@ -1604,7 +1642,6 @@ func testIdentityTransferEncoding(t *testing.T, h2 bool) { w.(Flusher).Flush() io.WriteString(w, body) })) - defer cst.close() req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader(body)) res, err := cst.c.Do(req) if err != nil { @@ -1620,14 +1657,11 @@ func testIdentityTransferEncoding(t *testing.T, h2 bool) { } } -func TestEarlyHintsRequest_h1(t *testing.T) { testEarlyHintsRequest(t, h1Mode) } -func TestEarlyHintsRequest_h2(t *testing.T) { testEarlyHintsRequest(t, h2Mode) } -func testEarlyHintsRequest(t *testing.T, h2 bool) { - defer afterTest(t) - +func TestEarlyHintsRequest(t *testing.T) { run(t, testEarlyHintsRequest) } +func testEarlyHintsRequest(t *testing.T, mode testMode) { var wg sync.WaitGroup wg.Add(1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { h := w.Header() h.Add("Content-Length", "123") // must be ignored @@ -1642,7 +1676,6 @@ func testEarlyHintsRequest(t *testing.T, h2 bool) { w.Write([]byte("Hello")) })) - defer cst.close() checkLinkHeaders := func(t *testing.T, expected, got []string) { t.Helper() |
