aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/net/http/transport.go11
-rw-r--r--src/net/http/transport_test.go143
2 files changed, 154 insertions, 0 deletions
diff --git a/src/net/http/transport.go b/src/net/http/transport.go
index bb9657f4ee..f0ae6ef0b9 100644
--- a/src/net/http/transport.go
+++ b/src/net/http/transport.go
@@ -1375,6 +1375,17 @@ func (w persistConnWriter) Write(p []byte) (n int, err error) {
return
}
+// ReadFrom exposes persistConnWriter's underlying Conn to io.Copy and if
+// the Conn implements io.ReaderFrom, it can take advantage of optimizations
+// such as sendfile.
+func (w persistConnWriter) ReadFrom(r io.Reader) (n int64, err error) {
+ n, err = io.Copy(w.pc.conn, r)
+ w.pc.nwrite += n
+ return
+}
+
+var _ io.ReaderFrom = (*persistConnWriter)(nil)
+
// connectMethod is the map key (in its String form) for keeping persistent
// TCP connections alive for subsequent HTTP requests.
//
diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go
index 6e075847dd..74767f8499 100644
--- a/src/net/http/transport_test.go
+++ b/src/net/http/transport_test.go
@@ -5059,3 +5059,146 @@ func TestTransportRequestReplayable(t *testing.T) {
})
}
}
+
+// testMockTCPConn is a mock TCP connection used to test that
+// ReadFrom is called when sending the request body.
+type testMockTCPConn struct {
+ *net.TCPConn
+
+ ReadFromCalled bool
+}
+
+func (c *testMockTCPConn) ReadFrom(r io.Reader) (int64, error) {
+ c.ReadFromCalled = true
+ return c.TCPConn.ReadFrom(r)
+}
+
+func TestTransportRequestWriteRoundTrip(t *testing.T) {
+ nBytes := int64(1 << 10)
+ newFileFunc := func() (r io.Reader, done func(), err error) {
+ f, err := ioutil.TempFile("", "net-http-newfilefunc")
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // Write some bytes to the file to enable reading.
+ if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil {
+ return nil, nil, fmt.Errorf("failed to write data to file: %v", err)
+ }
+ if _, err := f.Seek(0, 0); err != nil {
+ return nil, nil, fmt.Errorf("failed to seek to front: %v", err)
+ }
+
+ done = func() {
+ f.Close()
+ os.Remove(f.Name())
+ }
+
+ return f, done, nil
+ }
+
+ newBufferFunc := func() (io.Reader, func(), error) {
+ return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil
+ }
+
+ cases := []struct {
+ name string
+ readerFunc func() (io.Reader, func(), error)
+ contentLength int64
+ expectedReadFrom bool
+ }{
+ {
+ name: "file, length",
+ readerFunc: newFileFunc,
+ contentLength: nBytes,
+ expectedReadFrom: true,
+ },
+ {
+ name: "file, no length",
+ readerFunc: newFileFunc,
+ },
+ {
+ name: "file, negative length",
+ readerFunc: newFileFunc,
+ contentLength: -1,
+ },
+ {
+ name: "buffer",
+ contentLength: nBytes,
+ readerFunc: newBufferFunc,
+ },
+ {
+ name: "buffer, no length",
+ readerFunc: newBufferFunc,
+ },
+ {
+ name: "buffer, length -1",
+ contentLength: -1,
+ readerFunc: newBufferFunc,
+ },
+ }
+
+ for _, tc := range cases {
+ t.Run(tc.name, func(t *testing.T) {
+ r, cleanup, err := tc.readerFunc()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ tConn := &testMockTCPConn{}
+ trFunc := func(tr *Transport) {
+ tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
+ var d net.Dialer
+ conn, err := d.DialContext(ctx, network, addr)
+ if err != nil {
+ return nil, err
+ }
+
+ tcpConn, ok := conn.(*net.TCPConn)
+ if !ok {
+ return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr)
+ }
+
+ tConn.TCPConn = tcpConn
+ return tConn, nil
+ }
+ }
+
+ cst := newClientServerTest(
+ t,
+ h1Mode,
+ HandlerFunc(func(w ResponseWriter, r *Request) {
+ io.Copy(ioutil.Discard, r.Body)
+ r.Body.Close()
+ w.WriteHeader(200)
+ }),
+ trFunc,
+ )
+ defer cst.close()
+
+ req, err := NewRequest("PUT", cst.ts.URL, r)
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.ContentLength = tc.contentLength
+ req.Header.Set("Content-Type", "application/octet-stream")
+ resp, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != 200 {
+ t.Fatalf("status code = %d; want 200", resp.StatusCode)
+ }
+
+ if !tConn.ReadFromCalled && tc.expectedReadFrom {
+ t.Fatalf("did not call ReadFrom")
+ }
+
+ if tConn.ReadFromCalled && !tc.expectedReadFrom {
+ t.Fatalf("ReadFrom was unexpectedly invoked")
+ }
+ })
+ }
+}