aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/net/http/export_test.go7
-rw-r--r--src/net/http/serve_test.go63
-rw-r--r--src/net/http/server.go20
3 files changed, 71 insertions, 19 deletions
diff --git a/src/net/http/export_test.go b/src/net/http/export_test.go
index 096a6d382a..a849327f45 100644
--- a/src/net/http/export_test.go
+++ b/src/net/http/export_test.go
@@ -88,12 +88,7 @@ func SetPendingDialHooks(before, after func()) {
func SetTestHookServerServe(fn func(*Server, net.Listener)) { testHookServerServe = fn }
-func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler {
- ctx, cancel := context.WithCancel(context.Background())
- go func() {
- <-ch
- cancel()
- }()
+func NewTestTimeoutHandler(handler Handler, ctx context.Context) Handler {
return &timeoutHandler{
handler: handler,
testContext: ctx,
diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go
index a98d6c313f..e8fb77446c 100644
--- a/src/net/http/serve_test.go
+++ b/src/net/http/serve_test.go
@@ -2274,6 +2274,18 @@ func TestRequestBodyTimeoutClosesConnection(t *testing.T) {
}
}
+// cancelableTimeoutContext overwrites the error message to DeadlineExceeded
+type cancelableTimeoutContext struct {
+ context.Context
+}
+
+func (c cancelableTimeoutContext) Err() error {
+ if c.Context.Err() != nil {
+ return context.DeadlineExceeded
+ }
+ return nil
+}
+
func TestTimeoutHandler_h1(t *testing.T) { testTimeoutHandler(t, h1Mode) }
func TestTimeoutHandler_h2(t *testing.T) { testTimeoutHandler(t, h2Mode) }
func testTimeoutHandler(t *testing.T, h2 bool) {
@@ -2286,8 +2298,9 @@ func testTimeoutHandler(t *testing.T, h2 bool) {
_, werr := w.Write([]byte("hi"))
writeErrors <- werr
})
- timeout := make(chan time.Time, 1) // write to this to force timeouts
- cst := newClientServerTest(t, h2, NewTestTimeoutHandler(sayHi, timeout))
+ ctx, cancel := context.WithCancel(context.Background())
+ h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
+ cst := newClientServerTest(t, h2, h)
defer cst.close()
// Succeed without timing out:
@@ -2308,7 +2321,8 @@ func testTimeoutHandler(t *testing.T, h2 bool) {
}
// Times out:
- timeout <- time.Time{}
+ cancel()
+
res, err = cst.c.Get(cst.ts.URL)
if err != nil {
t.Error(err)
@@ -2429,8 +2443,9 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) {
_, werr := w.Write([]byte("hi"))
writeErrors <- werr
})
- timeout := make(chan time.Time, 1) // write to this to force timeouts
- cst := newClientServerTest(t, h1Mode, NewTestTimeoutHandler(sayHi, timeout))
+ ctx, cancel := context.WithCancel(context.Background())
+ h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
+ cst := newClientServerTest(t, h1Mode, h)
defer cst.close()
// Succeed without timing out:
@@ -2451,7 +2466,8 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) {
}
// Times out:
- timeout <- time.Time{}
+ cancel()
+
res, err = cst.c.Get(cst.ts.URL)
if err != nil {
t.Error(err)
@@ -2501,6 +2517,41 @@ func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) {
}
}
+func TestTimeoutHandlerContextCanceled(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+ sendHi := make(chan bool, 1)
+ writeErrors := make(chan error, 1)
+ sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Type", "text/plain")
+ <-sendHi
+ _, werr := w.Write([]byte("hi"))
+ writeErrors <- werr
+ })
+ ctx, cancel := context.WithTimeout(context.Background(), 1*time.Hour)
+ h := NewTestTimeoutHandler(sayHi, ctx)
+ cancel()
+ cst := newClientServerTest(t, h1Mode, h)
+ defer cst.close()
+
+ // Succeed without timing out:
+ sendHi <- true
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Error(err)
+ }
+ if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ body, _ := io.ReadAll(res.Body)
+ if g, e := string(body), ""; g != e {
+ t.Errorf("got body %q; expected %q", g, e)
+ }
+ if g, e := <-writeErrors, context.Canceled; g != e {
+ t.Errorf("got unexpected Write error on first request: %v", g)
+ }
+}
+
// https://golang.org/issue/15948
func TestTimeoutHandlerEmptyResponse(t *testing.T) {
setParallel(t)
diff --git a/src/net/http/server.go b/src/net/http/server.go
index 91fad68694..08fd478ed9 100644
--- a/src/net/http/server.go
+++ b/src/net/http/server.go
@@ -3391,9 +3391,15 @@ func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) {
case <-ctx.Done():
tw.mu.Lock()
defer tw.mu.Unlock()
- w.WriteHeader(StatusServiceUnavailable)
- io.WriteString(w, h.errorBody())
- tw.timedOut = true
+ switch err := ctx.Err(); err {
+ case context.DeadlineExceeded:
+ w.WriteHeader(StatusServiceUnavailable)
+ io.WriteString(w, h.errorBody())
+ tw.err = ErrHandlerTimeout
+ default:
+ w.WriteHeader(StatusServiceUnavailable)
+ tw.err = err
+ }
}
}
@@ -3404,7 +3410,7 @@ type timeoutWriter struct {
req *Request
mu sync.Mutex
- timedOut bool
+ err error
wroteHeader bool
code int
}
@@ -3424,8 +3430,8 @@ func (tw *timeoutWriter) Header() Header { return tw.h }
func (tw *timeoutWriter) Write(p []byte) (int, error) {
tw.mu.Lock()
defer tw.mu.Unlock()
- if tw.timedOut {
- return 0, ErrHandlerTimeout
+ if tw.err != nil {
+ return 0, tw.err
}
if !tw.wroteHeader {
tw.writeHeaderLocked(StatusOK)
@@ -3437,7 +3443,7 @@ func (tw *timeoutWriter) writeHeaderLocked(code int) {
checkWriteHeaderCode(code)
switch {
- case tw.timedOut:
+ case tw.err != nil:
return
case tw.wroteHeader:
if tw.req != nil {