diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/net/http/httputil/reverseproxy.go | 11 | ||||
| -rw-r--r-- | src/net/http/httputil/reverseproxy_test.go | 48 |
2 files changed, 55 insertions, 4 deletions
diff --git a/src/net/http/httputil/reverseproxy.go b/src/net/http/httputil/reverseproxy.go index bbb7c13d41..079d5c86f7 100644 --- a/src/net/http/httputil/reverseproxy.go +++ b/src/net/http/httputil/reverseproxy.go @@ -794,16 +794,19 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R go spc.copyToBackend(errc) go spc.copyFromBackend(errc) - // wait until both copy functions have sent on the error channel + // Wait until both copy functions have sent on the error channel, + // or until one fails. err := <-errc if err == nil { err = <-errc } - if err != nil { + if err != nil && err != errCopyDone { p.getErrorHandler()(rw, req, fmt.Errorf("can't copy: %v", err)) } } +var errCopyDone = errors.New("hijacked connection copy complete") + // switchProtocolCopier exists so goroutines proxying data back and // forth have nice names in stacks. type switchProtocolCopier struct { @@ -822,7 +825,7 @@ func (c switchProtocolCopier) copyFromBackend(errc chan<- error) { return } - errc <- nil + errc <- errCopyDone } func (c switchProtocolCopier) copyToBackend(errc chan<- error) { @@ -837,7 +840,7 @@ func (c switchProtocolCopier) copyToBackend(errc chan<- error) { return } - errc <- nil + errc <- errCopyDone } func cleanQueryParams(s string) string { diff --git a/src/net/http/httputil/reverseproxy_test.go b/src/net/http/httputil/reverseproxy_test.go index a826dc82fe..1acbc296c3 100644 --- a/src/net/http/httputil/reverseproxy_test.go +++ b/src/net/http/httputil/reverseproxy_test.go @@ -1701,6 +1701,54 @@ func TestReverseProxyWebSocketHalfTCP(t *testing.T) { } } +func TestReverseProxyUpgradeNoCloseWrite(t *testing.T) { + // The backend hijacks the connection, + // reads all data from the client, + // and returns. + backendDone := make(chan struct{}) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Connection", "upgrade") + w.Header().Set("Upgrade", "u") + w.WriteHeader(101) + conn, _, err := http.NewResponseController(w).Hijack() + if err != nil { + t.Errorf("Hijack: %v", err) + } + io.Copy(io.Discard, conn) + close(backendDone) + })) + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + // The proxy includes a ModifyResponse function which replaces the response body + // with its own wrapper, dropping the original body's CloseWrite method. + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.ModifyResponse = func(resp *http.Response) error { + type readWriteCloserOnly struct { + io.ReadWriteCloser + } + resp.Body = readWriteCloserOnly{resp.Body.(io.ReadWriteCloser)} + return nil + } + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + // The client sends a request and closes the connection. + req, _ := http.NewRequest("GET", frontend.URL, nil) + req.Header.Set("Connection", "upgrade") + req.Header.Set("Upgrade", "u") + resp, err := frontend.Client().Do(req) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + + // We expect that the client's closure of the connection is propagated to the backend. + <-backendDone +} + func TestUnannouncedTrailer(t *testing.T) { backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) |
