diff options
Diffstat (limited to 'src/net/http/httputil/reverseproxy.go')
| -rw-r--r-- | src/net/http/httputil/reverseproxy.go | 61 |
1 files changed, 28 insertions, 33 deletions
diff --git a/src/net/http/httputil/reverseproxy.go b/src/net/http/httputil/reverseproxy.go index eece455ac6..2a76b0b8dc 100644 --- a/src/net/http/httputil/reverseproxy.go +++ b/src/net/http/httputil/reverseproxy.go @@ -524,9 +524,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { // Force chunking if we saw a response trailer. // This prevents net/http from calculating the length for short // bodies and adding a Content-Length. - if fl, ok := rw.(http.Flusher); ok { - fl.Flush() - } + http.NewResponseController(rw).Flush() } if len(res.Trailer) == announcedTrailers { @@ -601,21 +599,22 @@ func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration { return p.FlushInterval } -func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error { +func (p *ReverseProxy) copyResponse(dst http.ResponseWriter, src io.Reader, flushInterval time.Duration) error { + var w io.Writer = dst + if flushInterval != 0 { - if wf, ok := dst.(writeFlusher); ok { - mlw := &maxLatencyWriter{ - dst: wf, - latency: flushInterval, - } - defer mlw.stop() + mlw := &maxLatencyWriter{ + dst: dst, + flush: http.NewResponseController(dst).Flush, + latency: flushInterval, + } + defer mlw.stop() - // set up initial timer so headers get flushed even if body writes are delayed - mlw.flushPending = true - mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush) + // set up initial timer so headers get flushed even if body writes are delayed + mlw.flushPending = true + mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush) - dst = mlw - } + w = mlw } var buf []byte @@ -623,7 +622,7 @@ func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval buf = p.BufferPool.Get() defer p.BufferPool.Put(buf) } - _, err := p.copyBuffer(dst, src, buf) + _, err := p.copyBuffer(w, src, buf) return err } @@ -668,13 +667,9 @@ func (p *ReverseProxy) logf(format string, args ...any) { } } -type writeFlusher interface { - io.Writer - http.Flusher -} - type maxLatencyWriter struct { - dst writeFlusher + dst io.Writer + flush func() error latency time.Duration // non-zero; negative means to flush immediately mu sync.Mutex // protects t, flushPending, and dst.Flush @@ -687,7 +682,7 @@ func (m *maxLatencyWriter) Write(p []byte) (n int, err error) { defer m.mu.Unlock() n, err = m.dst.Write(p) if m.latency < 0 { - m.dst.Flush() + m.flush() return } if m.flushPending { @@ -708,7 +703,7 @@ func (m *maxLatencyWriter) delayedFlush() { if !m.flushPending { // if stop was called but AfterFunc already started this goroutine return } - m.dst.Flush() + m.flush() m.flushPending = false } @@ -739,17 +734,19 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R return } - hj, ok := rw.(http.Hijacker) - if !ok { - p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) - return - } backConn, ok := res.Body.(io.ReadWriteCloser) if !ok { p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body")) return } + rc := http.NewResponseController(rw) + conn, brw, hijackErr := rc.Hijack() + if errors.Is(hijackErr, http.ErrNotSupported) { + p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) + return + } + backConnCloseCh := make(chan bool) go func() { // Ensure that the cancellation of a request closes the backend. @@ -760,12 +757,10 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R } backConn.Close() }() - defer close(backConnCloseCh) - conn, brw, err := hj.Hijack() - if err != nil { - p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err)) + if hijackErr != nil { + p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", hijackErr)) return } defer conn.Close() |
