diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/net/http/client_test.go | 181 | ||||
| -rw-r--r-- | src/net/http/fs_test.go | 21 | ||||
| -rw-r--r-- | src/net/http/httptest/server.go | 4 | ||||
| -rw-r--r-- | src/net/http/httptest/server_test.go | 24 | ||||
| -rw-r--r-- | src/net/http/httputil/reverseproxy_test.go | 31 | ||||
| -rw-r--r-- | src/net/http/main_test.go | 4 | ||||
| -rw-r--r-- | src/net/http/npn_test.go | 24 | ||||
| -rw-r--r-- | src/net/http/serve_test.go | 81 | ||||
| -rw-r--r-- | src/net/http/transport_test.go | 407 |
9 files changed, 321 insertions, 456 deletions
diff --git a/src/net/http/client_test.go b/src/net/http/client_test.go index c75456ae53..73f22212f6 100644 --- a/src/net/http/client_test.go +++ b/src/net/http/client_test.go @@ -10,7 +10,6 @@ import ( "bytes" "context" "crypto/tls" - "crypto/x509" "encoding/base64" "errors" "fmt" @@ -73,7 +72,7 @@ func TestClient(t *testing.T) { ts := httptest.NewServer(robotsTxtHandler) defer ts.Close() - c := &Client{Transport: &Transport{DisableKeepAlives: true}} + c := ts.Client() r, err := c.Get(ts.URL) var b []byte if err == nil { @@ -220,10 +219,7 @@ func TestClientRedirects(t *testing.T) { })) defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - - c := &Client{Transport: tr} + c := ts.Client() _, err := c.Get(ts.URL) if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g { t.Errorf("with default client Get, expected error %q, got %q", e, g) @@ -252,13 +248,10 @@ func TestClientRedirects(t *testing.T) { var checkErr error var lastVia []*Request var lastReq *Request - c = &Client{ - Transport: tr, - CheckRedirect: func(req *Request, via []*Request) error { - lastReq = req - lastVia = via - return checkErr - }, + c.CheckRedirect = func(req *Request, via []*Request) error { + lastReq = req + lastVia = via + return checkErr } res, err := c.Get(ts.URL) if err != nil { @@ -313,21 +306,16 @@ func TestClientRedirectContext(t *testing.T) { })) defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - ctx, cancel := context.WithCancel(context.Background()) - c := &Client{ - Transport: tr, - CheckRedirect: func(req *Request, via []*Request) error { - cancel() - select { - case <-req.Context().Done(): - return nil - case <-time.After(5 * time.Second): - return errors.New("redirected request's context never expired after root request canceled") - } - }, + c := ts.Client() + c.CheckRedirect = func(req *Request, via []*Request) error { + cancel() + select { + case <-req.Context().Done(): + return nil + case <-time.After(5 * time.Second): + return errors.New("redirected request's context never expired after root request canceled") + } } req, _ := NewRequest("GET", ts.URL, nil) req = req.WithContext(ctx) @@ -461,11 +449,12 @@ func testRedirectsByMethod(t *testing.T, method string, table []redirectTest, wa })) defer ts.Close() + c := ts.Client() for _, tt := range table { content := tt.redirectBody req, _ := NewRequest(method, ts.URL+tt.suffix, strings.NewReader(content)) req.GetBody = func() (io.ReadCloser, error) { return ioutil.NopCloser(strings.NewReader(content)), nil } - res, err := DefaultClient.Do(req) + res, err := c.Do(req) if err != nil { t.Fatal(err) @@ -519,17 +508,12 @@ func TestClientRedirectUseResponse(t *testing.T) { })) defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - - c := &Client{ - Transport: tr, - CheckRedirect: func(req *Request, via []*Request) error { - if req.Response == nil { - t.Error("expected non-nil Request.Response") - } - return ErrUseLastResponse - }, + c := ts.Client() + c.CheckRedirect = func(req *Request, via []*Request) error { + if req.Response == nil { + t.Error("expected non-nil Request.Response") + } + return ErrUseLastResponse } res, err := c.Get(ts.URL) if err != nil { @@ -558,7 +542,7 @@ func TestClientRedirect308NoLocation(t *testing.T) { w.WriteHeader(308) })) defer ts.Close() - c := &Client{Transport: &Transport{DisableKeepAlives: true}} + c := ts.Client() res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) @@ -586,7 +570,7 @@ func TestClientRedirect308NoGetBody(t *testing.T) { if err != nil { t.Fatal(err) } - c := &Client{Transport: &Transport{DisableKeepAlives: true}} + c := ts.Client() req.GetBody = nil // so it can't rewind. res, err := c.Do(req) if err != nil { @@ -678,12 +662,8 @@ func TestRedirectCookiesJar(t *testing.T) { var ts *httptest.Server ts = httptest.NewServer(echoCookiesRedirectHandler) defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{ - Transport: tr, - Jar: new(TestJar), - } + c := ts.Client() + c.Jar = new(TestJar) u, _ := url.Parse(ts.URL) c.Jar.SetCookies(u, []*Cookie{expectedCookies[0]}) resp, err := c.Get(ts.URL) @@ -727,13 +707,10 @@ func TestJarCalls(t *testing.T) { })) defer ts.Close() jar := new(RecordingJar) - c := &Client{ - Jar: jar, - Transport: &Transport{ - Dial: func(_ string, _ string) (net.Conn, error) { - return net.Dial("tcp", ts.Listener.Addr().String()) - }, - }, + c := ts.Client() + c.Jar = jar + c.Transport.(*Transport).Dial = func(_ string, _ string) (net.Conn, error) { + return net.Dial("tcp", ts.Listener.Addr().String()) } _, err := c.Get("http://firsthost.fake/") if err != nil { @@ -845,7 +822,8 @@ func TestClientWrites(t *testing.T) { } return c, err } - c := &Client{Transport: &Transport{Dial: dialer}} + c := ts.Client() + c.Transport.(*Transport).Dial = dialer _, err := c.Get(ts.URL) if err != nil { @@ -878,14 +856,11 @@ func TestClientInsecureTransport(t *testing.T) { // TODO(bradfitz): add tests for skipping hostname checks too? // would require a new cert for testing, and probably // redundant with these tests. + c := ts.Client() for _, insecure := range []bool{true, false} { - tr := &Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: insecure, - }, + c.Transport.(*Transport).TLSClientConfig = &tls.Config{ + InsecureSkipVerify: insecure, } - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} res, err := c.Get(ts.URL) if (err == nil) != insecure { t.Errorf("insecure=%v: got unexpected err=%v", insecure, err) @@ -919,22 +894,6 @@ func TestClientErrorWithRequestURI(t *testing.T) { } } -func newTLSTransport(t *testing.T, ts *httptest.Server) *Transport { - certs := x509.NewCertPool() - for _, c := range ts.TLS.Certificates { - roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1]) - if err != nil { - t.Fatalf("error parsing server's root cert: %v", err) - } - for _, root := range roots { - certs.AddCert(root) - } - } - return &Transport{ - TLSClientConfig: &tls.Config{RootCAs: certs}, - } -} - func TestClientWithCorrectTLSServerName(t *testing.T) { defer afterTest(t) @@ -946,9 +905,8 @@ func TestClientWithCorrectTLSServerName(t *testing.T) { })) defer ts.Close() - trans := newTLSTransport(t, ts) - trans.TLSClientConfig.ServerName = serverName - c := &Client{Transport: trans} + c := ts.Client() + c.Transport.(*Transport).TLSClientConfig.ServerName = serverName if _, err := c.Get(ts.URL); err != nil { t.Fatalf("expected successful TLS connection, got error: %v", err) } @@ -961,9 +919,8 @@ func TestClientWithIncorrectTLSServerName(t *testing.T) { errc := make(chanWriter, 10) // but only expecting 1 ts.Config.ErrorLog = log.New(errc, "", 0) - trans := newTLSTransport(t, ts) - trans.TLSClientConfig.ServerName = "badserver" - c := &Client{Transport: trans} + c := ts.Client() + c.Transport.(*Transport).TLSClientConfig.ServerName = "badserver" _, err := c.Get(ts.URL) if err == nil { t.Fatalf("expected an error") @@ -997,13 +954,12 @@ func TestTransportUsesTLSConfigServerName(t *testing.T) { })) defer ts.Close() - tr := newTLSTransport(t, ts) + c := ts.Client() + tr := c.Transport.(*Transport) tr.TLSClientConfig.ServerName = "example.com" // one of httptest's Server cert names tr.Dial = func(netw, addr string) (net.Conn, error) { return net.Dial(netw, ts.Listener.Addr().String()) } - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} res, err := c.Get("https://some-other-host.tld/") if err != nil { t.Fatal(err) @@ -1018,13 +974,12 @@ func TestResponseSetsTLSConnectionState(t *testing.T) { })) defer ts.Close() - tr := newTLSTransport(t, ts) + c := ts.Client() + tr := c.Transport.(*Transport) tr.TLSClientConfig.CipherSuites = []uint16{tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA} tr.Dial = func(netw, addr string) (net.Conn, error) { return net.Dial(netw, ts.Listener.Addr().String()) } - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} res, err := c.Get("https://example.com/") if err != nil { t.Fatal(err) @@ -1119,14 +1074,12 @@ func TestEmptyPasswordAuth(t *testing.T) { } })) defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} req, err := NewRequest("GET", ts.URL, nil) if err != nil { t.Fatal(err) } req.URL.User = url.User(gopher) + c := ts.Client() resp, err := c.Do(req) if err != nil { t.Fatal(err) @@ -1503,21 +1456,17 @@ func TestClientCopyHeadersOnRedirect(t *testing.T) { defer ts2.Close() ts2URL = ts2.URL - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{ - Transport: tr, - CheckRedirect: func(r *Request, via []*Request) error { - want := Header{ - "User-Agent": []string{ua}, - "X-Foo": []string{xfoo}, - "Referer": []string{ts2URL}, - } - if !reflect.DeepEqual(r.Header, want) { - t.Errorf("CheckRedirect Request.Header = %#v; want %#v", r.Header, want) - } - return nil - }, + c := ts1.Client() + c.CheckRedirect = func(r *Request, via []*Request) error { + want := Header{ + "User-Agent": []string{ua}, + "X-Foo": []string{xfoo}, + "Referer": []string{ts2URL}, + } + if !reflect.DeepEqual(r.Header, want) { + t.Errorf("CheckRedirect Request.Header = %#v; want %#v", r.Header, want) + } + return nil } req, _ := NewRequest("GET", ts2.URL, nil) @@ -1606,13 +1555,9 @@ func TestClientAltersCookiesOnRedirect(t *testing.T) { })) defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() jar, _ := cookiejar.New(nil) - c := &Client{ - Transport: tr, - Jar: jar, - } + c := ts.Client() + c.Jar = jar u, _ := url.Parse(ts.URL) req, _ := NewRequest("GET", ts.URL, nil) @@ -1730,9 +1675,7 @@ func TestClientRedirectTypes(t *testing.T) { })) defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - + c := ts.Client() for i, tt := range tests { handlerc <- func(w ResponseWriter, r *Request) { w.Header().Set("Location", ts.URL) @@ -1745,7 +1688,6 @@ func TestClientRedirectTypes(t *testing.T) { continue } - c := &Client{Transport: tr} c.CheckRedirect = func(req *Request, via []*Request) error { if got, want := req.Method, tt.wantMethod; got != want { return fmt.Errorf("#%d: got next method %q; want %q", i, got, want) @@ -1799,9 +1741,8 @@ func TestTransportBodyReadError(t *testing.T) { w.Header().Set("X-Body-Read", fmt.Sprintf("%v, %v", n, err)) })) defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c := ts.Client() + tr := c.Transport.(*Transport) // Do one initial successful request to create an idle TCP connection // for the subsequent request to reuse. (The Transport only retries diff --git a/src/net/http/fs_test.go b/src/net/http/fs_test.go index 1de1cd53d0..e12350efd7 100644 --- a/src/net/http/fs_test.go +++ b/src/net/http/fs_test.go @@ -74,6 +74,7 @@ func TestServeFile(t *testing.T) { ServeFile(w, r, "testdata/file") })) defer ts.Close() + c := ts.Client() var err error @@ -91,7 +92,7 @@ func TestServeFile(t *testing.T) { req.Method = "GET" // straight GET - _, body := getBody(t, "straight get", req) + _, body := getBody(t, "straight get", req, c) if !bytes.Equal(body, file) { t.Fatalf("body mismatch: got %q, want %q", body, file) } @@ -102,7 +103,7 @@ Cases: if rt.r != "" { req.Header.Set("Range", rt.r) } - resp, body := getBody(t, fmt.Sprintf("range test %q", rt.r), req) + resp, body := getBody(t, fmt.Sprintf("range test %q", rt.r), req, c) if resp.StatusCode != rt.code { t.Errorf("range=%q: StatusCode=%d, want %d", rt.r, resp.StatusCode, rt.code) } @@ -704,7 +705,8 @@ func TestDirectoryIfNotModified(t *testing.T) { req, _ := NewRequest("GET", ts.URL, nil) req.Header.Set("If-Modified-Since", lastMod) - res, err = DefaultClient.Do(req) + c := ts.Client() + res, err = c.Do(req) if err != nil { t.Fatal(err) } @@ -716,7 +718,7 @@ func TestDirectoryIfNotModified(t *testing.T) { // Advance the index.html file's modtime, but not the directory's. indexFile.modtime = indexFile.modtime.Add(1 * time.Hour) - res, err = DefaultClient.Do(req) + res, err = c.Do(req) if err != nil { t.Fatal(err) } @@ -995,7 +997,9 @@ func TestServeContent(t *testing.T) { for k, v := range tt.reqHeader { req.Header.Set(k, v) } - res, err := DefaultClient.Do(req) + + c := ts.Client() + res, err := c.Do(req) if err != nil { t.Fatal(err) } @@ -1050,8 +1054,9 @@ func TestServeContentErrorMessages(t *testing.T) { } ts := httptest.NewServer(FileServer(fs)) defer ts.Close() + c := ts.Client() for _, code := range []int{403, 404, 500} { - res, err := DefaultClient.Get(fmt.Sprintf("%s/%d", ts.URL, code)) + res, err := c.Get(fmt.Sprintf("%s/%d", ts.URL, code)) if err != nil { t.Errorf("Error fetching /%d: %v", code, err) continue @@ -1125,8 +1130,8 @@ func TestLinuxSendfile(t *testing.T) { } } -func getBody(t *testing.T, testName string, req Request) (*Response, []byte) { - r, err := DefaultClient.Do(&req) +func getBody(t *testing.T, testName string, req Request, client *Client) (*Response, []byte) { + r, err := client.Do(&req) if err != nil { t.Fatalf("%s: for URL %q, send error: %v", testName, req.URL.String(), err) } diff --git a/src/net/http/httptest/server.go b/src/net/http/httptest/server.go index 56ad18ee9b..b5b18c747d 100644 --- a/src/net/http/httptest/server.go +++ b/src/net/http/httptest/server.go @@ -93,7 +93,9 @@ func NewUnstartedServer(handler http.Handler) *Server { return &Server{ Listener: newLocalListener(), Config: &http.Server{Handler: handler}, - client: &http.Client{}, + client: &http.Client{ + Transport: &http.Transport{}, + }, } } diff --git a/src/net/http/httptest/server_test.go b/src/net/http/httptest/server_test.go index 7d80fa15dd..62846de02c 100644 --- a/src/net/http/httptest/server_test.go +++ b/src/net/http/httptest/server_test.go @@ -121,3 +121,27 @@ func TestServerClient(t *testing.T) { t.Errorf("got %q, want hello", string(got)) } } + +// Tests that the Server.Client.Transport interface is implemented +// by a *http.Transport. +func TestServerClientTransportType(t *testing.T) { + ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + })) + defer ts.Close() + client := ts.Client() + if _, ok := client.Transport.(*http.Transport); !ok { + t.Errorf("got %T, want *http.Transport", client.Transport) + } +} + +// Tests that the TLS Server.Client.Transport interface is implemented +// by a *http.Transport. +func TestTLSServerClientTransportType(t *testing.T) { + ts := NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + })) + defer ts.Close() + client := ts.Client() + if _, ok := client.Transport.(*http.Transport); !ok { + t.Errorf("got %T, want *http.Transport", client.Transport) + } +} diff --git a/src/net/http/httputil/reverseproxy_test.go b/src/net/http/httputil/reverseproxy_test.go index 9153508ef4..008e4e717f 100644 --- a/src/net/http/httputil/reverseproxy_test.go +++ b/src/net/http/httputil/reverseproxy_test.go @@ -79,6 +79,7 @@ func TestReverseProxy(t *testing.T) { proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests frontend := httptest.NewServer(proxyHandler) defer frontend.Close() + frontendClient := frontend.Client() getReq, _ := http.NewRequest("GET", frontend.URL, nil) getReq.Host = "some-name" @@ -86,7 +87,7 @@ func TestReverseProxy(t *testing.T) { getReq.Header.Set("Proxy-Connection", "should be deleted") getReq.Header.Set("Upgrade", "foo") getReq.Close = true - res, err := http.DefaultClient.Do(getReq) + res, err := frontendClient.Do(getReq) if err != nil { t.Fatalf("Get: %v", err) } @@ -126,7 +127,7 @@ func TestReverseProxy(t *testing.T) { // a response results in a StatusBadGateway. getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil) getReq.Close = true - res, err = http.DefaultClient.Do(getReq) + res, err = frontendClient.Do(getReq) if err != nil { t.Fatal(err) } @@ -172,7 +173,7 @@ func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) { getReq.Header.Set("Connection", "Upgrade, "+fakeConnectionToken) getReq.Header.Set("Upgrade", "original value") getReq.Header.Set(fakeConnectionToken, "should be deleted") - res, err := http.DefaultClient.Do(getReq) + res, err := frontend.Client().Do(getReq) if err != nil { t.Fatalf("Get: %v", err) } @@ -220,7 +221,7 @@ func TestXForwardedFor(t *testing.T) { getReq.Header.Set("Connection", "close") getReq.Header.Set("X-Forwarded-For", prevForwardedFor) getReq.Close = true - res, err := http.DefaultClient.Do(getReq) + res, err := frontend.Client().Do(getReq) if err != nil { t.Fatalf("Get: %v", err) } @@ -259,7 +260,7 @@ func TestReverseProxyQuery(t *testing.T) { frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL)) req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil) req.Close = true - res, err := http.DefaultClient.Do(req) + res, err := frontend.Client().Do(req) if err != nil { t.Fatalf("%d. Get: %v", i, err) } @@ -295,7 +296,7 @@ func TestReverseProxyFlushInterval(t *testing.T) { req, _ := http.NewRequest("GET", frontend.URL, nil) req.Close = true - res, err := http.DefaultClient.Do(req) + res, err := frontend.Client().Do(req) if err != nil { t.Fatalf("Get: %v", err) } @@ -349,13 +350,14 @@ func TestReverseProxyCancelation(t *testing.T) { frontend := httptest.NewServer(proxyHandler) defer frontend.Close() + frontendClient := frontend.Client() getReq, _ := http.NewRequest("GET", frontend.URL, nil) go func() { <-reqInFlight - http.DefaultTransport.(*http.Transport).CancelRequest(getReq) + frontendClient.Transport.(*http.Transport).CancelRequest(getReq) }() - res, err := http.DefaultClient.Do(getReq) + res, err := frontendClient.Do(getReq) if res != nil { t.Errorf("got response %v; want nil", res.Status) } @@ -363,7 +365,7 @@ func TestReverseProxyCancelation(t *testing.T) { // This should be an error like: // Get http://127.0.0.1:58079: read tcp 127.0.0.1:58079: // use of closed network connection - t.Error("DefaultClient.Do() returned nil error; want non-nil error") + t.Error("Server.Client().Do() returned nil error; want non-nil error") } } @@ -428,11 +430,12 @@ func TestUserAgentHeader(t *testing.T) { proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests frontend := httptest.NewServer(proxyHandler) defer frontend.Close() + frontendClient := frontend.Client() getReq, _ := http.NewRequest("GET", frontend.URL, nil) getReq.Header.Set("User-Agent", explicitUA) getReq.Close = true - res, err := http.DefaultClient.Do(getReq) + res, err := frontendClient.Do(getReq) if err != nil { t.Fatalf("Get: %v", err) } @@ -441,7 +444,7 @@ func TestUserAgentHeader(t *testing.T) { getReq, _ = http.NewRequest("GET", frontend.URL+"/noua", nil) getReq.Header.Set("User-Agent", "") getReq.Close = true - res, err = http.DefaultClient.Do(getReq) + res, err = frontendClient.Do(getReq) if err != nil { t.Fatalf("Get: %v", err) } @@ -493,7 +496,7 @@ func TestReverseProxyGetPutBuffer(t *testing.T) { req, _ := http.NewRequest("GET", frontend.URL, nil) req.Close = true - res, err := http.DefaultClient.Do(req) + res, err := frontend.Client().Do(req) if err != nil { t.Fatalf("Get: %v", err) } @@ -540,7 +543,7 @@ func TestReverseProxy_Post(t *testing.T) { defer frontend.Close() postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody)) - res, err := http.DefaultClient.Do(postReq) + res, err := frontend.Client().Do(postReq) if err != nil { t.Fatalf("Do: %v", err) } @@ -573,7 +576,7 @@ func TestReverseProxy_NilBody(t *testing.T) { frontend := httptest.NewServer(proxyHandler) defer frontend.Close() - res, err := http.DefaultClient.Get(frontend.URL) + res, err := frontend.Client().Get(frontend.URL) if err != nil { t.Fatal(err) } diff --git a/src/net/http/main_test.go b/src/net/http/main_test.go index 438bd2e58f..fc0437e211 100644 --- a/src/net/http/main_test.go +++ b/src/net/http/main_test.go @@ -151,7 +151,3 @@ func waitErrCondition(waitFor, checkEvery time.Duration, fn func() error) error } return err } - -func closeClient(c *http.Client) { - c.Transport.(*http.Transport).CloseIdleConnections() -} diff --git a/src/net/http/npn_test.go b/src/net/http/npn_test.go index 4c1f6b573d..618bdbe54a 100644 --- a/src/net/http/npn_test.go +++ b/src/net/http/npn_test.go @@ -8,6 +8,7 @@ import ( "bufio" "bytes" "crypto/tls" + "crypto/x509" "fmt" "io" "io/ioutil" @@ -43,10 +44,7 @@ func TestNextProtoUpgrade(t *testing.T) { // Normal request, without NPN. { - tr := newTLSTransport(t, ts) - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} - + c := ts.Client() res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) @@ -63,11 +61,18 @@ func TestNextProtoUpgrade(t *testing.T) { // Request to an advertised but unhandled NPN protocol. // Server will hang up. { - tr := newTLSTransport(t, ts) - tr.TLSClientConfig.NextProtos = []string{"unhandled-proto"} + certPool := x509.NewCertPool() + certPool.AddCert(ts.Certificate()) + tr := &Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: certPool, + NextProtos: []string{"unhandled-proto"}, + }, + } defer tr.CloseIdleConnections() - c := &Client{Transport: tr} - + c := &Client{ + Transport: tr, + } res, err := c.Get(ts.URL) if err == nil { defer res.Body.Close() @@ -80,7 +85,8 @@ func TestNextProtoUpgrade(t *testing.T) { // Request using the "tls-0.9" protocol, which we register here. // It is HTTP/0.9 over TLS. { - tlsConfig := newTLSTransport(t, ts).TLSClientConfig + c := ts.Client() + tlsConfig := c.Transport.(*Transport).TLSClientConfig tlsConfig.NextProtos = []string{"tls-0.9"} conn, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig) if err != nil { diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index 8092cc1bcb..d301d15eb1 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -474,9 +474,7 @@ func TestServerTimeouts(t *testing.T) { defer ts.Close() // Hit the HTTP server successfully. - tr := &Transport{DisableKeepAlives: true} // they interfere with this test - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c := ts.Client() r, err := c.Get(ts.URL) if err != nil { t.Fatalf("http Get #1: %v", err) @@ -548,12 +546,10 @@ func TestHTTP2WriteDeadlineExtendedOnNewRequest(t *testing.T) { ts.StartTLS() defer ts.Close() - tr := newTLSTransport(t, ts) - defer tr.CloseIdleConnections() - if err := ExportHttp2ConfigureTransport(tr); err != nil { + c := ts.Client() + if err := ExportHttp2ConfigureTransport(c.Transport.(*Transport)); err != nil { t.Fatal(err) } - c := &Client{Transport: tr} for i := 1; i <= 3; i++ { req, err := NewRequest("GET", ts.URL, nil) @@ -608,9 +604,7 @@ func TestOnlyWriteTimeout(t *testing.T) { ts.Start() defer ts.Close() - tr := &Transport{DisableKeepAlives: false} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c := ts.Client() errc := make(chan error) go func() { @@ -671,8 +665,7 @@ func TestIdentityResponse(t *testing.T) { ts := httptest.NewServer(handler) defer ts.Close() - c := &Client{Transport: new(Transport)} - defer closeClient(c) + c := ts.Client() // Note: this relies on the assumption (which is true) that // Get sends HTTP/1.1 or greater requests. Otherwise the @@ -949,9 +942,8 @@ func TestServerAllowsBlockingRemoteAddr(t *testing.T) { ts.Start() defer ts.Close() - tr := &Transport{DisableKeepAlives: true} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr, Timeout: time.Second} + c := ts.Client() + c.Timeout = time.Second fetch := func(num int, response chan<- string) { resp, err := c.Get(ts.URL) @@ -1022,9 +1014,7 @@ func TestIdentityResponseHeaders(t *testing.T) { })) defer ts.Close() - c := &Client{Transport: new(Transport)} - defer closeClient(c) - + c := ts.Client() res, err := c.Get(ts.URL) if err != nil { t.Fatalf("Get error: %v", err) @@ -1145,12 +1135,7 @@ func TestTLSServer(t *testing.T) { t.Errorf("expected test TLS server to start with https://, got %q", ts.URL) return } - noVerifyTransport := &Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - } - client := &Client{Transport: noVerifyTransport} + client := ts.Client() res, err := client.Get(ts.URL) if err != nil { t.Error(err) @@ -1967,8 +1952,7 @@ func TestTimeoutHandlerRace(t *testing.T) { ts := httptest.NewServer(TimeoutHandler(delayHi, 20*time.Millisecond, "")) defer ts.Close() - c := &Client{Transport: new(Transport)} - defer closeClient(c) + c := ts.Client() var wg sync.WaitGroup gate := make(chan bool, 10) @@ -2011,8 +1995,8 @@ func TestTimeoutHandlerRaceHeader(t *testing.T) { if testing.Short() { n = 10 } - c := &Client{Transport: new(Transport)} - defer closeClient(c) + + c := ts.Client() for i := 0; i < n; i++ { gate <- true wg.Add(1) @@ -2099,8 +2083,7 @@ func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) { ts := httptest.NewServer(TimeoutHandler(handler, timeout, "")) defer ts.Close() - c := &Client{Transport: new(Transport)} - defer closeClient(c) + c := ts.Client() // Issue was caused by the timeout handler starting the timer when // was created, not when the request. So wait for more than the timeout @@ -2127,8 +2110,7 @@ func TestTimeoutHandlerEmptyResponse(t *testing.T) { ts := httptest.NewServer(TimeoutHandler(handler, timeout, "")) defer ts.Close() - c := &Client{Transport: new(Transport)} - defer closeClient(c) + c := ts.Client() res, err := c.Get(ts.URL) if err != nil { @@ -2364,9 +2346,7 @@ func TestServerWriteHijackZeroBytes(t *testing.T) { ts.Start() defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c := ts.Client() res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) @@ -2411,8 +2391,7 @@ func TestStripPrefix(t *testing.T) { ts := httptest.NewServer(StripPrefix("/foo", h)) defer ts.Close() - c := &Client{Transport: new(Transport)} - defer closeClient(c) + c := ts.Client() res, err := c.Get(ts.URL + "/foo/bar") if err != nil { @@ -3654,9 +3633,7 @@ func TestServerConnState(t *testing.T) { } ts.Start() - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c := ts.Client() mustGet := func(url string, headers ...string) { req, err := NewRequest("GET", url, nil) @@ -4491,15 +4468,9 @@ func benchmarkClientServerParallel(b *testing.B, parallelism int, useTLS bool) { b.ResetTimer() b.SetParallelism(parallelism) b.RunParallel(func(pb *testing.PB) { - noVerifyTransport := &Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - } - defer noVerifyTransport.CloseIdleConnections() - client := &Client{Transport: noVerifyTransport} + c := ts.Client() for pb.Next() { - res, err := client.Get(ts.URL) + res, err := c.Get(ts.URL) if err != nil { b.Logf("Get: %v", err) continue @@ -4934,10 +4905,7 @@ func TestServerIdleTimeout(t *testing.T) { ts.Config.IdleTimeout = 2 * time.Second ts.Start() defer ts.Close() - - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c := ts.Client() get := func() string { res, err := c.Get(ts.URL) @@ -4998,9 +4966,8 @@ func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) { })) defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c := ts.Client() + tr := c.Transport.(*Transport) get := func() string { return get(t, c, ts.URL) } @@ -5119,9 +5086,7 @@ func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { ts.Start() defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c := ts.Client() res, err := c.Get(ts.URL) if err != nil { diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index cb315f14f4..09bfef4b10 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -131,11 +131,9 @@ func TestTransportKeepAlives(t *testing.T) { ts := httptest.NewServer(hostPortHandler) defer ts.Close() + c := ts.Client() for _, disableKeepAlive := range []bool{false, true} { - tr := &Transport{DisableKeepAlives: disableKeepAlive} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} - + c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive fetch := func(n int) string { res, err := c.Get(ts.URL) if err != nil { @@ -166,12 +164,11 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) { connSet, testDial := makeTestDial(t) - for _, connectionClose := range []bool{false, true} { - tr := &Transport{ - Dial: testDial, - } - c := &Client{Transport: tr} + c := ts.Client() + tr := c.Transport.(*Transport) + tr.Dial = testDial + for _, connectionClose := range []bool{false, true} { fetch := func(n int) string { req := new(Request) var err error @@ -217,12 +214,10 @@ func TestTransportConnectionCloseOnRequest(t *testing.T) { connSet, testDial := makeTestDial(t) + c := ts.Client() + tr := c.Transport.(*Transport) + tr.Dial = testDial for _, connectionClose := range []bool{false, true} { - tr := &Transport{ - Dial: testDial, - } - c := &Client{Transport: tr} - fetch := func(n int) string { req := new(Request) var err error @@ -273,10 +268,9 @@ func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) { ts := httptest.NewServer(hostPortHandler) defer ts.Close() - tr := &Transport{ - DisableKeepAlives: true, - } - c := &Client{Transport: tr} + c := ts.Client() + c.Transport.(*Transport).DisableKeepAlives = true + res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) @@ -291,9 +285,8 @@ func TestTransportIdleCacheKeys(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() - - tr := &Transport{DisableKeepAlives: false} - c := &Client{Transport: tr} + c := ts.Client() + tr := c.Transport.(*Transport) if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g) @@ -385,9 +378,11 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { } })) defer ts.Close() + + c := ts.Client() + tr := c.Transport.(*Transport) maxIdleConnsPerHost := 2 - tr := &Transport{DisableKeepAlives: false, MaxIdleConnsPerHost: maxIdleConnsPerHost} - c := &Client{Transport: tr} + tr.MaxIdleConnsPerHost = maxIdleConnsPerHost // Start 3 outstanding requests and wait for the server to get them. // Their responses will hang until we write to resch, though. @@ -450,9 +445,8 @@ func TestTransportRemovesDeadIdleConnections(t *testing.T) { })) defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c := ts.Client() + tr := c.Transport.(*Transport) doReq := func(name string) string { // Do a POST instead of a GET to prevent the Transport's @@ -496,9 +490,7 @@ func TestTransportServerClosingUnexpectedly(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() - - tr := &Transport{} - c := &Client{Transport: tr} + c := ts.Client() fetch := func(n, retries int) string { condFatalf := func(format string, arg ...interface{}) { @@ -564,10 +556,7 @@ func TestStressSurpriseServerCloses(t *testing.T) { conn.Close() })) defer ts.Close() - - tr := &Transport{DisableKeepAlives: false} - c := &Client{Transport: tr} - defer tr.CloseIdleConnections() + c := ts.Client() // Do a bunch of traffic from different goroutines. Send to activityc // after each request completes, regardless of whether it failed. @@ -620,9 +609,8 @@ func TestTransportHeadResponses(t *testing.T) { w.WriteHeader(200) })) defer ts.Close() + c := ts.Client() - tr := &Transport{DisableKeepAlives: false} - c := &Client{Transport: tr} for i := 0; i < 2; i++ { res, err := c.Head(ts.URL) if err != nil { @@ -656,10 +644,7 @@ func TestTransportHeadChunkedResponse(t *testing.T) { w.WriteHeader(200) })) defer ts.Close() - - tr := &Transport{DisableKeepAlives: false} - c := &Client{Transport: tr} - defer tr.CloseIdleConnections() + c := ts.Client() // Ensure that we wait for the readLoop to complete before // calling Head again @@ -720,6 +705,7 @@ func TestRoundTripGzip(t *testing.T) { } })) defer ts.Close() + tr := ts.Client().Transport.(*Transport) for i, test := range roundTripTests { // Test basic request (no accept-encoding) @@ -727,7 +713,7 @@ func TestRoundTripGzip(t *testing.T) { if test.accept != "" { req.Header.Set("Accept-Encoding", test.accept) } - res, err := DefaultTransport.RoundTrip(req) + res, err := tr.RoundTrip(req) var body []byte if test.compressed { var r *gzip.Reader @@ -792,10 +778,9 @@ func TestTransportGzip(t *testing.T) { gz.Close() })) defer ts.Close() + c := ts.Client() for _, chunked := range []string{"1", "0"} { - c := &Client{Transport: &Transport{}} - // First fetch something large, but only read some of it. res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked) if err != nil { @@ -845,7 +830,6 @@ func TestTransportGzip(t *testing.T) { } // And a HEAD request too, because they're always weird. - c := &Client{Transport: &Transport{}} res, err := c.Head(ts.URL) if err != nil { t.Fatalf("Head: %v", err) @@ -915,11 +899,13 @@ func TestTransportExpect100Continue(t *testing.T) { {path: "/timeout", body: []byte("hello"), sent: 5, status: 200}, // Timeout exceeded and entire body is sent. } + c := ts.Client() for i, v := range tests { - tr := &Transport{ExpectContinueTimeout: 2 * time.Second} + tr := &Transport{ + ExpectContinueTimeout: 2 * time.Second, + } defer tr.CloseIdleConnections() - c := &Client{Transport: tr} - + c.Transport = tr body := bytes.NewReader(v.body) req, err := NewRequest("PUT", ts.URL+v.path, body) if err != nil { @@ -1016,7 +1002,8 @@ func TestSocks5Proxy(t *testing.T) { if err != nil { t.Fatal(err) } - c := &Client{Transport: &Transport{Proxy: ProxyURL(pu)}} + c := ts.Client() + c.Transport.(*Transport).Proxy = ProxyURL(pu) if _, err := c.Head(ts.URL); err != nil { t.Error(err) } @@ -1052,7 +1039,8 @@ func TestTransportProxy(t *testing.T) { if err != nil { t.Fatal(err) } - c := &Client{Transport: &Transport{Proxy: ProxyURL(pu)}} + c := ts.Client() + c.Transport.(*Transport).Proxy = ProxyURL(pu) if _, err := c.Head(ts.URL); err != nil { t.Error(err) } @@ -1122,9 +1110,7 @@ func TestTransportGzipRecursive(t *testing.T) { })) defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c := ts.Client() res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) @@ -1152,9 +1138,7 @@ func TestTransportGzipShort(t *testing.T) { })) defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c := ts.Client() res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) @@ -1195,9 +1179,8 @@ func TestTransportPersistConnLeak(t *testing.T) { w.WriteHeader(204) })) defer ts.Close() - - tr := &Transport{} - c := &Client{Transport: tr} + c := ts.Client() + tr := c.Transport.(*Transport) n0 := runtime.NumGoroutine() @@ -1260,9 +1243,8 @@ func TestTransportPersistConnLeakShortBody(t *testing.T) { ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { })) defer ts.Close() - - tr := &Transport{} - c := &Client{Transport: tr} + c := ts.Client() + tr := c.Transport.(*Transport) n0 := runtime.NumGoroutine() body := []byte("Hello") @@ -1294,8 +1276,7 @@ func TestTransportPersistConnLeakShortBody(t *testing.T) { // This used to crash; https://golang.org/issue/3266 func TestTransportIdleConnCrash(t *testing.T) { defer afterTest(t) - tr := &Transport{} - c := &Client{Transport: tr} + var tr *Transport unblockCh := make(chan bool, 1) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { @@ -1303,6 +1284,8 @@ func TestTransportIdleConnCrash(t *testing.T) { tr.CloseIdleConnections() })) defer ts.Close() + c := ts.Client() + tr = c.Transport.(*Transport) didreq := make(chan bool) go func() { @@ -1332,8 +1315,7 @@ func TestIssue3644(t *testing.T) { } })) defer ts.Close() - tr := &Transport{} - c := &Client{Transport: tr} + c := ts.Client() res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) @@ -1358,8 +1340,7 @@ func TestIssue3595(t *testing.T) { Error(w, deniedMsg, StatusUnauthorized) })) defer ts.Close() - tr := &Transport{} - c := &Client{Transport: tr} + c := ts.Client() res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a')) if err != nil { t.Errorf("Post: %v", err) @@ -1383,8 +1364,8 @@ func TestChunkedNoContent(t *testing.T) { })) defer ts.Close() + c := ts.Client() for _, closeBody := range []bool{true, false} { - c := &Client{Transport: &Transport{}} const n = 4 for i := 1; i <= n; i++ { res, err := c.Get(ts.URL) @@ -1424,10 +1405,7 @@ func TestTransportConcurrency(t *testing.T) { SetPendingDialHooks(func() { wg.Add(1) }, wg.Done) defer SetPendingDialHooks(nil, nil) - tr := &Transport{} - defer tr.CloseIdleConnections() - - c := &Client{Transport: tr} + c := ts.Client() reqs := make(chan string) defer close(reqs) @@ -1469,23 +1447,20 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) { io.Copy(w, neverEnding('a')) }) ts := httptest.NewServer(mux) + defer ts.Close() timeout := 100 * time.Millisecond - client := &Client{ - Transport: &Transport{ - Dial: func(n, addr string) (net.Conn, error) { - conn, err := net.Dial(n, addr) - if err != nil { - return nil, err - } - conn.SetDeadline(time.Now().Add(timeout)) - if debug { - conn = NewLoggingConn("client", conn) - } - return conn, nil - }, - DisableKeepAlives: true, - }, + c := ts.Client() + c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) { + conn, err := net.Dial(n, addr) + if err != nil { + return nil, err + } + conn.SetDeadline(time.Now().Add(timeout)) + if debug { + conn = NewLoggingConn("client", conn) + } + return conn, nil } getFailed := false @@ -1497,7 +1472,7 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) { if debug { println("run", i+1, "of", nRuns) } - sres, err := client.Get(ts.URL + "/get") + sres, err := c.Get(ts.URL + "/get") if err != nil { if !getFailed { // Make the timeout longer, once. @@ -1519,7 +1494,6 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) { if debug { println("tests complete; waiting for handlers to finish") } - ts.Close() } func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { @@ -1537,21 +1511,17 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { ts := httptest.NewServer(mux) timeout := 100 * time.Millisecond - client := &Client{ - Transport: &Transport{ - Dial: func(n, addr string) (net.Conn, error) { - conn, err := net.Dial(n, addr) - if err != nil { - return nil, err - } - conn.SetDeadline(time.Now().Add(timeout)) - if debug { - conn = NewLoggingConn("client", conn) - } - return conn, nil - }, - DisableKeepAlives: true, - }, + c := ts.Client() + c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) { + conn, err := net.Dial(n, addr) + if err != nil { + return nil, err + } + conn.SetDeadline(time.Now().Add(timeout)) + if debug { + conn = NewLoggingConn("client", conn) + } + return conn, nil } getFailed := false @@ -1563,7 +1533,7 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { if debug { println("run", i+1, "of", nRuns) } - sres, err := client.Get(ts.URL + "/get") + sres, err := c.Get(ts.URL + "/get") if err != nil { if !getFailed { // Make the timeout longer, once. @@ -1577,7 +1547,7 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { break } req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body) - _, err = client.Do(req) + _, err = c.Do(req) if err == nil { sres.Body.Close() t.Errorf("Unexpected successful PUT") @@ -1609,11 +1579,8 @@ func TestTransportResponseHeaderTimeout(t *testing.T) { ts := httptest.NewServer(mux) defer ts.Close() - tr := &Transport{ - ResponseHeaderTimeout: 500 * time.Millisecond, - } - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c := ts.Client() + c.Transport.(*Transport).ResponseHeaderTimeout = 500 * time.Millisecond tests := []struct { path string @@ -1680,9 +1647,8 @@ func TestTransportCancelRequest(t *testing.T) { defer ts.Close() defer close(unblockc) - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c := ts.Client() + tr := c.Transport.(*Transport) req, _ := NewRequest("GET", ts.URL, nil) res, err := c.Do(req) @@ -1790,9 +1756,8 @@ func TestCancelRequestWithChannel(t *testing.T) { defer ts.Close() defer close(unblockc) - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c := ts.Client() + tr := c.Transport.(*Transport) req, _ := NewRequest("GET", ts.URL, nil) ch := make(chan struct{}) @@ -1849,9 +1814,7 @@ func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) { defer ts.Close() defer close(unblockc) - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c := ts.Client() req, _ := NewRequest("GET", ts.URL, nil) if withCtx { @@ -1939,9 +1902,8 @@ func TestTransportCloseResponseBody(t *testing.T) { })) defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c := ts.Client() + tr := c.Transport.(*Transport) req, _ := NewRequest("GET", ts.URL, nil) defer tr.CancelRequest(req) @@ -2061,18 +2023,12 @@ func TestTransportSocketLateBinding(t *testing.T) { defer ts.Close() dialGate := make(chan bool, 1) - tr := &Transport{ - Dial: func(n, addr string) (net.Conn, error) { - if <-dialGate { - return net.Dial(n, addr) - } - return nil, errors.New("manually closed") - }, - DisableKeepAlives: false, - } - defer tr.CloseIdleConnections() - c := &Client{ - Transport: tr, + c := ts.Client() + c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) { + if <-dialGate { + return net.Dial(n, addr) + } + return nil, errors.New("manually closed") } dialGate <- true // only allow one dial @@ -2326,14 +2282,11 @@ func TestIdleConnChannelLeak(t *testing.T) { SetReadLoopBeforeNextReadHook(func() { didRead <- true }) defer SetReadLoopBeforeNextReadHook(nil) - tr := &Transport{ - Dial: func(netw, addr string) (net.Conn, error) { - return net.Dial(netw, ts.Listener.Addr().String()) - }, + c := ts.Client() + tr := c.Transport.(*Transport) + tr.Dial = func(netw, addr string) (net.Conn, error) { + return net.Dial(netw, ts.Listener.Addr().String()) } - defer tr.CloseIdleConnections() - - c := &Client{Transport: tr} // First, without keep-alives. for _, disableKeep := range []bool{true, false} { @@ -2376,13 +2329,11 @@ func TestTransportClosesRequestBody(t *testing.T) { })) defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - cl := &Client{Transport: tr} + c := ts.Client() closes := 0 - res, err := cl.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")}) + res, err := c.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")}) if err != nil { t.Fatal(err) } @@ -2468,20 +2419,16 @@ func TestTLSServerClosesConnection(t *testing.T) { fmt.Fprintf(w, "hello") })) defer ts.Close() - tr := &Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - } - defer tr.CloseIdleConnections() - client := &Client{Transport: tr} + + c := ts.Client() + tr := c.Transport.(*Transport) var nSuccess = 0 var errs []error const trials = 20 for i := 0; i < trials; i++ { tr.CloseIdleConnections() - res, err := client.Get(ts.URL + "/keep-alive-then-die") + res, err := c.Get(ts.URL + "/keep-alive-then-die") if err != nil { t.Fatal(err) } @@ -2496,7 +2443,7 @@ func TestTLSServerClosesConnection(t *testing.T) { // Now try again and see if we successfully // pick a new connection. - res, err = client.Get(ts.URL + "/") + res, err = c.Get(ts.URL + "/") if err != nil { errs = append(errs, err) continue @@ -2575,22 +2522,20 @@ func TestTransportNoReuseAfterEarlyResponse(t *testing.T) { go io.Copy(ioutil.Discard, conn) })) defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - client := &Client{Transport: tr} + c := ts.Client() const bodySize = 256 << 10 finalBit := make(byteFromChanReader, 1) req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit)) req.ContentLength = bodySize - res, err := client.Do(req) + res, err := c.Do(req) if err := wantBody(res, err, "foo"); err != nil { t.Errorf("POST response: %v", err) } donec := make(chan bool) go func() { defer close(donec) - res, err = client.Get(ts.URL) + res, err = c.Get(ts.URL) if err := wantBody(res, err, "bar"); err != nil { t.Errorf("GET response: %v", err) return @@ -2622,10 +2567,9 @@ func TestTransportIssue10457(t *testing.T) { conn.Close() })) defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - cl := &Client{Transport: tr} - res, err := cl.Get(ts.URL) + c := ts.Client() + + res, err := c.Get(ts.URL) if err != nil { t.Fatalf("Get: %v", err) } @@ -2686,29 +2630,26 @@ func TestRetryIdempotentRequestsOnError(t *testing.T) { defer ts.Close() var writeNumAtomic int32 - tr := &Transport{ - Dial: func(network, addr string) (net.Conn, error) { - logf("Dial") - c, err := net.Dial(network, ts.Listener.Addr().String()) - if err != nil { - logf("Dial error: %v", err) - return nil, err - } - return &writerFuncConn{ - Conn: c, - write: func(p []byte) (n int, err error) { - if atomic.AddInt32(&writeNumAtomic, 1) == 2 { - logf("intentional write failure") - return 0, errors.New("second write fails") - } - logf("Write(%q)", p) - return c.Write(p) - }, - }, nil - }, + c := ts.Client() + c.Transport.(*Transport).Dial = func(network, addr string) (net.Conn, error) { + logf("Dial") + c, err := net.Dial(network, ts.Listener.Addr().String()) + if err != nil { + logf("Dial error: %v", err) + return nil, err + } + return &writerFuncConn{ + Conn: c, + write: func(p []byte) (n int, err error) { + if atomic.AddInt32(&writeNumAtomic, 1) == 2 { + logf("intentional write failure") + return 0, errors.New("second write fails") + } + logf("Write(%q)", p) + return c.Write(p) + }, + }, nil } - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} SetRoundTripRetried(func() { logf("Retried.") @@ -2752,6 +2693,7 @@ func TestTransportClosesBodyOnError(t *testing.T) { readBody <- err })) defer ts.Close() + c := ts.Client() fakeErr := errors.New("fake error") didClose := make(chan bool, 1) req, _ := NewRequest("POST", ts.URL, struct { @@ -2767,7 +2709,7 @@ func TestTransportClosesBodyOnError(t *testing.T) { return nil }), }) - res, err := DefaultClient.Do(req) + res, err := c.Do(req) if res != nil { defer res.Body.Close() } @@ -2801,23 +2743,19 @@ func TestTransportDialTLS(t *testing.T) { mu.Unlock() })) defer ts.Close() - tr := &Transport{ - DialTLS: func(netw, addr string) (net.Conn, error) { - mu.Lock() - didDial = true - mu.Unlock() - c, err := tls.Dial(netw, addr, &tls.Config{ - InsecureSkipVerify: true, - }) - if err != nil { - return nil, err - } - return c, c.Handshake() - }, + c := ts.Client() + c.Transport.(*Transport).DialTLS = func(netw, addr string) (net.Conn, error) { + mu.Lock() + didDial = true + mu.Unlock() + c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig) + if err != nil { + return nil, err + } + return c, c.Handshake() } - defer tr.CloseIdleConnections() - client := &Client{Transport: tr} - res, err := client.Get(ts.URL) + + res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) } @@ -2899,10 +2837,11 @@ func TestTransportRangeAndGzip(t *testing.T) { reqc <- r })) defer ts.Close() + c := ts.Client() req, _ := NewRequest("GET", ts.URL, nil) req.Header.Set("Range", "bytes=7-11") - res, err := DefaultClient.Do(req) + res, err := c.Do(req) if err != nil { t.Fatal(err) } @@ -2931,9 +2870,7 @@ func TestTransportResponseCancelRace(t *testing.T) { w.Write(b[:]) })) defer ts.Close() - - tr := &Transport{} - defer tr.CloseIdleConnections() + tr := ts.Client().Transport.(*Transport) req, err := NewRequest("GET", ts.URL, nil) if err != nil { @@ -2967,9 +2904,7 @@ func TestTransportDialCancelRace(t *testing.T) { ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) defer ts.Close() - - tr := &Transport{} - defer tr.CloseIdleConnections() + tr := ts.Client().Transport.(*Transport) req, err := NewRequest("GET", ts.URL, nil) if err != nil { @@ -3096,6 +3031,7 @@ func TestTransportPrefersResponseOverWriteError(t *testing.T) { w.WriteHeader(StatusOK) })) defer ts.Close() + c := ts.Client() fail := 0 count := 100 @@ -3105,10 +3041,7 @@ func TestTransportPrefersResponseOverWriteError(t *testing.T) { if err != nil { t.Fatal(err) } - tr := new(Transport) - defer tr.CloseIdleConnections() - client := &Client{Transport: tr} - resp, err := client.Do(req) + resp, err := c.Do(req) if err != nil { fail++ t.Logf("%d = %#v", i, err) @@ -3321,10 +3254,8 @@ func testTransportReuseConnection_Gzip(t *testing.T, chunked bool) { w.Write(rgz) // arbitrary gzip response })) defer ts.Close() + c := ts.Client() - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} for i := 0; i < 2; i++ { res, err := c.Get(ts.URL) if err != nil { @@ -3353,12 +3284,9 @@ func TestTransportResponseHeaderLength(t *testing.T) { } })) defer ts.Close() + c := ts.Client() + c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10 - tr := &Transport{ - MaxResponseHeaderBytes: 512 << 10, - } - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} if res, err := c.Get(ts.URL); err != nil { t.Fatal(err) } else { @@ -3619,8 +3547,8 @@ func TestTransportRejectsAlphaPort(t *testing.T) { // connections. The http2 test is done in TestTransportEventTrace_h2 func TestTLSHandshakeTrace(t *testing.T) { defer afterTest(t) - s := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) - defer s.Close() + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + defer ts.Close() var mu sync.Mutex var start, done bool @@ -3640,10 +3568,8 @@ func TestTLSHandshakeTrace(t *testing.T) { }, } - tr := &Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} - req, err := NewRequest("GET", s.URL, nil) + c := ts.Client() + req, err := NewRequest("GET", ts.URL, nil) if err != nil { t.Fatal("Unable to construct test request:", err) } @@ -3670,16 +3596,14 @@ func TestTransportMaxIdleConns(t *testing.T) { // No body for convenience. })) defer ts.Close() - tr := &Transport{ - MaxIdleConns: 4, - } - defer tr.CloseIdleConnections() + c := ts.Client() + tr := c.Transport.(*Transport) + tr.MaxIdleConns = 4 ip, port, err := net.SplitHostPort(ts.Listener.Addr().String()) if err != nil { t.Fatal(err) } - c := &Client{Transport: tr} ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, host string) ([]net.IPAddr, error) { return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil }) @@ -3975,17 +3899,16 @@ func TestTransportProxyConnectHeader(t *testing.T) { c.Close() })) defer ts.Close() - tr := &Transport{ - ProxyConnectHeader: Header{ - "User-Agent": {"foo"}, - "Other": {"bar"}, - }, - Proxy: func(r *Request) (*url.URL, error) { - return url.Parse(ts.URL) - }, + + c := ts.Client() + c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) { + return url.Parse(ts.URL) } - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c.Transport.(*Transport).ProxyConnectHeader = Header{ + "User-Agent": {"foo"}, + "Other": {"bar"}, + } + res, err := c.Get("https://dummy.tld/") // https to force a CONNECT if err == nil { res.Body.Close() |
