aboutsummaryrefslogtreecommitdiff
path: root/src/net/http
diff options
context:
space:
mode:
Diffstat (limited to 'src/net/http')
-rw-r--r--src/net/http/client_test.go10
-rw-r--r--src/net/http/export_test.go6
-rw-r--r--src/net/http/request.go8
-rw-r--r--src/net/http/transfer.go23
-rw-r--r--src/net/http/transport.go147
-rw-r--r--src/net/http/transport_internal_test.go5
-rw-r--r--src/net/http/transport_test.go4
7 files changed, 137 insertions, 66 deletions
diff --git a/src/net/http/client_test.go b/src/net/http/client_test.go
index 534986e867..c75456ae53 100644
--- a/src/net/http/client_test.go
+++ b/src/net/http/client_test.go
@@ -304,6 +304,7 @@ func TestClientRedirects(t *testing.T) {
}
}
+// Tests that Client redirects' contexts are derived from the original request's context.
func TestClientRedirectContext(t *testing.T) {
setParallel(t)
defer afterTest(t)
@@ -320,10 +321,12 @@ func TestClientRedirectContext(t *testing.T) {
Transport: tr,
CheckRedirect: func(req *Request, via []*Request) error {
cancel()
- if len(via) > 2 {
- return errors.New("too many redirects")
+ select {
+ case <-req.Context().Done():
+ return nil
+ case <-time.After(5 * time.Second):
+ return errors.New("redirected request's context never expired after root request canceled")
}
- return nil
},
}
req, _ := NewRequest("GET", ts.URL, nil)
@@ -1818,6 +1821,7 @@ func TestTransportBodyReadError(t *testing.T) {
if err != nil {
t.Fatal(err)
}
+ req = req.WithT(t)
_, err = tr.RoundTrip(req)
if err != someErr {
t.Errorf("Got error: %v; want Request.Body read error: %v", err, someErr)
diff --git a/src/net/http/export_test.go b/src/net/http/export_test.go
index b61f58b2db..596171f5f0 100644
--- a/src/net/http/export_test.go
+++ b/src/net/http/export_test.go
@@ -8,9 +8,11 @@
package http
import (
+ "context"
"net"
"sort"
"sync"
+ "testing"
"time"
)
@@ -199,3 +201,7 @@ func (s *Server) ExportAllConnsIdle() bool {
}
return true
}
+
+func (r *Request) WithT(t *testing.T) *Request {
+ return r.WithContext(context.WithValue(r.Context(), tLogKey{}, t.Logf))
+}
diff --git a/src/net/http/request.go b/src/net/http/request.go
index 168c03e86c..09d998dacf 100644
--- a/src/net/http/request.go
+++ b/src/net/http/request.go
@@ -621,6 +621,9 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, wai
// Write body and trailer
err = tw.WriteBody(w)
if err != nil {
+ if tw.bodyReadError == err {
+ err = requestBodyReadError{err}
+ }
return err
}
@@ -630,6 +633,11 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, wai
return nil
}
+// requestBodyReadError wraps an error from (*Request).write to indicate
+// that the error came from a Read call on the Request.Body.
+// This error type should not escape the net/http package to users.
+type requestBodyReadError struct{ error }
+
func idnaASCII(v string) (string, error) {
if isASCII(v) {
return v, nil
diff --git a/src/net/http/transfer.go b/src/net/http/transfer.go
index 4f47637aa7..2a021154c9 100644
--- a/src/net/http/transfer.go
+++ b/src/net/http/transfer.go
@@ -51,6 +51,19 @@ func (br *byteReader) Read(p []byte) (n int, err error) {
return 1, io.EOF
}
+// transferBodyReader is an io.Reader that reads from tw.Body
+// and records any non-EOF error in tw.bodyReadError.
+// It is exactly 1 pointer wide to avoid allocations into interfaces.
+type transferBodyReader struct{ tw *transferWriter }
+
+func (br transferBodyReader) Read(p []byte) (n int, err error) {
+ n, err = br.tw.Body.Read(p)
+ if err != nil && err != io.EOF {
+ br.tw.bodyReadError = err
+ }
+ return
+}
+
// transferWriter inspects the fields of a user-supplied Request or Response,
// sanitizes them without changing the user object and provides methods for
// writing the respective header, body and trailer in wire format.
@@ -64,6 +77,7 @@ type transferWriter struct {
TransferEncoding []string
Trailer Header
IsResponse bool
+ bodyReadError error // any non-EOF error from reading Body
FlushHeaders bool // flush headers to network before body
ByteReadCh chan readResult // non-nil if probeRequestBody called
@@ -304,24 +318,25 @@ func (t *transferWriter) WriteBody(w io.Writer) error {
// Write body
if t.Body != nil {
+ var body = transferBodyReader{t}
if chunked(t.TransferEncoding) {
if bw, ok := w.(*bufio.Writer); ok && !t.IsResponse {
w = &internal.FlushAfterChunkWriter{Writer: bw}
}
cw := internal.NewChunkedWriter(w)
- _, err = io.Copy(cw, t.Body)
+ _, err = io.Copy(cw, body)
if err == nil {
err = cw.Close()
}
} else if t.ContentLength == -1 {
- ncopy, err = io.Copy(w, t.Body)
+ ncopy, err = io.Copy(w, body)
} else {
- ncopy, err = io.Copy(w, io.LimitReader(t.Body, t.ContentLength))
+ ncopy, err = io.Copy(w, io.LimitReader(body, t.ContentLength))
if err != nil {
return err
}
var nextra int64
- nextra, err = io.Copy(ioutil.Discard, t.Body)
+ nextra, err = io.Copy(ioutil.Discard, body)
ncopy += nextra
}
if err != nil {
diff --git a/src/net/http/transport.go b/src/net/http/transport.go
index 2aa00de50a..0d4f427a57 100644
--- a/src/net/http/transport.go
+++ b/src/net/http/transport.go
@@ -303,11 +303,15 @@ func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, error) {
}
// transportRequest is a wrapper around a *Request that adds
-// optional extra headers to write.
+// optional extra headers to write and stores any error to return
+// from roundTrip.
type transportRequest struct {
*Request // original request, not to be mutated
extra Header // extra headers to write, or nil
trace *httptrace.ClientTrace // optional
+
+ mu sync.Mutex // guards err
+ err error // first setError value for mapRoundTripError to consider
}
func (tr *transportRequest) extraHeaders() Header {
@@ -317,6 +321,14 @@ func (tr *transportRequest) extraHeaders() Header {
return tr.extra
}
+func (tr *transportRequest) setError(err error) {
+ tr.mu.Lock()
+ if tr.err == nil {
+ tr.err = err
+ }
+ tr.mu.Unlock()
+}
+
// RoundTrip implements the RoundTripper interface.
//
// For higher-level HTTP client support (such as handling of cookies
@@ -1420,22 +1432,41 @@ func (pc *persistConn) closeConnIfStillIdle() {
pc.close(errIdleConnTimeout)
}
-// mapRoundTripErrorFromReadLoop maps the provided readLoop error into
-// the error value that should be returned from persistConn.roundTrip.
+// mapRoundTripError returns the appropriate error value for
+// persistConn.roundTrip.
+//
+// The provided err is the first error that (*persistConn).roundTrip
+// happened to receive from its select statement.
//
// The startBytesWritten value should be the value of pc.nwrite before the roundTrip
// started writing the request.
-func (pc *persistConn) mapRoundTripErrorFromReadLoop(req *Request, startBytesWritten int64, err error) (out error) {
+func (pc *persistConn) mapRoundTripError(req *transportRequest, startBytesWritten int64, err error) error {
if err == nil {
return nil
}
- if err := pc.canceled(); err != nil {
- return err
+
+ // If the request was canceled, that's better than network
+ // failures that were likely the result of tearing down the
+ // connection.
+ if cerr := pc.canceled(); cerr != nil {
+ return cerr
+ }
+
+ // See if an error was set explicitly.
+ req.mu.Lock()
+ reqErr := req.err
+ req.mu.Unlock()
+ if reqErr != nil {
+ return reqErr
}
+
if err == errServerClosedIdle {
+ // Don't decorate
return err
}
+
if _, ok := err.(transportReadFromServerError); ok {
+ // Don't decorate
return err
}
if pc.isBroken() {
@@ -1443,40 +1474,11 @@ func (pc *persistConn) mapRoundTripErrorFromReadLoop(req *Request, startBytesWri
if pc.nwrite == startBytesWritten && req.outgoingLength() == 0 {
return nothingWrittenError{err}
}
+ return fmt.Errorf("net/http: HTTP/1.x transport connection broken: %v", err)
}
return err
}
-// mapRoundTripErrorAfterClosed returns the error value to be propagated
-// up to Transport.RoundTrip method when persistConn.roundTrip sees
-// its pc.closech channel close, indicating the persistConn is dead.
-// (after closech is closed, pc.closed is valid).
-func (pc *persistConn) mapRoundTripErrorAfterClosed(req *Request, startBytesWritten int64) error {
- if err := pc.canceled(); err != nil {
- return err
- }
- err := pc.closed
- if err == errServerClosedIdle {
- // Don't decorate
- return err
- }
- if _, ok := err.(transportReadFromServerError); ok {
- // Don't decorate
- return err
- }
-
- // Wait for the writeLoop goroutine to terminated, and then
- // see if we actually managed to write anything. If not, we
- // can retry the request.
- <-pc.writeLoopDone
- if pc.nwrite == startBytesWritten && req.outgoingLength() == 0 {
- return nothingWrittenError{err}
- }
-
- return fmt.Errorf("net/http: HTTP/1.x transport connection broken: %v", err)
-
-}
-
func (pc *persistConn) readLoop() {
closeErr := errReadLoopExiting // default value, if not changed below
defer func() {
@@ -1746,6 +1748,17 @@ func (pc *persistConn) writeLoop() {
case wr := <-pc.writech:
startBytesWritten := pc.nwrite
err := wr.req.Request.write(pc.bw, pc.isProxy, wr.req.extra, pc.waitForContinue(wr.continueCh))
+ if bre, ok := err.(requestBodyReadError); ok {
+ err = bre.error
+ // Errors reading from the user's
+ // Request.Body are high priority.
+ // Set it here before sending on the
+ // channels below or calling
+ // pc.close() which tears town
+ // connections and causes other
+ // errors.
+ wr.req.setError(err)
+ }
if err == nil {
err = pc.bw.Flush()
}
@@ -1913,6 +1926,14 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
gone := make(chan struct{})
defer close(gone)
+ defer func() {
+ if err != nil {
+ pc.t.setReqCanceler(req.Request, nil)
+ }
+ }()
+
+ const debugRoundTrip = false
+
// Write the request concurrently with waiting for a response,
// in case the server decides to reply before reading our full
// request body.
@@ -1929,38 +1950,50 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
callerGone: gone,
}
- var re responseAndError
var respHeaderTimer <-chan time.Time
cancelChan := req.Request.Cancel
ctxDoneChan := req.Context().Done()
-WaitResponse:
for {
testHookWaitResLoop()
select {
case err := <-writeErrCh:
+ if debugRoundTrip {
+ req.logf("writeErrCh resv: %T/%#v", err, err)
+ }
if err != nil {
- if cerr := pc.canceled(); cerr != nil {
- err = cerr
- }
- re = responseAndError{err: err}
pc.close(fmt.Errorf("write error: %v", err))
- break WaitResponse
+ return nil, pc.mapRoundTripError(req, startBytesWritten, err)
}
if d := pc.t.ResponseHeaderTimeout; d > 0 {
+ if debugRoundTrip {
+ req.logf("starting timer for %v", d)
+ }
timer := time.NewTimer(d)
defer timer.Stop() // prevent leaks
respHeaderTimer = timer.C
}
case <-pc.closech:
- re = responseAndError{err: pc.mapRoundTripErrorAfterClosed(req.Request, startBytesWritten)}
- break WaitResponse
+ if debugRoundTrip {
+ req.logf("closech recv: %T %#v", pc.closed, pc.closed)
+ }
+ return nil, pc.mapRoundTripError(req, startBytesWritten, pc.closed)
case <-respHeaderTimer:
+ if debugRoundTrip {
+ req.logf("timeout waiting for response headers.")
+ }
pc.close(errTimeout)
- re = responseAndError{err: errTimeout}
- break WaitResponse
- case re = <-resc:
- re.err = pc.mapRoundTripErrorFromReadLoop(req.Request, startBytesWritten, re.err)
- break WaitResponse
+ return nil, errTimeout
+ case re := <-resc:
+ if (re.res == nil) == (re.err == nil) {
+ panic(fmt.Sprintf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil))
+ }
+ if debugRoundTrip {
+ req.logf("resc recv: %p, %T/%#v", re.res, re.err, re.err)
+ }
+ if re.err != nil {
+ return nil, pc.mapRoundTripError(req, startBytesWritten, re.err)
+ }
+ return re.res, nil
case <-cancelChan:
pc.t.CancelRequest(req.Request)
cancelChan = nil
@@ -1970,14 +2003,16 @@ WaitResponse:
ctxDoneChan = nil
}
}
+}
- if re.err != nil {
- pc.t.setReqCanceler(req.Request, nil)
- }
- if (re.res == nil) == (re.err == nil) {
- panic("internal error: exactly one of res or err should be set")
+// tLogKey is a context WithValue key for test debugging contexts containing
+// a t.Logf func. See export_test.go's Request.WithT method.
+type tLogKey struct{}
+
+func (r *transportRequest) logf(format string, args ...interface{}) {
+ if logf, ok := r.Request.Context().Value(tLogKey{}).(func(string, ...interface{})); ok {
+ logf(time.Now().Format(time.RFC3339Nano)+": "+format, args...)
}
- return re.res, re.err
}
// markReused marks this connection as having been successfully used for a
diff --git a/src/net/http/transport_internal_test.go b/src/net/http/transport_internal_test.go
index 3d24fc127d..262d8b4ac5 100644
--- a/src/net/http/transport_internal_test.go
+++ b/src/net/http/transport_internal_test.go
@@ -30,6 +30,7 @@ func TestTransportPersistConnReadLoopEOF(t *testing.T) {
tr := new(Transport)
req, _ := NewRequest("GET", "http://"+ln.Addr().String(), nil)
+ req = req.WithT(t)
treq := &transportRequest{Request: req}
cm := connectMethod{targetScheme: "http", targetAddr: ln.Addr().String()}
pc, err := tr.getConn(treq, cm)
@@ -47,13 +48,13 @@ func TestTransportPersistConnReadLoopEOF(t *testing.T) {
_, err = pc.roundTrip(treq)
if !isTransportReadFromServerError(err) && err != errServerClosedIdle {
- t.Fatalf("roundTrip = %#v, %v; want errServerClosedConn or errServerClosedIdle", err, err)
+ t.Errorf("roundTrip = %#v, %v; want errServerClosedIdle or transportReadFromServerError", err, err)
}
<-pc.closech
err = pc.closed
if !isTransportReadFromServerError(err) && err != errServerClosedIdle {
- t.Fatalf("pc.closed = %#v, %v; want errServerClosedConn or errServerClosedIdle", err, err)
+ t.Errorf("pc.closed = %#v, %v; want errServerClosedIdle or transportReadFromServerError", err, err)
}
}
diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go
index ce98157ed5..cb315f14f4 100644
--- a/src/net/http/transport_test.go
+++ b/src/net/http/transport_test.go
@@ -1625,7 +1625,9 @@ func TestTransportResponseHeaderTimeout(t *testing.T) {
{path: "/fast", want: 200},
}
for i, tt := range tests {
- res, err := c.Get(ts.URL + tt.path)
+ req, _ := NewRequest("GET", ts.URL+tt.path, nil)
+ req = req.WithT(t)
+ res, err := c.Do(req)
select {
case <-inHandler:
case <-time.After(5 * time.Second):