aboutsummaryrefslogtreecommitdiff
path: root/src/net/http/transport_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/net/http/transport_test.go')
-rw-r--r--src/net/http/transport_test.go317
1 files changed, 258 insertions, 59 deletions
diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go
index 2d9ca10bf0..7f6e0938c2 100644
--- a/src/net/http/transport_test.go
+++ b/src/net/http/transport_test.go
@@ -23,8 +23,8 @@ import (
"go/token"
"internal/nettrace"
"io"
- "io/ioutil"
"log"
+ mrand "math/rand"
"net"
. "net/http"
"net/http/httptest"
@@ -172,7 +172,7 @@ func TestTransportKeepAlives(t *testing.T) {
if err != nil {
t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err)
}
- body, err := ioutil.ReadAll(res.Body)
+ body, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err)
}
@@ -219,7 +219,7 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) {
t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
}
defer res.Body.Close()
- body, err := ioutil.ReadAll(res.Body)
+ body, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
}
@@ -272,7 +272,7 @@ func TestTransportConnectionCloseOnRequest(t *testing.T) {
t.Errorf("For connectionClose = %v; handler's X-Saw-Close was %v; want %v",
connectionClose, got, !connectionClose)
}
- body, err := ioutil.ReadAll(res.Body)
+ body, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
}
@@ -381,7 +381,7 @@ func TestTransportIdleCacheKeys(t *testing.T) {
if err != nil {
t.Error(err)
}
- ioutil.ReadAll(resp.Body)
+ io.ReadAll(resp.Body)
keys := tr.IdleConnKeysForTesting()
if e, g := 1, len(keys); e != g {
@@ -411,7 +411,7 @@ func TestTransportReadToEndReusesConn(t *testing.T) {
w.WriteHeader(200)
w.(Flusher).Flush()
} else {
- w.Header().Set("Content-Type", strconv.Itoa(len(msg)))
+ w.Header().Set("Content-Length", strconv.Itoa(len(msg)))
w.WriteHeader(200)
}
w.Write([]byte(msg))
@@ -494,7 +494,7 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) {
t.Error(err)
return
}
- if _, err := ioutil.ReadAll(resp.Body); err != nil {
+ if _, err := io.ReadAll(resp.Body); err != nil {
t.Errorf("ReadAll: %v", err)
return
}
@@ -574,7 +574,7 @@ func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) {
if err != nil {
t.Errorf("unexpected error for request %s: %v", reqId, err)
}
- _, err = ioutil.ReadAll(resp.Body)
+ _, err = io.ReadAll(resp.Body)
if err != nil {
t.Errorf("unexpected error for request %s: %v", reqId, err)
}
@@ -654,7 +654,7 @@ func TestTransportMaxConnsPerHost(t *testing.T) {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
- _, err = ioutil.ReadAll(resp.Body)
+ _, err = io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("read body failed: %v", err)
}
@@ -732,7 +732,7 @@ func TestTransportRemovesDeadIdleConnections(t *testing.T) {
t.Fatalf("%s: %v", name, res.Status)
}
defer res.Body.Close()
- slurp, err := ioutil.ReadAll(res.Body)
+ slurp, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("%s: %v", name, err)
}
@@ -782,7 +782,7 @@ func TestTransportServerClosingUnexpectedly(t *testing.T) {
condFatalf("error in req #%d, GET: %v", n, err)
continue
}
- body, err := ioutil.ReadAll(res.Body)
+ body, err := io.ReadAll(res.Body)
if err != nil {
condFatalf("error in req #%d, ReadAll: %v", n, err)
continue
@@ -902,7 +902,7 @@ func TestTransportHeadResponses(t *testing.T) {
if e, g := int64(123), res.ContentLength; e != g {
t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g)
}
- if all, err := ioutil.ReadAll(res.Body); err != nil {
+ if all, err := io.ReadAll(res.Body); err != nil {
t.Errorf("loop %d: Body ReadAll: %v", i, err)
} else if len(all) != 0 {
t.Errorf("Bogus body %q", all)
@@ -1005,10 +1005,10 @@ func TestRoundTripGzip(t *testing.T) {
t.Errorf("%d. gzip NewReader: %v", i, err)
continue
}
- body, err = ioutil.ReadAll(r)
+ body, err = io.ReadAll(r)
res.Body.Close()
} else {
- body, err = ioutil.ReadAll(res.Body)
+ body, err = io.ReadAll(res.Body)
}
if err != nil {
t.Errorf("%d. Error: %q", i, err)
@@ -1089,7 +1089,7 @@ func TestTransportGzip(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- body, err := ioutil.ReadAll(res.Body)
+ body, err := io.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
@@ -1132,7 +1132,7 @@ func TestTransportExpect100Continue(t *testing.T) {
switch req.URL.Path {
case "/100":
// This endpoint implicitly responds 100 Continue and reads body.
- if _, err := io.Copy(ioutil.Discard, req.Body); err != nil {
+ if _, err := io.Copy(io.Discard, req.Body); err != nil {
t.Error("Failed to read Body", err)
}
rw.WriteHeader(StatusOK)
@@ -1158,7 +1158,7 @@ func TestTransportExpect100Continue(t *testing.T) {
if err != nil {
log.Fatal(err)
}
- if _, err := io.CopyN(ioutil.Discard, bufrw, req.ContentLength); err != nil {
+ if _, err := io.CopyN(io.Discard, bufrw, req.ContentLength); err != nil {
t.Error("Failed to read Body", err)
}
bufrw.WriteString("HTTP/1.1 200 OK\r\n\r\n")
@@ -1624,7 +1624,7 @@ func TestTransportGzipRecursive(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- body, err := ioutil.ReadAll(res.Body)
+ body, err := io.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
@@ -1653,7 +1653,7 @@ func TestTransportGzipShort(t *testing.T) {
t.Fatal(err)
}
defer res.Body.Close()
- _, err = ioutil.ReadAll(res.Body)
+ _, err = io.ReadAll(res.Body)
if err == nil {
t.Fatal("Expect an error from reading a body.")
}
@@ -1700,7 +1700,7 @@ func TestTransportPersistConnLeak(t *testing.T) {
res, err := c.Get(ts.URL)
didReqCh <- true
if err != nil {
- t.Errorf("client fetch error: %v", err)
+ t.Logf("client fetch error: %v", err)
failed <- true
return
}
@@ -1714,17 +1714,15 @@ func TestTransportPersistConnLeak(t *testing.T) {
case <-gotReqCh:
// ok
case <-failed:
- close(unblockCh)
- return
+ // Not great but not what we are testing:
+ // sometimes an overloaded system will fail to make all the connections.
}
}
nhigh := runtime.NumGoroutine()
// Tell all handlers to unblock and reply.
- for i := 0; i < numReq; i++ {
- unblockCh <- true
- }
+ close(unblockCh)
// Wait for all HTTP clients to be done.
for i := 0; i < numReq; i++ {
@@ -2000,7 +1998,7 @@ func TestIssue3644(t *testing.T) {
t.Fatal(err)
}
defer res.Body.Close()
- bs, err := ioutil.ReadAll(res.Body)
+ bs, err := io.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
@@ -2025,7 +2023,7 @@ func TestIssue3595(t *testing.T) {
t.Errorf("Post: %v", err)
return
}
- got, err := ioutil.ReadAll(res.Body)
+ got, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("Body ReadAll: %v", err)
}
@@ -2097,7 +2095,7 @@ func TestTransportConcurrency(t *testing.T) {
wg.Done()
continue
}
- all, err := ioutil.ReadAll(res.Body)
+ all, err := io.ReadAll(res.Body)
if err != nil {
t.Errorf("read error on req %s: %v", req, err)
wg.Done()
@@ -2164,7 +2162,7 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) {
t.Errorf("Error issuing GET: %v", err)
break
}
- _, err = io.Copy(ioutil.Discard, sres.Body)
+ _, err = io.Copy(io.Discard, sres.Body)
if err == nil {
t.Errorf("Unexpected successful copy")
break
@@ -2185,7 +2183,7 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
})
mux.HandleFunc("/put", func(w ResponseWriter, r *Request) {
defer r.Body.Close()
- io.Copy(ioutil.Discard, r.Body)
+ io.Copy(io.Discard, r.Body)
})
ts := httptest.NewServer(mux)
timeout := 100 * time.Millisecond
@@ -2339,7 +2337,7 @@ func TestTransportCancelRequest(t *testing.T) {
tr.CancelRequest(req)
}()
t0 := time.Now()
- body, err := ioutil.ReadAll(res.Body)
+ body, err := io.ReadAll(res.Body)
d := time.Since(t0)
if err != ExportErrRequestCanceled {
@@ -2498,7 +2496,7 @@ func TestCancelRequestWithChannel(t *testing.T) {
close(ch)
}()
t0 := time.Now()
- body, err := ioutil.ReadAll(res.Body)
+ body, err := io.ReadAll(res.Body)
d := time.Since(t0)
if err != ExportErrRequestCanceled {
@@ -2679,7 +2677,7 @@ func (fooProto) RoundTrip(req *Request) (*Response, error) {
Status: "200 OK",
StatusCode: 200,
Header: make(Header),
- Body: ioutil.NopCloser(strings.NewReader("You wanted " + req.URL.String())),
+ Body: io.NopCloser(strings.NewReader("You wanted " + req.URL.String())),
}
return res, nil
}
@@ -2693,7 +2691,7 @@ func TestTransportAltProto(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- bodyb, err := ioutil.ReadAll(res.Body)
+ bodyb, err := io.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
@@ -2770,7 +2768,7 @@ func TestTransportSocketLateBinding(t *testing.T) {
// let the foo response finish so we can use its
// connection for /bar
fooGate <- true
- io.Copy(ioutil.Discard, fooRes.Body)
+ io.Copy(io.Discard, fooRes.Body)
fooRes.Body.Close()
})
@@ -2809,7 +2807,7 @@ func TestTransportReading100Continue(t *testing.T) {
t.Error(err)
return
}
- slurp, err := ioutil.ReadAll(req.Body)
+ slurp, err := io.ReadAll(req.Body)
if err != nil {
t.Errorf("Server request body slurp: %v", err)
return
@@ -2873,7 +2871,7 @@ Content-Length: %d
if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack {
t.Errorf("%s: response id %q != request id %q", name, idBack, id)
}
- _, err = ioutil.ReadAll(res.Body)
+ _, err = io.ReadAll(res.Body)
if err != nil {
t.Fatalf("%s: Slurp error: %v", name, err)
}
@@ -3152,7 +3150,7 @@ func TestIdleConnChannelLeak(t *testing.T) {
func TestTransportClosesRequestBody(t *testing.T) {
defer afterTest(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
- io.Copy(ioutil.Discard, r.Body)
+ io.Copy(io.Discard, r.Body)
}))
defer ts.Close()
@@ -3259,7 +3257,7 @@ func TestTLSServerClosesConnection(t *testing.T) {
t.Fatal(err)
}
<-closedc
- slurp, err := ioutil.ReadAll(res.Body)
+ slurp, err := io.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
@@ -3274,7 +3272,7 @@ func TestTLSServerClosesConnection(t *testing.T) {
errs = append(errs, err)
continue
}
- slurp, err = ioutil.ReadAll(res.Body)
+ slurp, err = io.ReadAll(res.Body)
if err != nil {
errs = append(errs, err)
continue
@@ -3345,7 +3343,7 @@ func TestTransportNoReuseAfterEarlyResponse(t *testing.T) {
sconn.c = conn
sconn.Unlock()
conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) // keep-alive
- go io.Copy(ioutil.Discard, conn)
+ go io.Copy(io.Discard, conn)
}))
defer ts.Close()
c := ts.Client()
@@ -3594,7 +3592,7 @@ func TestTransportClosesBodyOnError(t *testing.T) {
defer afterTest(t)
readBody := make(chan error, 1)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
- _, err := ioutil.ReadAll(r.Body)
+ _, err := io.ReadAll(r.Body)
readBody <- err
}))
defer ts.Close()
@@ -3736,7 +3734,7 @@ func TestTransportDialTLSContext(t *testing.T) {
if err != nil {
return nil, err
}
- return c, c.Handshake()
+ return c, c.HandshakeContext(ctx)
}
req, err := NewRequest("GET", ts.URL, nil)
@@ -3942,7 +3940,7 @@ func TestTransportResponseCancelRace(t *testing.T) {
// If we do an early close, Transport just throws the connection away and
// doesn't reuse it. In order to trigger the bug, it has to reuse the connection
// so read the body
- if _, err := io.Copy(ioutil.Discard, res.Body); err != nil {
+ if _, err := io.Copy(io.Discard, res.Body); err != nil {
t.Fatal(err)
}
@@ -3979,7 +3977,7 @@ func TestTransportContentEncodingCaseInsensitive(t *testing.T) {
t.Fatal(err)
}
- body, err := ioutil.ReadAll(res.Body)
+ body, err := io.ReadAll(res.Body)
res.Body.Close()
if err != nil {
t.Fatal(err)
@@ -4086,7 +4084,7 @@ func TestTransportFlushesBodyChunks(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- io.Copy(ioutil.Discard, req.Body)
+ io.Copy(io.Discard, req.Body)
// Unblock the transport's roundTrip goroutine.
resBody <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n")
@@ -4467,7 +4465,7 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) {
// Do nothing for the second request.
return
}
- if _, err := ioutil.ReadAll(r.Body); err != nil {
+ if _, err := io.ReadAll(r.Body); err != nil {
t.Error(err)
}
if !noHooks {
@@ -4555,7 +4553,7 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) {
t.Fatal(err)
}
logf("got roundtrip.response")
- slurp, err := ioutil.ReadAll(res.Body)
+ slurp, err := io.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
@@ -5173,6 +5171,57 @@ func TestTransportProxyConnectHeader(t *testing.T) {
}
}
+func TestTransportProxyGetConnectHeader(t *testing.T) {
+ defer afterTest(t)
+ reqc := make(chan *Request, 1)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.Method != "CONNECT" {
+ t.Errorf("method = %q; want CONNECT", r.Method)
+ }
+ reqc <- r
+ c, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Errorf("Hijack: %v", err)
+ return
+ }
+ c.Close()
+ }))
+ defer ts.Close()
+
+ c := ts.Client()
+ c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
+ return url.Parse(ts.URL)
+ }
+ // These should be ignored:
+ c.Transport.(*Transport).ProxyConnectHeader = Header{
+ "User-Agent": {"foo"},
+ "Other": {"bar"},
+ }
+ c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) {
+ return Header{
+ "User-Agent": {"foo2"},
+ "Other": {"bar2"},
+ }, nil
+ }
+
+ res, err := c.Get("https://dummy.tld/") // https to force a CONNECT
+ if err == nil {
+ res.Body.Close()
+ t.Errorf("unexpected success")
+ }
+ select {
+ case <-time.After(3 * time.Second):
+ t.Fatal("timeout")
+ case r := <-reqc:
+ if got, want := r.Header.Get("User-Agent"), "foo2"; got != want {
+ t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
+ }
+ if got, want := r.Header.Get("Other"), "bar2"; got != want {
+ t.Errorf("CONNECT request Other = %q; want %q", got, want)
+ }
+ }
+}
+
var errFakeRoundTrip = errors.New("fake roundtrip")
type funcRoundTripper func()
@@ -5186,7 +5235,7 @@ func wantBody(res *Response, err error, want string) error {
if err != nil {
return err
}
- slurp, err := ioutil.ReadAll(res.Body)
+ slurp, err := io.ReadAll(res.Body)
if err != nil {
return fmt.Errorf("error reading body: %v", err)
}
@@ -5285,7 +5334,7 @@ func TestMissingStatusNoPanic(t *testing.T) {
conn, _ := ln.Accept()
if conn != nil {
io.WriteString(conn, raw)
- ioutil.ReadAll(conn)
+ io.ReadAll(conn)
conn.Close()
}
}()
@@ -5303,7 +5352,7 @@ func TestMissingStatusNoPanic(t *testing.T) {
t.Error("panicked, expecting an error")
}
if res != nil && res.Body != nil {
- io.Copy(ioutil.Discard, res.Body)
+ io.Copy(io.Discard, res.Body)
res.Body.Close()
}
@@ -5489,7 +5538,7 @@ func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) {
}
close(cancel)
- got, err := ioutil.ReadAll(res.Body)
+ got, err := io.ReadAll(res.Body)
if err == nil {
t.Fatalf("unexpected success; read %q, nil", got)
}
@@ -5628,7 +5677,7 @@ func TestTransportCONNECTBidi(t *testing.T) {
}
func TestTransportRequestReplayable(t *testing.T) {
- someBody := ioutil.NopCloser(strings.NewReader(""))
+ someBody := io.NopCloser(strings.NewReader(""))
tests := []struct {
name string
req *Request
@@ -5696,7 +5745,7 @@ func (c *testMockTCPConn) ReadFrom(r io.Reader) (int64, error) {
func TestTransportRequestWriteRoundTrip(t *testing.T) {
nBytes := int64(1 << 10)
newFileFunc := func() (r io.Reader, done func(), err error) {
- f, err := ioutil.TempFile("", "net-http-newfilefunc")
+ f, err := os.CreateTemp("", "net-http-newfilefunc")
if err != nil {
return nil, nil, err
}
@@ -5789,7 +5838,7 @@ func TestTransportRequestWriteRoundTrip(t *testing.T) {
t,
h1Mode,
HandlerFunc(func(w ResponseWriter, r *Request) {
- io.Copy(ioutil.Discard, r.Body)
+ io.Copy(io.Discard, r.Body)
r.Body.Close()
w.WriteHeader(200)
}),
@@ -5841,6 +5890,7 @@ func TestTransportClone(t *testing.T) {
ResponseHeaderTimeout: time.Second,
ExpectContinueTimeout: time.Second,
ProxyConnectHeader: Header{},
+ GetProxyConnectHeader: func(context.Context, *url.URL, string) (Header, error) { return nil, nil },
MaxResponseHeaderBytes: 1,
ForceAttemptHTTP2: true,
TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{
@@ -5924,7 +5974,7 @@ func TestTransportIgnores408(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- slurp, err := ioutil.ReadAll(res.Body)
+ slurp, err := io.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
@@ -6186,7 +6236,7 @@ func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) {
return
}
defer resp.Body.Close()
- _, err = ioutil.ReadAll(resp.Body)
+ _, err = io.ReadAll(resp.Body)
if err != nil {
errCh <- fmt.Errorf("read body failed: %v", err)
}
@@ -6248,7 +6298,7 @@ func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) { return f(r) }
func TestIssue32441(t *testing.T) {
defer afterTest(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
- if n, _ := io.Copy(ioutil.Discard, r.Body); n == 0 {
+ if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
t.Error("body length is zero")
}
}))
@@ -6256,7 +6306,7 @@ func TestIssue32441(t *testing.T) {
c := ts.Client()
c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) {
// Draining body to trigger failure condition on actual request to server.
- if n, _ := io.Copy(ioutil.Discard, r.Body); n == 0 {
+ if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
t.Error("body length is zero during round trip")
}
return nil, ErrSkipAltProtocol
@@ -6284,3 +6334,152 @@ func TestTransportRejectsSignInContentLength(t *testing.T) {
t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want)
}
}
+
+// dumpConn is a net.Conn which writes to Writer and reads from Reader
+type dumpConn struct {
+ io.Writer
+ io.Reader
+}
+
+func (c *dumpConn) Close() error { return nil }
+func (c *dumpConn) LocalAddr() net.Addr { return nil }
+func (c *dumpConn) RemoteAddr() net.Addr { return nil }
+func (c *dumpConn) SetDeadline(t time.Time) error { return nil }
+func (c *dumpConn) SetReadDeadline(t time.Time) error { return nil }
+func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil }
+
+// delegateReader is a reader that delegates to another reader,
+// once it arrives on a channel.
+type delegateReader struct {
+ c chan io.Reader
+ r io.Reader // nil until received from c
+}
+
+func (r *delegateReader) Read(p []byte) (int, error) {
+ if r.r == nil {
+ var ok bool
+ if r.r, ok = <-r.c; !ok {
+ return 0, errors.New("delegate closed")
+ }
+ }
+ return r.r.Read(p)
+}
+
+func testTransportRace(req *Request) {
+ save := req.Body
+ pr, pw := io.Pipe()
+ defer pr.Close()
+ defer pw.Close()
+ dr := &delegateReader{c: make(chan io.Reader)}
+
+ t := &Transport{
+ Dial: func(net, addr string) (net.Conn, error) {
+ return &dumpConn{pw, dr}, nil
+ },
+ }
+ defer t.CloseIdleConnections()
+
+ quitReadCh := make(chan struct{})
+ // Wait for the request before replying with a dummy response:
+ go func() {
+ defer close(quitReadCh)
+
+ req, err := ReadRequest(bufio.NewReader(pr))
+ if err == nil {
+ // Ensure all the body is read; otherwise
+ // we'll get a partial dump.
+ io.Copy(io.Discard, req.Body)
+ req.Body.Close()
+ }
+ select {
+ case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"):
+ case quitReadCh <- struct{}{}:
+ // Ensure delegate is closed so Read doesn't block forever.
+ close(dr.c)
+ }
+ }()
+
+ t.RoundTrip(req)
+
+ // Ensure the reader returns before we reset req.Body to prevent
+ // a data race on req.Body.
+ pw.Close()
+ <-quitReadCh
+
+ req.Body = save
+}
+
+// Issue 37669
+// Test that a cancellation doesn't result in a data race due to the writeLoop
+// goroutine being left running, if the caller mutates the processed Request
+// upon completion.
+func TestErrorWriteLoopRace(t *testing.T) {
+ if testing.Short() {
+ return
+ }
+ t.Parallel()
+ for i := 0; i < 1000; i++ {
+ delay := time.Duration(mrand.Intn(5)) * time.Millisecond
+ ctx, cancel := context.WithTimeout(context.Background(), delay)
+ defer cancel()
+
+ r := bytes.NewBuffer(make([]byte, 10000))
+ req, err := NewRequestWithContext(ctx, MethodPost, "http://example.com", r)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ testTransportRace(req)
+ }
+}
+
+// Issue 41600
+// Test that a new request which uses the connection of an active request
+// cannot cause it to be canceled as well.
+func TestCancelRequestWhenSharingConnection(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, req *Request) {
+ w.Header().Add("Content-Length", "0")
+ }))
+ defer ts.Close()
+
+ client := ts.Client()
+ transport := client.Transport.(*Transport)
+ transport.MaxIdleConns = 1
+ transport.MaxConnsPerHost = 1
+
+ var wg sync.WaitGroup
+
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+
+ for i := 0; i < 10; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for ctx.Err() == nil {
+ reqctx, reqcancel := context.WithCancel(ctx)
+ go reqcancel()
+ req, _ := NewRequestWithContext(reqctx, "GET", ts.URL, nil)
+ res, err := client.Do(req)
+ if err == nil {
+ res.Body.Close()
+ }
+ }
+ }()
+ }
+
+ for ctx.Err() == nil {
+ req, _ := NewRequest("GET", ts.URL, nil)
+ if res, err := client.Do(req); err != nil {
+ t.Errorf("unexpected: %p %v", req, err)
+ break
+ } else {
+ res.Body.Close()
+ }
+ }
+
+ cancel()
+ wg.Wait()
+}