aboutsummaryrefslogtreecommitdiff
path: root/src/net/http/serve_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/net/http/serve_test.go')
-rw-r--r--src/net/http/serve_test.go1180
1 files changed, 496 insertions, 684 deletions
diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go
index 4fadc56c9e..a93f6eff1b 100644
--- a/src/net/http/serve_test.go
+++ b/src/net/http/serve_test.go
@@ -246,15 +246,13 @@ var vtests = []struct {
{"http://someHost.com/someDir", "/someDir/"},
}
-func TestHostHandlers(t *testing.T) {
- setParallel(t)
- defer afterTest(t)
+func TestHostHandlers(t *testing.T) { run(t, testHostHandlers, []testMode{http1Mode}) }
+func testHostHandlers(t *testing.T, mode testMode) {
mux := NewServeMux()
for _, h := range handlers {
mux.Handle(h.pattern, stringHandler(h.msg))
}
- ts := httptest.NewServer(mux)
- defer ts.Close()
+ ts := newClientServerTest(t, mode, mux).ts
conn, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil {
@@ -487,9 +485,9 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) {
// properly sets the query string in the redirect URL.
// See Issue 17841.
func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) {
- setParallel(t)
- defer afterTest(t)
-
+ run(t, testServeWithSlashRedirectKeepsQueryString, []testMode{http1Mode})
+}
+func testServeWithSlashRedirectKeepsQueryString(t *testing.T, mode testMode) {
writeBackQuery := func(w ResponseWriter, r *Request) {
fmt.Fprintf(w, "%s", r.URL.RawQuery)
}
@@ -502,8 +500,7 @@ func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) {
fmt.Fprintf(w, "%s:bar", r.URL.RawQuery)
})
- ts := httptest.NewServer(mux)
- defer ts.Close()
+ ts := newClientServerTest(t, mode, mux).ts
tests := [...]struct {
path string
@@ -546,7 +543,6 @@ func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) {
func TestServeWithSlashRedirectForHostPatterns(t *testing.T) {
setParallel(t)
- defer afterTest(t)
mux := NewServeMux()
mux.Handle("example.com/pkg/foo/", stringHandler("example.com/pkg/foo/"))
@@ -578,9 +574,6 @@ func TestServeWithSlashRedirectForHostPatterns(t *testing.T) {
{"CONNECT", "http://example.com:3000/pkg/connect", 301, "/pkg/connect/", ""},
}
- ts := httptest.NewServer(mux)
- defer ts.Close()
-
for i, tt := range tests {
req, _ := NewRequest(tt.method, tt.url, nil)
w := httptest.NewRecorder()
@@ -602,13 +595,10 @@ func TestServeWithSlashRedirectForHostPatterns(t *testing.T) {
}
}
-func TestShouldRedirectConcurrency(t *testing.T) {
- setParallel(t)
- defer afterTest(t)
-
+func TestShouldRedirectConcurrency(t *testing.T) { run(t, testShouldRedirectConcurrency) }
+func testShouldRedirectConcurrency(t *testing.T, mode testMode) {
mux := NewServeMux()
- ts := httptest.NewServer(mux)
- defer ts.Close()
+ newClientServerTest(t, mode, mux)
mux.HandleFunc("/", func(w ResponseWriter, r *Request) {})
}
@@ -656,13 +646,12 @@ func benchmarkServeMux(b *testing.B, runHandler bool) {
}
}
-func TestServerTimeouts(t *testing.T) {
- setParallel(t)
- defer afterTest(t)
+func TestServerTimeouts(t *testing.T) { run(t, testServerTimeouts, []testMode{http1Mode}) }
+func testServerTimeouts(t *testing.T, mode testMode) {
// Try three times, with increasing timeouts.
tries := []time.Duration{250 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second}
for i, timeout := range tries {
- err := testServerTimeouts(timeout)
+ err := testServerTimeoutsWithTimeout(t, timeout, mode)
if err == nil {
return
}
@@ -674,16 +663,15 @@ func TestServerTimeouts(t *testing.T) {
t.Fatal("all attempts failed")
}
-func testServerTimeouts(timeout time.Duration) error {
+func testServerTimeoutsWithTimeout(t *testing.T, timeout time.Duration, mode testMode) error {
reqNum := 0
- ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
reqNum++
fmt.Fprintf(res, "req=%d", reqNum)
- }))
- ts.Config.ReadTimeout = timeout
- ts.Config.WriteTimeout = timeout
- ts.Start()
- defer ts.Close()
+ }), func(ts *httptest.Server) {
+ ts.Config.ReadTimeout = timeout
+ ts.Config.WriteTimeout = timeout
+ }).ts
// Hit the HTTP server successfully.
c := ts.Client()
@@ -749,22 +737,20 @@ func testServerTimeouts(timeout time.Duration) error {
}
// Test that the HTTP/2 server handles Server.WriteTimeout (Issue 18437)
-func TestHTTP2WriteDeadlineExtendedOnNewRequest(t *testing.T) {
+func TestWriteDeadlineExtendedOnNewRequest(t *testing.T) {
+ run(t, testWriteDeadlineExtendedOnNewRequest)
+}
+func testWriteDeadlineExtendedOnNewRequest(t *testing.T, mode testMode) {
if testing.Short() {
t.Skip("skipping in short mode")
}
- setParallel(t)
- defer afterTest(t)
- ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) {}))
- ts.Config.WriteTimeout = 250 * time.Millisecond
- ts.TLS = &tls.Config{NextProtos: []string{"h2"}}
- ts.StartTLS()
- defer ts.Close()
+ ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {}),
+ func(ts *httptest.Server) {
+ ts.Config.WriteTimeout = 250 * time.Millisecond
+ },
+ ).ts
c := ts.Client()
- if err := ExportHttp2ConfigureTransport(c.Transport.(*Transport)); err != nil {
- t.Fatal(err)
- }
for i := 1; i <= 3; i++ {
req, err := NewRequest("GET", ts.URL, nil)
@@ -785,9 +771,6 @@ func TestHTTP2WriteDeadlineExtendedOnNewRequest(t *testing.T) {
t.Fatalf("http2 Get #%d: %v", i, err)
}
r.Body.Close()
- if r.ProtoMajor != 2 {
- t.Fatalf("http2 Get expected HTTP/2.0, got %q", r.Proto)
- }
time.Sleep(ts.Config.WriteTimeout / 2)
}
}
@@ -810,33 +793,31 @@ func tryTimeouts(t *testing.T, testFunc func(timeout time.Duration) error) {
}
// Test that the HTTP/2 server RSTs stream on slow write.
-func TestHTTP2WriteDeadlineEnforcedPerStream(t *testing.T) {
+func TestWriteDeadlineEnforcedPerStream(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
setParallel(t)
- defer afterTest(t)
- tryTimeouts(t, testHTTP2WriteDeadlineEnforcedPerStream)
+ run(t, func(t *testing.T, mode testMode) {
+ tryTimeouts(t, func(timeout time.Duration) error {
+ return testWriteDeadlineEnforcedPerStream(t, mode, timeout)
+ })
+ })
}
-func testHTTP2WriteDeadlineEnforcedPerStream(timeout time.Duration) error {
+func testWriteDeadlineEnforcedPerStream(t *testing.T, mode testMode, timeout time.Duration) error {
reqNum := 0
- ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
reqNum++
if reqNum == 1 {
return // first request succeeds
}
time.Sleep(timeout) // second request times out
- }))
- ts.Config.WriteTimeout = timeout / 2
- ts.TLS = &tls.Config{NextProtos: []string{"h2"}}
- ts.StartTLS()
- defer ts.Close()
+ }), func(ts *httptest.Server) {
+ ts.Config.WriteTimeout = timeout / 2
+ }).ts
c := ts.Client()
- if err := ExportHttp2ConfigureTransport(c.Transport.(*Transport)); err != nil {
- return fmt.Errorf("ExportHttp2ConfigureTransport: %v", err)
- }
req, err := NewRequest("GET", ts.URL, nil)
if err != nil {
@@ -844,12 +825,9 @@ func testHTTP2WriteDeadlineEnforcedPerStream(timeout time.Duration) error {
}
r, err := c.Do(req)
if err != nil {
- return fmt.Errorf("http2 Get #1: %v", err)
+ return fmt.Errorf("Get #1: %v", err)
}
r.Body.Close()
- if r.ProtoMajor != 2 {
- return fmt.Errorf("http2 Get expected HTTP/2.0, got %q", r.Proto)
- }
req, err = NewRequest("GET", ts.URL, nil)
if err != nil {
@@ -858,45 +836,42 @@ func testHTTP2WriteDeadlineEnforcedPerStream(timeout time.Duration) error {
r, err = c.Do(req)
if err == nil {
r.Body.Close()
- if r.ProtoMajor != 2 {
- return fmt.Errorf("http2 Get expected HTTP/2.0, got %q", r.Proto)
- }
- return fmt.Errorf("http2 Get #2 expected error, got nil")
+ return fmt.Errorf("Get #2 expected error, got nil")
}
- expected := "stream ID 3; INTERNAL_ERROR" // client IDs are odd, second stream should be 3
- if !strings.Contains(err.Error(), expected) {
- return fmt.Errorf("http2 Get #2: expected error to contain %q, got %q", expected, err)
+ if mode == http2Mode {
+ expected := "stream ID 3; INTERNAL_ERROR" // client IDs are odd, second stream should be 3
+ if !strings.Contains(err.Error(), expected) {
+ return fmt.Errorf("http2 Get #2: expected error to contain %q, got %q", expected, err)
+ }
}
return nil
}
// Test that the HTTP/2 server does not send RST when WriteDeadline not set.
-func TestHTTP2NoWriteDeadline(t *testing.T) {
+func TestNoWriteDeadline(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
setParallel(t)
defer afterTest(t)
- tryTimeouts(t, testHTTP2NoWriteDeadline)
+ run(t, func(t *testing.T, mode testMode) {
+ tryTimeouts(t, func(timeout time.Duration) error {
+ return testNoWriteDeadline(t, mode, timeout)
+ })
+ })
}
-func testHTTP2NoWriteDeadline(timeout time.Duration) error {
+func testNoWriteDeadline(t *testing.T, mode testMode, timeout time.Duration) error {
reqNum := 0
- ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
reqNum++
if reqNum == 1 {
return // first request succeeds
}
time.Sleep(timeout) // second request timesout
- }))
- ts.TLS = &tls.Config{NextProtos: []string{"h2"}}
- ts.StartTLS()
- defer ts.Close()
+ })).ts
c := ts.Client()
- if err := ExportHttp2ConfigureTransport(c.Transport.(*Transport)); err != nil {
- return fmt.Errorf("ExportHttp2ConfigureTransport: %v", err)
- }
for i := 0; i < 2; i++ {
req, err := NewRequest("GET", ts.URL, nil)
@@ -905,12 +880,9 @@ func testHTTP2NoWriteDeadline(timeout time.Duration) error {
}
r, err := c.Do(req)
if err != nil {
- return fmt.Errorf("http2 Get #%d: %v", i, err)
+ return fmt.Errorf("Get #%d: %v", i, err)
}
r.Body.Close()
- if r.ProtoMajor != 2 {
- return fmt.Errorf("http2 Get expected HTTP/2.0, got %q", r.Proto)
- }
}
return nil
}
@@ -918,15 +890,14 @@ func testHTTP2NoWriteDeadline(timeout time.Duration) error {
// golang.org/issue/4741 -- setting only a write timeout that triggers
// shouldn't cause a handler to block forever on reads (next HTTP
// request) that will never happen.
-func TestOnlyWriteTimeout(t *testing.T) {
- setParallel(t)
- defer afterTest(t)
+func TestOnlyWriteTimeout(t *testing.T) { run(t, testOnlyWriteTimeout, []testMode{http1Mode}) }
+func testOnlyWriteTimeout(t *testing.T, mode testMode) {
var (
mu sync.RWMutex
conn net.Conn
)
var afterTimeoutErrc = make(chan error, 1)
- ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, req *Request) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
buf := make([]byte, 512<<10)
_, err := w.Write(buf)
if err != nil {
@@ -942,10 +913,9 @@ func TestOnlyWriteTimeout(t *testing.T) {
conn.SetWriteDeadline(time.Now().Add(-30 * time.Second))
_, err = w.Write(buf)
afterTimeoutErrc <- err
- }))
- ts.Listener = trackLastConnListener{ts.Listener, &mu, &conn}
- ts.Start()
- defer ts.Close()
+ }), func(ts *httptest.Server) {
+ ts.Listener = trackLastConnListener{ts.Listener, &mu, &conn}
+ }).ts
c := ts.Client()
@@ -992,9 +962,12 @@ func (l trackLastConnListener) Accept() (c net.Conn, err error) {
}
// TestIdentityResponse verifies that a handler can unset
-func TestIdentityResponse(t *testing.T) {
- setParallel(t)
- defer afterTest(t)
+func TestIdentityResponse(t *testing.T) { run(t, testIdentityResponse) }
+func testIdentityResponse(t *testing.T, mode testMode) {
+ if mode == http2Mode {
+ t.Skip("https://go.dev/issue/56019")
+ }
+
handler := HandlerFunc(func(rw ResponseWriter, req *Request) {
rw.Header().Set("Content-Length", "3")
rw.Header().Set("Transfer-Encoding", req.FormValue("te"))
@@ -1012,9 +985,7 @@ func TestIdentityResponse(t *testing.T) {
}
})
- ts := httptest.NewServer(handler)
- defer ts.Close()
-
+ ts := newClientServerTest(t, mode, handler).ts
c := ts.Client()
// Note: this relies on the assumption (which is true) that
@@ -1048,6 +1019,10 @@ func TestIdentityResponse(t *testing.T) {
}
res.Body.Close()
+ if mode != http1Mode {
+ return
+ }
+
// Verify that the connection is closed when the declared Content-Length
// is larger than what the handler wrote.
conn, err := net.Dial("tcp", ts.Listener.Addr().String())
@@ -1070,9 +1045,7 @@ func TestIdentityResponse(t *testing.T) {
func testTCPConnectionCloses(t *testing.T, req string, h Handler) {
setParallel(t)
- defer afterTest(t)
- s := httptest.NewServer(h)
- defer s.Close()
+ s := newClientServerTest(t, http1Mode, h).ts
conn, err := net.Dial("tcp", s.Listener.Addr().String())
if err != nil {
@@ -1114,9 +1087,7 @@ func testTCPConnectionCloses(t *testing.T, req string, h Handler) {
func testTCPConnectionStaysOpen(t *testing.T, req string, handler Handler) {
setParallel(t)
- defer afterTest(t)
- ts := httptest.NewServer(handler)
- defer ts.Close()
+ ts := newClientServerTest(t, http1Mode, handler).ts
conn, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil {
t.Fatal(err)
@@ -1192,14 +1163,12 @@ func TestHTTP10KeepAlive304Response(t *testing.T) {
}
// Issue 15703
-func TestKeepAliveFinalChunkWithEOF(t *testing.T) {
- setParallel(t)
- defer afterTest(t)
- cst := newClientServerTest(t, false /* h1 */, HandlerFunc(func(w ResponseWriter, r *Request) {
+func TestKeepAliveFinalChunkWithEOF(t *testing.T) { run(t, testKeepAliveFinalChunkWithEOF) }
+func testKeepAliveFinalChunkWithEOF(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
w.(Flusher).Flush() // force chunked encoding
w.Write([]byte("{\"Addr\": \"" + r.RemoteAddr + "\"}"))
}))
- defer cst.close()
type data struct {
Addr string
}
@@ -1222,16 +1191,11 @@ func TestKeepAliveFinalChunkWithEOF(t *testing.T) {
}
}
-func TestSetsRemoteAddr_h1(t *testing.T) { testSetsRemoteAddr(t, h1Mode) }
-func TestSetsRemoteAddr_h2(t *testing.T) { testSetsRemoteAddr(t, h2Mode) }
-
-func testSetsRemoteAddr(t *testing.T, h2 bool) {
- setParallel(t)
- defer afterTest(t)
- cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+func TestSetsRemoteAddr(t *testing.T) { run(t, testSetsRemoteAddr) }
+func testSetsRemoteAddr(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
fmt.Fprintf(w, "%s", r.RemoteAddr)
}))
- defer cst.close()
res, err := cst.c.Get(cst.ts.URL)
if err != nil {
@@ -1276,17 +1240,18 @@ func (c *blockingRemoteAddrConn) RemoteAddr() net.Addr {
// Issue 12943
func TestServerAllowsBlockingRemoteAddr(t *testing.T) {
- defer afterTest(t)
- ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
- fmt.Fprintf(w, "RA:%s", r.RemoteAddr)
- }))
+ run(t, testServerAllowsBlockingRemoteAddr, []testMode{http1Mode})
+}
+func testServerAllowsBlockingRemoteAddr(t *testing.T, mode testMode) {
conns := make(chan net.Conn)
- ts.Listener = &blockingRemoteAddrListener{
- Listener: ts.Listener,
- conns: conns,
- }
- ts.Start()
- defer ts.Close()
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "RA:%s", r.RemoteAddr)
+ }), func(ts *httptest.Server) {
+ ts.Listener = &blockingRemoteAddrListener{
+ Listener: ts.Listener,
+ conns: conns,
+ }
+ }).ts
c := ts.Client()
c.Timeout = time.Second
@@ -1351,13 +1316,9 @@ func TestServerAllowsBlockingRemoteAddr(t *testing.T) {
// TestHeadResponses verifies that all MIME type sniffing and Content-Length
// counting of GET requests also happens on HEAD requests.
-func TestHeadResponses_h1(t *testing.T) { testHeadResponses(t, h1Mode) }
-func TestHeadResponses_h2(t *testing.T) { testHeadResponses(t, h2Mode) }
-
-func testHeadResponses(t *testing.T, h2 bool) {
- setParallel(t)
- defer afterTest(t)
- cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+func TestHeadResponses(t *testing.T) { run(t, testHeadResponses) }
+func testHeadResponses(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
_, err := w.Write([]byte("<html>"))
if err != nil {
t.Errorf("ResponseWriter.Write: %v", err)
@@ -1369,7 +1330,6 @@ func testHeadResponses(t *testing.T, h2 bool) {
t.Errorf("Copy(ResponseWriter, ...): %v", err)
}
}))
- defer cst.close()
res, err := cst.c.Head(cst.ts.URL)
if err != nil {
t.Error(err)
@@ -1393,14 +1353,16 @@ func testHeadResponses(t *testing.T, h2 bool) {
}
func TestTLSHandshakeTimeout(t *testing.T) {
- setParallel(t)
- defer afterTest(t)
- ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
+ run(t, testTLSHandshakeTimeout, []testMode{https1Mode, http2Mode})
+}
+func testTLSHandshakeTimeout(t *testing.T, mode testMode) {
errc := make(chanWriter, 10) // but only expecting 1
- ts.Config.ReadTimeout = 250 * time.Millisecond
- ts.Config.ErrorLog = log.New(errc, "", 0)
- ts.StartTLS()
- defer ts.Close()
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
+ func(ts *httptest.Server) {
+ ts.Config.ReadTimeout = 250 * time.Millisecond
+ ts.Config.ErrorLog = log.New(errc, "", 0)
+ },
+ ).ts
conn, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil {
t.Fatalf("Dial: %v", err)
@@ -1423,19 +1385,18 @@ func TestTLSHandshakeTimeout(t *testing.T) {
}
}
-func TestTLSServer(t *testing.T) {
- setParallel(t)
- defer afterTest(t)
- ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+func TestTLSServer(t *testing.T) { run(t, testTLSServer, []testMode{https1Mode, http2Mode}) }
+func testTLSServer(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
if r.TLS != nil {
w.Header().Set("X-TLS-Set", "true")
if r.TLS.HandshakeComplete {
w.Header().Set("X-TLS-HandshakeComplete", "true")
}
}
- }))
- ts.Config.ErrorLog = log.New(io.Discard, "", 0)
- defer ts.Close()
+ }), func(ts *httptest.Server) {
+ ts.Config.ErrorLog = log.New(io.Discard, "", 0)
+ }).ts
// Connect an idle TCP connection to this server before we run
// our real tests. This idle connection used to block forever
@@ -1528,14 +1489,15 @@ func TestServeTLS(t *testing.T) {
// Test that the HTTPS server nicely rejects plaintext HTTP/1.x requests.
func TestTLSServerRejectHTTPRequests(t *testing.T) {
- setParallel(t)
- defer afterTest(t)
- ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ run(t, testTLSServerRejectHTTPRequests, []testMode{https1Mode, http2Mode})
+}
+func testTLSServerRejectHTTPRequests(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
t.Error("unexpected HTTPS request")
- }))
- var errBuf bytes.Buffer
- ts.Config.ErrorLog = log.New(&errBuf, "", 0)
- defer ts.Close()
+ }), func(ts *httptest.Server) {
+ var errBuf bytes.Buffer
+ ts.Config.ErrorLog = log.New(&errBuf, "", 0)
+ }).ts
conn, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil {
t.Fatal(err)
@@ -1727,11 +1689,9 @@ var serverExpectTests = []serverExpectTest{
// Tests that the server responds to the "Expect" request header
// correctly.
-// http2 test: TestServer_Response_Automatic100Continue
-func TestServerExpect(t *testing.T) {
- setParallel(t)
- defer afterTest(t)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+func TestServerExpect(t *testing.T) { run(t, testServerExpect, []testMode{http1Mode}) }
+func testServerExpect(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
// Note using r.FormValue("readbody") because for POST
// requests that would read from r.Body, which we only
// conditionally want to do.
@@ -1741,8 +1701,7 @@ func TestServerExpect(t *testing.T) {
} else {
w.WriteHeader(StatusUnauthorized)
}
- }))
- defer ts.Close()
+ })).ts
runTest := func(test serverExpectTest) {
conn, err := net.Dial("tcp", ts.Listener.Addr().String())
@@ -2287,11 +2246,8 @@ func (c cancelableTimeoutContext) Err() error {
return nil
}
-func TestTimeoutHandler_h1(t *testing.T) { testTimeoutHandler(t, h1Mode) }
-func TestTimeoutHandler_h2(t *testing.T) { testTimeoutHandler(t, h2Mode) }
-func testTimeoutHandler(t *testing.T, h2 bool) {
- setParallel(t)
- defer afterTest(t)
+func TestTimeoutHandler(t *testing.T) { run(t, testTimeoutHandler) }
+func testTimeoutHandler(t *testing.T, mode testMode) {
sendHi := make(chan bool, 1)
writeErrors := make(chan error, 1)
sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
@@ -2301,8 +2257,7 @@ func testTimeoutHandler(t *testing.T, h2 bool) {
})
ctx, cancel := context.WithCancel(context.Background())
h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
- cst := newClientServerTest(t, h2, h)
- defer cst.close()
+ cst := newClientServerTest(t, mode, h)
// Succeed without timing out:
sendHi <- true
@@ -2348,10 +2303,8 @@ func testTimeoutHandler(t *testing.T, h2 bool) {
}
// See issues 8209 and 8414.
-func TestTimeoutHandlerRace(t *testing.T) {
- setParallel(t)
- defer afterTest(t)
-
+func TestTimeoutHandlerRace(t *testing.T) { run(t, testTimeoutHandlerRace) }
+func testTimeoutHandlerRace(t *testing.T, mode testMode) {
delayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
ms, _ := strconv.Atoi(r.URL.Path[1:])
if ms == 0 {
@@ -2363,8 +2316,7 @@ func TestTimeoutHandlerRace(t *testing.T) {
}
})
- ts := httptest.NewServer(TimeoutHandler(delayHi, 20*time.Millisecond, ""))
- defer ts.Close()
+ ts := newClientServerTest(t, mode, TimeoutHandler(delayHi, 20*time.Millisecond, "")).ts
c := ts.Client()
@@ -2393,16 +2345,13 @@ func TestTimeoutHandlerRace(t *testing.T) {
// See issues 8209 and 8414.
// Both issues involved panics in the implementation of TimeoutHandler.
-func TestTimeoutHandlerRaceHeader(t *testing.T) {
- setParallel(t)
- defer afterTest(t)
-
+func TestTimeoutHandlerRaceHeader(t *testing.T) { run(t, testTimeoutHandlerRaceHeader) }
+func testTimeoutHandlerRaceHeader(t *testing.T, mode testMode) {
delay204 := HandlerFunc(func(w ResponseWriter, r *Request) {
w.WriteHeader(204)
})
- ts := httptest.NewServer(TimeoutHandler(delay204, time.Nanosecond, ""))
- defer ts.Close()
+ ts := newClientServerTest(t, mode, TimeoutHandler(delay204, time.Nanosecond, "")).ts
var wg sync.WaitGroup
gate := make(chan bool, 50)
@@ -2433,9 +2382,8 @@ func TestTimeoutHandlerRaceHeader(t *testing.T) {
}
// Issue 9162
-func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) {
- setParallel(t)
- defer afterTest(t)
+func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { run(t, testTimeoutHandlerRaceHeaderTimeout) }
+func testTimeoutHandlerRaceHeaderTimeout(t *testing.T, mode testMode) {
sendHi := make(chan bool, 1)
writeErrors := make(chan error, 1)
sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
@@ -2446,8 +2394,7 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) {
})
ctx, cancel := context.WithCancel(context.Background())
h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
- cst := newClientServerTest(t, h1Mode, h)
- defer cst.close()
+ cst := newClientServerTest(t, mode, h)
// Succeed without timing out:
sendHi <- true
@@ -2491,15 +2438,17 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) {
// Issue 14568.
func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) {
+ run(t, testTimeoutHandlerStartTimerWhenServing)
+}
+func testTimeoutHandlerStartTimerWhenServing(t *testing.T, mode testMode) {
if testing.Short() {
t.Skip("skipping sleeping test in -short mode")
}
- defer afterTest(t)
var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
w.WriteHeader(StatusNoContent)
}
timeout := 300 * time.Millisecond
- ts := httptest.NewServer(TimeoutHandler(handler, timeout, ""))
+ ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts
defer ts.Close()
c := ts.Client()
@@ -2518,9 +2467,8 @@ func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) {
}
}
-func TestTimeoutHandlerContextCanceled(t *testing.T) {
- setParallel(t)
- defer afterTest(t)
+func TestTimeoutHandlerContextCanceled(t *testing.T) { run(t, testTimeoutHandlerContextCanceled) }
+func testTimeoutHandlerContextCanceled(t *testing.T, mode testMode) {
writeErrors := make(chan error, 1)
sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set("Content-Type", "text/plain")
@@ -2540,7 +2488,7 @@ func TestTimeoutHandlerContextCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
h := NewTestTimeoutHandler(sayHi, ctx)
- cst := newClientServerTest(t, h1Mode, h)
+ cst := newClientServerTest(t, mode, h)
defer cst.close()
res, err := cst.c.Get(cst.ts.URL)
@@ -2560,15 +2508,13 @@ func TestTimeoutHandlerContextCanceled(t *testing.T) {
}
// https://golang.org/issue/15948
-func TestTimeoutHandlerEmptyResponse(t *testing.T) {
- setParallel(t)
- defer afterTest(t)
+func TestTimeoutHandlerEmptyResponse(t *testing.T) { run(t, testTimeoutHandlerEmptyResponse) }
+func testTimeoutHandlerEmptyResponse(t *testing.T, mode testMode) {
var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
// No response.
}
timeout := 300 * time.Millisecond
- ts := httptest.NewServer(TimeoutHandler(handler, timeout, ""))
- defer ts.Close()
+ ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts
c := ts.Client()
@@ -2587,7 +2533,9 @@ func TestTimeoutHandlerPanicRecovery(t *testing.T) {
wrapper := func(h Handler) Handler {
return TimeoutHandler(h, time.Second, "")
}
- testHandlerPanic(t, false, false, wrapper, "intentional death for testing")
+ run(t, func(t *testing.T, mode testMode) {
+ testHandlerPanic(t, false, mode, wrapper, "intentional death for testing")
+ }, testNotParallel)
}
func TestRedirectBadPath(t *testing.T) {
@@ -2705,17 +2653,10 @@ func TestRedirectContentTypeAndBody(t *testing.T) {
// connection immediately. But when it re-uses the connection, it typically closes
// the previous request's body, which is not optimal for zero-lengthed bodies,
// as the client would then see http.ErrBodyReadAfterClose and not 0, io.EOF.
-func TestZeroLengthPostAndResponse_h1(t *testing.T) {
- testZeroLengthPostAndResponse(t, h1Mode)
-}
-func TestZeroLengthPostAndResponse_h2(t *testing.T) {
- testZeroLengthPostAndResponse(t, h2Mode)
-}
+func TestZeroLengthPostAndResponse(t *testing.T) { run(t, testZeroLengthPostAndResponse) }
-func testZeroLengthPostAndResponse(t *testing.T, h2 bool) {
- setParallel(t)
- defer afterTest(t)
- cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, r *Request) {
+func testZeroLengthPostAndResponse(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
all, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("handler ReadAll: %v", err)
@@ -2725,7 +2666,6 @@ func testZeroLengthPostAndResponse(t *testing.T, h2 bool) {
}
rw.Header().Set("Content-Length", "0")
}))
- defer cst.close()
req, err := NewRequest("POST", cst.ts.URL, strings.NewReader(""))
if err != nil {
@@ -2752,23 +2692,26 @@ func testZeroLengthPostAndResponse(t *testing.T, h2 bool) {
}
}
-func TestHandlerPanicNil_h1(t *testing.T) { testHandlerPanic(t, false, h1Mode, nil, nil) }
-func TestHandlerPanicNil_h2(t *testing.T) { testHandlerPanic(t, false, h2Mode, nil, nil) }
-
-func TestHandlerPanic_h1(t *testing.T) {
- testHandlerPanic(t, false, h1Mode, nil, "intentional death for testing")
+func TestHandlerPanicNil(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testHandlerPanic(t, false, mode, nil, nil)
+ }, testNotParallel)
}
-func TestHandlerPanic_h2(t *testing.T) {
- testHandlerPanic(t, false, h2Mode, nil, "intentional death for testing")
+
+func TestHandlerPanic(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testHandlerPanic(t, false, mode, nil, "intentional death for testing")
+ }, testNotParallel)
}
func TestHandlerPanicWithHijack(t *testing.T) {
// Only testing HTTP/1, and our http2 server doesn't support hijacking.
- testHandlerPanic(t, true, h1Mode, nil, "intentional death for testing")
+ run(t, func(t *testing.T, mode testMode) {
+ testHandlerPanic(t, true, mode, nil, "intentional death for testing")
+ }, testNotParallel, []testMode{http1Mode})
}
-func testHandlerPanic(t *testing.T, withHijack, h2 bool, wrapper func(Handler) Handler, panicValue any) {
- defer afterTest(t)
+func testHandlerPanic(t *testing.T, withHijack bool, mode testMode, wrapper func(Handler) Handler, panicValue any) {
// Unlike the other tests that set the log output to io.Discard
// to quiet the output, this test uses a pipe. The pipe serves three
// purposes:
@@ -2803,8 +2746,7 @@ func testHandlerPanic(t *testing.T, withHijack, h2 bool, wrapper func(Handler) H
if wrapper != nil {
handler = wrapper(handler)
}
- cst := newClientServerTest(t, h2, handler)
- defer cst.close()
+ cst := newClientServerTest(t, mode, handler)
// Do a blocking read on the log output pipe so its logging
// doesn't bleed into the next test. But wait only 5 seconds
@@ -2847,9 +2789,11 @@ func (w terrorWriter) Write(p []byte) (int, error) {
// Issue 16456: allow writing 0 bytes on hijacked conn to test hijack
// without any log spam.
func TestServerWriteHijackZeroBytes(t *testing.T) {
- defer afterTest(t)
+ run(t, testServerWriteHijackZeroBytes, []testMode{http1Mode})
+}
+func testServerWriteHijackZeroBytes(t *testing.T, mode testMode) {
done := make(chan struct{})
- ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
defer close(done)
w.(Flusher).Flush()
conn, _, err := w.(Hijacker).Hijack()
@@ -2862,10 +2806,9 @@ func TestServerWriteHijackZeroBytes(t *testing.T) {
if err != ErrHijacked {
t.Errorf("Write error = %v; want ErrHijacked", err)
}
- }))
- ts.Config.ErrorLog = log.New(terrorWriter{t}, "Unexpected write: ", 0)
- ts.Start()
- defer ts.Close()
+ }), func(ts *httptest.Server) {
+ ts.Config.ErrorLog = log.New(terrorWriter{t}, "Unexpected write: ", 0)
+ }).ts
c := ts.Client()
res, err := c.Get(ts.URL)
@@ -2880,19 +2823,23 @@ func TestServerWriteHijackZeroBytes(t *testing.T) {
}
}
-func TestServerNoDate_h1(t *testing.T) { testServerNoHeader(t, h1Mode, "Date") }
-func TestServerNoDate_h2(t *testing.T) { testServerNoHeader(t, h2Mode, "Date") }
-func TestServerNoContentType_h1(t *testing.T) { testServerNoHeader(t, h1Mode, "Content-Type") }
-func TestServerNoContentType_h2(t *testing.T) { testServerNoHeader(t, h2Mode, "Content-Type") }
+func TestServerNoDate(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testServerNoHeader(t, mode, "Date")
+ })
+}
-func testServerNoHeader(t *testing.T, h2 bool, header string) {
- setParallel(t)
- defer afterTest(t)
- cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+func TestServerContentType(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testServerNoHeader(t, mode, "Content-Type")
+ })
+}
+
+func testServerNoHeader(t *testing.T, mode testMode, header string) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header()[header] = nil
io.WriteString(w, "<html>foo</html>") // non-empty
}))
- defer cst.close()
res, err := cst.c.Get(cst.ts.URL)
if err != nil {
t.Fatal(err)
@@ -2903,15 +2850,13 @@ func testServerNoHeader(t *testing.T, h2 bool, header string) {
}
}
-func TestStripPrefix(t *testing.T) {
- setParallel(t)
- defer afterTest(t)
+func TestStripPrefix(t *testing.T) { run(t, testStripPrefix) }
+func testStripPrefix(t *testing.T, mode testMode) {
h := HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set("X-Path", r.URL.Path)
w.Header().Set("X-RawPath", r.URL.RawPath)
})
- ts := httptest.NewServer(StripPrefix("/foo/bar", h))
- defer ts.Close()
+ ts := newClientServerTest(t, mode, StripPrefix("/foo/bar", h)).ts
c := ts.Client()
@@ -2961,15 +2906,11 @@ func TestStripPrefixNotModifyRequest(t *testing.T) {
}
}
-func TestRequestLimit_h1(t *testing.T) { testRequestLimit(t, h1Mode) }
-func TestRequestLimit_h2(t *testing.T) { testRequestLimit(t, h2Mode) }
-func testRequestLimit(t *testing.T, h2 bool) {
- setParallel(t)
- defer afterTest(t)
- cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+func TestRequestLimit(t *testing.T) { run(t, testRequestLimit) }
+func testRequestLimit(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
t.Fatalf("didn't expect to get request in Handler")
}), optQuietLog)
- defer cst.close()
req, _ := NewRequest("GET", cst.ts.URL, nil)
var bytesPerHeader = len("header12345: val12345\r\n")
for i := 0; i < ((DefaultMaxHeaderBytes+4096)/bytesPerHeader)+1; i++ {
@@ -2979,7 +2920,7 @@ func testRequestLimit(t *testing.T, h2 bool) {
if res != nil {
defer res.Body.Close()
}
- if h2 {
+ if mode == http2Mode {
// In HTTP/2, the result depends on a race. If the client has received the
// server's SETTINGS before RoundTrip starts sending the request, then RoundTrip
// will fail with an error. Otherwise, the client should receive a 431 from the
@@ -3021,13 +2962,10 @@ func (cr countReader) Read(p []byte) (n int, err error) {
return
}
-func TestRequestBodyLimit_h1(t *testing.T) { testRequestBodyLimit(t, h1Mode) }
-func TestRequestBodyLimit_h2(t *testing.T) { testRequestBodyLimit(t, h2Mode) }
-func testRequestBodyLimit(t *testing.T, h2 bool) {
- setParallel(t)
- defer afterTest(t)
+func TestRequestBodyLimit(t *testing.T) { run(t, testRequestBodyLimit) }
+func testRequestBodyLimit(t *testing.T, mode testMode) {
const limit = 1 << 20
- cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
r.Body = MaxBytesReader(w, r.Body, limit)
n, err := io.Copy(io.Discard, r.Body)
if err == nil {
@@ -3044,7 +2982,6 @@ func testRequestBodyLimit(t *testing.T, h2 bool) {
t.Errorf("MaxBytesError.Limit = %d, want %d", mbErr.Limit, limit)
}
}))
- defer cst.close()
nWritten := new(int64)
req, _ := NewRequest("POST", cst.ts.URL, io.LimitReader(countReader{neverEnding('a'), nWritten}, limit*200))
@@ -3068,13 +3005,12 @@ func testRequestBodyLimit(t *testing.T, h2 bool) {
// TestClientWriteShutdown tests that if the client shuts down the write
// side of their TCP connection, the server doesn't send a 400 Bad Request.
-func TestClientWriteShutdown(t *testing.T) {
+func TestClientWriteShutdown(t *testing.T) { run(t, testClientWriteShutdown) }
+func testClientWriteShutdown(t *testing.T, mode testMode) {
if runtime.GOOS == "plan9" {
t.Skip("skipping test; see https://golang.org/issue/17906")
}
- defer afterTest(t)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
- defer ts.Close()
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
conn, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil {
t.Fatalf("Dial: %v", err)
@@ -3119,12 +3055,12 @@ func TestServerBufferedChunking(t *testing.T) {
// closing the TCP connection, causing the client to get a RST.
// See https://golang.org/issue/3595
func TestServerGracefulClose(t *testing.T) {
- setParallel(t)
- defer afterTest(t)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ run(t, testServerGracefulClose, []testMode{http1Mode})
+}
+func testServerGracefulClose(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
Error(w, "bye", StatusUnauthorized)
- }))
- defer ts.Close()
+ })).ts
conn, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil {
@@ -3162,11 +3098,9 @@ func TestServerGracefulClose(t *testing.T) {
<-writeErr
}
-func TestCaseSensitiveMethod_h1(t *testing.T) { testCaseSensitiveMethod(t, h1Mode) }
-func TestCaseSensitiveMethod_h2(t *testing.T) { testCaseSensitiveMethod(t, h2Mode) }
-func testCaseSensitiveMethod(t *testing.T, h2 bool) {
- defer afterTest(t)
- cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+func TestCaseSensitiveMethod(t *testing.T) { run(t, testCaseSensitiveMethod) }
+func testCaseSensitiveMethod(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
if r.Method != "get" {
t.Errorf(`Got method %q; want "get"`, r.Method)
}
@@ -3187,8 +3121,10 @@ func testCaseSensitiveMethod(t *testing.T, h2 bool) {
// response, the net/http package adds a "Content-Length: 0" response
// header.
func TestContentLengthZero(t *testing.T) {
- ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {}))
- defer ts.Close()
+ run(t, testContentLengthZero, []testMode{http1Mode})
+}
+func testContentLengthZero(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {})).ts
for _, version := range []string{"HTTP/1.0", "HTTP/1.1"} {
conn, err := net.Dial("tcp", ts.Listener.Addr().String())
@@ -3215,15 +3151,17 @@ func TestContentLengthZero(t *testing.T) {
}
func TestCloseNotifier(t *testing.T) {
- defer afterTest(t)
+ run(t, testCloseNotifier, []testMode{http1Mode})
+}
+func testCloseNotifier(t *testing.T, mode testMode) {
gotReq := make(chan bool, 1)
sawClose := make(chan bool, 1)
- ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
gotReq <- true
cc := rw.(CloseNotifier).CloseNotify()
<-cc
sawClose <- true
- }))
+ })).ts
conn, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil {
t.Fatalf("error dialing: %v", err)
@@ -3257,11 +3195,12 @@ For:
//
// Issue 13165 (where it used to deadlock), but behavior changed in Issue 23921.
func TestCloseNotifierPipelined(t *testing.T) {
- setParallel(t)
- defer afterTest(t)
+ run(t, testCloseNotifierPipelined, []testMode{http1Mode})
+}
+func testCloseNotifierPipelined(t *testing.T, mode testMode) {
gotReq := make(chan bool, 2)
sawClose := make(chan bool, 2)
- ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
gotReq <- true
cc := rw.(CloseNotifier).CloseNotify()
select {
@@ -3270,8 +3209,7 @@ func TestCloseNotifierPipelined(t *testing.T) {
case <-time.After(100 * time.Millisecond):
}
sawClose <- true
- }))
- defer ts.Close()
+ })).ts
conn, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil {
t.Fatalf("error dialing: %v", err)
@@ -3341,12 +3279,14 @@ func TestCloseNotifierChanLeak(t *testing.T) {
// Issue 9763.
// HTTP/1-only test. (http2 doesn't have Hijack)
func TestHijackAfterCloseNotifier(t *testing.T) {
- defer afterTest(t)
+ run(t, testHijackAfterCloseNotifier, []testMode{http1Mode})
+}
+func testHijackAfterCloseNotifier(t *testing.T, mode testMode) {
script := make(chan string, 2)
script <- "closenotify"
script <- "hijack"
close(script)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
plan := <-script
switch plan {
default:
@@ -3369,13 +3309,12 @@ func TestHijackAfterCloseNotifier(t *testing.T) {
c.Close()
return
}
- }))
- defer ts.Close()
- res1, err := Get(ts.URL)
+ })).ts
+ res1, err := ts.Client().Get(ts.URL)
if err != nil {
log.Fatal(err)
}
- res2, err := Get(ts.URL)
+ res2, err := ts.Client().Get(ts.URL)
if err != nil {
log.Fatal(err)
}
@@ -3387,12 +3326,13 @@ func TestHijackAfterCloseNotifier(t *testing.T) {
}
func TestHijackBeforeRequestBodyRead(t *testing.T) {
- setParallel(t)
- defer afterTest(t)
+ run(t, testHijackBeforeRequestBodyRead, []testMode{http1Mode})
+}
+func testHijackBeforeRequestBodyRead(t *testing.T, mode testMode) {
var requestBody = bytes.Repeat([]byte("a"), 1<<20)
bodyOkay := make(chan bool, 1)
gotCloseNotify := make(chan bool, 1)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
defer close(bodyOkay) // caller will read false if nothing else
reqBody := r.Body
@@ -3419,8 +3359,7 @@ func TestHijackBeforeRequestBodyRead(t *testing.T) {
case <-time.After(5 * time.Second):
gotCloseNotify <- false
}
- }))
- defer ts.Close()
+ })).ts
conn, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil {
@@ -3440,14 +3379,14 @@ func TestHijackBeforeRequestBodyRead(t *testing.T) {
}
}
-func TestOptions(t *testing.T) {
+func TestOptions(t *testing.T) { run(t, testOptions, []testMode{http1Mode}) }
+func testOptions(t *testing.T, mode testMode) {
uric := make(chan string, 2) // only expect 1, but leave space for 2
mux := NewServeMux()
mux.HandleFunc("/", func(w ResponseWriter, r *Request) {
uric <- r.RequestURI
})
- ts := httptest.NewServer(mux)
- defer ts.Close()
+ ts := newClientServerTest(t, mode, mux).ts
conn, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil {
@@ -3492,15 +3431,15 @@ func TestOptions(t *testing.T) {
}
}
-func TestOptionsHandler(t *testing.T) {
+func TestOptionsHandler(t *testing.T) { run(t, testOptionsHandler, []testMode{http1Mode}) }
+func testOptionsHandler(t *testing.T, mode testMode) {
rc := make(chan *Request, 1)
- ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
rc <- r
- }))
- ts.Config.DisableGeneralOptionsHandler = true
- ts.Start()
- defer ts.Close()
+ }), func(ts *httptest.Server) {
+ ts.Config.DisableGeneralOptionsHandler = true
+ }).ts
conn, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil {
@@ -3804,12 +3743,12 @@ func TestDoubleHijack(t *testing.T) {
// optimization and is pointless if dealing with a
// badly behaved client.
func TestHTTP10ConnectionHeader(t *testing.T) {
- defer afterTest(t)
-
+ run(t, testHTTP10ConnectionHeader, []testMode{http1Mode})
+}
+func testHTTP10ConnectionHeader(t *testing.T, mode testMode) {
mux := NewServeMux()
mux.Handle("/", HandlerFunc(func(ResponseWriter, *Request) {}))
- ts := httptest.NewServer(mux)
- defer ts.Close()
+ ts := newClientServerTest(t, mode, mux).ts
// net/http uses HTTP/1.1 for requests, so write requests manually
tests := []struct {
@@ -3856,14 +3795,11 @@ func TestHTTP10ConnectionHeader(t *testing.T) {
}
// See golang.org/issue/5660
-func TestServerReaderFromOrder_h1(t *testing.T) { testServerReaderFromOrder(t, h1Mode) }
-func TestServerReaderFromOrder_h2(t *testing.T) { testServerReaderFromOrder(t, h2Mode) }
-func testServerReaderFromOrder(t *testing.T, h2 bool) {
- setParallel(t)
- defer afterTest(t)
+func TestServerReaderFromOrder(t *testing.T) { run(t, testServerReaderFromOrder) }
+func testServerReaderFromOrder(t *testing.T, mode testMode) {
pr, pw := io.Pipe()
const size = 3 << 20
- cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
rw.Header().Set("Content-Type", "text/plain") // prevent sniffing path
done := make(chan bool)
go func() {
@@ -3883,7 +3819,6 @@ func testServerReaderFromOrder(t *testing.T, h2 bool) {
pw.Close()
<-done
}))
- defer cst.close()
req, err := NewRequest("POST", cst.ts.URL, io.LimitReader(neverEnding('a'), size))
if err != nil {
@@ -3957,16 +3892,10 @@ func TestContentTypeOkayOn204(t *testing.T) {
// proxy). So then two people own that Request.Body (both the server
// and the http client), and both think they can close it on failure.
// Therefore, all incoming server requests Bodies need to be thread-safe.
-func TestTransportAndServerSharedBodyRace_h1(t *testing.T) {
- testTransportAndServerSharedBodyRace(t, h1Mode)
+func TestTransportAndServerSharedBodyRace(t *testing.T) {
+ run(t, testTransportAndServerSharedBodyRace)
}
-func TestTransportAndServerSharedBodyRace_h2(t *testing.T) {
- testTransportAndServerSharedBodyRace(t, h2Mode)
-}
-func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) {
- setParallel(t)
- defer afterTest(t)
-
+func testTransportAndServerSharedBodyRace(t *testing.T, mode testMode) {
const bodySize = 1 << 20
// errorf is like t.Errorf, but also writes to println. When
@@ -3980,7 +3909,7 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) {
}
unblockBackend := make(chan bool)
- backend := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ backend := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
gone := rw.(CloseNotifier).CloseNotify()
didCopy := make(chan any)
go func() {
@@ -4007,7 +3936,7 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) {
backendRespc := make(chan *Response, 1)
var proxy *clientServerTest
- proxy = newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ proxy = newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
req2, _ := NewRequest("POST", backend.ts.URL, req.Body)
req2.ContentLength = bodySize
cancel := make(chan struct{})
@@ -4027,7 +3956,7 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) {
// Try to cause a race: Both the Transport and the proxy handler's Server
// will try to read/close req.Body (aka req2.Body)
- if h2 {
+ if mode == http2Mode {
close(cancel)
} else {
proxy.c.Transport.(*Transport).CancelRequest(req2)
@@ -4071,22 +4000,23 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) {
// cause the Handler goroutine's Request.Body.Close to block.
// See issue 7121.
func TestRequestBodyCloseDoesntBlock(t *testing.T) {
+ run(t, testRequestBodyCloseDoesntBlock, []testMode{http1Mode})
+}
+func testRequestBodyCloseDoesntBlock(t *testing.T, mode testMode) {
if testing.Short() {
t.Skip("skipping in -short mode")
}
- defer afterTest(t)
readErrCh := make(chan error, 1)
errCh := make(chan error, 2)
- server := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+ server := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
go func(body io.Reader) {
_, err := body.Read(make([]byte, 100))
readErrCh <- err
}(req.Body)
time.Sleep(500 * time.Millisecond)
- }))
- defer server.Close()
+ })).ts
closeConn := make(chan bool)
defer close(closeConn)
@@ -4149,9 +4079,8 @@ func TestAppendTime(t *testing.T) {
}
}
-func TestServerConnState(t *testing.T) {
- setParallel(t)
- defer afterTest(t)
+func TestServerConnState(t *testing.T) { run(t, testServerConnState, []testMode{http1Mode}) }
+func testServerConnState(t *testing.T, mode testMode) {
handler := map[string]func(w ResponseWriter, r *Request){
"/": func(w ResponseWriter, r *Request) {
fmt.Fprintf(w, "Hello.")
@@ -4217,37 +4146,36 @@ func TestServerConnState(t *testing.T) {
// next call to wantLog.
}
- ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
handler[r.URL.Path](w, r)
- }))
+ }), func(ts *httptest.Server) {
+ ts.Config.ErrorLog = log.New(io.Discard, "", 0)
+ ts.Config.ConnState = func(c net.Conn, state ConnState) {
+ if c == nil {
+ t.Errorf("nil conn seen in state %s", state)
+ return
+ }
+ sl := <-activeLog
+ if sl.active == nil && state == StateNew {
+ sl.active = c
+ } else if sl.active != c {
+ t.Errorf("unexpected conn in state %s", state)
+ activeLog <- sl
+ return
+ }
+ sl.got = append(sl.got, state)
+ if sl.complete != nil && (len(sl.got) >= len(sl.want) || !reflect.DeepEqual(sl.got, sl.want[:len(sl.got)])) {
+ close(sl.complete)
+ sl.complete = nil
+ }
+ activeLog <- sl
+ }
+ }).ts
defer func() {
activeLog <- &stateLog{} // If the test failed, allow any remaining ConnState callbacks to complete.
ts.Close()
}()
- ts.Config.ErrorLog = log.New(io.Discard, "", 0)
- ts.Config.ConnState = func(c net.Conn, state ConnState) {
- if c == nil {
- t.Errorf("nil conn seen in state %s", state)
- return
- }
- sl := <-activeLog
- if sl.active == nil && state == StateNew {
- sl.active = c
- } else if sl.active != c {
- t.Errorf("unexpected conn in state %s", state)
- activeLog <- sl
- return
- }
- sl.got = append(sl.got, state)
- if sl.complete != nil && (len(sl.got) >= len(sl.want) || !reflect.DeepEqual(sl.got, sl.want[:len(sl.got)])) {
- close(sl.complete)
- sl.complete = nil
- }
- activeLog <- sl
- }
-
- ts.Start()
c := ts.Client()
mustGet := func(url string, headers ...string) {
@@ -4329,13 +4257,15 @@ func TestServerConnState(t *testing.T) {
}, StateNew, StateActive, StateIdle, StateClosed)
}
-func TestServerKeepAlivesEnabled(t *testing.T) {
- defer afterTest(t)
- ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
- ts.Config.SetKeepAlivesEnabled(false)
- ts.Start()
- defer ts.Close()
- res, err := Get(ts.URL)
+func TestServerKeepAlivesEnabledResultClose(t *testing.T) {
+ run(t, testServerKeepAlivesEnabledResultClose, []testMode{http1Mode})
+}
+func testServerKeepAlivesEnabledResultClose(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ }), func(ts *httptest.Server) {
+ ts.Config.SetKeepAlivesEnabled(false)
+ }).ts
+ res, err := ts.Client().Get(ts.URL)
if err != nil {
t.Fatal(err)
}
@@ -4346,16 +4276,12 @@ func TestServerKeepAlivesEnabled(t *testing.T) {
}
// golang.org/issue/7856
-func TestServerEmptyBodyRace_h1(t *testing.T) { testServerEmptyBodyRace(t, h1Mode) }
-func TestServerEmptyBodyRace_h2(t *testing.T) { testServerEmptyBodyRace(t, h2Mode) }
-func testServerEmptyBodyRace(t *testing.T, h2 bool) {
- setParallel(t)
- defer afterTest(t)
+func TestServerEmptyBodyRace(t *testing.T) { run(t, testServerEmptyBodyRace) }
+func testServerEmptyBodyRace(t *testing.T, mode testMode) {
var n int32
- cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
atomic.AddInt32(&n, 1)
}), optQuietLog)
- defer cst.close()
var wg sync.WaitGroup
const reqs = 20
for i := 0; i < reqs; i++ {
@@ -4436,9 +4362,9 @@ func TestCloseWrite(t *testing.T) {
// fixed.
//
// So add an explicit test for this.
-func TestServerFlushAndHijack(t *testing.T) {
- defer afterTest(t)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+func TestServerFlushAndHijack(t *testing.T) { run(t, testServerFlushAndHijack, []testMode{http1Mode}) }
+func testServerFlushAndHijack(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
io.WriteString(w, "Hello, ")
w.(Flusher).Flush()
conn, buf, _ := w.(Hijacker).Hijack()
@@ -4449,8 +4375,7 @@ func TestServerFlushAndHijack(t *testing.T) {
if err := conn.Close(); err != nil {
t.Error(err)
}
- }))
- defer ts.Close()
+ })).ts
res, err := Get(ts.URL)
if err != nil {
t.Fatal(err)
@@ -4472,20 +4397,21 @@ func TestServerFlushAndHijack(t *testing.T) {
// To test, verify we don't timeout or see fewer unique client
// addresses (== unique connections) than requests.
func TestServerKeepAliveAfterWriteError(t *testing.T) {
+ run(t, testServerKeepAliveAfterWriteError, []testMode{http1Mode})
+}
+func testServerKeepAliveAfterWriteError(t *testing.T, mode testMode) {
if testing.Short() {
t.Skip("skipping in -short mode")
}
- defer afterTest(t)
const numReq = 3
addrc := make(chan string, numReq)
- ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
addrc <- r.RemoteAddr
time.Sleep(500 * time.Millisecond)
w.(Flusher).Flush()
- }))
- ts.Config.WriteTimeout = 250 * time.Millisecond
- ts.Start()
- defer ts.Close()
+ }), func(ts *httptest.Server) {
+ ts.Config.WriteTimeout = 250 * time.Millisecond
+ }).ts
errc := make(chan error, numReq)
go func() {
@@ -4529,12 +4455,13 @@ func TestServerKeepAliveAfterWriteError(t *testing.T) {
// Issue 9987: shouldn't add automatic Content-Length (or
// Content-Type) if a Transfer-Encoding was set by the handler.
func TestNoContentLengthIfTransferEncoding(t *testing.T) {
- defer afterTest(t)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ run(t, testNoContentLengthIfTransferEncoding, []testMode{http1Mode})
+}
+func testNoContentLengthIfTransferEncoding(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set("Transfer-Encoding", "foo")
io.WriteString(w, "<html>")
- }))
- defer ts.Close()
+ })).ts
c, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil {
t.Fatalf("Dial: %v", err)
@@ -4682,15 +4609,12 @@ func TestHandlerFinishSkipBigContentLengthRead(t *testing.T) {
}
}
-func TestHandlerSetsBodyNil_h1(t *testing.T) { testHandlerSetsBodyNil(t, h1Mode) }
-func TestHandlerSetsBodyNil_h2(t *testing.T) { testHandlerSetsBodyNil(t, h2Mode) }
-func testHandlerSetsBodyNil(t *testing.T, h2 bool) {
- defer afterTest(t)
- cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+func TestHandlerSetsBodyNil(t *testing.T) { run(t, testHandlerSetsBodyNil) }
+func testHandlerSetsBodyNil(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
r.Body = nil
fmt.Fprintf(w, "%v", r.RemoteAddr)
}))
- defer cst.close()
get := func() string {
res, err := cst.c.Get(cst.ts.URL)
if err != nil {
@@ -4780,9 +4704,11 @@ func TestServerValidatesHostHeader(t *testing.T) {
}
func TestServerHandlersCanHandleH2PRI(t *testing.T) {
+ run(t, testServerHandlersCanHandleH2PRI, []testMode{http1Mode})
+}
+func testServerHandlersCanHandleH2PRI(t *testing.T, mode testMode) {
const upgradeResponse = "upgrade here"
- defer afterTest(t)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
conn, br, err := w.(Hijacker).Hijack()
if err != nil {
t.Error(err)
@@ -4804,8 +4730,7 @@ func TestServerHandlersCanHandleH2PRI(t *testing.T) {
return
}
io.WriteString(conn, upgradeResponse)
- }))
- defer ts.Close()
+ })).ts
c, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil {
@@ -4872,17 +4797,12 @@ func TestServerValidatesHeaders(t *testing.T) {
}
}
-func TestServerRequestContextCancel_ServeHTTPDone_h1(t *testing.T) {
- testServerRequestContextCancel_ServeHTTPDone(t, h1Mode)
-}
-func TestServerRequestContextCancel_ServeHTTPDone_h2(t *testing.T) {
- testServerRequestContextCancel_ServeHTTPDone(t, h2Mode)
+func TestServerRequestContextCancel_ServeHTTPDone(t *testing.T) {
+ run(t, testServerRequestContextCancel_ServeHTTPDone)
}
-func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, h2 bool) {
- setParallel(t)
- defer afterTest(t)
+func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, mode testMode) {
ctxc := make(chan context.Context, 1)
- cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
ctx := r.Context()
select {
case <-ctx.Done():
@@ -4891,7 +4811,6 @@ func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, h2 bool) {
}
ctxc <- ctx
}))
- defer cst.close()
res, err := cst.c.Get(cst.ts.URL)
if err != nil {
t.Fatal(err)
@@ -4910,16 +4829,16 @@ func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, h2 bool) {
// is always blocked in a Read call so it notices the EOF from the client.
// See issues 15927 and 15224.
func TestServerRequestContextCancel_ConnClose(t *testing.T) {
- setParallel(t)
- defer afterTest(t)
+ run(t, testServerRequestContextCancel_ConnClose, []testMode{http1Mode})
+}
+func testServerRequestContextCancel_ConnClose(t *testing.T, mode testMode) {
inHandler := make(chan struct{})
handlerDone := make(chan struct{})
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
close(inHandler)
<-r.Context().Done()
close(handlerDone)
- }))
- defer ts.Close()
+ })).ts
c, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil {
t.Fatal(err)
@@ -4931,23 +4850,17 @@ func TestServerRequestContextCancel_ConnClose(t *testing.T) {
<-handlerDone
}
-func TestServerContext_ServerContextKey_h1(t *testing.T) {
- testServerContext_ServerContextKey(t, h1Mode)
-}
-func TestServerContext_ServerContextKey_h2(t *testing.T) {
- testServerContext_ServerContextKey(t, h2Mode)
+func TestServerContext_ServerContextKey(t *testing.T) {
+ run(t, testServerContext_ServerContextKey)
}
-func testServerContext_ServerContextKey(t *testing.T, h2 bool) {
- setParallel(t)
- defer afterTest(t)
- cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+func testServerContext_ServerContextKey(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
ctx := r.Context()
got := ctx.Value(ServerContextKey)
if _, ok := got.(*Server); !ok {
t.Errorf("context value = %T; want *http.Server", got)
}
}))
- defer cst.close()
res, err := cst.c.Get(cst.ts.URL)
if err != nil {
t.Fatal(err)
@@ -4955,20 +4868,14 @@ func testServerContext_ServerContextKey(t *testing.T, h2 bool) {
res.Body.Close()
}
-func TestServerContext_LocalAddrContextKey_h1(t *testing.T) {
- testServerContext_LocalAddrContextKey(t, h1Mode)
-}
-func TestServerContext_LocalAddrContextKey_h2(t *testing.T) {
- testServerContext_LocalAddrContextKey(t, h2Mode)
+func TestServerContext_LocalAddrContextKey(t *testing.T) {
+ run(t, testServerContext_LocalAddrContextKey)
}
-func testServerContext_LocalAddrContextKey(t *testing.T, h2 bool) {
- setParallel(t)
- defer afterTest(t)
+func testServerContext_LocalAddrContextKey(t *testing.T, mode testMode) {
ch := make(chan any, 1)
- cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
ch <- r.Context().Value(LocalAddrContextKey)
}))
- defer cst.close()
if _, err := cst.c.Head(cst.ts.URL); err != nil {
t.Fatal(err)
}
@@ -5021,16 +4928,19 @@ func TestHandlerSetTransferEncodingGzip(t *testing.T) {
}
func BenchmarkClientServer(b *testing.B) {
+ run(b, benchmarkClientServer, []testMode{http1Mode, https1Mode, http2Mode})
+}
+func benchmarkClientServer(b *testing.B, mode testMode) {
b.ReportAllocs()
b.StopTimer()
- ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
+ ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
fmt.Fprintf(rw, "Hello world.\n")
- }))
- defer ts.Close()
+ })).ts
b.StartTimer()
+ c := ts.Client()
for i := 0; i < b.N; i++ {
- res, err := Get(ts.URL)
+ res, err := c.Get(ts.URL)
if err != nil {
b.Fatal("Get:", err)
}
@@ -5048,33 +4958,21 @@ func BenchmarkClientServer(b *testing.B) {
b.StopTimer()
}
-func BenchmarkClientServerParallel4(b *testing.B) {
- benchmarkClientServerParallel(b, 4, false)
-}
-
-func BenchmarkClientServerParallel64(b *testing.B) {
- benchmarkClientServerParallel(b, 64, false)
-}
-
-func BenchmarkClientServerParallelTLS4(b *testing.B) {
- benchmarkClientServerParallel(b, 4, true)
-}
-
-func BenchmarkClientServerParallelTLS64(b *testing.B) {
- benchmarkClientServerParallel(b, 64, true)
+func BenchmarkClientServerParallel(b *testing.B) {
+ for _, parallelism := range []int{4, 64} {
+ b.Run(fmt.Sprint(parallelism), func(b *testing.B) {
+ run(b, func(b *testing.B, mode testMode) {
+ benchmarkClientServerParallel(b, parallelism, mode)
+ }, []testMode{http1Mode, https1Mode, http2Mode})
+ })
+ }
}
-func benchmarkClientServerParallel(b *testing.B, parallelism int, useTLS bool) {
+func benchmarkClientServerParallel(b *testing.B, parallelism int, mode testMode) {
b.ReportAllocs()
- ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
+ ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
fmt.Fprintf(rw, "Hello world.\n")
- }))
- if useTLS {
- ts.StartTLS()
- } else {
- ts.Start()
- }
- defer ts.Close()
+ })).ts
b.ResetTimer()
b.SetParallelism(parallelism)
b.RunParallel(func(pb *testing.PB) {
@@ -5464,15 +5362,15 @@ Host: golang.org
}
}
-func BenchmarkCloseNotifier(b *testing.B) {
+func BenchmarkCloseNotifier(b *testing.B) { run(b, benchmarkCloseNotifier, []testMode{http1Mode}) }
+func benchmarkCloseNotifier(b *testing.B, mode testMode) {
b.ReportAllocs()
b.StopTimer()
sawClose := make(chan bool)
- ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+ ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
<-rw.(CloseNotifier).CloseNotify()
sawClose <- true
- }))
- defer ts.Close()
+ })).ts
tot := time.NewTimer(5 * time.Second)
defer tot.Stop()
b.StartTimer()
@@ -5508,20 +5406,18 @@ func TestConcurrentServerServe(t *testing.T) {
}
}
-func TestServerIdleTimeout(t *testing.T) {
+func TestServerIdleTimeout(t *testing.T) { run(t, testServerIdleTimeout, []testMode{http1Mode}) }
+func testServerIdleTimeout(t *testing.T, mode testMode) {
if testing.Short() {
t.Skip("skipping in short mode")
}
- setParallel(t)
- defer afterTest(t)
- ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
io.Copy(io.Discard, r.Body)
io.WriteString(w, r.RemoteAddr)
- }))
- ts.Config.ReadHeaderTimeout = 1 * time.Second
- ts.Config.IdleTimeout = 2 * time.Second
- ts.Start()
- defer ts.Close()
+ }), func(ts *httptest.Server) {
+ ts.Config.ReadHeaderTimeout = 1 * time.Second
+ ts.Config.IdleTimeout = 2 * time.Second
+ }).ts
c := ts.Client()
get := func() string {
@@ -5576,12 +5472,12 @@ func get(t *testing.T, c *Client, url string) string {
// Tests that calls to Server.SetKeepAlivesEnabled(false) closes any
// currently-open connections.
func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) {
- setParallel(t)
- defer afterTest(t)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ run(t, testServerSetKeepAlivesEnabledClosesConns, []testMode{http1Mode})
+}
+func testServerSetKeepAlivesEnabledClosesConns(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
io.WriteString(w, r.RemoteAddr)
- }))
- defer ts.Close()
+ })).ts
c := ts.Client()
tr := c.Transport.(*Transport)
@@ -5620,16 +5516,8 @@ func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) {
}
}
-func TestServerShutdown_h1(t *testing.T) {
- testServerShutdown(t, h1Mode)
-}
-func TestServerShutdown_h2(t *testing.T) {
- testServerShutdown(t, h2Mode)
-}
-
-func testServerShutdown(t *testing.T, h2 bool) {
- setParallel(t)
- defer afterTest(t)
+func TestServerShutdown(t *testing.T) { run(t, testServerShutdown) }
+func testServerShutdown(t *testing.T, mode testMode) {
var doShutdown func() // set later
var doStateCount func()
var shutdownRes = make(chan error, 1)
@@ -5645,10 +5533,9 @@ func testServerShutdown(t *testing.T, h2 bool) {
time.Sleep(20 * time.Millisecond)
io.WriteString(w, r.RemoteAddr)
})
- cst := newClientServerTest(t, h2, handler, func(srv *httptest.Server) {
+ cst := newClientServerTest(t, mode, handler, func(srv *httptest.Server) {
srv.Config.RegisterOnShutdown(func() { gotOnShutdown <- struct{}{} })
})
- defer cst.close()
doShutdown = func() {
shutdownRes <- cst.ts.Config.Shutdown(context.Background())
@@ -5678,24 +5565,22 @@ func testServerShutdown(t *testing.T, h2 bool) {
}
}
-func TestServerShutdownStateNew(t *testing.T) {
+func TestServerShutdownStateNew(t *testing.T) { run(t, testServerShutdownStateNew) }
+func testServerShutdownStateNew(t *testing.T, mode testMode) {
if testing.Short() {
t.Skip("test takes 5-6 seconds; skipping in short mode")
}
- setParallel(t)
- defer afterTest(t)
- ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
- // nothing.
- }))
var connAccepted sync.WaitGroup
- ts.Config.ConnState = func(conn net.Conn, state ConnState) {
- if state == StateNew {
- connAccepted.Done()
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ // nothing.
+ }), func(ts *httptest.Server) {
+ ts.Config.ConnState = func(conn net.Conn, state ConnState) {
+ if state == StateNew {
+ connAccepted.Done()
+ }
}
- }
- ts.Start()
- defer ts.Close()
+ }).ts
// Start a connection but never write to it.
connAccepted.Add(1)
@@ -5757,16 +5642,14 @@ func TestServerCloseDeadlock(t *testing.T) {
// Issue 17717: tests that Server.SetKeepAlivesEnabled is respected by
// both HTTP/1 and HTTP/2.
-func TestServerKeepAlivesEnabled_h1(t *testing.T) { testServerKeepAlivesEnabled(t, h1Mode) }
-func TestServerKeepAlivesEnabled_h2(t *testing.T) { testServerKeepAlivesEnabled(t, h2Mode) }
-func testServerKeepAlivesEnabled(t *testing.T, h2 bool) {
- if h2 {
+func TestServerKeepAlivesEnabled(t *testing.T) { run(t, testServerKeepAlivesEnabled, testNotParallel) }
+func testServerKeepAlivesEnabled(t *testing.T, mode testMode) {
+ if mode == http2Mode {
restore := ExportSetH2GoawayTimeout(10 * time.Millisecond)
defer restore()
}
// Not parallel: messes with global variable. (http2goAwayTimeout)
- defer afterTest(t)
- cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {}))
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}))
defer cst.close()
srv := cst.ts.Config
srv.SetKeepAlivesEnabled(false)
@@ -5803,9 +5686,8 @@ func testServerKeepAlivesEnabled(t *testing.T, h2 bool) {
// Issue 18447: test that the Server's ReadTimeout is stopped while
// the server's doing its 1-byte background read between requests,
// waiting for the connection to maybe close.
-func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) {
- setParallel(t)
- defer afterTest(t)
+func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { run(t, testServerCancelsReadTimeoutWhenIdle) }
+func testServerCancelsReadTimeoutWhenIdle(t *testing.T, mode testMode) {
runTimeSensitiveTest(t, []time.Duration{
10 * time.Millisecond,
50 * time.Millisecond,
@@ -5813,17 +5695,16 @@ func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) {
time.Second,
2 * time.Second,
}, func(t *testing.T, timeout time.Duration) error {
- ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
select {
case <-time.After(2 * timeout):
fmt.Fprint(w, "ok")
case <-r.Context().Done():
fmt.Fprint(w, r.Context().Err())
}
- }))
- ts.Config.ReadTimeout = timeout
- ts.Start()
- defer ts.Close()
+ }), func(ts *httptest.Server) {
+ ts.Config.ReadTimeout = timeout
+ }).ts
c := ts.Client()
@@ -5847,8 +5728,9 @@ func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) {
// beginning of a request has been received, rather than including time the
// connection spent idle.
func TestServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T) {
- setParallel(t)
- defer afterTest(t)
+ run(t, testServerCancelsReadHeaderTimeoutWhenIdle, []testMode{http1Mode})
+}
+func testServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T, mode testMode) {
runTimeSensitiveTest(t, []time.Duration{
10 * time.Millisecond,
50 * time.Millisecond,
@@ -5856,11 +5738,10 @@ func TestServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T) {
time.Second,
2 * time.Second,
}, func(t *testing.T, timeout time.Duration) error {
- ts := httptest.NewUnstartedServer(serve(200))
- ts.Config.ReadHeaderTimeout = timeout
- ts.Config.IdleTimeout = 0 // disable idle timeout
- ts.Start()
- defer ts.Close()
+ ts := newClientServerTest(t, mode, serve(200), func(ts *httptest.Server) {
+ ts.Config.ReadHeaderTimeout = timeout
+ ts.Config.IdleTimeout = 0 // disable idle timeout
+ }).ts
// rather than using an http.Client, create a single connection, so that
// we can ensure this connection is not closed.
@@ -5912,13 +5793,13 @@ func runTimeSensitiveTest(t *testing.T, durations []time.Duration, test func(t *
// Issue 18535: test that the Server doesn't try to do a background
// read if it's already done one.
func TestServerDuplicateBackgroundRead(t *testing.T) {
+ run(t, testServerDuplicateBackgroundRead, []testMode{http1Mode})
+}
+func testServerDuplicateBackgroundRead(t *testing.T, mode testMode) {
if runtime.GOOS == "netbsd" && runtime.GOARCH == "arm" {
testenv.SkipFlaky(t, 24826)
}
- setParallel(t)
- defer afterTest(t)
-
goroutines := 5
requests := 2000
if testing.Short() {
@@ -5926,8 +5807,7 @@ func TestServerDuplicateBackgroundRead(t *testing.T) {
requests = 100
}
- hts := httptest.NewServer(HandlerFunc(NotFound))
- defer hts.Close()
+ hts := newClientServerTest(t, mode, HandlerFunc(NotFound)).ts
reqBytes := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")
@@ -5970,14 +5850,15 @@ func TestServerDuplicateBackgroundRead(t *testing.T) {
// bufio.Reader.Buffered(), without resorting to Reading it
// (potentially blocking) to get at it.
func TestServerHijackGetsBackgroundByte(t *testing.T) {
+ run(t, testServerHijackGetsBackgroundByte, []testMode{http1Mode})
+}
+func testServerHijackGetsBackgroundByte(t *testing.T, mode testMode) {
if runtime.GOOS == "plan9" {
t.Skip("skipping test; see https://golang.org/issue/18657")
}
- setParallel(t)
- defer afterTest(t)
done := make(chan struct{})
inHandler := make(chan bool, 1)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
defer close(done)
// Tell the client to send more data after the GET request.
@@ -6000,8 +5881,7 @@ func TestServerHijackGetsBackgroundByte(t *testing.T) {
t.Error("context unexpectedly canceled")
default:
}
- }))
- defer ts.Close()
+ })).ts
cn, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil {
@@ -6030,14 +5910,15 @@ func TestServerHijackGetsBackgroundByte(t *testing.T) {
// immediate 1MB of data to the server to fill up the server's 4KB
// buffer.
func TestServerHijackGetsBackgroundByte_big(t *testing.T) {
+ run(t, testServerHijackGetsBackgroundByte_big, []testMode{http1Mode})
+}
+func testServerHijackGetsBackgroundByte_big(t *testing.T, mode testMode) {
if runtime.GOOS == "plan9" {
t.Skip("skipping test; see https://golang.org/issue/18657")
}
- setParallel(t)
- defer afterTest(t)
done := make(chan struct{})
const size = 8 << 10
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
defer close(done)
conn, buf, err := w.(Hijacker).Hijack()
@@ -6061,8 +5942,7 @@ func TestServerHijackGetsBackgroundByte_big(t *testing.T) {
} else if !allX {
t.Errorf("read %q; want %d 'x'", slurp, size)
}
- }))
- defer ts.Close()
+ })).ts
cn, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil {
@@ -6198,73 +6078,27 @@ func TestStripPortFromHost(t *testing.T) {
}
}
-func TestServerContexts(t *testing.T) {
- setParallel(t)
- defer afterTest(t)
- type baseKey struct{}
- type connKey struct{}
- ch := make(chan context.Context, 1)
- ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
- ch <- r.Context()
- }))
- ts.Config.BaseContext = func(ln net.Listener) context.Context {
- if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") {
- t.Errorf("unexpected onceClose listener type %T", ln)
- }
- return context.WithValue(context.Background(), baseKey{}, "base")
- }
- ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
- if got, want := ctx.Value(baseKey{}), "base"; got != want {
- t.Errorf("in ConnContext, base context key = %#v; want %q", got, want)
- }
- return context.WithValue(ctx, connKey{}, "conn")
- }
- ts.Start()
- defer ts.Close()
- res, err := ts.Client().Get(ts.URL)
- if err != nil {
- t.Fatal(err)
- }
- res.Body.Close()
- ctx := <-ch
- if got, want := ctx.Value(baseKey{}), "base"; got != want {
- t.Errorf("base context key = %#v; want %q", got, want)
- }
- if got, want := ctx.Value(connKey{}), "conn"; got != want {
- t.Errorf("conn context key = %#v; want %q", got, want)
- }
-}
-
-func TestServerContextsHTTP2(t *testing.T) {
- setParallel(t)
- defer afterTest(t)
+func TestServerContexts(t *testing.T) { run(t, testServerContexts) }
+func testServerContexts(t *testing.T, mode testMode) {
type baseKey struct{}
type connKey struct{}
ch := make(chan context.Context, 1)
- ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
- if r.ProtoMajor != 2 {
- t.Errorf("unexpected HTTP/1.x request")
- }
+ ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
ch <- r.Context()
- }))
- ts.Config.BaseContext = func(ln net.Listener) context.Context {
- if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") {
- t.Errorf("unexpected onceClose listener type %T", ln)
+ }), func(ts *httptest.Server) {
+ ts.Config.BaseContext = func(ln net.Listener) context.Context {
+ if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") {
+ t.Errorf("unexpected onceClose listener type %T", ln)
+ }
+ return context.WithValue(context.Background(), baseKey{}, "base")
}
- return context.WithValue(context.Background(), baseKey{}, "base")
- }
- ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
- if got, want := ctx.Value(baseKey{}), "base"; got != want {
- t.Errorf("in ConnContext, base context key = %#v; want %q", got, want)
+ ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
+ if got, want := ctx.Value(baseKey{}), "base"; got != want {
+ t.Errorf("in ConnContext, base context key = %#v; want %q", got, want)
+ }
+ return context.WithValue(ctx, connKey{}, "conn")
}
- return context.WithValue(ctx, connKey{}, "conn")
- }
- ts.TLS = &tls.Config{
- NextProtos: []string{"h2", "http/1.1"},
- }
- ts.StartTLS()
- defer ts.Close()
- ts.Client().Transport.(*Transport).ForceAttemptHTTP2 = true
+ }).ts
res, err := ts.Client().Get(ts.URL)
if err != nil {
t.Fatal(err)
@@ -6281,20 +6115,20 @@ func TestServerContextsHTTP2(t *testing.T) {
// Issue 35750: check ConnContext not modifying context for other connections
func TestConnContextNotModifyingAllContexts(t *testing.T) {
- setParallel(t)
- defer afterTest(t)
+ run(t, testConnContextNotModifyingAllContexts)
+}
+func testConnContextNotModifyingAllContexts(t *testing.T, mode testMode) {
type connKey struct{}
- ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
rw.Header().Set("Connection", "close")
- }))
- ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
- if got := ctx.Value(connKey{}); got != nil {
- t.Errorf("in ConnContext, unexpected context key = %#v", got)
+ }), func(ts *httptest.Server) {
+ ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
+ if got := ctx.Value(connKey{}); got != nil {
+ t.Errorf("in ConnContext, unexpected context key = %#v", got)
+ }
+ return context.WithValue(ctx, connKey{}, "conn")
}
- return context.WithValue(ctx, connKey{}, "conn")
- }
- ts.Start()
- defer ts.Close()
+ }).ts
var res *Response
var err error
@@ -6315,10 +6149,12 @@ func TestConnContextNotModifyingAllContexts(t *testing.T) {
// Issue 30710: ensure that as per the spec, a server responds
// with 501 Not Implemented for unsupported transfer-encodings.
func TestUnsupportedTransferEncodingsReturn501(t *testing.T) {
- cst := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ run(t, testUnsupportedTransferEncodingsReturn501, []testMode{http1Mode})
+}
+func testUnsupportedTransferEncodingsReturn501(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
w.Write([]byte("Hello, World!"))
- }))
- defer cst.Close()
+ })).ts
serverURL, err := url.Parse(cst.URL)
if err != nil {
@@ -6353,19 +6189,9 @@ func TestUnsupportedTransferEncodingsReturn501(t *testing.T) {
}
}
-func TestContentEncodingNoSniffing_h1(t *testing.T) {
- testContentEncodingNoSniffing(t, h1Mode)
-}
-
-func TestContentEncodingNoSniffing_h2(t *testing.T) {
- testContentEncodingNoSniffing(t, h2Mode)
-}
-
// Issue 31753: don't sniff when Content-Encoding is set
-func testContentEncodingNoSniffing(t *testing.T, h2 bool) {
- setParallel(t)
- defer afterTest(t)
-
+func TestContentEncodingNoSniffing(t *testing.T) { run(t, testContentEncodingNoSniffing) }
+func testContentEncodingNoSniffing(t *testing.T, mode testMode) {
type setting struct {
name string
body []byte
@@ -6428,13 +6254,12 @@ func testContentEncodingNoSniffing(t *testing.T, h2 bool) {
for _, tt := range settings {
t.Run(tt.name, func(t *testing.T) {
- cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, r *Request) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
if tt.contentEncoding != nil {
rw.Header().Set("Content-Encoding", tt.contentEncoding.(string))
}
rw.Write(tt.body)
}))
- defer cst.close()
res, err := cst.c.Get(cst.ts.URL)
if err != nil {
@@ -6460,13 +6285,13 @@ func testContentEncodingNoSniffing(t *testing.T, h2 bool) {
// Issue 30803: ensure that TimeoutHandler logs spurious
// WriteHeader calls, for consistency with other Handlers.
func TestTimeoutHandlerSuperfluousLogs(t *testing.T) {
+ run(t, testTimeoutHandlerSuperfluousLogs, []testMode{http1Mode})
+}
+func testTimeoutHandlerSuperfluousLogs(t *testing.T, mode testMode) {
if testing.Short() {
t.Skip("skipping in short mode")
}
- setParallel(t)
- defer afterTest(t)
-
pc, curFile, _, _ := runtime.Caller(0)
curFileBaseName := filepath.Base(curFile)
testFuncName := runtime.FuncForPC(pc).Name()
@@ -6520,7 +6345,7 @@ func TestTimeoutHandlerSuperfluousLogs(t *testing.T) {
dur = 10 * time.Second
}
th := TimeoutHandler(sh, dur, timeoutMsg)
- cst := newClientServerTest(t, h1Mode /* the test is protocol-agnostic */, th, optWithServerLog(srvLog))
+ cst := newClientServerTest(t, mode, th, optWithServerLog(srvLog))
defer cst.close()
res, err := cst.c.Get(cst.ts.URL)
@@ -6590,15 +6415,16 @@ func BenchmarkResponseStatusLine(b *testing.B) {
}
})
}
+
func TestDisableKeepAliveUpgrade(t *testing.T) {
+ run(t, testDisableKeepAliveUpgrade, []testMode{http1Mode})
+}
+func testDisableKeepAliveUpgrade(t *testing.T, mode testMode) {
if testing.Short() {
t.Skip("skipping in short mode")
}
- setParallel(t)
- defer afterTest(t)
-
- s := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ s := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set("Connection", "Upgrade")
w.Header().Set("Upgrade", "someProto")
w.WriteHeader(StatusSwitchingProtocols)
@@ -6611,10 +6437,9 @@ func TestDisableKeepAliveUpgrade(t *testing.T) {
// Copy from the *bufio.ReadWriter, which may contain buffered data.
// Copy to the net.Conn, to avoid buffering the output.
io.Copy(c, buf)
- }))
- s.Config.SetKeepAlivesEnabled(false)
- s.Start()
- defer s.Close()
+ }), func(ts *httptest.Server) {
+ ts.Config.SetKeepAlivesEnabled(false)
+ }).ts
cl := s.Client()
cl.Transport.(*Transport).DisableKeepAlives = true
@@ -6683,21 +6508,21 @@ func TestQuerySemicolon(t *testing.T) {
{"?a=1;x=good;x=bad", "", "good", true},
}
- for _, tt := range tests {
- t.Run(tt.query+"/allow=false", func(t *testing.T) {
- allowSemicolons := false
- testQuerySemicolon(t, tt.query, tt.xNoSemicolons, allowSemicolons, tt.warning)
- })
- t.Run(tt.query+"/allow=true", func(t *testing.T) {
- allowSemicolons, expectWarning := true, false
- testQuerySemicolon(t, tt.query, tt.xWithSemicolons, allowSemicolons, expectWarning)
- })
- }
+ run(t, func(t *testing.T, mode testMode) {
+ for _, tt := range tests {
+ t.Run(tt.query+"/allow=false", func(t *testing.T) {
+ allowSemicolons := false
+ testQuerySemicolon(t, mode, tt.query, tt.xNoSemicolons, allowSemicolons, tt.warning)
+ })
+ t.Run(tt.query+"/allow=true", func(t *testing.T) {
+ allowSemicolons, expectWarning := true, false
+ testQuerySemicolon(t, mode, tt.query, tt.xWithSemicolons, allowSemicolons, expectWarning)
+ })
+ }
+ })
}
-func testQuerySemicolon(t *testing.T, query string, wantX string, allowSemicolons, expectWarning bool) {
- setParallel(t)
-
+func testQuerySemicolon(t *testing.T, mode testMode, query string, wantX string, allowSemicolons, expectWarning bool) {
writeBackX := func(w ResponseWriter, r *Request) {
x := r.URL.Query().Get("x")
if expectWarning {
@@ -6720,11 +6545,10 @@ func testQuerySemicolon(t *testing.T, query string, wantX string, allowSemicolon
h = AllowQuerySemicolons(h)
}
- ts := httptest.NewUnstartedServer(h)
logBuf := &strings.Builder{}
- ts.Config.ErrorLog = log.New(logBuf, "", 0)
- ts.Start()
- defer ts.Close()
+ ts := newClientServerTest(t, mode, h, func(ts *httptest.Server) {
+ ts.Config.ErrorLog = log.New(logBuf, "", 0)
+ }).ts
req, _ := NewRequest("GET", ts.URL+query, nil)
res, err := ts.Client().Do(req)
@@ -6759,13 +6583,15 @@ func TestMaxBytesHandler(t *testing.T) {
for _, requestSize := range []int64{100, 1_000, 1_000_000} {
t.Run(fmt.Sprintf("max size %d request size %d", maxSize, requestSize),
func(t *testing.T) {
- testMaxBytesHandler(t, maxSize, requestSize)
+ run(t, func(t *testing.T, mode testMode) {
+ testMaxBytesHandler(t, mode, maxSize, requestSize)
+ })
})
}
}
}
-func testMaxBytesHandler(t *testing.T, maxSize, requestSize int64) {
+func testMaxBytesHandler(t *testing.T, mode testMode, maxSize, requestSize int64) {
var (
handlerN int64
handlerErr error
@@ -6776,7 +6602,7 @@ func testMaxBytesHandler(t *testing.T, maxSize, requestSize int64) {
io.Copy(w, &buf)
})
- ts := httptest.NewServer(MaxBytesHandler(echo, maxSize))
+ ts := newClientServerTest(t, mode, MaxBytesHandler(echo, maxSize)).ts
defer ts.Close()
c := ts.Client()
@@ -6843,13 +6669,12 @@ func TestProcessing(t *testing.T) {
}
}
-func TestParseFormCleanup_h1(t *testing.T) { testParseFormCleanup(t, h1Mode) }
-func TestParseFormCleanup_h2(t *testing.T) {
- t.Skip("https://go.dev/issue/20253")
- testParseFormCleanup(t, h2Mode)
-}
+func TestParseFormCleanup(t *testing.T) { run(t, testParseFormCleanup) }
+func testParseFormCleanup(t *testing.T, mode testMode) {
+ if mode == http2Mode {
+ t.Skip("https://go.dev/issue/20253")
+ }
-func testParseFormCleanup(t *testing.T, h2 bool) {
const maxMemory = 1024
const key = "file"
@@ -6858,9 +6683,7 @@ func testParseFormCleanup(t *testing.T, h2 bool) {
t.Skip("https://go.dev/issue/25965")
}
- setParallel(t)
- defer afterTest(t)
- cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
r.ParseMultipartForm(maxMemory)
f, _, err := r.FormFile(key)
if err != nil {
@@ -6874,7 +6697,6 @@ func testParseFormCleanup(t *testing.T, h2 bool) {
}
w.Write([]byte(of.Name()))
}))
- defer cst.close()
fBuf := new(bytes.Buffer)
mw := multipart.NewWriter(fBuf)
@@ -6911,33 +6733,23 @@ func testParseFormCleanup(t *testing.T, h2 bool) {
func TestHeadBody(t *testing.T) {
const identityMode = false
const chunkedMode = true
- t.Run("h1", func(t *testing.T) {
- t.Run("identity", func(t *testing.T) { testHeadBody(t, h1Mode, identityMode, "HEAD") })
- t.Run("chunked", func(t *testing.T) { testHeadBody(t, h1Mode, chunkedMode, "HEAD") })
- })
- t.Run("h2", func(t *testing.T) {
- t.Run("identity", func(t *testing.T) { testHeadBody(t, h2Mode, identityMode, "HEAD") })
- t.Run("chunked", func(t *testing.T) { testHeadBody(t, h2Mode, chunkedMode, "HEAD") })
+ run(t, func(t *testing.T, mode testMode) {
+ t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "HEAD") })
+ t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "HEAD") })
})
}
func TestGetBody(t *testing.T) {
const identityMode = false
const chunkedMode = true
- t.Run("h1", func(t *testing.T) {
- t.Run("identity", func(t *testing.T) { testHeadBody(t, h1Mode, identityMode, "GET") })
- t.Run("chunked", func(t *testing.T) { testHeadBody(t, h1Mode, chunkedMode, "GET") })
- })
- t.Run("h2", func(t *testing.T) {
- t.Run("identity", func(t *testing.T) { testHeadBody(t, h2Mode, identityMode, "GET") })
- t.Run("chunked", func(t *testing.T) { testHeadBody(t, h2Mode, chunkedMode, "GET") })
+ run(t, func(t *testing.T, mode testMode) {
+ t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "GET") })
+ t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "GET") })
})
}
-func testHeadBody(t *testing.T, h2, chunked bool, method string) {
- setParallel(t)
- defer afterTest(t)
- cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+func testHeadBody(t *testing.T, mode testMode, chunked bool, method string) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
b, err := io.ReadAll(r.Body)
if err != nil {
t.Errorf("server reading body: %v", err)