diff options
Diffstat (limited to 'src/net/http')
| -rw-r--r-- | src/net/http/client_test.go | 10 | ||||
| -rw-r--r-- | src/net/http/export_test.go | 6 | ||||
| -rw-r--r-- | src/net/http/request.go | 8 | ||||
| -rw-r--r-- | src/net/http/transfer.go | 23 | ||||
| -rw-r--r-- | src/net/http/transport.go | 147 | ||||
| -rw-r--r-- | src/net/http/transport_internal_test.go | 5 | ||||
| -rw-r--r-- | src/net/http/transport_test.go | 4 |
7 files changed, 137 insertions, 66 deletions
diff --git a/src/net/http/client_test.go b/src/net/http/client_test.go index 534986e867..c75456ae53 100644 --- a/src/net/http/client_test.go +++ b/src/net/http/client_test.go @@ -304,6 +304,7 @@ func TestClientRedirects(t *testing.T) { } } +// Tests that Client redirects' contexts are derived from the original request's context. func TestClientRedirectContext(t *testing.T) { setParallel(t) defer afterTest(t) @@ -320,10 +321,12 @@ func TestClientRedirectContext(t *testing.T) { Transport: tr, CheckRedirect: func(req *Request, via []*Request) error { cancel() - if len(via) > 2 { - return errors.New("too many redirects") + select { + case <-req.Context().Done(): + return nil + case <-time.After(5 * time.Second): + return errors.New("redirected request's context never expired after root request canceled") } - return nil }, } req, _ := NewRequest("GET", ts.URL, nil) @@ -1818,6 +1821,7 @@ func TestTransportBodyReadError(t *testing.T) { if err != nil { t.Fatal(err) } + req = req.WithT(t) _, err = tr.RoundTrip(req) if err != someErr { t.Errorf("Got error: %v; want Request.Body read error: %v", err, someErr) diff --git a/src/net/http/export_test.go b/src/net/http/export_test.go index b61f58b2db..596171f5f0 100644 --- a/src/net/http/export_test.go +++ b/src/net/http/export_test.go @@ -8,9 +8,11 @@ package http import ( + "context" "net" "sort" "sync" + "testing" "time" ) @@ -199,3 +201,7 @@ func (s *Server) ExportAllConnsIdle() bool { } return true } + +func (r *Request) WithT(t *testing.T) *Request { + return r.WithContext(context.WithValue(r.Context(), tLogKey{}, t.Logf)) +} diff --git a/src/net/http/request.go b/src/net/http/request.go index 168c03e86c..09d998dacf 100644 --- a/src/net/http/request.go +++ b/src/net/http/request.go @@ -621,6 +621,9 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, wai // Write body and trailer err = tw.WriteBody(w) if err != nil { + if tw.bodyReadError == err { + err = requestBodyReadError{err} + } return err } @@ -630,6 +633,11 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, wai return nil } +// requestBodyReadError wraps an error from (*Request).write to indicate +// that the error came from a Read call on the Request.Body. +// This error type should not escape the net/http package to users. +type requestBodyReadError struct{ error } + func idnaASCII(v string) (string, error) { if isASCII(v) { return v, nil diff --git a/src/net/http/transfer.go b/src/net/http/transfer.go index 4f47637aa7..2a021154c9 100644 --- a/src/net/http/transfer.go +++ b/src/net/http/transfer.go @@ -51,6 +51,19 @@ func (br *byteReader) Read(p []byte) (n int, err error) { return 1, io.EOF } +// transferBodyReader is an io.Reader that reads from tw.Body +// and records any non-EOF error in tw.bodyReadError. +// It is exactly 1 pointer wide to avoid allocations into interfaces. +type transferBodyReader struct{ tw *transferWriter } + +func (br transferBodyReader) Read(p []byte) (n int, err error) { + n, err = br.tw.Body.Read(p) + if err != nil && err != io.EOF { + br.tw.bodyReadError = err + } + return +} + // transferWriter inspects the fields of a user-supplied Request or Response, // sanitizes them without changing the user object and provides methods for // writing the respective header, body and trailer in wire format. @@ -64,6 +77,7 @@ type transferWriter struct { TransferEncoding []string Trailer Header IsResponse bool + bodyReadError error // any non-EOF error from reading Body FlushHeaders bool // flush headers to network before body ByteReadCh chan readResult // non-nil if probeRequestBody called @@ -304,24 +318,25 @@ func (t *transferWriter) WriteBody(w io.Writer) error { // Write body if t.Body != nil { + var body = transferBodyReader{t} if chunked(t.TransferEncoding) { if bw, ok := w.(*bufio.Writer); ok && !t.IsResponse { w = &internal.FlushAfterChunkWriter{Writer: bw} } cw := internal.NewChunkedWriter(w) - _, err = io.Copy(cw, t.Body) + _, err = io.Copy(cw, body) if err == nil { err = cw.Close() } } else if t.ContentLength == -1 { - ncopy, err = io.Copy(w, t.Body) + ncopy, err = io.Copy(w, body) } else { - ncopy, err = io.Copy(w, io.LimitReader(t.Body, t.ContentLength)) + ncopy, err = io.Copy(w, io.LimitReader(body, t.ContentLength)) if err != nil { return err } var nextra int64 - nextra, err = io.Copy(ioutil.Discard, t.Body) + nextra, err = io.Copy(ioutil.Discard, body) ncopy += nextra } if err != nil { diff --git a/src/net/http/transport.go b/src/net/http/transport.go index 2aa00de50a..0d4f427a57 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -303,11 +303,15 @@ func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, error) { } // transportRequest is a wrapper around a *Request that adds -// optional extra headers to write. +// optional extra headers to write and stores any error to return +// from roundTrip. type transportRequest struct { *Request // original request, not to be mutated extra Header // extra headers to write, or nil trace *httptrace.ClientTrace // optional + + mu sync.Mutex // guards err + err error // first setError value for mapRoundTripError to consider } func (tr *transportRequest) extraHeaders() Header { @@ -317,6 +321,14 @@ func (tr *transportRequest) extraHeaders() Header { return tr.extra } +func (tr *transportRequest) setError(err error) { + tr.mu.Lock() + if tr.err == nil { + tr.err = err + } + tr.mu.Unlock() +} + // RoundTrip implements the RoundTripper interface. // // For higher-level HTTP client support (such as handling of cookies @@ -1420,22 +1432,41 @@ func (pc *persistConn) closeConnIfStillIdle() { pc.close(errIdleConnTimeout) } -// mapRoundTripErrorFromReadLoop maps the provided readLoop error into -// the error value that should be returned from persistConn.roundTrip. +// mapRoundTripError returns the appropriate error value for +// persistConn.roundTrip. +// +// The provided err is the first error that (*persistConn).roundTrip +// happened to receive from its select statement. // // The startBytesWritten value should be the value of pc.nwrite before the roundTrip // started writing the request. -func (pc *persistConn) mapRoundTripErrorFromReadLoop(req *Request, startBytesWritten int64, err error) (out error) { +func (pc *persistConn) mapRoundTripError(req *transportRequest, startBytesWritten int64, err error) error { if err == nil { return nil } - if err := pc.canceled(); err != nil { - return err + + // If the request was canceled, that's better than network + // failures that were likely the result of tearing down the + // connection. + if cerr := pc.canceled(); cerr != nil { + return cerr + } + + // See if an error was set explicitly. + req.mu.Lock() + reqErr := req.err + req.mu.Unlock() + if reqErr != nil { + return reqErr } + if err == errServerClosedIdle { + // Don't decorate return err } + if _, ok := err.(transportReadFromServerError); ok { + // Don't decorate return err } if pc.isBroken() { @@ -1443,40 +1474,11 @@ func (pc *persistConn) mapRoundTripErrorFromReadLoop(req *Request, startBytesWri if pc.nwrite == startBytesWritten && req.outgoingLength() == 0 { return nothingWrittenError{err} } + return fmt.Errorf("net/http: HTTP/1.x transport connection broken: %v", err) } return err } -// mapRoundTripErrorAfterClosed returns the error value to be propagated -// up to Transport.RoundTrip method when persistConn.roundTrip sees -// its pc.closech channel close, indicating the persistConn is dead. -// (after closech is closed, pc.closed is valid). -func (pc *persistConn) mapRoundTripErrorAfterClosed(req *Request, startBytesWritten int64) error { - if err := pc.canceled(); err != nil { - return err - } - err := pc.closed - if err == errServerClosedIdle { - // Don't decorate - return err - } - if _, ok := err.(transportReadFromServerError); ok { - // Don't decorate - return err - } - - // Wait for the writeLoop goroutine to terminated, and then - // see if we actually managed to write anything. If not, we - // can retry the request. - <-pc.writeLoopDone - if pc.nwrite == startBytesWritten && req.outgoingLength() == 0 { - return nothingWrittenError{err} - } - - return fmt.Errorf("net/http: HTTP/1.x transport connection broken: %v", err) - -} - func (pc *persistConn) readLoop() { closeErr := errReadLoopExiting // default value, if not changed below defer func() { @@ -1746,6 +1748,17 @@ func (pc *persistConn) writeLoop() { case wr := <-pc.writech: startBytesWritten := pc.nwrite err := wr.req.Request.write(pc.bw, pc.isProxy, wr.req.extra, pc.waitForContinue(wr.continueCh)) + if bre, ok := err.(requestBodyReadError); ok { + err = bre.error + // Errors reading from the user's + // Request.Body are high priority. + // Set it here before sending on the + // channels below or calling + // pc.close() which tears town + // connections and causes other + // errors. + wr.req.setError(err) + } if err == nil { err = pc.bw.Flush() } @@ -1913,6 +1926,14 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err gone := make(chan struct{}) defer close(gone) + defer func() { + if err != nil { + pc.t.setReqCanceler(req.Request, nil) + } + }() + + const debugRoundTrip = false + // Write the request concurrently with waiting for a response, // in case the server decides to reply before reading our full // request body. @@ -1929,38 +1950,50 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err callerGone: gone, } - var re responseAndError var respHeaderTimer <-chan time.Time cancelChan := req.Request.Cancel ctxDoneChan := req.Context().Done() -WaitResponse: for { testHookWaitResLoop() select { case err := <-writeErrCh: + if debugRoundTrip { + req.logf("writeErrCh resv: %T/%#v", err, err) + } if err != nil { - if cerr := pc.canceled(); cerr != nil { - err = cerr - } - re = responseAndError{err: err} pc.close(fmt.Errorf("write error: %v", err)) - break WaitResponse + return nil, pc.mapRoundTripError(req, startBytesWritten, err) } if d := pc.t.ResponseHeaderTimeout; d > 0 { + if debugRoundTrip { + req.logf("starting timer for %v", d) + } timer := time.NewTimer(d) defer timer.Stop() // prevent leaks respHeaderTimer = timer.C } case <-pc.closech: - re = responseAndError{err: pc.mapRoundTripErrorAfterClosed(req.Request, startBytesWritten)} - break WaitResponse + if debugRoundTrip { + req.logf("closech recv: %T %#v", pc.closed, pc.closed) + } + return nil, pc.mapRoundTripError(req, startBytesWritten, pc.closed) case <-respHeaderTimer: + if debugRoundTrip { + req.logf("timeout waiting for response headers.") + } pc.close(errTimeout) - re = responseAndError{err: errTimeout} - break WaitResponse - case re = <-resc: - re.err = pc.mapRoundTripErrorFromReadLoop(req.Request, startBytesWritten, re.err) - break WaitResponse + return nil, errTimeout + case re := <-resc: + if (re.res == nil) == (re.err == nil) { + panic(fmt.Sprintf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil)) + } + if debugRoundTrip { + req.logf("resc recv: %p, %T/%#v", re.res, re.err, re.err) + } + if re.err != nil { + return nil, pc.mapRoundTripError(req, startBytesWritten, re.err) + } + return re.res, nil case <-cancelChan: pc.t.CancelRequest(req.Request) cancelChan = nil @@ -1970,14 +2003,16 @@ WaitResponse: ctxDoneChan = nil } } +} - if re.err != nil { - pc.t.setReqCanceler(req.Request, nil) - } - if (re.res == nil) == (re.err == nil) { - panic("internal error: exactly one of res or err should be set") +// tLogKey is a context WithValue key for test debugging contexts containing +// a t.Logf func. See export_test.go's Request.WithT method. +type tLogKey struct{} + +func (r *transportRequest) logf(format string, args ...interface{}) { + if logf, ok := r.Request.Context().Value(tLogKey{}).(func(string, ...interface{})); ok { + logf(time.Now().Format(time.RFC3339Nano)+": "+format, args...) } - return re.res, re.err } // markReused marks this connection as having been successfully used for a diff --git a/src/net/http/transport_internal_test.go b/src/net/http/transport_internal_test.go index 3d24fc127d..262d8b4ac5 100644 --- a/src/net/http/transport_internal_test.go +++ b/src/net/http/transport_internal_test.go @@ -30,6 +30,7 @@ func TestTransportPersistConnReadLoopEOF(t *testing.T) { tr := new(Transport) req, _ := NewRequest("GET", "http://"+ln.Addr().String(), nil) + req = req.WithT(t) treq := &transportRequest{Request: req} cm := connectMethod{targetScheme: "http", targetAddr: ln.Addr().String()} pc, err := tr.getConn(treq, cm) @@ -47,13 +48,13 @@ func TestTransportPersistConnReadLoopEOF(t *testing.T) { _, err = pc.roundTrip(treq) if !isTransportReadFromServerError(err) && err != errServerClosedIdle { - t.Fatalf("roundTrip = %#v, %v; want errServerClosedConn or errServerClosedIdle", err, err) + t.Errorf("roundTrip = %#v, %v; want errServerClosedIdle or transportReadFromServerError", err, err) } <-pc.closech err = pc.closed if !isTransportReadFromServerError(err) && err != errServerClosedIdle { - t.Fatalf("pc.closed = %#v, %v; want errServerClosedConn or errServerClosedIdle", err, err) + t.Errorf("pc.closed = %#v, %v; want errServerClosedIdle or transportReadFromServerError", err, err) } } diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index ce98157ed5..cb315f14f4 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -1625,7 +1625,9 @@ func TestTransportResponseHeaderTimeout(t *testing.T) { {path: "/fast", want: 200}, } for i, tt := range tests { - res, err := c.Get(ts.URL + tt.path) + req, _ := NewRequest("GET", ts.URL+tt.path, nil) + req = req.WithT(t) + res, err := c.Do(req) select { case <-inHandler: case <-time.After(5 * time.Second): |
