aboutsummaryrefslogtreecommitdiff
path: root/src/net/http
diff options
context:
space:
mode:
Diffstat (limited to 'src/net/http')
-rw-r--r--src/net/http/server.go10
-rw-r--r--src/net/http/transport_test.go29
2 files changed, 39 insertions, 0 deletions
diff --git a/src/net/http/server.go b/src/net/http/server.go
index 5b113cff97..4d0ce5619f 100644
--- a/src/net/http/server.go
+++ b/src/net/http/server.go
@@ -1794,6 +1794,7 @@ func isCommonNetReadError(err error) bool {
func (c *conn) serve(ctx context.Context) {
c.remoteAddr = c.rwc.RemoteAddr().String()
ctx = context.WithValue(ctx, LocalAddrContextKey, c.rwc.LocalAddr())
+ var inFlightResponse *response
defer func() {
if err := recover(); err != nil && err != ErrAbortHandler {
const size = 64 << 10
@@ -1801,7 +1802,14 @@ func (c *conn) serve(ctx context.Context) {
buf = buf[:runtime.Stack(buf, false)]
c.server.logf("http: panic serving %v: %v\n%s", c.remoteAddr, err, buf)
}
+ if inFlightResponse != nil {
+ inFlightResponse.cancelCtx()
+ }
if !c.hijacked() {
+ if inFlightResponse != nil {
+ inFlightResponse.conn.r.abortPendingRead()
+ inFlightResponse.reqBody.Close()
+ }
c.close()
c.setState(c.rwc, StateClosed, runHooks)
}
@@ -1926,7 +1934,9 @@ func (c *conn) serve(ctx context.Context) {
// in parallel even if their responses need to be serialized.
// But we're not going to implement HTTP pipelining because it
// was never deployed in the wild and the answer is HTTP/2.
+ inFlightResponse = w
serverHandler{c.server}.ServeHTTP(w, w.req)
+ inFlightResponse = nil
w.cancelCtx()
if c.hijacked() {
return
diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go
index eeaa492644..0cdd946de4 100644
--- a/src/net/http/transport_test.go
+++ b/src/net/http/transport_test.go
@@ -6512,3 +6512,32 @@ func TestCancelRequestWhenSharingConnection(t *testing.T) {
close(r2c)
wg.Wait()
}
+
+func TestHandlerAbortRacesBodyRead(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+
+ ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+ go io.Copy(io.Discard, req.Body)
+ panic(ErrAbortHandler)
+ }))
+ defer ts.Close()
+
+ var wg sync.WaitGroup
+ for i := 0; i < 2; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for j := 0; j < 10; j++ {
+ const reqLen = 6 * 1024 * 1024
+ req, _ := NewRequest("POST", ts.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
+ req.ContentLength = reqLen
+ resp, _ := ts.Client().Transport.RoundTrip(req)
+ if resp != nil {
+ resp.Body.Close()
+ }
+ }
+ }()
+ }
+ wg.Wait()
+}