diff options
| author | Damien Neil <dneil@google.com> | 2026-03-03 08:12:25 -0800 |
|---|---|---|
| committer | Gopher Robot <gobot@golang.org> | 2026-03-12 08:13:20 -0700 |
| commit | 080aa8e9647e5211650f34f3a93fb493afbe396d (patch) | |
| tree | ab00dcd761d3622f08ff5aa7e2a52ff8c1fd0591 /src/net/http/internal | |
| parent | 81908597a8787b09b1da90e7c6d3461b4302820f (diff) | |
| download | go-080aa8e9647e5211650f34f3a93fb493afbe396d.tar.xz | |
net/http: use net/http/internal/http2 rather than h2_bundle.go
Rework net/http/internal/http2 to use internally-defined types
rather than net/http types (to avoid an import cycle).
Remove h2_bundle.go, and replace it with calls into
net/http/internal/http2 instead.
For #67810
Change-Id: I56a1b28dbd0e302ab15a30f819dd46256a6a6964
Reviewed-on: https://go-review.googlesource.com/c/go/+/751304
Reviewed-by: Nicholas Husin <nsh@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Damien Neil <dneil@google.com>
Reviewed-by: Nicholas Husin <husin@google.com>
Diffstat (limited to 'src/net/http/internal')
20 files changed, 924 insertions, 892 deletions
diff --git a/src/net/http/internal/common.go b/src/net/http/internal/common.go new file mode 100644 index 0000000000..b32f9553c4 --- /dev/null +++ b/src/net/http/internal/common.go @@ -0,0 +1,14 @@ +// Copyright 2026 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package internal + +import "errors" + +var ( + ErrAbortHandler = errors.New("net/http: abort Handler") + ErrBodyNotAllowed = errors.New("http: request method or response status code does not allow body") + ErrRequestCanceled = errors.New("net/http: request canceled") + ErrSkipAltProtocol = errors.New("net/http: skip alternate protocol") +) diff --git a/src/net/http/internal/http2/api.go b/src/net/http/internal/http2/api.go new file mode 100644 index 0000000000..33a711a278 --- /dev/null +++ b/src/net/http/internal/http2/api.go @@ -0,0 +1,158 @@ +// Copyright 2026 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http2 + +import ( + "context" + "crypto/tls" + "errors" + "io" + "log" + "mime/multipart" + "net" + "net/http/internal" + "net/textproto" + "net/url" + "time" +) + +// Since net/http imports the http2 package, http2 cannot use any net/http types. +// This file contains definitions which exist to to avoid introducing a dependency cycle. + +// Variables defined in net/http and initialized by an init func in that package. +// +// NoBody and LocalAddrContextKey have concrete types in net/http, +// and therefore can't be moved into a common package without introducing +// a dependency cycle. +var ( + NoBody io.ReadCloser + LocalAddrContextKey any +) + +var ( + ErrAbortHandler = internal.ErrAbortHandler + ErrBodyNotAllowed = internal.ErrBodyNotAllowed + ErrNotSupported = errors.ErrUnsupported + ErrSkipAltProtocol = internal.ErrSkipAltProtocol +) + +// A ClientRequest is a Request used by the HTTP/2 client (Transport). +type ClientRequest struct { + Context context.Context + Method string + URL *url.URL + Header Header + Trailer Header + Body io.ReadCloser + Host string + GetBody func() (io.ReadCloser, error) + ContentLength int64 + Cancel <-chan struct{} + Close bool + ResTrailer *Header + + // Include the per-request stream in the ClientRequest to avoid an allocation. + stream clientStream +} + +// Clone makes a shallow copy of ClientRequest. +// +// Clone is only used in shouldRetryRequest. +// We can drop it if we ever get rid of or rework that function. +func (req *ClientRequest) Clone() *ClientRequest { + return &ClientRequest{ + Context: req.Context, + Method: req.Method, + URL: req.URL, + Header: req.Header, + Trailer: req.Trailer, + Body: req.Body, + Host: req.Host, + GetBody: req.GetBody, + ContentLength: req.ContentLength, + Cancel: req.Cancel, + Close: req.Close, + ResTrailer: req.ResTrailer, + } +} + +// A ClientResponse is a Request used by the HTTP/2 client (Transport). +type ClientResponse struct { + Status string // e.g. "200" + StatusCode int // e.g. 200 + ContentLength int64 + Uncompressed bool + Header Header + Trailer Header + Body io.ReadCloser + TLS *tls.ConnectionState +} + +type Header = textproto.MIMEHeader + +// TransportConfig is configuration from an http.Transport. +type TransportConfig interface { + MaxResponseHeaderBytes() int64 + DisableCompression() bool + DisableKeepAlives() bool + ExpectContinueTimeout() time.Duration + ResponseHeaderTimeout() time.Duration + IdleConnTimeout() time.Duration + HTTP2Config() Config +} + +// ServerConfig is configuration from an http.Server. +type ServerConfig interface { + MaxHeaderBytes() int + ConnState(net.Conn, ConnState) + DoKeepAlives() bool + WriteTimeout() time.Duration + SendPingTimeout() time.Duration + ErrorLog() *log.Logger + ReadTimeout() time.Duration + HTTP2Config() Config + DisableClientPriority() bool +} + +type Handler interface { + ServeHTTP(*ResponseWriter, *ServerRequest) +} + +type ResponseWriter = responseWriter + +type PushOptions struct { + Method string + Header Header +} + +// A ServerRequest is a Request used by the HTTP/2 server. +type ServerRequest struct { + Context context.Context + Proto string // e.g. "HTTP/1.0" + ProtoMajor int // e.g. 1 + ProtoMinor int // e.g. 0 + Method string + URL *url.URL + Header Header + Trailer Header + Body io.ReadCloser + Host string + ContentLength int64 + RemoteAddr string + RequestURI string + TLS *tls.ConnectionState + MultipartForm *multipart.Form +} + +// ConnState is identical to net/http.ConnState. +type ConnState int + +const ( + ConnStateNew ConnState = iota + ConnStateActive + ConnStateIdle + ConnStateHijacked + ConnStateClosed +) diff --git a/src/net/http/internal/http2/client_conn_pool.go b/src/net/http/internal/http2/client_conn_pool.go index e81b73e6a7..ded7c39e77 100644 --- a/src/net/http/internal/http2/client_conn_pool.go +++ b/src/net/http/internal/http2/client_conn_pool.go @@ -10,7 +10,6 @@ import ( "context" "errors" "net" - "net/http" "sync" ) @@ -21,8 +20,8 @@ type ClientConnPool interface { // returned ClientConn accounts for the upcoming RoundTrip // call, so the caller should not omit it. If the caller needs // to, ClientConn.RoundTrip can be called with a bogus - // new(http.Request) to release the stream reservation. - GetClientConn(req *http.Request, addr string) (*ClientConn, error) + // new(ClientRequest) to release the stream reservation. + GetClientConn(req *ClientRequest, addr string) (*ClientConn, error) MarkDead(*ClientConn) } @@ -51,7 +50,7 @@ type clientConnPool struct { addConnCalls map[string]*addConnCall // in-flight addConnIfNeeded calls } -func (p *clientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) { +func (p *clientConnPool) GetClientConn(req *ClientRequest, addr string) (*ClientConn, error) { return p.getClientConn(req, addr, dialOnMiss) } @@ -60,13 +59,13 @@ const ( noDialOnMiss = false ) -func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMiss bool) (*ClientConn, error) { +func (p *clientConnPool) getClientConn(req *ClientRequest, addr string, dialOnMiss bool) (*ClientConn, error) { // TODO(dneil): Dial a new connection when t.DisableKeepAlives is set? if isConnectionCloseRequest(req) && dialOnMiss { // It gets its own connection. traceGetConn(req, addr) const singleUse = true - cc, err := p.t.dialClientConn(req.Context(), addr, singleUse) + cc, err := p.t.dialClientConn(req.Context, addr, singleUse) if err != nil { return nil, err } @@ -92,7 +91,7 @@ func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMis return nil, ErrNoCachedConn } traceGetConn(req, addr) - call := p.getStartDialLocked(req.Context(), addr) + call := p.getStartDialLocked(req.Context, addr) p.mu.Unlock() <-call.done if shouldRetryDial(call, req) { @@ -195,7 +194,7 @@ type addConnCall struct { } func (c *addConnCall) run(t *Transport, key string, nc net.Conn) { - cc, err := t.NewClientConn(nc) + cc, err := t.newClientConn(nc, t.disableKeepAlives(), nil) p := c.p p.mu.Lock() @@ -281,7 +280,7 @@ func filterOutClientConn(in []*ClientConn, exclude *ClientConn) []*ClientConn { // connection instead. type noDialClientConnPool struct{ *clientConnPool } -func (p noDialClientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) { +func (p noDialClientConnPool) GetClientConn(req *ClientRequest, addr string) (*ClientConn, error) { return p.getClientConn(req, addr, noDialOnMiss) } @@ -289,12 +288,12 @@ func (p noDialClientConnPool) GetClientConn(req *http.Request, addr string) (*Cl // retry dialing after the call finished unsuccessfully, for example // if the dial was canceled because of a context cancellation or // deadline expiry. -func shouldRetryDial(call *dialCall, req *http.Request) bool { +func shouldRetryDial(call *dialCall, req *ClientRequest) bool { if call.err == nil { // No error, no need to retry return false } - if call.ctx == req.Context() { + if call.ctx == req.Context { // If the call has the same context as the request, the dial // should not be retried, since any cancellation will have come // from this request. diff --git a/src/net/http/internal/http2/client_priority_go127.go b/src/net/http/internal/http2/client_priority_go127.go deleted file mode 100644 index 9e94eb642d..0000000000 --- a/src/net/http/internal/http2/client_priority_go127.go +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright 2026 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package http2 - -import "net/http" - -func clientPriorityDisabled(s *http.Server) bool { - return s.DisableClientPriority -} diff --git a/src/net/http/internal/http2/clientconn_test.go b/src/net/http/internal/http2/clientconn_test.go index 91a1ed50e3..d34a35e02a 100644 --- a/src/net/http/internal/http2/clientconn_test.go +++ b/src/net/http/internal/http2/clientconn_test.go @@ -14,6 +14,7 @@ import ( "fmt" "internal/gate" "io" + "net" "net/http" . "net/http/internal/http2" "reflect" @@ -21,6 +22,7 @@ import ( "testing" "testing/synctest" "time" + _ "unsafe" // for go:linkname "golang.org/x/net/http2/hpack" ) @@ -116,26 +118,11 @@ func newTestClientConnFromClientConn(t testing.TB, tr *Transport, cc *ClientConn cc: cc, } - // srv is the side controlled by the test. - var srv *synctestNetConn - if tconn := cc.TestNetConn(); tconn == nil { - // If cc.tconn is nil, we're being called with a new conn created by the - // Transport's client pool. This path skips dialing the server, and we - // create a test connection pair here. - var cli *synctestNetConn - cli, srv = synctestNetPipe() - cc.TestSetNetConn(cli) - } else { - // If cc.tconn is non-nil, we're in a test which provides a conn to the - // Transport via a TLSNextProto hook. Extract the test connection pair. - if tc, ok := tconn.(*tls.Conn); ok { - // Unwrap any *tls.Conn to the underlying net.Conn, - // to avoid dealing with encryption in tests. - tconn = tc.NetConn() - cc.TestSetNetConn(tconn) - } - srv = tconn.(*synctestNetConn).peer - } + // cli is the conn used by the client under test, srv is the side controlled by the test. + // We replace the conn being used by the client (possibly a *tls.Conn) with a new one, + // to avoid dealing with encryption in tests. + cli, srv := synctestNetPipe() + cc.TestSetNetConn(cli) srv.SetReadDeadline(time.Now()) tc.netconn = srv @@ -171,7 +158,8 @@ func newTestClientConn(t testing.TB, opts ...any) *testClientConn { tt := newTestTransport(t, opts...) const singleUse = false - _, err := tt.tr.TestNewClientConn(nil, singleUse, nil) + tr := transportFromH1Transport(tt.tr1).(*Transport) + _, err := tr.TestNewClientConn(nil, singleUse, nil) if err != nil { t.Fatalf("newClientConn: %v", err) } @@ -307,10 +295,48 @@ func (tc *testClientConn) roundTrip(req *http.Request) *testRoundTrip { } tc.roundtrips = append(tc.roundtrips, rt) go func() { + // TODO: This duplicates too much of the net/http RoundTrip flow. + // We need to do that here because many of the http2 Transport tests + // rely on having a ClientConn to operate on. + // + // We should switch to using net/http.Transport.NewClientConn to create + // single-target client connections, and move any http2 tests which + // exercise pooling behavior into net/http. defer close(rt.donec) - rt.resp, rt.respErr = tc.cc.TestRoundTrip(req, func(streamID uint32) { - rt.id.Store(streamID) + cresp := &http.Response{} + creq := &ClientRequest{ + Context: req.Context(), + Method: req.Method, + URL: req.URL, + Header: Header(req.Header), + Trailer: Header(req.Trailer), + Body: req.Body, + Host: req.Host, + GetBody: req.GetBody, + ContentLength: req.ContentLength, + Cancel: req.Cancel, + Close: req.Close, + ResTrailer: (*Header)(&cresp.Trailer), + } + resp, err := tc.cc.TestRoundTrip(creq, func(id uint32) { + rt.id.Store(id) }) + rt.respErr = err + if resp != nil { + cresp.Status = resp.Status + " " + http.StatusText(resp.StatusCode) + cresp.StatusCode = resp.StatusCode + cresp.Proto = "HTTP/2.0" + cresp.ProtoMajor = 2 + cresp.ProtoMinor = 0 + cresp.ContentLength = resp.ContentLength + cresp.Uncompressed = resp.Uncompressed + cresp.Header = http.Header(resp.Header) + cresp.Trailer = http.Header(resp.Trailer) + cresp.Body = resp.Body + cresp.TLS = resp.TLS + cresp.Request = req + rt.resp = cresp + } }() synctest.Wait() @@ -493,8 +519,8 @@ func diffHeaders(got, want http.Header) string { // Tests that aren't specifically exercising RoundTrip's retry loop or connection pooling // should use testClientConn instead. type testTransport struct { - t testing.TB - tr *Transport + t testing.TB + tr1 *http.Transport ccs []*testClientConn } @@ -505,16 +531,33 @@ func newTestTransport(t testing.TB, opts ...any) *testTransport { t: t, } - tr := &Transport{} + tr1 := &http.Transport{ + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + // This connection will be replaced by newTestClientConnFromClientConn. + // net/http will perform a TLS handshake on it, though. + // + // TODO: We can simplify connection handling if we support + // returning a non-*tls.Conn from Transport.DialTLSContext, + // in which case we could have a DialTLSContext function that + // returns an unencrypted conn. + cli, srv := synctestNetPipe() + go func() { + tlsSrv := tls.Server(srv, testServerTLSConfig) + if err := tlsSrv.Handshake(); err != nil { + t.Errorf("unexpected TLS server handshake error: %v", err) + } + }() + return cli, nil + }, + Protocols: protocols("h2"), + TLSClientConfig: testClientTLSConfig, + } for _, o := range opts { switch o := o.(type) { case nil: case func(*http.Transport): - o(tr.TestTransport()) - case *Transport: - tr = o + o(tr1) case func(*http.HTTP2Config): - tr1 := tr.TestTransport() if tr1.HTTP2 == nil { tr1.HTTP2 = &http.HTTP2Config{} } @@ -523,10 +566,11 @@ func newTestTransport(t testing.TB, opts ...any) *testTransport { t.Fatalf("unknown newTestTransport option type %T", o) } } - tt.tr = tr + tt.tr1 = tr1 - tr.TestSetNewClientConnHook(func(cc *ClientConn) { - tc := newTestClientConnFromClientConn(t, tr, cc) + tr2 := transportFromH1Transport(tr1).(*Transport) + tr2.TestSetNewClientConnHook(func(cc *ClientConn) { + tc := newTestClientConnFromClientConn(t, tr2, cc) tt.ccs = append(tt.ccs, tc) }) @@ -552,20 +596,22 @@ func (tt *testTransport) getConn() *testClientConn { } tc := tt.ccs[0] tt.ccs = tt.ccs[1:] - synctest.Wait() tc.readClientPreface() synctest.Wait() return tc } func (tt *testTransport) roundTrip(req *http.Request) *testRoundTrip { + ctx, cancel := context.WithCancel(req.Context()) + req = req.WithContext(ctx) rt := &testRoundTrip{ - t: tt.t, - donec: make(chan struct{}), + t: tt.t, + donec: make(chan struct{}), + cancel: cancel, } go func() { defer close(rt.donec) - rt.resp, rt.respErr = tt.tr.RoundTrip(req) + rt.resp, rt.respErr = tt.tr1.RoundTrip(req) }() synctest.Wait() diff --git a/src/net/http/internal/http2/config.go b/src/net/http/internal/http2/config.go index 8a7a89d016..53dfec367f 100644 --- a/src/net/http/internal/http2/config.go +++ b/src/net/http/internal/http2/config.go @@ -6,72 +6,52 @@ package http2 import ( "math" - "net/http" "time" ) -// http2Config is a package-internal version of net/http.HTTP2Config. -// -// http.HTTP2Config was added in Go 1.24. -// When running with a version of net/http that includes HTTP2Config, -// we merge the configuration with the fields in Transport or Server -// to produce an http2Config. -// -// Zero valued fields in http2Config are interpreted as in the -// net/http.HTTPConfig documentation. -// -// Precedence order for reconciling configurations is: -// -// - Use the net/http.{Server,Transport}.HTTP2Config value, when non-zero. -// - Otherwise use the http2.{Server.Transport} value. -// - If the resulting value is zero or out of range, use a default. -type http2Config struct { - MaxConcurrentStreams uint32 - StrictMaxConcurrentRequests bool - MaxDecoderHeaderTableSize uint32 - MaxEncoderHeaderTableSize uint32 - MaxReadFrameSize uint32 - MaxUploadBufferPerConnection int32 - MaxUploadBufferPerStream int32 - SendPingTimeout time.Duration - PingTimeout time.Duration - WriteByteTimeout time.Duration - PermitProhibitedCipherSuites bool - CountError func(errType string) +// Config must be kept in sync with net/http.HTTP2Config. +type Config struct { + MaxConcurrentStreams int + StrictMaxConcurrentRequests bool + MaxDecoderHeaderTableSize int + MaxEncoderHeaderTableSize int + MaxReadFrameSize int + MaxReceiveBufferPerConnection int + MaxReceiveBufferPerStream int + SendPingTimeout time.Duration + PingTimeout time.Duration + WriteByteTimeout time.Duration + PermitProhibitedCipherSuites bool + CountError func(errType string) } -// configFromServer merges configuration settings from -// net/http.Server.HTTP2Config and http2.Server. -func configFromServer(h1 *http.Server, h2 *Server) http2Config { - conf := http2Config{ - MaxConcurrentStreams: h2.MaxConcurrentStreams, - MaxEncoderHeaderTableSize: h2.MaxEncoderHeaderTableSize, - MaxDecoderHeaderTableSize: h2.MaxDecoderHeaderTableSize, - MaxReadFrameSize: h2.MaxReadFrameSize, - MaxUploadBufferPerConnection: h2.MaxUploadBufferPerConnection, - MaxUploadBufferPerStream: h2.MaxUploadBufferPerStream, - SendPingTimeout: h2.ReadIdleTimeout, - PingTimeout: h2.PingTimeout, - WriteByteTimeout: h2.WriteByteTimeout, - PermitProhibitedCipherSuites: h2.PermitProhibitedCipherSuites, - CountError: h2.CountError, +func configFromServer(h1 ServerConfig, h2 *Server) Config { + conf := Config{ + MaxConcurrentStreams: int(h2.MaxConcurrentStreams), + MaxEncoderHeaderTableSize: int(h2.MaxEncoderHeaderTableSize), + MaxDecoderHeaderTableSize: int(h2.MaxDecoderHeaderTableSize), + MaxReadFrameSize: int(h2.MaxReadFrameSize), + MaxReceiveBufferPerConnection: int(h2.MaxUploadBufferPerConnection), + MaxReceiveBufferPerStream: int(h2.MaxUploadBufferPerStream), + SendPingTimeout: h2.ReadIdleTimeout, + PingTimeout: h2.PingTimeout, + WriteByteTimeout: h2.WriteByteTimeout, + PermitProhibitedCipherSuites: h2.PermitProhibitedCipherSuites, + CountError: h2.CountError, } - fillNetHTTPConfig(&conf, h1.HTTP2) + fillNetHTTPConfig(&conf, h1.HTTP2Config()) setConfigDefaults(&conf, true) return conf } -// configFromTransport merges configuration settings from h2 and h2.t1.HTTP2 -// (the net/http Transport). -func configFromTransport(h2 *Transport) http2Config { - conf := http2Config{ - StrictMaxConcurrentRequests: h2.StrictMaxConcurrentStreams, - MaxEncoderHeaderTableSize: h2.MaxEncoderHeaderTableSize, - MaxDecoderHeaderTableSize: h2.MaxDecoderHeaderTableSize, - MaxReadFrameSize: h2.MaxReadFrameSize, - SendPingTimeout: h2.ReadIdleTimeout, - PingTimeout: h2.PingTimeout, - WriteByteTimeout: h2.WriteByteTimeout, +func configFromTransport(h2 *Transport) Config { + conf := Config{ + MaxEncoderHeaderTableSize: int(h2.MaxEncoderHeaderTableSize), + MaxDecoderHeaderTableSize: int(h2.MaxDecoderHeaderTableSize), + MaxReadFrameSize: int(h2.MaxReadFrameSize), + SendPingTimeout: h2.ReadIdleTimeout, + PingTimeout: h2.PingTimeout, + WriteByteTimeout: h2.WriteByteTimeout, } // Unlike most config fields, where out-of-range values revert to the default, @@ -83,8 +63,9 @@ func configFromTransport(h2 *Transport) http2Config { } if h2.t1 != nil { - fillNetHTTPConfig(&conf, h2.t1.HTTP2) + fillNetHTTPConfig(&conf, h2.t1.HTTP2Config()) } + setConfigDefaults(&conf, false) return conf } @@ -95,19 +76,19 @@ func setDefault[T ~int | ~int32 | ~uint32 | ~int64](v *T, minval, maxval, defval } } -func setConfigDefaults(conf *http2Config, server bool) { - setDefault(&conf.MaxConcurrentStreams, 1, math.MaxUint32, defaultMaxStreams) - setDefault(&conf.MaxEncoderHeaderTableSize, 1, math.MaxUint32, initialHeaderTableSize) - setDefault(&conf.MaxDecoderHeaderTableSize, 1, math.MaxUint32, initialHeaderTableSize) +func setConfigDefaults(conf *Config, server bool) { + setDefault(&conf.MaxConcurrentStreams, 1, math.MaxInt32, defaultMaxStreams) + setDefault(&conf.MaxEncoderHeaderTableSize, 1, math.MaxInt32, initialHeaderTableSize) + setDefault(&conf.MaxDecoderHeaderTableSize, 1, math.MaxInt32, initialHeaderTableSize) if server { - setDefault(&conf.MaxUploadBufferPerConnection, initialWindowSize, math.MaxInt32, 1<<20) + setDefault(&conf.MaxReceiveBufferPerConnection, initialWindowSize, math.MaxInt32, 1<<20) } else { - setDefault(&conf.MaxUploadBufferPerConnection, initialWindowSize, math.MaxInt32, transportDefaultConnFlow) + setDefault(&conf.MaxReceiveBufferPerConnection, initialWindowSize, math.MaxInt32, transportDefaultConnFlow) } if server { - setDefault(&conf.MaxUploadBufferPerStream, 1, math.MaxInt32, 1<<20) + setDefault(&conf.MaxReceiveBufferPerStream, 1, math.MaxInt32, 1<<20) } else { - setDefault(&conf.MaxUploadBufferPerStream, 1, math.MaxInt32, transportDefaultStreamFlow) + setDefault(&conf.MaxReceiveBufferPerStream, 1, math.MaxInt32, transportDefaultStreamFlow) } setDefault(&conf.MaxReadFrameSize, minMaxFrameSize, maxFrameSize, defaultMaxReadFrameSize) setDefault(&conf.PingTimeout, 1, math.MaxInt64, 15*time.Second) @@ -123,33 +104,30 @@ func adjustHTTP1MaxHeaderSize(n int64) int64 { return n + typicalHeaders*perFieldOverhead } -func fillNetHTTPConfig(conf *http2Config, h2 *http.HTTP2Config) { - if h2 == nil { - return - } +func fillNetHTTPConfig(conf *Config, h2 Config) { if h2.MaxConcurrentStreams != 0 { - conf.MaxConcurrentStreams = uint32(h2.MaxConcurrentStreams) + conf.MaxConcurrentStreams = h2.MaxConcurrentStreams } - if http2ConfigStrictMaxConcurrentRequests(h2) { + if h2.StrictMaxConcurrentRequests { conf.StrictMaxConcurrentRequests = true } if h2.MaxEncoderHeaderTableSize != 0 { - conf.MaxEncoderHeaderTableSize = uint32(h2.MaxEncoderHeaderTableSize) + conf.MaxEncoderHeaderTableSize = h2.MaxEncoderHeaderTableSize } if h2.MaxDecoderHeaderTableSize != 0 { - conf.MaxDecoderHeaderTableSize = uint32(h2.MaxDecoderHeaderTableSize) + conf.MaxDecoderHeaderTableSize = h2.MaxDecoderHeaderTableSize } if h2.MaxConcurrentStreams != 0 { - conf.MaxConcurrentStreams = uint32(h2.MaxConcurrentStreams) + conf.MaxConcurrentStreams = h2.MaxConcurrentStreams } if h2.MaxReadFrameSize != 0 { - conf.MaxReadFrameSize = uint32(h2.MaxReadFrameSize) + conf.MaxReadFrameSize = h2.MaxReadFrameSize } if h2.MaxReceiveBufferPerConnection != 0 { - conf.MaxUploadBufferPerConnection = int32(h2.MaxReceiveBufferPerConnection) + conf.MaxReceiveBufferPerConnection = h2.MaxReceiveBufferPerConnection } if h2.MaxReceiveBufferPerStream != 0 { - conf.MaxUploadBufferPerStream = int32(h2.MaxReceiveBufferPerStream) + conf.MaxReceiveBufferPerStream = h2.MaxReceiveBufferPerStream } if h2.SendPingTimeout != 0 { conf.SendPingTimeout = h2.SendPingTimeout diff --git a/src/net/http/internal/http2/config_go126.go b/src/net/http/internal/http2/config_go126.go deleted file mode 100644 index e1a9c63153..0000000000 --- a/src/net/http/internal/http2/config_go126.go +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright 2025 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package http2 - -import ( - "net/http" -) - -func http2ConfigStrictMaxConcurrentRequests(h2 *http.HTTP2Config) bool { - return h2.StrictMaxConcurrentRequests -} diff --git a/src/net/http/internal/http2/errors.go b/src/net/http/internal/http2/errors.go index f2067dabc5..35c34b7ba5 100644 --- a/src/net/http/internal/http2/errors.go +++ b/src/net/http/internal/http2/errors.go @@ -7,6 +7,7 @@ package http2 import ( "errors" "fmt" + "reflect" ) // An ErrCode is an unsigned 32-bit error code as defined in the HTTP/2 spec. @@ -90,6 +91,33 @@ func (e StreamError) Error() string { return fmt.Sprintf("stream error: stream ID %d; %v", e.StreamID, e.Code) } +// This As function permits converting a StreamError into a x/net/http2.StreamError. +func (e StreamError) As(target any) bool { + dst := reflect.ValueOf(target).Elem() + dstType := dst.Type() + if dstType.Kind() != reflect.Struct { + return false + } + src := reflect.ValueOf(e) + srcType := src.Type() + numField := srcType.NumField() + if dstType.NumField() != numField { + return false + } + for i := 0; i < numField; i++ { + sf := srcType.Field(i) + df := dstType.Field(i) + if sf.Name != df.Name || !sf.Type.ConvertibleTo(df.Type) { + return false + } + } + for i := 0; i < numField; i++ { + df := dst.Field(i) + df.Set(src.Field(i).Convert(df.Type())) + } + return true +} + // 6.9.1 The Flow Control Window // "If a sender receives a WINDOW_UPDATE that causes a flow control // window to exceed this maximum it MUST terminate either the stream diff --git a/src/net/http/internal/http2/export_test.go b/src/net/http/internal/http2/export_test.go index f818ee8008..a9e230f5d0 100644 --- a/src/net/http/internal/http2/export_test.go +++ b/src/net/http/internal/http2/export_test.go @@ -8,7 +8,6 @@ import ( "context" "fmt" "net" - "net/http" "net/textproto" "sync" "testing" @@ -19,6 +18,10 @@ import ( "golang.org/x/net/http2/hpack" ) +func init() { + inTests = true +} + const ( DefaultMaxReadFrameSize = defaultMaxReadFrameSize DefaultMaxStreams = defaultMaxStreams @@ -68,7 +71,7 @@ func (sc *serverConn) TestFlowControlConsumed() (consumed int32) { donec := make(chan struct{}) sc.sendServeMsg(func(sc *serverConn) { defer close(donec) - initial := conf.MaxUploadBufferPerConnection + initial := int32(conf.MaxReceiveBufferPerConnection) avail := sc.inflow.avail + sc.inflow.unsent consumed = initial - avail }) @@ -117,16 +120,10 @@ func (t *Transport) TestSetNewClientConnHook(f func(*ClientConn)) { } } -func (t *Transport) TestTransport() *http.Transport { - if t.t1 == nil { - t.t1 = &http.Transport{} - } - return t.t1 -} - func (cc *ClientConn) TestNetConn() net.Conn { return cc.tconn } func (cc *ClientConn) TestSetNetConn(c net.Conn) { cc.tconn = c } -func (cc *ClientConn) TestRoundTrip(req *http.Request, f func(stremaID uint32)) (*http.Response, error) { + +func (cc *ClientConn) TestRoundTrip(req *ClientRequest, f func(stremaID uint32)) (*ClientResponse, error) { return cc.roundTrip(req, func(cs *clientStream) { f(cs.ID) }) @@ -237,6 +234,6 @@ func NewNoDialClientConnPool() ClientConnPool { return noDialClientConnPool{new(clientConnPool)} } -func EncodeRequestHeaders(req *http.Request, addGzipHeader bool, peerMaxHeaderListSize uint64, headerf func(name, value string)) (httpcommon.EncodeHeadersResult, error) { +func EncodeRequestHeaders(req *ClientRequest, addGzipHeader bool, peerMaxHeaderListSize uint64, headerf func(name, value string)) (httpcommon.EncodeHeadersResult, error) { return encodeRequestHeaders(req, addGzipHeader, peerMaxHeaderListSize, headerf) } diff --git a/src/net/http/internal/http2/http2.go b/src/net/http/internal/http2/http2.go index edf34e1f77..425e90c4f2 100644 --- a/src/net/http/internal/http2/http2.go +++ b/src/net/http/internal/http2/http2.go @@ -19,7 +19,6 @@ import ( "errors" "fmt" "net" - "net/http" "os" "sort" "strconv" @@ -43,6 +42,8 @@ var ( // // Issue #71128. disableExtendedConnectProtocol = true + + inTests = false ) func init() { @@ -224,11 +225,6 @@ func httpCodeString(code int) string { return strconv.Itoa(code) } -// from pkg io -type stringWriter interface { - WriteString(s string) (n int, err error) -} - // A closeWaiter is like a sync.WaitGroup but only goes 1 to 0 (open to closed). type closeWaiter chan struct{} @@ -394,7 +390,7 @@ func (s *sorter) Less(i, j int) bool { return s.v[i] < s.v[j] } // // The returned slice is only valid until s used again or returned to // its pool. -func (s *sorter) Keys(h http.Header) []string { +func (s *sorter) Keys(h Header) []string { keys := s.v[:0] for k := range h { keys = append(keys, k) diff --git a/src/net/http/internal/http2/http2_test.go b/src/net/http/internal/http2/http2_test.go index 89003fd6b5..af82416046 100644 --- a/src/net/http/internal/http2/http2_test.go +++ b/src/net/http/internal/http2/http2_test.go @@ -7,7 +7,6 @@ package http2 import ( "flag" "fmt" - "net/http" "os" "path/filepath" "regexp" @@ -47,7 +46,7 @@ func TestSettingString(t *testing.T) { func TestSorterPoolAllocs(t *testing.T) { ss := []string{"a", "b", "c"} - h := http.Header{ + h := Header{ "a": nil, "b": nil, "c": nil, @@ -106,59 +105,6 @@ func equalError(a, b error) bool { return a.Error() == b.Error() } -// Tests that http2.Server.IdleTimeout is initialized from -// http.Server.{Idle,Read}Timeout. http.Server.IdleTimeout was -// added in Go 1.8. -func TestConfigureServerIdleTimeout_Go18(t *testing.T) { - const timeout = 5 * time.Second - const notThisOne = 1 * time.Second - - // With a zero http2.Server, verify that it copies IdleTimeout: - { - s1 := &http.Server{ - IdleTimeout: timeout, - ReadTimeout: notThisOne, - } - s2 := &Server{} - if err := ConfigureServer(s1, s2); err != nil { - t.Fatal(err) - } - if s2.IdleTimeout != timeout { - t.Errorf("s2.IdleTimeout = %v; want %v", s2.IdleTimeout, timeout) - } - } - - // And that it falls back to ReadTimeout: - { - s1 := &http.Server{ - ReadTimeout: timeout, - } - s2 := &Server{} - if err := ConfigureServer(s1, s2); err != nil { - t.Fatal(err) - } - if s2.IdleTimeout != timeout { - t.Errorf("s2.IdleTimeout = %v; want %v", s2.IdleTimeout, timeout) - } - } - - // Verify that s1's IdleTimeout doesn't overwrite an existing setting: - { - s1 := &http.Server{ - IdleTimeout: notThisOne, - } - s2 := &Server{ - IdleTimeout: timeout, - } - if err := ConfigureServer(s1, s2); err != nil { - t.Fatal(err) - } - if s2.IdleTimeout != timeout { - t.Errorf("s2.IdleTimeout = %v; want %v", s2.IdleTimeout, timeout) - } - } -} - var forbiddenStringsFunctions = map[string]bool{ // Functions that use Unicode-aware case folding. "EqualFold": true, diff --git a/src/net/http/internal/http2/netconn_test.go b/src/net/http/internal/http2/netconn_test.go index ef318a1abe..6cae2dd306 100644 --- a/src/net/http/internal/http2/netconn_test.go +++ b/src/net/http/internal/http2/netconn_test.go @@ -334,3 +334,36 @@ func (t *deadlineContext) setDeadline(deadline time.Time) { t.cancel = nil }) } + +type oneConnListener struct { + ch chan net.Conn + err error + once sync.Once + addr net.Addr +} + +func newOneConnListener(conn net.Conn) net.Listener { + ch := make(chan net.Conn, 1) + ch <- conn + return &oneConnListener{ch: ch} +} + +func (li *oneConnListener) Accept() (net.Conn, error) { + c := <-li.ch + if c == nil { + return nil, li.err + } + return c, nil +} + +func (li *oneConnListener) Close() error { + li.once.Do(func() { + li.err = errors.New("closed") + close(li.ch) + }) + return nil +} + +func (li *oneConnListener) Addr() net.Addr { + return li.addr +} diff --git a/src/net/http/internal/http2/server.go b/src/net/http/internal/http2/server.go index 98bec52343..a37a7280dd 100644 --- a/src/net/http/internal/http2/server.go +++ b/src/net/http/internal/http2/server.go @@ -37,7 +37,7 @@ import ( "log" "math" "net" - "net/http" + "net/http/internal" "net/http/internal/httpcommon" "net/textproto" "net/url" @@ -237,39 +237,18 @@ func (s *serverInternalState) putErrChan(ch chan error) { s.errChanPool.Put(ch) } -// ConfigureServer adds HTTP/2 support to a net/http Server. -// -// The configuration conf may be nil. -// -// ConfigureServer must be called before s begins serving. -func ConfigureServer(s *http.Server, conf *Server) error { - if s == nil { - panic("nil *http.Server") - } - if conf == nil { - conf = new(Server) - } - conf.state = &serverInternalState{ +func (s *Server) Configure(conf ServerConfig, tcfg *tls.Config) error { + s.state = &serverInternalState{ activeConns: make(map[*serverConn]struct{}), errChanPool: sync.Pool{New: func() any { return make(chan error, 1) }}, } - if h1, h2 := s, conf; h2.IdleTimeout == 0 { - if h1.IdleTimeout != 0 { - h2.IdleTimeout = h1.IdleTimeout - } else { - h2.IdleTimeout = h1.ReadTimeout - } - } - s.RegisterOnShutdown(conf.state.startGracefulShutdown) - if s.TLSConfig == nil { - s.TLSConfig = new(tls.Config) - } else if s.TLSConfig.CipherSuites != nil && s.TLSConfig.MinVersion < tls.VersionTLS13 { + if tcfg.CipherSuites != nil && tcfg.MinVersion < tls.VersionTLS13 { // If they already provided a TLS 1.0–1.2 CipherSuite list, return an // error if it is missing ECDHE_RSA_WITH_AES_128_GCM_SHA256 or // ECDHE_ECDSA_WITH_AES_128_GCM_SHA256. haveRequired := false - for _, cs := range s.TLSConfig.CipherSuites { + for _, cs := range tcfg.CipherSuites { switch cs { case tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, // Alternative MTI cipher to not discourage ECDSA-only servers. @@ -290,63 +269,14 @@ func ConfigureServer(s *http.Server, conf *Server) error { // during next-proto selection, but using TLS <1.2 with // HTTP/2 is still the client's bug. - s.TLSConfig.PreferServerCipherSuites = true - - if !strSliceContains(s.TLSConfig.NextProtos, NextProtoTLS) { - s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, NextProtoTLS) - } - if !strSliceContains(s.TLSConfig.NextProtos, "http/1.1") { - s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "http/1.1") - } - - if s.TLSNextProto == nil { - s.TLSNextProto = map[string]func(*http.Server, *tls.Conn, http.Handler){} - } - protoHandler := func(hs *http.Server, c net.Conn, h http.Handler, sawClientPreface bool) { - if testHookOnConn != nil { - testHookOnConn() - } - // The TLSNextProto interface predates contexts, so - // the net/http package passes down its per-connection - // base context via an exported but unadvertised - // method on the Handler. This is for internal - // net/http<=>http2 use only. - var ctx context.Context - type baseContexter interface { - BaseContext() context.Context - } - if bc, ok := h.(baseContexter); ok { - ctx = bc.BaseContext() - } - conf.ServeConn(c, &ServeConnOpts{ - Context: ctx, - Handler: h, - BaseConfig: hs, - SawClientPreface: sawClientPreface, - }) - } - s.TLSNextProto[NextProtoTLS] = func(hs *http.Server, c *tls.Conn, h http.Handler) { - protoHandler(hs, c, h, false) - } - // The "unencrypted_http2" TLSNextProto key is used to pass off non-TLS HTTP/2 conns. - // - // A connection passed in this method has already had the HTTP/2 preface read from it. - s.TLSNextProto[nextProtoUnencryptedHTTP2] = func(hs *http.Server, c *tls.Conn, h http.Handler) { - nc, err := unencryptedNetConnFromTLSConn(c) - if err != nil { - if lg := hs.ErrorLog; lg != nil { - lg.Print(err) - } else { - log.Print(err) - } - go c.Close() - return - } - protoHandler(hs, nc, h, true) - } + tcfg.PreferServerCipherSuites = true return nil } +func (s *Server) GracefulShutdown() { + s.state.startGracefulShutdown() +} + // ServeConnOpts are options for the Server.ServeConn method. type ServeConnOpts struct { // Context is the base context to use. @@ -355,12 +285,12 @@ type ServeConnOpts struct { // BaseConfig optionally sets the base configuration // for values. If nil, defaults are used. - BaseConfig *http.Server + BaseConfig ServerConfig // Handler specifies which handler to use for processing // requests. If nil, BaseConfig.Handler is used. If BaseConfig // or BaseConfig.Handler is nil, http.DefaultServeMux is used. - Handler http.Handler + Handler Handler // Settings is the decoded contents of the HTTP2-Settings header // in an h2c upgrade request. @@ -378,25 +308,6 @@ func (o *ServeConnOpts) context() context.Context { return context.Background() } -func (o *ServeConnOpts) baseConfig() *http.Server { - if o != nil && o.BaseConfig != nil { - return o.BaseConfig - } - return new(http.Server) -} - -func (o *ServeConnOpts) handler() http.Handler { - if o != nil { - if o.Handler != nil { - return o.Handler - } - if o.BaseConfig != nil && o.BaseConfig.Handler != nil { - return o.BaseConfig.Handler - } - } - return http.DefaultServeMux -} - // ServeConn serves HTTP/2 requests on the provided connection and // blocks until the connection is no longer readable. // @@ -415,23 +326,36 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) { if opts == nil { opts = &ServeConnOpts{} } - s.serveConn(c, opts, nil) + + var newf func(*serverConn) + if inTests { + // Fetch NewConnContextKey if set, leave newf as nil otherwise. + newf, _ = opts.Context.Value(NewConnContextKey).(func(*serverConn)) + } + + s.serveConn(c, opts, newf) } +type contextKey string + +var ( + NewConnContextKey = new("NewConnContextKey") + ConnectionStateContextKey = new("ConnectionStateContextKey") +) + func (s *Server) serveConn(c net.Conn, opts *ServeConnOpts, newf func(*serverConn)) { baseCtx, cancel := serverConnBaseContext(c, opts) defer cancel() - http1srv := opts.baseConfig() - conf := configFromServer(http1srv, s) + conf := configFromServer(opts.BaseConfig, s) sc := &serverConn{ srv: s, - hs: http1srv, + hs: opts.BaseConfig, conn: c, baseCtx: baseCtx, remoteAddrStr: c.RemoteAddr().String(), bw: newBufferedWriter(c, conf.WriteByteTimeout), - handler: opts.handler(), + handler: opts.Handler, streams: make(map[uint32]*stream), readFrameCh: make(chan readFrameResult), wantWriteFrameCh: make(chan FrameWriteRequest, 8), @@ -440,9 +364,9 @@ func (s *Server) serveConn(c net.Conn, opts *ServeConnOpts, newf func(*serverCon bodyReadCh: make(chan bodyReadMsg), // buffering doesn't matter either way doneServing: make(chan struct{}), clientMaxStreams: math.MaxUint32, // Section 6.5.2: "Initially, there is no limit to this value" - advMaxStreams: conf.MaxConcurrentStreams, + advMaxStreams: uint32(conf.MaxConcurrentStreams), initialStreamSendWindowSize: initialWindowSize, - initialStreamRecvWindowSize: conf.MaxUploadBufferPerStream, + initialStreamRecvWindowSize: int32(conf.MaxReceiveBufferPerStream), maxFrameSize: initialMaxFrameSize, pingTimeout: conf.PingTimeout, countErrorFunc: conf.CountError, @@ -462,14 +386,14 @@ func (s *Server) serveConn(c net.Conn, opts *ServeConnOpts, newf func(*serverCon // passes the connection off to us with the deadline already set. // Write deadlines are set per stream in serverConn.newStream. // Disarm the net.Conn write deadline here. - if sc.hs.WriteTimeout > 0 { + if sc.hs.WriteTimeout() > 0 { sc.conn.SetWriteDeadline(time.Time{}) } switch { case s.NewWriteScheduler != nil: sc.writeSched = s.NewWriteScheduler() - case clientPriorityDisabled(http1srv): + case sc.hs.DisableClientPriority(): sc.writeSched = newRoundRobinWriteScheduler() default: sc.writeSched = newPriorityWriteSchedulerRFC9218() @@ -481,20 +405,29 @@ func (s *Server) serveConn(c net.Conn, opts *ServeConnOpts, newf func(*serverCon sc.flow.add(initialWindowSize) sc.inflow.init(initialWindowSize) sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf) - sc.hpackEncoder.SetMaxDynamicTableSizeLimit(conf.MaxEncoderHeaderTableSize) + sc.hpackEncoder.SetMaxDynamicTableSizeLimit(uint32(conf.MaxEncoderHeaderTableSize)) fr := NewFramer(sc.bw, c) if conf.CountError != nil { fr.countError = conf.CountError } - fr.ReadMetaHeaders = hpack.NewDecoder(conf.MaxDecoderHeaderTableSize, nil) + fr.ReadMetaHeaders = hpack.NewDecoder(uint32(conf.MaxDecoderHeaderTableSize), nil) fr.MaxHeaderListSize = sc.maxHeaderListSize() - fr.SetMaxReadFrameSize(conf.MaxReadFrameSize) + fr.SetMaxReadFrameSize(uint32(conf.MaxReadFrameSize)) sc.framer = fr if tc, ok := c.(connectionStater); ok { sc.tlsState = new(tls.ConnectionState) *sc.tlsState = tc.ConnectionState() + + // Optionally override the ConnectionState in tests. + if inTests { + f, ok := opts.Context.Value(ConnectionStateContextKey).(func() tls.ConnectionState) + if ok { + *sc.tlsState = f() + } + } + // 9.2 Use of TLS Features // An implementation of HTTP/2 over TLS MUST use TLS // 1.2 or higher with the restrictions on feature set @@ -558,12 +491,7 @@ func (s *Server) serveConn(c net.Conn, opts *ServeConnOpts, newf func(*serverCon } func serverConnBaseContext(c net.Conn, opts *ServeConnOpts) (ctx context.Context, cancel func()) { - ctx, cancel = context.WithCancel(opts.context()) - ctx = context.WithValue(ctx, http.LocalAddrContextKey, c.LocalAddr()) - if hs := opts.baseConfig(); hs != nil { - ctx = context.WithValue(ctx, http.ServerContextKey, hs) - } - return + return context.WithCancel(opts.context()) } func (sc *serverConn) rejectConn(err ErrCode, debug string) { @@ -577,10 +505,10 @@ func (sc *serverConn) rejectConn(err ErrCode, debug string) { type serverConn struct { // Immutable: srv *Server - hs *http.Server + hs ServerConfig conn net.Conn bw *bufferedWriter // writing to conn - handler http.Handler + handler Handler baseCtx context.Context framer *Framer doneServing chan struct{} // closed when serverConn.serve ends @@ -659,10 +587,12 @@ func (sc *serverConn) writeSchedIgnoresRFC7540() bool { } } +const DefaultMaxHeaderBytes = 1 << 20 // keep this in sync with net/http + func (sc *serverConn) maxHeaderListSize() uint32 { - n := sc.hs.MaxHeaderBytes + n := sc.hs.MaxHeaderBytes() if n <= 0 { - n = http.DefaultMaxHeaderBytes + n = DefaultMaxHeaderBytes } return uint32(adjustHTTP1MaxHeaderSize(int64(n))) } @@ -701,8 +631,8 @@ type stream struct { writeDeadline *time.Timer // nil if unused closeErr error // set before cw is closed - trailer http.Header // accumulated trailers - reqTrailer http.Header // handler's Request.Trailer + trailer Header // accumulated trailers + reqTrailer Header // handler's Request.Trailer } func (sc *serverConn) Framer() *Framer { return sc.framer } @@ -739,10 +669,8 @@ func (sc *serverConn) state(streamID uint32) (streamState, *stream) { // setConnState calls the net/http ConnState hook for this connection, if configured. // Note that the net/http package does StateNew and StateClosed for us. // There is currently no plan for StateHijacked or hijacking HTTP/2 connections. -func (sc *serverConn) setConnState(state http.ConnState) { - if sc.hs.ConnState != nil { - sc.hs.ConnState(sc.conn, state) - } +func (sc *serverConn) setConnState(state ConnState) { + sc.hs.ConnState(sc.conn, state) } func (sc *serverConn) vlogf(format string, args ...interface{}) { @@ -752,7 +680,7 @@ func (sc *serverConn) vlogf(format string, args ...interface{}) { } func (sc *serverConn) logf(format string, args ...interface{}) { - if lg := sc.hs.ErrorLog; lg != nil { + if lg := sc.hs.ErrorLog(); lg != nil { lg.Printf(format, args...) } else { log.Printf(format, args...) @@ -831,7 +759,7 @@ func (sc *serverConn) canonicalHeader(v string) string { if sc.canonHeader == nil { sc.canonHeader = make(map[string]string) } - cv = http.CanonicalHeaderKey(v) + cv = textproto.CanonicalMIMEHeaderKey(v) size := 100 + len(v)*2 // 100 bytes of map overhead + key + value if sc.canonHeaderKeysSize+size <= maxCachedCanonicalHeadersKeysSize { sc.canonHeader[v] = cv @@ -925,7 +853,7 @@ func (sc *serverConn) notePanic() { } } -func (sc *serverConn) serve(conf http2Config) { +func (sc *serverConn) serve(conf Config) { sc.serveG.check() defer sc.notePanic() defer sc.conn.Close() @@ -938,10 +866,10 @@ func (sc *serverConn) serve(conf http2Config) { } settings := writeSettings{ - {SettingMaxFrameSize, conf.MaxReadFrameSize}, + {SettingMaxFrameSize, uint32(conf.MaxReadFrameSize)}, {SettingMaxConcurrentStreams, sc.advMaxStreams}, {SettingMaxHeaderListSize, sc.maxHeaderListSize()}, - {SettingHeaderTableSize, conf.MaxDecoderHeaderTableSize}, + {SettingHeaderTableSize, uint32(conf.MaxDecoderHeaderTableSize)}, {SettingInitialWindowSize, uint32(sc.initialStreamRecvWindowSize)}, } if !disableExtendedConnectProtocol { @@ -957,7 +885,7 @@ func (sc *serverConn) serve(conf http2Config) { // Each connection starts with initialWindowSize inflow tokens. // If a higher value is configured, we add more tokens. - if diff := conf.MaxUploadBufferPerConnection - initialWindowSize; diff > 0 { + if diff := conf.MaxReceiveBufferPerConnection - initialWindowSize; diff > 0 { sc.sendWindowUpdate(nil, int(diff)) } @@ -969,8 +897,8 @@ func (sc *serverConn) serve(conf http2Config) { // "StateNew" state. We can't go directly to idle, though. // Active means we read some data and anticipate a request. We'll // do another Active when we get a HEADERS frame. - sc.setConnState(http.StateActive) - sc.setConnState(http.StateIdle) + sc.setConnState(ConnStateActive) + sc.setConnState(ConnStateIdle) if sc.srv.IdleTimeout > 0 { sc.idleTimer = time.AfterFunc(sc.srv.IdleTimeout, sc.onIdleTimer) @@ -1737,7 +1665,7 @@ func (sc *serverConn) closeStream(st *stream, err error) { } delete(sc.streams, st.id) if len(sc.streams) == 0 { - sc.setConnState(http.StateIdle) + sc.setConnState(ConnStateIdle) if sc.srv.IdleTimeout > 0 && sc.idleTimer != nil { sc.idleTimer.Reset(sc.srv.IdleTimeout) } @@ -2126,7 +2054,7 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error { } st.reqTrailer = req.Trailer if st.reqTrailer != nil { - st.trailer = make(http.Header) + st.trailer = make(Header) } st.body = req.Body.(*requestBody).pipe // may be nil st.declBodyBytes = req.ContentLength @@ -2136,7 +2064,7 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error { // Their header list was too long. Send a 431 error. handler = handleHeaderListTooLong } else if err := checkValidHTTP2RequestHeaders(req.Header); err != nil { - handler = new400Handler(err) + handler = serve400Handler{err}.ServeHTTP } // The net/http package sets the read deadline from the @@ -2146,9 +2074,9 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error { // similar to how the http1 server works. Here it's // technically more like the http1 Server's ReadHeaderTimeout // (in Go 1.8), though. That's a more sane option anyway. - if sc.hs.ReadTimeout > 0 { + if sc.hs.ReadTimeout() > 0 { sc.conn.SetReadDeadline(time.Time{}) - st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout) + st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout(), st.onReadTimeout) } return sc.scheduleHandler(id, rw, req, handler) @@ -2242,8 +2170,8 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState, priority st.flow.conn = &sc.flow // link to conn-level counter st.flow.add(sc.initialStreamSendWindowSize) st.inflow.init(sc.initialStreamRecvWindowSize) - if sc.hs.WriteTimeout > 0 { - st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout) + if writeTimeout := sc.hs.WriteTimeout(); writeTimeout > 0 { + st.writeDeadline = time.AfterFunc(writeTimeout, st.onWriteTimeout) } sc.streams[id] = st @@ -2254,13 +2182,13 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState, priority sc.curClientStreams++ } if sc.curOpenStreams() == 1 { - sc.setConnState(http.StateActive) + sc.setConnState(ConnStateActive) } return st } -func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*responseWriter, *http.Request, error) { +func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*responseWriter, *ServerRequest, error) { sc.serveG.check() rp := httpcommon.ServerRequestParam{ @@ -2295,7 +2223,7 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res return nil, nil, sc.countError("bad_path_method", streamError(f.StreamID, ErrCodeProtocol)) } - header := make(http.Header) + header := make(Header) rp.Header = header for _, hf := range f.RegularFields() { header.Add(sc.canonicalHeader(hf.Name), hf.Value) @@ -2329,7 +2257,7 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res return rw, req, nil } -func (sc *serverConn) newWriterAndRequestNoBody(st *stream, rp httpcommon.ServerRequestParam) (*responseWriter, *http.Request, error) { +func (sc *serverConn) newWriterAndRequestNoBody(st *stream, rp httpcommon.ServerRequestParam) (*responseWriter, *ServerRequest, error) { sc.serveG.check() var tlsState *tls.ConnectionState // nil if not scheme https @@ -2347,7 +2275,9 @@ func (sc *serverConn) newWriterAndRequestNoBody(st *stream, rp httpcommon.Server stream: st, needsContinue: res.NeedsContinue, } - req := (&http.Request{ + rw := sc.newResponseWriter(st) + rw.rws.req = ServerRequest{ + Context: st.ctx, Method: rp.Method, URL: res.URL, RemoteAddr: sc.remoteAddrStr, @@ -2360,12 +2290,11 @@ func (sc *serverConn) newWriterAndRequestNoBody(st *stream, rp httpcommon.Server Host: rp.Authority, Body: body, Trailer: res.Trailer, - }).WithContext(st.ctx) - rw := sc.newResponseWriter(st, req) - return rw, req, nil + } + return rw, &rw.rws.req, nil } -func (sc *serverConn) newResponseWriter(st *stream, req *http.Request) *responseWriter { +func (sc *serverConn) newResponseWriter(st *stream) *responseWriter { rws := responseWriterStatePool.Get().(*responseWriterState) bwSave := rws.bw *rws = responseWriterState{} // zero all the fields @@ -2373,20 +2302,19 @@ func (sc *serverConn) newResponseWriter(st *stream, req *http.Request) *response rws.bw = bwSave rws.bw.Reset(chunkWriter{rws}) rws.stream = st - rws.req = req return &responseWriter{rws: rws} } type unstartedHandler struct { streamID uint32 rw *responseWriter - req *http.Request - handler func(http.ResponseWriter, *http.Request) + req *ServerRequest + handler func(*ResponseWriter, *ServerRequest) } // scheduleHandler starts a handler goroutine, // or schedules one to start as soon as an existing handler finishes. -func (sc *serverConn) scheduleHandler(streamID uint32, rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) error { +func (sc *serverConn) scheduleHandler(streamID uint32, rw *responseWriter, req *ServerRequest, handler func(*ResponseWriter, *ServerRequest)) error { sc.serveG.check() maxHandlers := sc.advMaxStreams if sc.curHandlers < maxHandlers { @@ -2431,7 +2359,7 @@ func (sc *serverConn) handlerDone() { } // Run on its own goroutine. -func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) { +func (sc *serverConn) runHandler(rw *responseWriter, req *ServerRequest, handler func(*ResponseWriter, *ServerRequest)) { defer sc.sendServeMsg(handlerDoneMsg) didPanic := true defer func() { @@ -2446,7 +2374,7 @@ func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request, handler stream: rw.rws.stream, }) // Same as net/http: - if e != nil && e != http.ErrAbortHandler { + if e != nil && e != ErrAbortHandler { const size = 64 << 10 buf := make([]byte, size) buf = buf[:runtime.Stack(buf, false)] @@ -2460,7 +2388,7 @@ func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request, handler didPanic = false } -func handleHeaderListTooLong(w http.ResponseWriter, r *http.Request) { +func handleHeaderListTooLong(w *ResponseWriter, r *ServerRequest) { // 10.5.1 Limits on Header Block Size: // .. "A server that receives a larger header block than it is // willing to handle can send an HTTP 431 (Request Header Fields Too @@ -2616,30 +2544,23 @@ type responseWriter struct { rws *responseWriterState } -// Optional http.ResponseWriter interfaces implemented. -var ( - _ http.CloseNotifier = (*responseWriter)(nil) - _ http.Flusher = (*responseWriter)(nil) - _ stringWriter = (*responseWriter)(nil) -) - type responseWriterState struct { // immutable within a request: stream *stream - req *http.Request + req ServerRequest conn *serverConn // TODO: adjust buffer writing sizes based on server config, frame size updates from peer, etc bw *bufio.Writer // writing to a chunkWriter{this *responseWriterState} // mutated by http.Handler goroutine: - handlerHeader http.Header // nil until called - snapHeader http.Header // snapshot of handlerHeader at WriteHeader time - trailers []string // set in writeChunk - status int // status code passed to WriteHeader - wroteHeader bool // WriteHeader called (explicitly or implicitly). Not necessarily sent to user yet. - sentHeader bool // have we sent the header frame? - handlerDone bool // handler has finished + handlerHeader Header // nil until called + snapHeader Header // snapshot of handlerHeader at WriteHeader time + trailers []string // set in writeChunk + status int // status code passed to WriteHeader + wroteHeader bool // WriteHeader called (explicitly or implicitly). Not necessarily sent to user yet. + sentHeader bool // have we sent the header frame? + handlerDone bool // handler has finished sentContentLen int64 // non-zero if handler set a Content-Length header wroteBytes int64 @@ -2675,7 +2596,7 @@ func (rws *responseWriterState) hasNonemptyTrailers() bool { // response header is written. It notes that a header will need to be // written in the trailers at the end of the response. func (rws *responseWriterState) declareTrailer(k string) { - k = http.CanonicalHeaderKey(k) + k = textproto.CanonicalMIMEHeaderKey(k) if !httpguts.ValidTrailerHeader(k) { // Forbidden by RFC 7230, section 4.1.2. rws.conn.logf("ignoring invalid trailer %q", k) @@ -2686,6 +2607,8 @@ func (rws *responseWriterState) declareTrailer(k string) { } } +const TimeFormat = "Mon, 02 Jan 2006 15:04:05 GMT" // keep in sync with net/http + // writeChunk writes chunks from the bufio.Writer. But because // bufio.Writer may bypass its chunking, sometimes p may be // arbitrarily large. @@ -2723,12 +2646,12 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) { ce := rws.snapHeader.Get("Content-Encoding") hasCE := len(ce) > 0 if !hasCE && !hasContentType && bodyAllowedForStatus(rws.status) && len(p) > 0 { - ctype = http.DetectContentType(p) + ctype = internal.DetectContentType(p) } var date string if _, ok := rws.snapHeader["Date"]; !ok { // TODO(bradfitz): be faster here, like net/http? measure. - date = time.Now().UTC().Format(http.TimeFormat) + date = time.Now().UTC().Format(TimeFormat) } for _, v := range rws.snapHeader["Trailer"] { @@ -2838,7 +2761,7 @@ func (rws *responseWriterState) promoteUndeclaredTrailers() { } trailerKey := strings.TrimPrefix(k, TrailerPrefix) rws.declareTrailer(trailerKey) - rws.handlerHeader[http.CanonicalHeaderKey(trailerKey)] = vv + rws.handlerHeader[textproto.CanonicalMIMEHeaderKey(trailerKey)] = vv } if len(rws.trailers) > 1 { @@ -2954,13 +2877,13 @@ func (w *responseWriter) CloseNotify() <-chan bool { return ch } -func (w *responseWriter) Header() http.Header { +func (w *responseWriter) Header() Header { rws := w.rws if rws == nil { panic("Header called after Handler finished") } if rws.handlerHeader == nil { - rws.handlerHeader = make(http.Header) + rws.handlerHeader = make(Header) } return rws.handlerHeader } @@ -3005,7 +2928,7 @@ func (rws *responseWriterState) writeHeader(code int) { _, cl := h["Content-Length"] _, te := h["Transfer-Encoding"] if cl || te { - h = h.Clone() + h = cloneHeader(h) h.Del("Content-Length") h.Del("Transfer-Encoding") } @@ -3027,8 +2950,8 @@ func (rws *responseWriterState) writeHeader(code int) { } } -func cloneHeader(h http.Header) http.Header { - h2 := make(http.Header, len(h)) +func cloneHeader(h Header) Header { + h2 := make(Header, len(h)) for k, vv := range h { vv2 := make([]string, len(vv)) copy(vv2, vv) @@ -3063,7 +2986,7 @@ func (w *responseWriter) write(lenData int, dataB []byte, dataS string) (n int, w.WriteHeader(200) } if !bodyAllowedForStatus(rws.status) { - return 0, http.ErrBodyNotAllowed + return 0, ErrBodyNotAllowed } rws.wroteBytes += int64(len(dataB)) + int64(len(dataS)) // only one can be set if rws.sentContentLen != 0 && rws.wroteBytes > rws.sentContentLen { @@ -3092,9 +3015,7 @@ var ( ErrPushLimitReached = errors.New("http2: push would exceed peer's SETTINGS_MAX_CONCURRENT_STREAMS") ) -var _ http.Pusher = (*responseWriter)(nil) - -func (w *responseWriter) Push(target string, opts *http.PushOptions) error { +func (w *responseWriter) Push(target, method string, header Header) error { st := w.rws.stream sc := st.sc sc.serveG.checkNotOn() @@ -3105,16 +3026,12 @@ func (w *responseWriter) Push(target string, opts *http.PushOptions) error { return ErrRecursivePush } - if opts == nil { - opts = new(http.PushOptions) - } - // Default options. - if opts.Method == "" { - opts.Method = "GET" + if method == "" { + method = "GET" } - if opts.Header == nil { - opts.Header = http.Header{} + if header == nil { + header = Header{} } wantScheme := "http" if w.rws.req.TLS != nil { @@ -3140,7 +3057,7 @@ func (w *responseWriter) Push(target string, opts *http.PushOptions) error { return errors.New("URL must have a host") } } - for k := range opts.Header { + for k := range header { if strings.HasPrefix(k, ":") { return fmt.Errorf("promised request headers cannot include pseudo header %q", k) } @@ -3157,22 +3074,22 @@ func (w *responseWriter) Push(target string, opts *http.PushOptions) error { return fmt.Errorf("promised request headers cannot include %q", k) } } - if err := checkValidHTTP2RequestHeaders(opts.Header); err != nil { + if err := checkValidHTTP2RequestHeaders(header); err != nil { return err } // The RFC effectively limits promised requests to GET and HEAD: // "Promised requests MUST be cacheable [GET, HEAD, or POST], and MUST be safe [GET or HEAD]" // http://tools.ietf.org/html/rfc7540#section-8.2 - if opts.Method != "GET" && opts.Method != "HEAD" { - return fmt.Errorf("method %q must be GET or HEAD", opts.Method) + if method != "GET" && method != "HEAD" { + return fmt.Errorf("method %q must be GET or HEAD", method) } msg := &startPushRequest{ parent: st, - method: opts.Method, + method: method, url: u, - header: cloneHeader(opts.Header), + header: cloneHeader(header), done: sc.srv.state.getErrChan(), } @@ -3199,7 +3116,7 @@ type startPushRequest struct { parent *stream method string url *url.URL - header http.Header + header Header done chan error } @@ -3217,7 +3134,7 @@ func (sc *serverConn) startPush(msg *startPushRequest) { // http://tools.ietf.org/html/rfc7540#section-6.6. if !sc.pushEnabled { - msg.done <- http.ErrNotSupported + msg.done <- ErrNotSupported return } @@ -3230,7 +3147,7 @@ func (sc *serverConn) startPush(msg *startPushRequest) { // Check this again, just in case. Technically, we might have received // an updated SETTINGS by the time we got around to writing this frame. if !sc.pushEnabled { - return 0, http.ErrNotSupported + return 0, ErrNotSupported } // http://tools.ietf.org/html/rfc7540#section-6.5.2. if sc.curPushedStreams+1 > sc.clientMaxStreams { @@ -3314,7 +3231,7 @@ var connHeaders = []string{ // checkValidHTTP2RequestHeaders checks whether h is a valid HTTP/2 request, // per RFC 7540 Section 8.1.2.2. // The returned error is reported to users. -func checkValidHTTP2RequestHeaders(h http.Header) error { +func checkValidHTTP2RequestHeaders(h Header) error { for _, k := range connHeaders { if _, ok := h[k]; ok { return fmt.Errorf("request header %q is not valid in HTTP/2", k) @@ -3327,24 +3244,27 @@ func checkValidHTTP2RequestHeaders(h http.Header) error { return nil } -func new400Handler(err error) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - http.Error(w, err.Error(), http.StatusBadRequest) - } +type serve400Handler struct { + err error +} + +func (handler serve400Handler) ServeHTTP(w *ResponseWriter, r *ServerRequest) { + const statusBadRequest = 400 + + // TODO: Dedup with http.Error? + h := w.Header() + h.Del("Content-Length") + h.Set("Content-Type", "text/plain; charset=utf-8") + h.Set("X-Content-Type-Options", "nosniff") + w.WriteHeader(statusBadRequest) + fmt.Fprintln(w, handler.err.Error()) } // h1ServerKeepAlivesDisabled reports whether hs has its keep-alives // disabled. See comments on h1ServerShutdownChan above for why // the code is written this way. -func h1ServerKeepAlivesDisabled(hs *http.Server) bool { - var x interface{} = hs - type I interface { - doKeepAlives() bool - } - if hs, ok := x.(I); ok { - return !hs.doKeepAlives() - } - return false +func h1ServerKeepAlivesDisabled(hs ServerConfig) bool { + return !hs.DoKeepAlives() } func (sc *serverConn) countError(name string, err error) error { diff --git a/src/net/http/internal/http2/server_internal_test.go b/src/net/http/internal/http2/server_internal_test.go index 763976e159..379b5de46d 100644 --- a/src/net/http/internal/http2/server_internal_test.go +++ b/src/net/http/internal/http2/server_internal_test.go @@ -7,42 +7,41 @@ package http2 import ( "errors" "fmt" - "net/http" "strings" "testing" ) func TestCheckValidHTTP2Request(t *testing.T) { tests := []struct { - h http.Header + h Header want error }{ { - h: http.Header{"Te": {"trailers"}}, + h: Header{"Te": {"trailers"}}, want: nil, }, { - h: http.Header{"Te": {"trailers", "bogus"}}, + h: Header{"Te": {"trailers", "bogus"}}, want: errors.New(`request header "TE" may only be "trailers" in HTTP/2`), }, { - h: http.Header{"Foo": {""}}, + h: Header{"Foo": {""}}, want: nil, }, { - h: http.Header{"Connection": {""}}, + h: Header{"Connection": {""}}, want: errors.New(`request header "Connection" is not valid in HTTP/2`), }, { - h: http.Header{"Proxy-Connection": {""}}, + h: Header{"Proxy-Connection": {""}}, want: errors.New(`request header "Proxy-Connection" is not valid in HTTP/2`), }, { - h: http.Header{"Keep-Alive": {""}}, + h: Header{"Keep-Alive": {""}}, want: errors.New(`request header "Keep-Alive" is not valid in HTTP/2`), }, { - h: http.Header{"Upgrade": {""}}, + h: Header{"Upgrade": {""}}, want: errors.New(`request header "Upgrade" is not valid in HTTP/2`), }, } diff --git a/src/net/http/internal/http2/server_test.go b/src/net/http/internal/http2/server_test.go index c793cf8444..11edfcb3be 100644 --- a/src/net/http/internal/http2/server_test.go +++ b/src/net/http/internal/http2/server_test.go @@ -30,7 +30,9 @@ import ( "testing" "testing/synctest" "time" + _ "unsafe" // for go:linkname + "net/http/internal/http2" . "net/http/internal/http2" "net/http/internal/testcert" @@ -79,6 +81,7 @@ type serverTester struct { logFilter []string // substrings to filter out scMu sync.Mutex // guards sc sc *ServerConn + wrotePreface bool testConnFramer callsMu sync.Mutex @@ -125,7 +128,6 @@ func newTestServer(t testing.TB, handler http.HandlerFunc, opts ...interface{}) ts.EnableHTTP2 = true ts.Config.ErrorLog = log.New(twriter{t: t}, "", log.LstdFlags) ts.Config.Protocols = protocols("h2") - h2server := new(Server) for _, opt := range opts { switch v := opt.(type) { case func(*httptest.Server): @@ -141,7 +143,6 @@ func newTestServer(t testing.TB, handler http.HandlerFunc, opts ...interface{}) t.Fatalf("unknown newTestServer option type %T", v) } } - ConfigureServer(ts.Config, h2server) if ts.Config.Protocols.HTTP2() { ts.TLS = testServerTLSConfig @@ -176,12 +177,7 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{} t.Helper() h1server := &http.Server{} - h2server := &Server{} - tlsState := tls.ConnectionState{ - Version: tls.VersionTLS13, - ServerName: "go.dev", - CipherSuite: tls.TLS_AES_128_GCM_SHA256, - } + var tlsState *tls.ConnectionState for _, opt := range opts { switch v := opt.(type) { case func(*http.Server): @@ -192,21 +188,51 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{} } v(h1server.HTTP2) case func(*tls.ConnectionState): - v(&tlsState) + if tlsState == nil { + tlsState = &tls.ConnectionState{ + Version: tls.VersionTLS13, + ServerName: "go.dev", + CipherSuite: tls.TLS_AES_128_GCM_SHA256, + } + } + v(tlsState) default: t.Fatalf("unknown newServerTester option type %T", v) } } - ConfigureServer(h1server, h2server) - cli, srv := synctestNetPipe() - cli.SetReadDeadline(time.Now()) + tlsConfig := h1server.TLSConfig + if tlsConfig == nil { + cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey) + if err != nil { + t.Fatal(err) + } + tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{cert}, + InsecureSkipVerify: true, + NextProtos: []string{"h2"}, + } + h1server.TLSConfig = tlsConfig + } + + var cli, srv net.Conn + + cliPipe, srvPipe := synctestNetPipe() + + if h1server.Protocols != nil && h1server.Protocols.UnencryptedHTTP2() { + cli, srv = cliPipe, srvPipe + } else { + cli = tls.Client(cliPipe, &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{"h2"}, + }) + srv = tls.Server(srvPipe, tlsConfig) + } st := &serverTester{ t: t, cc: cli, h1server: h1server, - h2server: h2server, } st.hpackEnc = hpack.NewEncoder(&st.headerBuf) if h1server.ErrorLog == nil { @@ -224,23 +250,42 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{} }) connc := make(chan *ServerConn) - go func() { - h2server.TestServeConn(&netConnWithConnectionState{ - Conn: srv, - state: tlsState, - }, &ServeConnOpts{ - Handler: handler, - BaseConfig: h1server, - }, func(sc *ServerConn) { + h1server.ConnContext = func(ctx context.Context, conn net.Conn) context.Context { + ctx = context.WithValue(ctx, NewConnContextKey, func(sc *ServerConn) { connc <- sc }) + if tlsState != nil { + ctx = context.WithValue(ctx, ConnectionStateContextKey, func() tls.ConnectionState { + return *tlsState + }) + } + return ctx + } + go func() { + li := newOneConnListener(srv) + t.Cleanup(func() { + li.Close() + }) + h1server.Serve(li) }() + if cliTLS, ok := cli.(*tls.Conn); ok { + if err := cliTLS.Handshake(); err != nil { + t.Fatalf("client TLS handshake: %v", err) + } + cliTLS.SetReadDeadline(time.Now()) + } else { + // Confusing but difficult to fix: Preface must be written + // before the conn appears on connc. + st.writePreface() + st.wrotePreface = true + cliPipe.SetReadDeadline(time.Now()) + } st.sc = <-connc st.fr = NewFramer(st.cc, st.cc) st.testConnFramer = testConnFramer{ t: t, - fr: NewFramer(st.cc, st.cc), + fr: NewFramer(cli, cli), dec: hpack.NewDecoder(InitialHeaderTableSize, nil), } synctest.Wait() @@ -256,6 +301,10 @@ func (c *netConnWithConnectionState) ConnectionState() tls.ConnectionState { return c.state } +func (c *netConnWithConnectionState) HandshakeContext() tls.ConnectionState { + return c.state +} + type serverTesterHandler struct { st *serverTester } @@ -337,8 +386,6 @@ func newServerTesterWithRealConn(t testing.TB, handler http.HandlerFunc, opts .. } } - ConfigureServer(ts.Config, h2server) - // Go 1.22 changes the default minimum TLS version to TLS 1.2, // in order to properly test cases where we want to reject low // TLS versions, we need to explicitly configure the minimum @@ -514,6 +561,9 @@ func (st *serverTester) greetAndCheckSettings(checkSetting func(s Setting) error } func (st *serverTester) writePreface() { + if st.wrotePreface { + return + } n, err := st.cc.Write([]byte(ClientPreface)) if err != nil { st.t.Fatalf("Error writing client preface: %v", err) @@ -1256,7 +1306,7 @@ func testServer_MaxQueuedControlFrames(t testing.TB) { st := newServerTester(t, nil) st.greet() - st.cc.(*synctestNetConn).SetReadBufferSize(0) // all writes block + st.cc.(*tls.Conn).NetConn().(*synctestNetConn).SetReadBufferSize(0) // all writes block // Send maxQueuedControlFrames pings, plus a few extra // to account for ones that enter the server's write buffer. @@ -1269,7 +1319,7 @@ func testServer_MaxQueuedControlFrames(t testing.TB) { // Unblock the server. // It should have closed the connection after exceeding the control frame limit. - st.cc.(*synctestNetConn).SetReadBufferSize(math.MaxInt) + st.cc.(*tls.Conn).NetConn().(*synctestNetConn).SetReadBufferSize(math.MaxInt) st.advance(GoAwayTimeout) // Some frames may have persisted in the server's buffers. @@ -3270,8 +3320,6 @@ func benchmarkServerToClientStream(b *testing.B, newServerOpts ...interface{}) { }) } -// go-fuzz bug, originally reported at https://github.com/bradfitz/http2/issues/53 -// Verify we don't hang. func TestIssue53(t *testing.T) { synctestTest(t, testIssue53) } func testIssue53(t testing.TB) { const data = "PRI * HTTP/2.0\r\n\r\nSM" + @@ -3279,6 +3327,7 @@ func testIssue53(t testing.TB) { st := newServerTester(t, func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello")) }) + st.cc.Write([]byte(data)) st.wantFrameType(FrameSettings) st.wantFrameType(FrameWindowUpdate) @@ -3287,18 +3336,13 @@ func testIssue53(t testing.TB) { st.wantClosed() } -// golang.org/issue/12895 -func TestConfigureServer(t *testing.T) { synctestTest(t, testConfigureServer) } -func testConfigureServer(t testing.TB) { +func TestServerServeNoBannedCiphers(t *testing.T) { tests := []struct { name string tlsConfig *tls.Config wantErr string }{ { - name: "empty server", - }, - { name: "empty CipherSuites", tlsConfig: &tls.Config{}, }, @@ -3342,9 +3386,15 @@ func testConfigureServer(t testing.TB) { }, } for _, tt := range tests { - srv := &http.Server{TLSConfig: tt.tlsConfig} - err := ConfigureServer(srv, nil) - if (err != nil) != (tt.wantErr != "") { + tt.tlsConfig.Certificates = testServerTLSConfig.Certificates + + srv := &http.Server{ + TLSConfig: tt.tlsConfig, + Protocols: protocols("h2"), + } + + err := srv.ServeTLS(errListener{}, "", "") + if (err != net.ErrClosed) != (tt.wantErr != "") { if tt.wantErr != "" { t.Errorf("%s: success, but want error", tt.name) } else { @@ -3360,6 +3410,12 @@ func testConfigureServer(t testing.TB) { } } +type errListener struct{} + +func (li errListener) Accept() (net.Conn, error) { return nil, net.ErrClosed } +func (li errListener) Close() error { return nil } +func (li errListener) Addr() net.Addr { return nil } + func TestServerNoAutoContentLengthOnHead(t *testing.T) { synctestTest(t, testServerNoAutoContentLengthOnHead) } @@ -4024,7 +4080,12 @@ func testServerGracefulShutdown(t testing.TB) { st.bodylessReq1() st.sync() - st.h1server.Shutdown(context.Background()) + + shutdownc := make(chan struct{}) + go func() { + defer close(shutdownc) + st.h1server.Shutdown(context.Background()) + }() st.wantGoAway(1, ErrCodeNo) @@ -4045,6 +4106,9 @@ func testServerGracefulShutdown(t testing.TB) { if n != 0 || err == nil { t.Errorf("Read = %v, %v; want 0, non-nil", n, err) } + + // Shutdown happens after GoAwayTimeout and net/http.Server polling delay. + <-shutdownc } // Issue 31753: don't sniff when Content-Encoding is set @@ -5171,6 +5235,15 @@ func testServerRFC9218PriorityAware(t testing.TB) { } } +func TestConsistentConstants(t *testing.T) { + if h1, h2 := http.DefaultMaxHeaderBytes, http2.DefaultMaxHeaderBytes; h1 != h2 { + t.Errorf("DefaultMaxHeaderBytes: http (%v) != http2 (%v)", h1, h2) + } + if h1, h2 := http.TimeFormat, http2.TimeFormat; h1 != h2 { + t.Errorf("TimeFormat: http (%v) != http2 (%v)", h1, h2) + } +} + var ( testServerTLSConfig *tls.Config testClientTLSConfig *tls.Config @@ -5215,3 +5288,6 @@ func protocols(protos ...string) *http.Protocols { } return p } + +//go:linkname transportFromH1Transport +func transportFromH1Transport(tr *http.Transport) any diff --git a/src/net/http/internal/http2/transport.go b/src/net/http/internal/http2/transport.go index 4fd5070701..6caba5046a 100644 --- a/src/net/http/internal/http2/transport.go +++ b/src/net/http/internal/http2/transport.go @@ -23,10 +23,11 @@ import ( "math/bits" mathrand "math/rand" "net" - "net/http" "net/http/httptrace" + "net/http/internal" "net/http/internal/httpcommon" "net/textproto" + "slices" "strconv" "strings" "sync" @@ -177,10 +178,7 @@ type Transport struct { // The errType consists of only ASCII word characters. CountError func(errType string) - // t1, if non-nil, is the standard library Transport using - // this transport. Its settings are used (but not its - // RoundTrip method, etc). - t1 *http.Transport + t1 TransportConfig connPoolOnce sync.Once connPoolOrDef ClientConnPool // non-nil version of ConnPool @@ -197,12 +195,9 @@ type transportTestHooks struct { } func (t *Transport) maxHeaderListSize() uint32 { - n := int64(t.MaxHeaderListSize) - if t.t1 != nil && t.t1.MaxResponseHeaderBytes != 0 { - n = t.t1.MaxResponseHeaderBytes - if n > 0 { - n = adjustHTTP1MaxHeaderSize(n) - } + n := t.t1.MaxResponseHeaderBytes() + if n > 0 { + n = adjustHTTP1MaxHeaderSize(n) } if n <= 0 { return 10 << 20 @@ -214,84 +209,38 @@ func (t *Transport) maxHeaderListSize() uint32 { } func (t *Transport) disableCompression() bool { - return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression) -} - -// ConfigureTransport configures a net/http HTTP/1 Transport to use HTTP/2. -// It returns an error if t1 has already been HTTP/2-enabled. -// -// Use ConfigureTransports instead to configure the HTTP/2 Transport. -func ConfigureTransport(t1 *http.Transport) error { - _, err := ConfigureTransports(t1) - return err -} - -// ConfigureTransports configures a net/http HTTP/1 Transport to use HTTP/2. -// It returns a new HTTP/2 Transport for further configuration. -// It returns an error if t1 has already been HTTP/2-enabled. -func ConfigureTransports(t1 *http.Transport) (*Transport, error) { - return configureTransports(t1) + return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression()) } -func configureTransports(t1 *http.Transport) (*Transport, error) { +func NewTransport(t1 TransportConfig) *Transport { connPool := new(clientConnPool) t2 := &Transport{ ConnPool: noDialClientConnPool{connPool}, t1: t1, } connPool.t = t2 - if err := registerHTTPSProtocol(t1, noDialH2RoundTripper{t2}); err != nil { - return nil, err - } - if t1.TLSClientConfig == nil { - t1.TLSClientConfig = new(tls.Config) - } - if !strSliceContains(t1.TLSClientConfig.NextProtos, "h2") { - t1.TLSClientConfig.NextProtos = append([]string{"h2"}, t1.TLSClientConfig.NextProtos...) - } - if !strSliceContains(t1.TLSClientConfig.NextProtos, "http/1.1") { - t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1") - } - upgradeFn := func(scheme, authority string, c net.Conn) http.RoundTripper { - addr := authorityAddr(scheme, authority) - if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil { - go c.Close() - return erringRoundTripper{err} - } else if !used { - // Turns out we don't need this c. - // For example, two goroutines made requests to the same host - // at the same time, both kicking off TCP dials. (since protocol - // was unknown) - go c.Close() - } - if scheme == "http" { - return (*unencryptedTransport)(t2) - } - return t2 - } - if t1.TLSNextProto == nil { - t1.TLSNextProto = make(map[string]func(string, *tls.Conn) http.RoundTripper) - } - t1.TLSNextProto[NextProtoTLS] = func(authority string, c *tls.Conn) http.RoundTripper { - return upgradeFn("https", authority, c) + return t2 +} + +func (t *Transport) AddConn(scheme, authority string, c net.Conn) error { + connPool, ok := t.ConnPool.(noDialClientConnPool) + if !ok { + go c.Close() + return nil } - // The "unencrypted_http2" TLSNextProto key is used to pass off non-TLS HTTP/2 conns. - t1.TLSNextProto[nextProtoUnencryptedHTTP2] = func(authority string, c *tls.Conn) http.RoundTripper { - nc, err := unencryptedNetConnFromTLSConn(c) - if err != nil { - go c.Close() - return erringRoundTripper{err} - } - return upgradeFn("http", authority, nc) + addr := authorityAddr(scheme, authority) + used, err := connPool.addConnIfNeeded(addr, t, c) + if !used { + go c.Close() } - return t2, nil + return err } // unencryptedTransport is a Transport with a RoundTrip method that // always permits http:// URLs. type unencryptedTransport Transport -func (t *unencryptedTransport) RoundTrip(req *http.Request) (*http.Response, error) { +func (t *unencryptedTransport) RoundTrip(req *ClientRequest) (*ClientResponse, error) { return (*Transport)(t).RoundTripOpt(req, RoundTripOpt{allowHTTP: true}) } @@ -428,8 +377,8 @@ type clientStream struct { donec chan struct{} // closed after the stream is in the closed state on100 chan struct{} // buffered; written to if a 100 is received - respHeaderRecv chan struct{} // closed when headers are received - res *http.Response // set if respHeaderRecv is closed + respHeaderRecv chan struct{} // closed when headers are received + res *ClientResponse // set if respHeaderRecv is closed flow outflow // guarded by cc.mu inflow inflow // guarded by cc.mu @@ -452,8 +401,10 @@ type clientStream struct { readAborted bool // read loop reset the stream totalHeaderSize int64 // total size of 1xx headers seen - trailer http.Header // accumulated trailers - resTrailer *http.Header // client's Response.Trailer + trailer Header // accumulated trailers + resTrailer *Header // client's Response.Trailer + + staticResp ClientResponse } var got1xxFuncForTests func(int, textproto.MIMEHeader) error @@ -557,7 +508,7 @@ type RoundTripOpt struct { allowHTTP bool // allow http:// URLs } -func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { +func (t *Transport) RoundTrip(req *ClientRequest) (*ClientResponse, error) { return t.RoundTripOpt(req, RoundTripOpt{}) } @@ -586,7 +537,7 @@ func authorityAddr(scheme string, authority string) (addr string) { } // RoundTripOpt is like RoundTrip, but takes options. -func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { +func (t *Transport) RoundTripOpt(req *ClientRequest, opt RoundTripOpt) (*ClientResponse, error) { switch req.URL.Scheme { case "https": // Always okay. @@ -624,9 +575,9 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res case <-tm.C: t.vlogf("RoundTrip retrying after failure: %v", roundTripErr) continue - case <-req.Context().Done(): + case <-req.Context.Done(): tm.Stop() - err = req.Context().Err() + err = req.Context.Err() } } } @@ -654,6 +605,26 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res } } +func (t *Transport) IdleConnStrsForTesting() []string { + pool, ok := t.connPool().(noDialClientConnPool) + if !ok { + return nil + } + + var ret []string + pool.mu.Lock() + defer pool.mu.Unlock() + for k, ccs := range pool.conns { + for _, cc := range ccs { + if cc.idleState().canTakeNewRequest { + ret = append(ret, k) + } + } + } + slices.Sort(ret) + return ret +} + // CloseIdleConnections closes any connections which were previously // connected from previous requests but are now sitting idle. // It does not interrupt any connections currently in use. @@ -675,13 +646,13 @@ var ( // response headers. It is always called with a non-nil error. // It returns either a request to retry (either the same request, or a // modified clone), or an error if the request can't be replayed. -func shouldRetryRequest(req *http.Request, err error) (*http.Request, error) { +func shouldRetryRequest(req *ClientRequest, err error) (*ClientRequest, error) { if !canRetryError(err) { return nil, err } // If the Body is nil (or http.NoBody), it's safe to reuse // this request and its Body. - if req.Body == nil || req.Body == http.NoBody { + if req.Body == nil || req.Body == NoBody { return req, nil } @@ -692,9 +663,9 @@ func shouldRetryRequest(req *http.Request, err error) (*http.Request, error) { if err != nil { return nil, err } - newReq := *req + newReq := req.Clone() newReq.Body = body - return &newReq, nil + return newReq, nil } // The Request.Body can't reset back to the beginning, but we @@ -770,18 +741,27 @@ func (t *Transport) dialTLS(ctx context.Context, network, addr string, tlsCfg *t // disableKeepAlives reports whether connections should be closed as // soon as possible after handling the first request. func (t *Transport) disableKeepAlives() bool { - return t.t1 != nil && t.t1.DisableKeepAlives + return t.t1 != nil && t.t1.DisableKeepAlives() } func (t *Transport) expectContinueTimeout() time.Duration { if t.t1 == nil { return 0 } - return t.t1.ExpectContinueTimeout + return t.t1.ExpectContinueTimeout() } -func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) { - return t.newClientConn(c, t.disableKeepAlives(), nil) +func (t *Transport) NewClientConn(c net.Conn, internalStateHook func()) (NetHTTPClientConn, error) { + cc, err := t.newClientConn(c, t.disableKeepAlives(), internalStateHook) + if err != nil { + return NetHTTPClientConn{}, err + } + + // RoundTrip should block when the conn is at its concurrency limit, + // not return an error. Setting strictMaxConcurrentStreams enables this. + cc.strictMaxConcurrentStreams = true + + return NetHTTPClientConn{cc}, nil } func (t *Transport) newClientConn(c net.Conn, singleUse bool, internalStateHook func()) (*ClientConn, error) { @@ -793,7 +773,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool, internalStateHook nextStreamID: 1, maxFrameSize: 16 << 10, // spec default initialWindowSize: 65535, // spec default - initialStreamRecvWindowSize: conf.MaxUploadBufferPerStream, + initialStreamRecvWindowSize: int32(conf.MaxReceiveBufferPerStream), maxConcurrentStreams: initialMaxConcurrentStreams, // "infinite", per spec. Use a smaller value until we have received server settings. strictMaxConcurrentStreams: conf.StrictMaxConcurrentRequests, peerMaxHeaderListSize: 0xffffffffffffffff, // "infinite", per spec. Use 2^64-1 instead. @@ -828,16 +808,16 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool, internalStateHook }) cc.br = bufio.NewReader(c) cc.fr = NewFramer(cc.bw, cc.br) - cc.fr.SetMaxReadFrameSize(conf.MaxReadFrameSize) + cc.fr.SetMaxReadFrameSize(uint32(conf.MaxReadFrameSize)) if t.CountError != nil { cc.fr.countError = t.CountError } - maxHeaderTableSize := conf.MaxDecoderHeaderTableSize + maxHeaderTableSize := uint32(conf.MaxDecoderHeaderTableSize) cc.fr.ReadMetaHeaders = hpack.NewDecoder(maxHeaderTableSize, nil) cc.fr.MaxHeaderListSize = t.maxHeaderListSize() cc.henc = hpack.NewEncoder(&cc.hbuf) - cc.henc.SetMaxDynamicTableSizeLimit(conf.MaxEncoderHeaderTableSize) + cc.henc.SetMaxDynamicTableSizeLimit(uint32(conf.MaxEncoderHeaderTableSize)) cc.peerMaxHeaderTableSize = initialHeaderTableSize if cs, ok := c.(connectionStater); ok { @@ -849,7 +829,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool, internalStateHook {ID: SettingEnablePush, Val: 0}, {ID: SettingInitialWindowSize, Val: uint32(cc.initialStreamRecvWindowSize)}, } - initialSettings = append(initialSettings, Setting{ID: SettingMaxFrameSize, Val: conf.MaxReadFrameSize}) + initialSettings = append(initialSettings, Setting{ID: SettingMaxFrameSize, Val: uint32(conf.MaxReadFrameSize)}) if max := t.maxHeaderListSize(); max != 0 { initialSettings = append(initialSettings, Setting{ID: SettingMaxHeaderListSize, Val: max}) } @@ -859,8 +839,8 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool, internalStateHook cc.bw.Write(clientPreface) cc.fr.WriteSettings(initialSettings...) - cc.fr.WriteWindowUpdate(0, uint32(conf.MaxUploadBufferPerConnection)) - cc.inflow.init(conf.MaxUploadBufferPerConnection + initialWindowSize) + cc.fr.WriteWindowUpdate(0, uint32(conf.MaxReceiveBufferPerConnection)) + cc.inflow.init(int32(conf.MaxReceiveBufferPerConnection) + initialWindowSize) cc.bw.Flush() if cc.werr != nil { cc.Close() @@ -1266,11 +1246,11 @@ func (cc *ClientConn) closeForLostPing() { // errRequestCanceled is a copy of net/http's errRequestCanceled because it's not // exported. At least they'll be DeepEqual for h1-vs-h2 comparisons tests. -var errRequestCanceled = errors.New("net/http: request canceled") +var errRequestCanceled = internal.ErrRequestCanceled func (cc *ClientConn) responseHeaderTimeout() time.Duration { if cc.t.t1 != nil { - return cc.t.t1.ResponseHeaderTimeout + return cc.t.t1.ResponseHeaderTimeout() } // No way to do this (yet?) with just an http2.Transport. Probably // no need. Request.Cancel this is the new way. We only need to support @@ -1282,8 +1262,8 @@ func (cc *ClientConn) responseHeaderTimeout() time.Duration { // actualContentLength returns a sanitized version of // req.ContentLength, where 0 actually means zero (not unknown) and -1 // means unknown. -func actualContentLength(req *http.Request) int64 { - if req.Body == nil || req.Body == http.NoBody { +func actualContentLength(req *ClientRequest) int64 { + if req.Body == nil || req.Body == NoBody { return 0 } if req.ContentLength != 0 { @@ -1304,13 +1284,13 @@ func (cc *ClientConn) decrStreamReservationsLocked() { } } -func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { +func (cc *ClientConn) RoundTrip(req *ClientRequest) (*ClientResponse, error) { return cc.roundTrip(req, nil) } -func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream)) (*http.Response, error) { - ctx := req.Context() - cs := &clientStream{ +func (cc *ClientConn) roundTrip(req *ClientRequest, streamf func(*clientStream)) (*ClientResponse, error) { + ctx := req.Context + req.stream = clientStream{ cc: cc, ctx: ctx, reqCancel: req.Cancel, @@ -1322,7 +1302,9 @@ func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream)) abort: make(chan struct{}), respHeaderRecv: make(chan struct{}), donec: make(chan struct{}), + resTrailer: req.ResTrailer, } + cs := &req.stream cs.requestedGzip = httpcommon.IsRequestGzip(req.Method, req.Header, cc.t.disableCompression()) @@ -1339,7 +1321,7 @@ func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream)) } } - handleResponseHeaders := func() (*http.Response, error) { + handleResponseHeaders := func() (*ClientResponse, error) { res := cs.res if res.StatusCode > 299 { // On error or status code 3xx, 4xx, 5xx, etc abort any @@ -1353,9 +1335,8 @@ func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream)) // we can keep it. cs.abortRequestBodyWrite() } - res.Request = req res.TLS = cc.tlsState - if res.Body == noBody && actualContentLength(req) == 0 { + if res.Body == NoBody && actualContentLength(req) == 0 { // If there isn't a request or response body still being // written, then wait for the stream to be closed before // RoundTrip returns. @@ -1419,7 +1400,7 @@ func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream)) // doRequest runs for the duration of the request lifetime. // // It sends the request and performs post-request cleanup (closing Request.Body, etc.). -func (cs *clientStream) doRequest(req *http.Request, streamf func(*clientStream)) { +func (cs *clientStream) doRequest(req *ClientRequest, streamf func(*clientStream)) { err := cs.writeRequest(req, streamf) cs.cleanupWriteRequest(err) } @@ -1433,7 +1414,7 @@ var errExtendedConnectNotSupported = errors.New("net/http: extended connect not // // It returns non-nil if the request ends otherwise. // If the returned error is StreamError, the error Code may be used in resetting the stream. -func (cs *clientStream) writeRequest(req *http.Request, streamf func(*clientStream)) (err error) { +func (cs *clientStream) writeRequest(req *ClientRequest, streamf func(*clientStream)) (err error) { cc := cs.cc ctx := cs.ctx @@ -1577,7 +1558,7 @@ func (cs *clientStream) writeRequest(req *http.Request, streamf func(*clientStre } } -func (cs *clientStream) encodeAndWriteHeaders(req *http.Request) error { +func (cs *clientStream) encodeAndWriteHeaders(req *ClientRequest) error { cc := cs.cc ctx := cs.ctx @@ -1617,8 +1598,8 @@ func (cs *clientStream) encodeAndWriteHeaders(req *http.Request) error { return err } -func encodeRequestHeaders(req *http.Request, addGzipHeader bool, peerMaxHeaderListSize uint64, headerf func(name, value string)) (httpcommon.EncodeHeadersResult, error) { - return httpcommon.EncodeHeaders(req.Context(), httpcommon.EncodeHeadersParam{ +func encodeRequestHeaders(req *ClientRequest, addGzipHeader bool, peerMaxHeaderListSize uint64, headerf func(name, value string)) (httpcommon.EncodeHeadersResult, error) { + return httpcommon.EncodeHeaders(req.Context, httpcommon.EncodeHeadersParam{ Request: httpcommon.Request{ Header: req.Header, Trailer: req.Trailer, @@ -1852,7 +1833,7 @@ func bufPoolIndex(size int) int { return index } -func (cs *clientStream) writeRequestBody(req *http.Request) (err error) { +func (cs *clientStream) writeRequestBody(req *ClientRequest) (err error) { cc := cs.cc body := cs.reqBody sentEnd := false // whether we sent the final DATA frame w/ END_STREAM @@ -2026,7 +2007,7 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) } // requires cc.wmu be held. -func (cc *ClientConn) encodeTrailers(trailer http.Header) ([]byte, error) { +func (cc *ClientConn) encodeTrailers(trailer Header) ([]byte, error) { cc.hbuf.Reset() hlSize := uint64(0) @@ -2065,7 +2046,7 @@ func (cc *ClientConn) writeHeader(name, value string) { type resAndError struct { _ incomparable - res *http.Response + res *ClientResponse err error } @@ -2359,7 +2340,6 @@ func (rl *clientConnReadLoop) processHeaders(f *MetaHeadersFrame) error { // (nil, nil) special case. See handleResponse docs. return nil } - cs.resTrailer = &res.Trailer cs.res = res close(cs.respHeaderRecv) if f.StreamEnded() { @@ -2374,7 +2354,7 @@ func (rl *clientConnReadLoop) processHeaders(f *MetaHeadersFrame) error { // // As a special case, handleResponse may return (nil, nil) to skip the // frame (currently only used for 1xx responses). -func (rl *clientConnReadLoop) handleResponse(cs *clientStream, f *MetaHeadersFrame) (*http.Response, error) { +func (rl *clientConnReadLoop) handleResponse(cs *clientStream, f *MetaHeadersFrame) (*ClientResponse, error) { if f.Truncated { return nil, errResponseHeaderListSize } @@ -2390,20 +2370,19 @@ func (rl *clientConnReadLoop) handleResponse(cs *clientStream, f *MetaHeadersFra regularFields := f.RegularFields() strs := make([]string, len(regularFields)) - header := make(http.Header, len(regularFields)) - res := &http.Response{ - Proto: "HTTP/2.0", - ProtoMajor: 2, + header := make(Header, len(regularFields)) + res := &cs.staticResp + cs.staticResp = ClientResponse{ Header: header, StatusCode: statusCode, - Status: status + " " + http.StatusText(statusCode), + Status: status, } for _, hf := range regularFields { key := httpcommon.CanonicalHeader(hf.Name) if key == "Trailer" { t := res.Trailer if t == nil { - t = make(http.Header) + t = make(Header) res.Trailer = t } foreachHeaderElement(hf.Value, func(v string) { @@ -2445,8 +2424,8 @@ func (rl *clientConnReadLoop) handleResponse(cs *clientStream, f *MetaHeadersFra // Use the larger limit of MaxHeaderListSize and // net/http.Transport.MaxResponseHeaderBytes. limit := int64(cs.cc.t.maxHeaderListSize()) - if t1 := cs.cc.t.t1; t1 != nil && t1.MaxResponseHeaderBytes > limit { - limit = t1.MaxResponseHeaderBytes + if t1 := cs.cc.t.t1; t1 != nil && t1.MaxResponseHeaderBytes() > limit { + limit = t1.MaxResponseHeaderBytes() } for _, h := range f.Fields { cs.totalHeaderSize += int64(h.Size()) @@ -2485,7 +2464,7 @@ func (rl *clientConnReadLoop) handleResponse(cs *clientStream, f *MetaHeadersFra } if cs.isHead { - res.Body = noBody + res.Body = NoBody return res, nil } @@ -2493,7 +2472,7 @@ func (rl *clientConnReadLoop) handleResponse(cs *clientStream, f *MetaHeadersFra if res.ContentLength > 0 { res.Body = missingBody{} } else { - res.Body = noBody + res.Body = NoBody } return res, nil } @@ -2529,7 +2508,7 @@ func (rl *clientConnReadLoop) processTrailers(cs *clientStream, f *MetaHeadersFr return ConnectionError(ErrCodeProtocol) } - trailer := make(http.Header) + trailer := make(Header) for _, hf := range f.RegularFields() { key := httpcommon.CanonicalHeader(hf.Name) trailer[key] = append(trailer[key], hf.Value) @@ -2808,7 +2787,7 @@ func (cs *clientStream) copyTrailers() { for k, vv := range cs.trailer { t := cs.resTrailer if *t == nil { - *t = make(http.Header) + *t = make(Header) } (*t)[k] = vv } @@ -3105,13 +3084,6 @@ func (t *Transport) logf(format string, args ...interface{}) { log.Printf(format, args...) } -var noBody io.ReadCloser = noBodyReader{} - -type noBodyReader struct{} - -func (noBodyReader) Close() error { return nil } -func (noBodyReader) Read([]byte) (int, error) { return 0, io.EOF } - type missingBody struct{} func (missingBody) Close() error { return nil } @@ -3128,8 +3100,8 @@ func strSliceContains(ss []string, s string) bool { type erringRoundTripper struct{ err error } -func (rt erringRoundTripper) RoundTripErr() error { return rt.err } -func (rt erringRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { return nil, rt.err } +func (rt erringRoundTripper) RoundTripErr() error { return rt.err } +func (rt erringRoundTripper) RoundTrip(*ClientRequest) (*ClientResponse, error) { return nil, rt.err } var errConcurrentReadOnResBody = errors.New("http2: concurrent read on response body") @@ -3231,69 +3203,25 @@ func (gz *gzipReader) Close() error { // isConnectionCloseRequest reports whether req should use its own // connection for a single request and then close the connection. -func isConnectionCloseRequest(req *http.Request) bool { +func isConnectionCloseRequest(req *ClientRequest) bool { return req.Close || httpguts.HeaderValuesContainsToken(req.Header["Connection"], "close") } -// registerHTTPSProtocol calls Transport.RegisterProtocol but -// converting panics into errors. -func registerHTTPSProtocol(t *http.Transport, rt noDialH2RoundTripper) (err error) { - defer func() { - if e := recover(); e != nil { - err = fmt.Errorf("%v", e) - } - }() - t.RegisterProtocol("https", rt) - return nil -} - -// noDialH2RoundTripper is a RoundTripper which only tries to complete the request -// if there's already a cached connection to the host. -// (The field is exported so it can be accessed via reflect from net/http; tested -// by TestNoDialH2RoundTripperType) -// -// A noDialH2RoundTripper is registered with http1.Transport.RegisterProtocol, -// and the http1.Transport can use type assertions to call non-RoundTrip methods on it. -// This lets us expose, for example, NewClientConn to net/http. -type noDialH2RoundTripper struct{ *Transport } - -func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - res, err := rt.Transport.RoundTrip(req) - if isNoCachedConnError(err) { - return nil, http.ErrSkipAltProtocol - } - return res, err -} - -func (rt noDialH2RoundTripper) NewClientConn(conn net.Conn, internalStateHook func()) (http.RoundTripper, error) { - tr := rt.Transport - cc, err := tr.newClientConn(conn, tr.disableKeepAlives(), internalStateHook) - if err != nil { - return nil, err - } - - // RoundTrip should block when the conn is at its concurrency limit, - // not return an error. Setting strictMaxConcurrentStreams enables this. - cc.strictMaxConcurrentStreams = true - - return netHTTPClientConn{cc}, nil -} - // netHTTPClientConn wraps ClientConn and implements the interface net/http expects from // the RoundTripper returned by NewClientConn. -type netHTTPClientConn struct { +type NetHTTPClientConn struct { cc *ClientConn } -func (cc netHTTPClientConn) RoundTrip(req *http.Request) (*http.Response, error) { +func (cc NetHTTPClientConn) RoundTrip(req *ClientRequest) (*ClientResponse, error) { return cc.cc.RoundTrip(req) } -func (cc netHTTPClientConn) Close() error { +func (cc NetHTTPClientConn) Close() error { return cc.cc.Close() } -func (cc netHTTPClientConn) Err() error { +func (cc NetHTTPClientConn) Err() error { cc.cc.mu.Lock() defer cc.cc.mu.Unlock() if cc.cc.closed { @@ -3302,7 +3230,7 @@ func (cc netHTTPClientConn) Err() error { return nil } -func (cc netHTTPClientConn) Reserve() error { +func (cc NetHTTPClientConn) Reserve() error { defer cc.cc.maybeCallStateHook() cc.cc.mu.Lock() defer cc.cc.mu.Unlock() @@ -3313,7 +3241,7 @@ func (cc netHTTPClientConn) Reserve() error { return nil } -func (cc netHTTPClientConn) Release() { +func (cc NetHTTPClientConn) Release() { defer cc.cc.maybeCallStateHook() cc.cc.mu.Lock() defer cc.cc.mu.Unlock() @@ -3326,13 +3254,13 @@ func (cc netHTTPClientConn) Release() { } } -func (cc netHTTPClientConn) Available() int { +func (cc NetHTTPClientConn) Available() int { cc.cc.mu.Lock() defer cc.cc.mu.Unlock() return cc.cc.availableLocked() } -func (cc netHTTPClientConn) InFlight() int { +func (cc NetHTTPClientConn) InFlight() int { cc.cc.mu.Lock() defer cc.cc.mu.Unlock() return cc.cc.currentRequestCountLocked() @@ -3353,22 +3281,22 @@ func (t *Transport) idleConnTimeout() time.Duration { } if t.t1 != nil { - return t.t1.IdleConnTimeout + return t.t1.IdleConnTimeout() } return 0 } -func traceGetConn(req *http.Request, hostPort string) { - trace := httptrace.ContextClientTrace(req.Context()) +func traceGetConn(req *ClientRequest, hostPort string) { + trace := httptrace.ContextClientTrace(req.Context) if trace == nil || trace.GetConn == nil { return } trace.GetConn(hostPort) } -func traceGotConn(req *http.Request, cc *ClientConn, reused bool) { - trace := httptrace.ContextClientTrace(req.Context()) +func traceGotConn(req *ClientRequest, cc *ClientConn, reused bool) { + trace := httptrace.ContextClientTrace(req.Context) if trace == nil || trace.GotConn == nil { return } diff --git a/src/net/http/internal/http2/transport_internal_test.go b/src/net/http/internal/http2/transport_internal_test.go index 2f8532fd75..a6d67a9567 100644 --- a/src/net/http/internal/http2/transport_internal_test.go +++ b/src/net/http/internal/http2/transport_internal_test.go @@ -11,7 +11,6 @@ import ( "fmt" "io" "io/fs" - "net/http" "reflect" "strings" "testing" @@ -25,27 +24,27 @@ func (panicReader) Close() error { panic("unexpected Close") } func TestActualContentLength(t *testing.T) { tests := []struct { - req *http.Request + req *ClientRequest want int64 }{ // Verify we don't read from Body: 0: { - req: &http.Request{Body: panicReader{}}, + req: &ClientRequest{Body: panicReader{}}, want: -1, }, // nil Body means 0, regardless of ContentLength: 1: { - req: &http.Request{Body: nil, ContentLength: 5}, + req: &ClientRequest{Body: nil, ContentLength: 5}, want: 0, }, // ContentLength is used if set. 2: { - req: &http.Request{Body: panicReader{}, ContentLength: 5}, + req: &ClientRequest{Body: panicReader{}, ContentLength: 5}, want: 5, }, // http.NoBody means 0, not -1. 3: { - req: &http.Request{Body: http.NoBody}, + req: &ClientRequest{Body: NoBody}, want: 0, }, } @@ -200,7 +199,7 @@ func TestTransportUsesGetBodyWhenPresent(t *testing.T) { someBody := func() io.ReadCloser { return struct{ io.ReadCloser }{io.NopCloser(bytes.NewReader(nil))} } - req := &http.Request{ + req := &ClientRequest{ Body: someBody(), GetBody: func() (io.ReadCloser, error) { calls++ @@ -232,28 +231,6 @@ func TestTransportUsesGetBodyWhenPresent(t *testing.T) { } } -// Issue 22891: verify that the "https" altproto we register with net/http -// is a certain type: a struct with one field with our *http2.Transport in it. -func TestNoDialH2RoundTripperType(t *testing.T) { - t1 := new(http.Transport) - t2 := new(Transport) - rt := noDialH2RoundTripper{t2} - if err := registerHTTPSProtocol(t1, rt); err != nil { - t.Fatal(err) - } - rv := reflect.ValueOf(rt) - if rv.Type().Kind() != reflect.Struct { - t.Fatalf("kind = %v; net/http expects struct", rv.Type().Kind()) - } - if n := rv.Type().NumField(); n != 1 { - t.Fatalf("fields = %d; net/http expects 1", n) - } - v := rv.Field(0) - if _, ok := v.Interface().(*Transport); !ok { - t.Fatalf("wrong kind %T; want *Transport", v.Interface()) - } -} - func TestClientConnTooIdle(t *testing.T) { tests := []struct { cc func() *ClientConn diff --git a/src/net/http/internal/http2/transport_test.go b/src/net/http/internal/http2/transport_test.go index faae10e6a4..8f1a589624 100644 --- a/src/net/http/internal/http2/transport_test.go +++ b/src/net/http/internal/http2/transport_test.go @@ -36,6 +36,7 @@ import ( "time" . "net/http/internal/http2" + "net/http/internal/httpcommon" "golang.org/x/net/http2/hpack" ) @@ -62,7 +63,6 @@ func newTransport(t testing.TB, opts ...any) *http.Transport { Protocols: protocols("h2"), HTTP2: &http.HTTP2Config{}, } - ConfigureTransport(tr1) for _, o := range opts { switch o := o.(type) { case func(*http.Transport): @@ -90,34 +90,6 @@ func TestTransportExternal(t *testing.T) { res.Write(os.Stdout) } -type fakeTLSConn struct { - net.Conn -} - -func (c *fakeTLSConn) ConnectionState() tls.ConnectionState { - const cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 = 0xC02F // defined in ciphers.go - return tls.ConnectionState{ - Version: tls.VersionTLS12, - CipherSuite: cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - } -} - -func startH2cServer(t *testing.T) net.Listener { - h2Server := &Server{} - l := newLocalListener(t) - go func() { - conn, err := l.Accept() - if err != nil { - t.Error(err) - return - } - h2Server.ServeConn(&fakeTLSConn{conn}, &ServeConnOpts{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "Hello, %v, http: %v", r.URL.Path, r.TLS == nil) - })}) - }() - return l -} - func TestIdleConnTimeout(t *testing.T) { for _, test := range []struct { name string @@ -201,9 +173,12 @@ func TestIdleConnTimeout(t *testing.T) { } func TestTransportH2c(t *testing.T) { - l := startH2cServer(t) - defer l.Close() - req, err := http.NewRequest("GET", "http://"+l.Addr().String()+"/foobar", nil) + ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "Hello, %v, http: %v", r.URL.Path, r.TLS == nil) + }, func(s *http.Server) { + s.Protocols = protocols("h2c") + }) + req, err := http.NewRequest("GET", ts.URL+"/foobar", nil) if err != nil { t.Fatal(err) } @@ -217,7 +192,7 @@ func TestTransportH2c(t *testing.T) { } req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) tr := newTransport(t) - tr.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { return net.Dial(network, addr) } tr.Protocols = protocols("h2c") @@ -1313,16 +1288,24 @@ func testTransportChecksRequestHeaderListSize(t testing.TB) { rt.wantStatus(http.StatusOK) } headerListSizeForRequest := func(req *http.Request) (size uint64) { - const addGzipHeader = true - const peerMaxHeaderListSize = 0xffffffffffffffff - _, err := EncodeRequestHeaders(req, addGzipHeader, peerMaxHeaderListSize, func(name, value string) { + _, err := httpcommon.EncodeHeaders(context.Background(), httpcommon.EncodeHeadersParam{ + Request: httpcommon.Request{ + Header: req.Header, + Trailer: req.Trailer, + URL: req.URL, + Host: req.Host, + Method: req.Method, + ActualContentLength: req.ContentLength, + }, + AddGzipHeader: true, + PeerMaxHeaderListSize: 0xffffffffffffffff, + }, func(name, value string) { hf := hpack.HeaderField{Name: name, Value: value} size += uint64(hf.Size()) }) if err != nil { t.Fatal(err) } - fmt.Println(size) return size } // Create a new Request for each test, rather than reusing the @@ -1330,17 +1313,19 @@ func testTransportChecksRequestHeaderListSize(t testing.TB) { // See https://github.com/golang/go/issues/21316 newRequest := func() *http.Request { // Body must be non-nil to enable writing trailers. - body := strings.NewReader("hello") + const bodytext = "hello" + body := strings.NewReader(bodytext) req, err := http.NewRequest("POST", "https://example.tld/", body) if err != nil { t.Fatalf("newRequest: NewRequest: %v", err) } + req.ContentLength = int64(len(bodytext)) + req.Header = http.Header{"User-Agent": nil} return req } // Pad headers & trailers, but stay under peerSize. req := newRequest() - req.Header = make(http.Header) req.Trailer = make(http.Header) filler := strings.Repeat("*", 1024) padHeaders(t, req.Trailer, peerSize, filler) @@ -1352,7 +1337,6 @@ func testTransportChecksRequestHeaderListSize(t testing.TB) { // Add enough header bytes to push us over peerSize. req = newRequest() - req.Header = make(http.Header) padHeaders(t, req.Header, peerSize, filler) checkRoundTrip(req, ErrRequestHeaderListSize, "Headers over limit") @@ -1365,7 +1349,6 @@ func testTransportChecksRequestHeaderListSize(t testing.TB) { // Send headers with a single large value. req = newRequest() filler = strings.Repeat("*", int(peerSize)) - req.Header = make(http.Header) req.Header.Set("Big", filler) checkRoundTrip(req, ErrRequestHeaderListSize, "Single large header") @@ -2596,10 +2579,18 @@ func TestTransportRequestPathPseudo(t *testing.T) { for i, tt := range tests { hbuf := &bytes.Buffer{} henc := hpack.NewEncoder(hbuf) - - const addGzipHeader = false - const peerMaxHeaderListSize = 0xffffffffffffffff - _, err := EncodeRequestHeaders(tt.req, addGzipHeader, peerMaxHeaderListSize, func(name, value string) { + _, err := httpcommon.EncodeHeaders(context.Background(), httpcommon.EncodeHeadersParam{ + Request: httpcommon.Request{ + Header: tt.req.Header, + Trailer: tt.req.Trailer, + URL: tt.req.URL, + Host: tt.req.Host, + Method: tt.req.Method, + ActualContentLength: tt.req.ContentLength, + }, + AddGzipHeader: false, + PeerMaxHeaderListSize: 0xffffffffffffffff, + }, func(name, value string) { henc.WriteField(hpack.HeaderField{Name: name, Value: value}) }) hdrs := hbuf.Bytes() @@ -3631,10 +3622,11 @@ func testClientConnCloseAtBody(t testing.TB) { ), }) tc.writeData(rt.streamID(), false, make([]byte, 64)) + resp := rt.response() tc.cc.Close() synctest.Wait() - if _, err := io.Copy(io.Discard, rt.response().Body); err == nil { + if _, err := io.Copy(io.Discard, resp.Body); err == nil { t.Error("expected a Copy error, got nil") } } @@ -4352,13 +4344,13 @@ func TestTransportCloseRequestBody(t *testing.T) { w.WriteHeader(statusCode) }) - tr := &Transport{TLSClientConfig: tlsConfigInsecure} - defer tr.CloseIdleConnections() + tr := newTransport(t) ctx := context.Background() - cc, err := tr.DialClientConn(ctx, ts.Listener.Addr().String(), false) + cc, err := tr.NewClientConn(ctx, "https", ts.Listener.Addr().String()) if err != nil { t.Fatal(err) } + defer cc.Close() for _, status := range []int{200, 401} { t.Run(fmt.Sprintf("status=%d", status), func(t *testing.T) { @@ -4966,39 +4958,40 @@ func testTransportDataAfter1xxHeader(t testing.TB) { } func TestIssue66763Race(t *testing.T) { - tr := &Transport{ - IdleConnTimeout: 1 * time.Nanosecond, - AllowHTTP: true, // issue 66763 only occurs when AllowHTTP is true - } - defer tr.CloseIdleConnections() + ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {}, + func(s *http.Server) { + s.Protocols = protocols("h2c") + }) + tr := newTransport(t) + tr.IdleConnTimeout = 1 * time.Nanosecond + tr.Protocols = protocols("h2c") - cli, srv := net.Pipe() donec := make(chan struct{}) go func() { // Creating the client conn may succeed or fail, // depending on when the idle timeout happens. // Either way, the idle timeout will close the net.Conn. - tr.NewClientConn(cli) + conn, err := tr.NewClientConn(t.Context(), "http", ts.URL) close(donec) + if err == nil { + conn.Close() + } }() // The client sends its preface and SETTINGS frame, // and then closes its conn after the idle timeout. - io.ReadAll(srv) - srv.Close() - <-donec } // Issue 67671: Sending a Connection: close request on a Transport with AllowHTTP // set caused a the transport to wedge. func TestIssue67671(t *testing.T) { - ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {}) - tr := &Transport{ - TLSClientConfig: tlsConfigInsecure, - AllowHTTP: true, - } - defer tr.CloseIdleConnections() + ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {}, + func(s *http.Server) { + s.Protocols = protocols("h2c") + }) + tr := newTransport(t) + tr.Protocols = protocols("h2c") req, _ := http.NewRequest("GET", ts.URL, nil) req.Close = true for i := 0; i < 2; i++ { @@ -5215,6 +5208,7 @@ func testTransportSendNoMoreThanOnePingWithReset(t testing.TB) { // because we haven't received a HEADERS or DATA frame from the server // since the last PING we sent. makeAndResetRequest() + tc.wantIdle() // Server belatedly responds to request 1. // The server has not responded to our first PING yet. @@ -5334,26 +5328,45 @@ func testTransportConnBecomesUnresponsive(t testing.TB) { rt2.response().Body.Close() } -// Test that the Transport can use a conn provided to it by a TLSNextProto hook. -func TestTransportTLSNextProtoConnOK(t *testing.T) { synctestTest(t, testTransportTLSNextProtoConnOK) } -func testTransportTLSNextProtoConnOK(t testing.TB) { - t1 := &http.Transport{} - t2, _ := ConfigureTransports(t1) - tt := newTestTransport(t, t2) +// newTestTransportWithUnusedConn creates a Transport, +// sends a request on the Transport, +// and then cancels the request before the resulting dial completes. +// It then waits for the dial to finish +// and returns the Transport with an unused conn in its pool. +func newTestTransportWithUnusedConn(t testing.TB, opts ...any) *testTransport { + tt := newTestTransport(t, opts...) - // Create a new, fake connection and pass it to the Transport via the TLSNextProto hook. - cli, _ := synctestNetPipe() - cliTLS := tls.Client(cli, tlsConfigInsecure) - go func() { - tt.tr.TestTransport().TLSNextProto["h2"]("dummy.tld", cliTLS) - }() + waitc := make(chan struct{}) + dialContext := tt.tr1.DialContext + tt.tr1.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) { + <-waitc + return dialContext(ctx, network, address) + } + + req := Must(http.NewRequest("GET", "https://dummy.tld/", nil)) + rt := tt.roundTrip(req) + rt.cancel() + if rt.err() == nil { + t.Fatalf("RoundTrip still running after request is canceled") + } + + close(waitc) synctest.Wait() + return tt +} + +// Test that the Transport can use a conn created for one request, but never used by it. +func TestTransportUnusedConnOK(t *testing.T) { synctestTest(t, testTransportUnusedConnOK) } +func testTransportUnusedConnOK(t testing.TB) { + tt := newTestTransportWithUnusedConn(t) + + req := Must(http.NewRequest("GET", "https://dummy.tld/", nil)) tc := tt.getConn() - tc.greet() + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) // Send a request on the Transport. // It uses the conn we provided. - req := Must(http.NewRequest("GET", "https://dummy.tld/", nil)) rt := tt.roundTrip(req) tc.wantHeaders(wantHeader{ streamID: 1, @@ -5364,6 +5377,11 @@ func testTransportTLSNextProtoConnOK(t testing.TB) { ":path": []string{"/"}, }, }) + + tc.writeSettings() + tc.writeSettingsAck() + tc.wantFrameType(FrameSettings) // acknowledgement + tc.writeHeaders(HeadersFrameParam{ StreamID: 1, EndHeaders: true, @@ -5376,26 +5394,16 @@ func testTransportTLSNextProtoConnOK(t testing.TB) { rt.wantBody(nil) } -// Test the case where a conn provided via a TLSNextProto hook immediately encounters an error. -func TestTransportTLSNextProtoConnImmediateFailureUsed(t *testing.T) { - synctestTest(t, testTransportTLSNextProtoConnImmediateFailureUsed) +// Test the case where an unused conn immediately encounters an error. +func TestTransportUnusedConnImmediateFailureUsed(t *testing.T) { + synctestTest(t, testTransportUnusedConnImmediateFailureUsed) } -func testTransportTLSNextProtoConnImmediateFailureUsed(t testing.TB) { - t1 := &http.Transport{} - t2, _ := ConfigureTransports(t1) - tt := newTestTransport(t, t2) - - // Create a new, fake connection and pass it to the Transport via the TLSNextProto hook. - cli, _ := synctestNetPipe() - cliTLS := tls.Client(cli, tlsConfigInsecure) - go func() { - t1.TLSNextProto["h2"]("dummy.tld", cliTLS) - }() - synctest.Wait() - tc := tt.getConn() +func testTransportUnusedConnImmediateFailureUsed(t testing.TB) { + tt := newTestTransportWithUnusedConn(t) // The connection encounters an error before we send a request that uses it. - tc.closeWrite() + tc1 := tt.getConn() + tc1.closeWrite() // Send a request on the Transport. // @@ -5407,33 +5415,24 @@ func testTransportTLSNextProtoConnImmediateFailureUsed(t testing.TB) { } // Send the request again. - // This time it should fail with ErrNoCachedConn, + // This time it is sent on a new conn // because the dead conn has been removed from the pool. - rt = tt.roundTrip(req) - if err := rt.err(); !errors.Is(err, ErrNoCachedConn) { - t.Fatalf("RoundTrip after broken conn is used: got %v, want ErrNoCachedConn", err) - } + _ = tt.roundTrip(req) + tc2 := tt.getConn() + tc2.wantFrameType(FrameSettings) + tc2.wantFrameType(FrameWindowUpdate) + tc2.wantFrameType(FrameHeaders) } -// Test the case where a conn provided via a TLSNextProto hook is closed for idleness -// before we use it. -func TestTransportTLSNextProtoConnIdleTimoutBeforeUse(t *testing.T) { - synctestTest(t, testTransportTLSNextProtoConnIdleTimoutBeforeUse) +// Test the case where an unused conn is closed for idleness before we use it. +func TestTransportUnusedConnIdleTimoutBeforeUse(t *testing.T) { + synctestTest(t, testTransportUnusedConnIdleTimoutBeforeUse) } -func testTransportTLSNextProtoConnIdleTimoutBeforeUse(t testing.TB) { - t1 := &http.Transport{ - IdleConnTimeout: 1 * time.Second, - } - t2, _ := ConfigureTransports(t1) - tt := newTestTransport(t, t2) +func testTransportUnusedConnIdleTimoutBeforeUse(t testing.TB) { + tt := newTestTransportWithUnusedConn(t, func(t1 *http.Transport) { + t1.IdleConnTimeout = 1 * time.Second + }) - // Create a new, fake connection and pass it to the Transport via the TLSNextProto hook. - cli, _ := synctestNetPipe() - cliTLS := tls.Client(cli, tlsConfigInsecure) - go func() { - t1.TLSNextProto["h2"]("dummy.tld", cliTLS) - }() - synctest.Wait() _ = tt.getConn() // The connection encounters an error before we send a request that uses it. @@ -5442,12 +5441,14 @@ func testTransportTLSNextProtoConnIdleTimoutBeforeUse(t testing.TB) { // Send a request on the Transport. // - // It should fail with ErrNoCachedConn. + // It is sent on a new conn + // because the old one has idled out and been removed from the pool. req := Must(http.NewRequest("GET", "https://dummy.tld/", nil)) - rt := tt.roundTrip(req) - if err := rt.err(); !errors.Is(err, ErrNoCachedConn) { - t.Fatalf("RoundTrip with conn closed for idleness: got %v, want ErrNoCachedConn", err) - } + _ = tt.roundTrip(req) + tc2 := tt.getConn() + tc2.wantFrameType(FrameSettings) + tc2.wantFrameType(FrameWindowUpdate) + tc2.wantFrameType(FrameHeaders) } // Test the case where a conn provided via a TLSNextProto hook immediately encounters an error, @@ -5456,21 +5457,13 @@ func TestTransportTLSNextProtoConnImmediateFailureUnused(t *testing.T) { synctestTest(t, testTransportTLSNextProtoConnImmediateFailureUnused) } func testTransportTLSNextProtoConnImmediateFailureUnused(t testing.TB) { - t1 := &http.Transport{} - t2, _ := ConfigureTransports(t1) - tt := newTestTransport(t, t2) - - // Create a new, fake connection and pass it to the Transport via the TLSNextProto hook. - cli, _ := synctestNetPipe() - cliTLS := tls.Client(cli, tlsConfigInsecure) - go func() { - t1.TLSNextProto["h2"]("dummy.tld", cliTLS) - }() - synctest.Wait() - tc := tt.getConn() + tt := newTestTransportWithUnusedConn(t, func(t1 *http.Transport) { + t1.IdleConnTimeout = 1 * time.Second + }) // The connection encounters an error before we send a request that uses it. - tc.closeWrite() + tc1 := tt.getConn() + tc1.closeWrite() // Some time passes. // The dead connection is removed from the pool. @@ -5478,12 +5471,13 @@ func testTransportTLSNextProtoConnImmediateFailureUnused(t testing.TB) { // Send a request on the Transport. // - // It should fail with ErrNoCachedConn, because the pool contains no conns. + // It is sent on a new conn. req := Must(http.NewRequest("GET", "https://dummy.tld/", nil)) - rt := tt.roundTrip(req) - if err := rt.err(); !errors.Is(err, ErrNoCachedConn) { - t.Fatalf("RoundTrip after broken conn expires: got %v, want ErrNoCachedConn", err) - } + _ = tt.roundTrip(req) + tc2 := tt.getConn() + tc2.wantFrameType(FrameSettings) + tc2.wantFrameType(FrameWindowUpdate) + tc2.wantFrameType(FrameHeaders) } func TestExtendedConnectClientWithServerSupport(t *testing.T) { diff --git a/src/net/http/internal/http2/unencrypted.go b/src/net/http/internal/http2/unencrypted.go deleted file mode 100644 index b2de211613..0000000000 --- a/src/net/http/internal/http2/unencrypted.go +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2024 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package http2 - -import ( - "crypto/tls" - "errors" - "net" -) - -const nextProtoUnencryptedHTTP2 = "unencrypted_http2" - -// unencryptedNetConnFromTLSConn retrieves a net.Conn wrapped in a *tls.Conn. -// -// TLSNextProto functions accept a *tls.Conn. -// -// When passing an unencrypted HTTP/2 connection to a TLSNextProto function, -// we pass a *tls.Conn with an underlying net.Conn containing the unencrypted connection. -// To be extra careful about mistakes (accidentally dropping TLS encryption in a place -// where we want it), the tls.Conn contains a net.Conn with an UnencryptedNetConn method -// that returns the actual connection we want to use. -func unencryptedNetConnFromTLSConn(tc *tls.Conn) (net.Conn, error) { - conner, ok := tc.NetConn().(interface { - UnencryptedNetConn() net.Conn - }) - if !ok { - return nil, errors.New("http2: TLS conn unexpectedly found in unencrypted handoff") - } - return conner.UnencryptedNetConn(), nil -} diff --git a/src/net/http/internal/http2/write.go b/src/net/http/internal/http2/write.go index 0691934f79..59c4e625a5 100644 --- a/src/net/http/internal/http2/write.go +++ b/src/net/http/internal/http2/write.go @@ -8,7 +8,6 @@ import ( "bytes" "fmt" "log" - "net/http" "net/http/internal/httpcommon" "net/url" @@ -189,9 +188,9 @@ func splitHeaderBlock(ctx writeContext, headerBlock []byte, fn func(ctx writeCon // for HTTP response headers or trailers from a server handler. type writeResHeaders struct { streamID uint32 - httpResCode int // 0 means no ":status" line - h http.Header // may be nil - trailers []string // if non-nil, which keys of h to write. nil means all. + httpResCode int // 0 means no ":status" line + h Header // may be nil + trailers []string // if non-nil, which keys of h to write. nil means all. endStream bool date string @@ -263,7 +262,7 @@ type writePushPromise struct { streamID uint32 // pusher stream method string // for :method url *url.URL // for :scheme, :authority, :path - h http.Header + h Header // Creates an ID for a pushed stream. This runs on serveG just before // the frame is written. The returned ID is copied to promisedID. @@ -341,7 +340,7 @@ func (wu writeWindowUpdate) writeFrame(ctx writeContext) error { // encodeHeaders encodes an http.Header. If keys is not nil, then (k, h[k]) // is encoded only if k is in keys. -func encodeHeaders(enc *hpack.Encoder, h http.Header, keys []string) { +func encodeHeaders(enc *hpack.Encoder, h Header, keys []string) { if keys == nil { sorter := sorterPool.Get().(*sorter) // Using defer here, since the returned keys from the |
