diff options
82 files changed, 16178 insertions, 29 deletions
diff --git a/src/go/build/deps_test.go b/src/go/build/deps_test.go index 2f424e7e3c..dd036a1568 100644 --- a/src/go/build/deps_test.go +++ b/src/go/build/deps_test.go @@ -569,6 +569,9 @@ var depsRules = ` crypto/mlkem < CRYPTO; + CRYPTO + < golang.org/x/crypto/hkdf; + CGO, fmt, net !< CRYPTO; # CRYPTO-MATH is crypto that exposes math/big APIs - no cgo, net; fmt now ok. @@ -663,6 +666,12 @@ var depsRules = ` < net/http/internal/http2 < net/http; + net/http, golang.org/x/crypto/hkdf, log/slog + < golang.org/x/net/internal/quic/quicwire + < golang.org/x/net/quic, golang.org/x/net/internal/httpcommon + < golang.org/x/net/internal/http3 + < golang.org/x/net/http3; + # HTTP-aware packages encoding/json, net/http diff --git a/src/net/http/clientserver_test.go b/src/net/http/clientserver_test.go index c25db82fe4..aadd8e3dc0 100644 --- a/src/net/http/clientserver_test.go +++ b/src/net/http/clientserver_test.go @@ -19,6 +19,7 @@ import ( "log" "maps" "net" + "net/http" . "net/http" "net/http/httptest" "net/http/httptrace" @@ -35,8 +36,17 @@ import ( "testing" "testing/synctest" "time" + _ "unsafe" // for linkname + + _ "golang.org/x/net/http3" ) +//go:linkname registerHTTP3Transport +func registerHTTP3Transport(*http.Transport) + +//go:linkname registerHTTP3Server +func registerHTTP3Server(*http.Server) <-chan string + type testMode string const ( @@ -44,13 +54,14 @@ const ( https1Mode = testMode("https1") // HTTPS/1.1 http2Mode = testMode("h2") // HTTP/2 http2UnencryptedMode = testMode("h2unencrypted") // HTTP/2 + http3Mode = testMode("h3") // HTTP/3 ) func (m testMode) Scheme() string { switch m { case http1Mode, http2UnencryptedMode: return "http" - case https1Mode, http2Mode: + case https1Mode, http2Mode, http3Mode: return "https" } panic("unknown testMode") @@ -189,7 +200,8 @@ func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *c var transportFuncs []func(*Transport) - if idx := slices.Index(opts, any(optFakeNet)); idx >= 0 { + switch idx := slices.Index(opts, any(optFakeNet)); { + case idx >= 0: opts = slices.Delete(opts, idx, idx+1) cst.li = fakeNetListen() cst.ts = &httptest.Server{ @@ -201,7 +213,12 @@ func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *c return cst.li.connect(), nil } }) - } else { + case mode == http3Mode: + // TODO: support testing HTTP/3 using fakenet. + cst.ts = &httptest.Server{ + Config: &Server{Handler: h}, + } + default: cst.ts = httptest.NewUnstartedServer(h) } @@ -241,6 +258,24 @@ func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *c cst.ts.EnableHTTP2 = true cst.ts.TLS = cst.ts.Config.TLSConfig cst.ts.StartTLS() + case http3Mode: + http.ProtocolSetHTTP3(p) + cst.ts.TLS = cst.ts.Config.TLSConfig + cst.ts.StartTLS() + listenAddrCh := registerHTTP3Server(cst.ts.Config) + + cst.ts.Config.TLSConfig = cst.ts.TLS + cst.ts.Config.Addr = "localhost:0" + go cst.ts.Config.ListenAndServeTLS("", "") + + listenAddr := <-listenAddrCh + cst.ts.URL = "https://" + listenAddr + t.Cleanup(func() { + // Same timeout as in HTTP/2 goAwayTimeout when shutting down in tests. + ctx, cancel := context.WithTimeout(t.Context(), 25*time.Millisecond) + defer cancel() + cst.ts.Config.Shutdown(ctx) + }) default: t.Fatalf("unknown test mode %v", mode) } @@ -252,6 +287,9 @@ func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *c if cst.tr.Protocols == nil { cst.tr.Protocols = p } + if mode == http3Mode { + registerHTTP3Transport(cst.tr) + } t.Cleanup(func() { cst.close() diff --git a/src/net/http/export_test.go b/src/net/http/export_test.go index b499769c4f..300785d20d 100644 --- a/src/net/http/export_test.go +++ b/src/net/http/export_test.go @@ -33,6 +33,7 @@ var ( Export_writeStatusLine = writeStatusLine Export_is408Message = is408Message MaxPostCloseReadTime = maxPostCloseReadTime + ProtocolSetHTTP3 = protocolSetHTTP3 ) var MaxWriteWaitBeforeConnReuse = &maxWriteWaitBeforeConnReuse diff --git a/src/net/http/http.go b/src/net/http/http.go index 407e15a1c4..c46b656581 100644 --- a/src/net/http/http.go +++ b/src/net/http/http.go @@ -88,6 +88,9 @@ func (p Protocols) String() string { if p.UnencryptedHTTP2() { s = append(s, "UnencryptedHTTP2") } + if p.http3() { + s = append(s, "HTTP3") + } return "{" + strings.Join(s, ",") + "}" } diff --git a/src/net/http/httptest/server.go b/src/net/http/httptest/server.go index 7ae2561b71..fd65e5797a 100644 --- a/src/net/http/httptest/server.go +++ b/src/net/http/httptest/server.go @@ -20,6 +20,7 @@ import ( "strings" "sync" "time" + _ "unsafe" // for linkname ) // A Server is an HTTP server listening on a system-chosen port on the @@ -45,6 +46,9 @@ type Server struct { // certificate is a parsed version of the TLS config certificate, if present. certificate *x509.Certificate + // started indicates whether the server has been started. + started bool + // wg counts the number of outstanding HTTP requests on this server. // Close blocks until all requests are finished. wg sync.WaitGroup @@ -124,30 +128,31 @@ func NewUnstartedServer(handler http.Handler) *Server { // Start starts a server from NewUnstartedServer. func (s *Server) Start() { - if s.URL != "" { + if s.started { panic("Server already started") } + s.started = true + s.wrap() - if s.client == nil { - tr := &http.Transport{} - dialer := net.Dialer{} - // User code may set either of Dial or DialContext, with DialContext taking precedence. - // We set DialContext here to preserve any context values that are passed in, - // but fall back to Dial if the user has set it. - tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - if tr.Dial != nil { - return tr.Dial(network, addr) - } - if addr == "example.com:80" || strings.HasSuffix(addr, ".example.com:80") { - addr = s.Listener.Addr().String() - } - return dialer.DialContext(ctx, network, addr) + tr := &http.Transport{} + s.client = &http.Client{Transport: tr} + if s.Listener == nil { + return + } + dialer := net.Dialer{} + // User code may set either of Dial or DialContext, with DialContext taking precedence. + // We set DialContext here to preserve any context values that are passed in, + // but fall back to Dial if the user has set it. + tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + if tr.Dial != nil { + return tr.Dial(network, addr) } - s.client = &http.Client{Transport: tr} - + if addr == "example.com:80" || strings.HasSuffix(addr, ".example.com:80") { + addr = s.Listener.Addr().String() + } + return dialer.DialContext(ctx, network, addr) } s.URL = "http://" + s.Listener.Addr().String() - s.wrap() s.goServe() if serveFlag != "" { fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL) @@ -157,12 +162,13 @@ func (s *Server) Start() { // StartTLS starts TLS on a server from NewUnstartedServer. func (s *Server) StartTLS() { - if s.URL != "" { + if s.started { panic("Server already started") } - if s.client == nil { - s.client = &http.Client{} - } + s.started = true + s.wrap() + + s.client = &http.Client{} cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey) if err != nil { panic(fmt.Sprintf("httptest: NewTLSServer: %v", err)) @@ -190,12 +196,18 @@ func (s *Server) StartTLS() { } certpool := x509.NewCertPool() certpool.AddCert(s.certificate) + tr := &http.Transport{ TLSClientConfig: &tls.Config{ RootCAs: certpool, }, ForceAttemptHTTP2: s.EnableHTTP2, } + s.client.Transport = tr + + if s.Listener == nil { + return + } dialer := net.Dialer{} tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { if tr.Dial != nil { @@ -206,10 +218,8 @@ func (s *Server) StartTLS() { } return dialer.DialContext(ctx, network, addr) } - s.client.Transport = tr s.Listener = tls.NewListener(s.Listener, s.TLS) s.URL = "https://" + s.Listener.Addr().String() - s.wrap() s.goServe() } @@ -231,7 +241,9 @@ func (s *Server) Close() { s.mu.Lock() if !s.closed { s.closed = true - s.Listener.Close() + if s.Listener != nil { + s.Listener.Close() + } s.Config.SetKeepAlivesEnabled(false) for c, st := range s.conns { // Force-close any idle connections (those between @@ -275,7 +287,6 @@ func (s *Server) Close() { t.CloseIdleConnections() } } - s.wg.Wait() } diff --git a/src/vendor/golang.org/x/crypto/hkdf/hkdf.go b/src/vendor/golang.org/x/crypto/hkdf/hkdf.go new file mode 100644 index 0000000000..3bee66294e --- /dev/null +++ b/src/vendor/golang.org/x/crypto/hkdf/hkdf.go @@ -0,0 +1,95 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package hkdf implements the HMAC-based Extract-and-Expand Key Derivation +// Function (HKDF) as defined in RFC 5869. +// +// HKDF is a cryptographic key derivation function (KDF) with the goal of +// expanding limited input keying material into one or more cryptographically +// strong secret keys. +package hkdf + +import ( + "crypto/hmac" + "errors" + "hash" + "io" +) + +// Extract generates a pseudorandom key for use with Expand from an input secret +// and an optional independent salt. +// +// Only use this function if you need to reuse the extracted key with multiple +// Expand invocations and different context values. Most common scenarios, +// including the generation of multiple keys, should use New instead. +func Extract(hash func() hash.Hash, secret, salt []byte) []byte { + if salt == nil { + salt = make([]byte, hash().Size()) + } + extractor := hmac.New(hash, salt) + extractor.Write(secret) + return extractor.Sum(nil) +} + +type hkdf struct { + expander hash.Hash + size int + + info []byte + counter byte + + prev []byte + buf []byte +} + +func (f *hkdf) Read(p []byte) (int, error) { + // Check whether enough data can be generated + need := len(p) + remains := len(f.buf) + int(255-f.counter+1)*f.size + if remains < need { + return 0, errors.New("hkdf: entropy limit reached") + } + // Read any leftover from the buffer + n := copy(p, f.buf) + p = p[n:] + + // Fill the rest of the buffer + for len(p) > 0 { + if f.counter > 1 { + f.expander.Reset() + } + f.expander.Write(f.prev) + f.expander.Write(f.info) + f.expander.Write([]byte{f.counter}) + f.prev = f.expander.Sum(f.prev[:0]) + f.counter++ + + // Copy the new batch into p + f.buf = f.prev + n = copy(p, f.buf) + p = p[n:] + } + // Save leftovers for next run + f.buf = f.buf[n:] + + return need, nil +} + +// Expand returns a Reader, from which keys can be read, using the given +// pseudorandom key and optional context info, skipping the extraction step. +// +// The pseudorandomKey should have been generated by Extract, or be a uniformly +// random or pseudorandom cryptographically strong key. See RFC 5869, Section +// 3.3. Most common scenarios will want to use New instead. +func Expand(hash func() hash.Hash, pseudorandomKey, info []byte) io.Reader { + expander := hmac.New(hash, pseudorandomKey) + return &hkdf{expander, expander.Size(), info, 1, nil, nil} +} + +// New returns a Reader, from which keys can be read, using the given hash, +// secret, salt and context info. Salt and info can be nil. +func New(hash func() hash.Hash, secret, salt, info []byte) io.Reader { + prk := Extract(hash, secret, salt) + return Expand(hash, prk, info) +} diff --git a/src/vendor/golang.org/x/net/http3/http3.go b/src/vendor/golang.org/x/net/http3/http3.go new file mode 100644 index 0000000000..72ac09041d --- /dev/null +++ b/src/vendor/golang.org/x/net/http3/http3.go @@ -0,0 +1,31 @@ +// Copyright 2026 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http3 + +import ( + "net/http" + _ "unsafe" // for linkname + + . "golang.org/x/net/internal/http3" + "golang.org/x/net/quic" +) + +//go:linkname registerHTTP3Server net/http_test.registerHTTP3Server +func registerHTTP3Server(s *http.Server) <-chan string { + listenAddr := make(chan string) + RegisterServer(s, ServerOpts{ + ListenQUIC: func(addr string, config *quic.Config) (*quic.Endpoint, error) { + e, err := quic.Listen("udp", addr, config) + listenAddr <- e.LocalAddr().String() + return e, err + }, + }) + return listenAddr +} + +//go:linkname registerHTTP3Transport net/http_test.registerHTTP3Transport +func registerHTTP3Transport(tr *http.Transport) { + RegisterTransport(tr) +} diff --git a/src/vendor/golang.org/x/net/internal/http3/body.go b/src/vendor/golang.org/x/net/internal/http3/body.go new file mode 100644 index 0000000000..6db183beb8 --- /dev/null +++ b/src/vendor/golang.org/x/net/internal/http3/body.go @@ -0,0 +1,230 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http3 + +import ( + "errors" + "fmt" + "io" + "net" + "net/http" + "net/textproto" + "strings" + "sync" + + "golang.org/x/net/http/httpguts" +) + +// extractTrailerFromHeader extracts the "Trailer" header values from a header +// map, and populates a trailer map with those values as keys. The extracted +// header values will be canonicalized. +func extractTrailerFromHeader(header, trailer http.Header) { + for _, names := range header["Trailer"] { + names = textproto.TrimString(names) + for name := range strings.SplitSeq(names, ",") { + name = textproto.CanonicalMIMEHeaderKey(textproto.TrimString(name)) + if !httpguts.ValidTrailerHeader(name) { + continue + } + trailer[name] = nil + } + } +} + +// A bodyWriter writes a request or response body to a stream +// as a series of DATA frames. +type bodyWriter struct { + st *stream + remain int64 // -1 when content-length is not known + flush bool // flush the stream after every write + name string // "request" or "response" + trailer http.Header // trailer headers that will be written once bodyWriter is closed. + enc *qpackEncoder // QPACK encoder used by the connection. +} + +func (w *bodyWriter) write(ps ...[]byte) (n int, err error) { + var size int64 + for _, p := range ps { + size += int64(len(p)) + } + // If write is called with empty byte slices, just return instead of + // sending out a DATA frame containing nothing. + if size == 0 { + return 0, nil + } + if w.remain >= 0 && size > w.remain { + return 0, &streamError{ + code: errH3InternalError, + message: w.name + " body longer than specified content length", + } + } + w.st.writeVarint(int64(frameTypeData)) + w.st.writeVarint(size) + for _, p := range ps { + var n2 int + n2, err = w.st.Write(p) + n += n2 + if w.remain >= 0 { + w.remain -= int64(n) + } + if err != nil { + break + } + } + if w.flush && err == nil { + err = w.st.Flush() + } + if err != nil { + err = fmt.Errorf("writing %v body: %w", w.name, err) + } + return n, err +} + +func (w *bodyWriter) Write(p []byte) (n int, err error) { + return w.write(p) +} + +func (w *bodyWriter) Close() error { + if w.remain > 0 { + return errors.New(w.name + " body shorter than specified content length") + } + if len(w.trailer) > 0 { + encTrailer := w.enc.encode(func(f func(itype indexType, name, value string)) { + for name, values := range w.trailer { + if !httpguts.ValidHeaderFieldName(name) { + continue + } + for _, val := range values { + if !httpguts.ValidHeaderFieldValue(val) { + continue + } + f(mayIndex, name, val) + } + } + }) + w.st.writeVarint(int64(frameTypeHeaders)) + w.st.writeVarint(int64(len(encTrailer))) + w.st.Write(encTrailer) + } + if w.st != nil && w.st.stream != nil { + w.st.stream.CloseWrite() + } + return nil +} + +// A bodyReader reads a request or response body from a stream. +type bodyReader struct { + st *stream + + mu sync.Mutex + remain int64 + err error + // If not nil, the body contains an "Expect: 100-continue" header, and + // send100Continue should be called when Read is invoked for the first + // time. + send100Continue func() + // A map where the key represents the trailer header names we expect. If + // there is a HEADERS frame after reading DATA frames to EOF, the value of + // the headers will be written here, provided that the name of the header + // exists in the map already. + trailer http.Header +} + +func (r *bodyReader) Read(p []byte) (n int, err error) { + // The HTTP/1 and HTTP/2 implementations both permit concurrent reads from a body, + // in the sense that the race detector won't complain. + // Use a mutex here to provide the same behavior. + r.mu.Lock() + defer r.mu.Unlock() + if r.send100Continue != nil { + r.send100Continue() + r.send100Continue = nil + } + if r.err != nil { + return 0, r.err + } + defer func() { + if err != nil { + r.err = err + } + }() + if r.st.lim == 0 { + // We've finished reading the previous DATA frame, so end it. + if err := r.st.endFrame(); err != nil { + return 0, err + } + } + // Read the next DATA frame header, + // if we aren't already in the middle of one. + for r.st.lim < 0 { + ftype, err := r.st.readFrameHeader() + if err == io.EOF && r.remain > 0 { + return 0, &streamError{ + code: errH3MessageError, + message: "body shorter than content-length", + } + } + if err != nil { + return 0, err + } + switch ftype { + case frameTypeData: + if r.remain >= 0 && r.st.lim > r.remain { + return 0, &streamError{ + code: errH3MessageError, + message: "body longer than content-length", + } + } + // Fall out of the loop and process the frame body below. + case frameTypeHeaders: + // This HEADERS frame contains the message trailers. + if r.remain > 0 { + return 0, &streamError{ + code: errH3MessageError, + message: "body shorter than content-length", + } + } + var dec qpackDecoder + if err := dec.decode(r.st, func(_ indexType, name, value string) error { + if _, ok := r.trailer[name]; ok { + r.trailer.Add(name, value) + } + return nil + }); err != nil { + return 0, err + } + if err := r.st.discardFrame(); err != nil { + return 0, err + } + return 0, io.EOF + default: + if err := r.st.discardUnknownFrame(ftype); err != nil { + return 0, err + } + } + } + // We are now reading the content of a DATA frame. + // Fill the read buffer or read to the end of the frame, + // whichever comes first. + if int64(len(p)) > r.st.lim { + p = p[:r.st.lim] + } + n, err = r.st.Read(p) + if r.remain > 0 { + r.remain -= int64(n) + } + return n, err +} + +func (r *bodyReader) Close() error { + // Unlike the HTTP/1 and HTTP/2 body readers (at the time of this comment being written), + // calling Close concurrently with Read will interrupt the read. + r.st.stream.CloseRead() + // Make sure that any data that has already been written to bodyReader + // cannot be read after it has been closed. + r.err = net.ErrClosed + r.remain = 0 + return nil +} diff --git a/src/vendor/golang.org/x/net/internal/http3/conn.go b/src/vendor/golang.org/x/net/internal/http3/conn.go new file mode 100644 index 0000000000..6a3c962b41 --- /dev/null +++ b/src/vendor/golang.org/x/net/internal/http3/conn.go @@ -0,0 +1,131 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http3 + +import ( + "context" + "io" + "sync" + + "golang.org/x/net/quic" +) + +type streamHandler interface { + handleControlStream(*stream) error + handlePushStream(*stream) error + handleEncoderStream(*stream) error + handleDecoderStream(*stream) error + handleRequestStream(*stream) error + abort(error) +} + +type genericConn struct { + mu sync.Mutex + + // The peer may create exactly one control, encoder, and decoder stream. + // streamsCreated is a bitset of streams created so far. + // Bits are 1 << streamType. + streamsCreated uint8 +} + +func (c *genericConn) acceptStreams(qconn *quic.Conn, h streamHandler) { + for { + // Use context.Background: This blocks until a stream is accepted + // or the connection closes. + st, err := qconn.AcceptStream(context.Background()) + if err != nil { + return // connection closed + } + if st.IsReadOnly() { + go c.handleUnidirectionalStream(newStream(st), h) + } else { + go c.handleRequestStream(newStream(st), h) + } + } +} + +func (c *genericConn) handleUnidirectionalStream(st *stream, h streamHandler) { + // Unidirectional stream header: One varint with the stream type. + v, err := st.readVarint() + if err != nil { + h.abort(&connectionError{ + code: errH3StreamCreationError, + message: "error reading unidirectional stream header", + }) + return + } + stype := streamType(v) + if err := c.checkStreamCreation(stype); err != nil { + h.abort(err) + return + } + switch stype { + case streamTypeControl: + err = h.handleControlStream(st) + case streamTypePush: + err = h.handlePushStream(st) + case streamTypeEncoder: + err = h.handleEncoderStream(st) + case streamTypeDecoder: + err = h.handleDecoderStream(st) + default: + // "Recipients of unknown stream types MUST either abort reading + // of the stream or discard incoming data without further processing." + // https://www.rfc-editor.org/rfc/rfc9114.html#section-6.2-7 + // + // We should send the H3_STREAM_CREATION_ERROR error code, + // but the quic package currently doesn't allow setting error codes + // for STOP_SENDING frames. + // TODO: Should CloseRead take an error code? + err = nil + } + if err == io.EOF { + err = &connectionError{ + code: errH3ClosedCriticalStream, + message: streamType(stype).String() + " stream closed", + } + } + c.handleStreamError(st, h, err) +} + +func (c *genericConn) handleRequestStream(st *stream, h streamHandler) { + c.handleStreamError(st, h, h.handleRequestStream(st)) +} + +func (c *genericConn) handleStreamError(st *stream, h streamHandler, err error) { + switch err := err.(type) { + case *connectionError: + h.abort(err) + case nil: + st.stream.CloseRead() + st.stream.CloseWrite() + case *streamError: + st.stream.CloseRead() + st.stream.Reset(uint64(err.code)) + default: + st.stream.CloseRead() + st.stream.Reset(uint64(errH3InternalError)) + } +} + +func (c *genericConn) checkStreamCreation(stype streamType) error { + switch stype { + case streamTypeControl, streamTypeEncoder, streamTypeDecoder: + // The peer may create exactly one control, encoder, and decoder stream. + default: + return nil + } + c.mu.Lock() + defer c.mu.Unlock() + bit := uint8(1) << stype + if c.streamsCreated&bit != 0 { + return &connectionError{ + code: errH3StreamCreationError, + message: "multiple " + stype.String() + " streams created", + } + } + c.streamsCreated |= bit + return nil +} diff --git a/src/vendor/golang.org/x/net/internal/http3/doc.go b/src/vendor/golang.org/x/net/internal/http3/doc.go new file mode 100644 index 0000000000..5530113f69 --- /dev/null +++ b/src/vendor/golang.org/x/net/internal/http3/doc.go @@ -0,0 +1,10 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package http3 implements the HTTP/3 protocol. +// +// This package is a work in progress. +// It is not ready for production usage. +// Its API is subject to change without notice. +package http3 diff --git a/src/vendor/golang.org/x/net/internal/http3/errors.go b/src/vendor/golang.org/x/net/internal/http3/errors.go new file mode 100644 index 0000000000..273ad014a6 --- /dev/null +++ b/src/vendor/golang.org/x/net/internal/http3/errors.go @@ -0,0 +1,102 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http3 + +import "fmt" + +// http3Error is an HTTP/3 error code. +type http3Error int + +const ( + // https://www.rfc-editor.org/rfc/rfc9114.html#section-8.1 + errH3NoError = http3Error(0x0100) + errH3GeneralProtocolError = http3Error(0x0101) + errH3InternalError = http3Error(0x0102) + errH3StreamCreationError = http3Error(0x0103) + errH3ClosedCriticalStream = http3Error(0x0104) + errH3FrameUnexpected = http3Error(0x0105) + errH3FrameError = http3Error(0x0106) + errH3ExcessiveLoad = http3Error(0x0107) + errH3IDError = http3Error(0x0108) + errH3SettingsError = http3Error(0x0109) + errH3MissingSettings = http3Error(0x010a) + errH3RequestRejected = http3Error(0x010b) + errH3RequestCancelled = http3Error(0x010c) + errH3RequestIncomplete = http3Error(0x010d) + errH3MessageError = http3Error(0x010e) + errH3ConnectError = http3Error(0x010f) + errH3VersionFallback = http3Error(0x0110) + + // https://www.rfc-editor.org/rfc/rfc9204.html#section-8.3 + errQPACKDecompressionFailed = http3Error(0x0200) + errQPACKEncoderStreamError = http3Error(0x0201) + errQPACKDecoderStreamError = http3Error(0x0202) +) + +func (e http3Error) Error() string { + switch e { + case errH3NoError: + return "H3_NO_ERROR" + case errH3GeneralProtocolError: + return "H3_GENERAL_PROTOCOL_ERROR" + case errH3InternalError: + return "H3_INTERNAL_ERROR" + case errH3StreamCreationError: + return "H3_STREAM_CREATION_ERROR" + case errH3ClosedCriticalStream: + return "H3_CLOSED_CRITICAL_STREAM" + case errH3FrameUnexpected: + return "H3_FRAME_UNEXPECTED" + case errH3FrameError: + return "H3_FRAME_ERROR" + case errH3ExcessiveLoad: + return "H3_EXCESSIVE_LOAD" + case errH3IDError: + return "H3_ID_ERROR" + case errH3SettingsError: + return "H3_SETTINGS_ERROR" + case errH3MissingSettings: + return "H3_MISSING_SETTINGS" + case errH3RequestRejected: + return "H3_REQUEST_REJECTED" + case errH3RequestCancelled: + return "H3_REQUEST_CANCELLED" + case errH3RequestIncomplete: + return "H3_REQUEST_INCOMPLETE" + case errH3MessageError: + return "H3_MESSAGE_ERROR" + case errH3ConnectError: + return "H3_CONNECT_ERROR" + case errH3VersionFallback: + return "H3_VERSION_FALLBACK" + case errQPACKDecompressionFailed: + return "QPACK_DECOMPRESSION_FAILED" + case errQPACKEncoderStreamError: + return "QPACK_ENCODER_STREAM_ERROR" + case errQPACKDecoderStreamError: + return "QPACK_DECODER_STREAM_ERROR" + } + return fmt.Sprintf("H3_ERROR_%v", int(e)) +} + +// A streamError is an error which terminates a stream, but not the connection. +// https://www.rfc-editor.org/rfc/rfc9114.html#section-8-1 +type streamError struct { + code http3Error + message string +} + +func (e *streamError) Error() string { return e.message } +func (e *streamError) Unwrap() error { return e.code } + +// A connectionError is an error which results in the entire connection closing. +// https://www.rfc-editor.org/rfc/rfc9114.html#section-8-2 +type connectionError struct { + code http3Error + message string +} + +func (e *connectionError) Error() string { return e.message } +func (e *connectionError) Unwrap() error { return e.code } diff --git a/src/vendor/golang.org/x/net/internal/http3/http3.go b/src/vendor/golang.org/x/net/internal/http3/http3.go new file mode 100644 index 0000000000..189e3e749b --- /dev/null +++ b/src/vendor/golang.org/x/net/internal/http3/http3.go @@ -0,0 +1,95 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http3 + +import ( + "context" + "fmt" +) + +// Stream types. +// +// For unidirectional streams, the value is the stream type sent over the wire. +// +// For bidirectional streams (which are always request streams), +// the value is arbitrary and never sent on the wire. +type streamType int64 + +const ( + // Bidirectional request stream. + // All bidirectional streams are request streams. + // This stream type is never sent over the wire. + // + // https://www.rfc-editor.org/rfc/rfc9114.html#section-6.1 + streamTypeRequest = streamType(-1) + + // https://www.rfc-editor.org/rfc/rfc9114.html#section-6.2 + streamTypeControl = streamType(0x00) + streamTypePush = streamType(0x01) + + // https://www.rfc-editor.org/rfc/rfc9204.html#section-4.2 + streamTypeEncoder = streamType(0x02) + streamTypeDecoder = streamType(0x03) +) + +// canceledCtx is a canceled Context. +// Used for performing non-blocking QUIC operations. +var canceledCtx = func() context.Context { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + return ctx +}() + +func (stype streamType) String() string { + switch stype { + case streamTypeRequest: + return "request" + case streamTypeControl: + return "control" + case streamTypePush: + return "push" + case streamTypeEncoder: + return "encoder" + case streamTypeDecoder: + return "decoder" + default: + return "unknown" + } +} + +// Frame types. +type frameType int64 + +const ( + // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2 + frameTypeData = frameType(0x00) + frameTypeHeaders = frameType(0x01) + frameTypeCancelPush = frameType(0x03) + frameTypeSettings = frameType(0x04) + frameTypePushPromise = frameType(0x05) + frameTypeGoaway = frameType(0x07) + frameTypeMaxPushID = frameType(0x0d) +) + +func (ftype frameType) String() string { + switch ftype { + case frameTypeData: + return "DATA" + case frameTypeHeaders: + return "HEADERS" + case frameTypeCancelPush: + return "CANCEL_PUSH" + case frameTypeSettings: + return "SETTINGS" + case frameTypePushPromise: + return "PUSH_PROMISE" + case frameTypeGoaway: + return "GOAWAY" + case frameTypeMaxPushID: + return "MAX_PUSH_ID" + default: + return fmt.Sprintf("UNKNOWN_%d", int64(ftype)) + } +} diff --git a/src/vendor/golang.org/x/net/internal/http3/qpack.go b/src/vendor/golang.org/x/net/internal/http3/qpack.go new file mode 100644 index 0000000000..64ce99aaa0 --- /dev/null +++ b/src/vendor/golang.org/x/net/internal/http3/qpack.go @@ -0,0 +1,332 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http3 + +import ( + "errors" + "io" + + "golang.org/x/net/http2/hpack" +) + +// QPACK (RFC 9204) header compression wire encoding. +// https://www.rfc-editor.org/rfc/rfc9204.html + +// tableType is the static or dynamic table. +// +// The T bit in QPACK instructions indicates whether a table index refers to +// the dynamic (T=0) or static (T=1) table. tableTypeForTBit and tableType.tbit +// convert a T bit from the wire encoding to/from a tableType. +type tableType byte + +const ( + dynamicTable = 0x00 // T=0, dynamic table + staticTable = 0xff // T=1, static table +) + +// tableTypeForTbit returns the table type corresponding to a T bit value. +// The input parameter contains a byte masked to contain only the T bit. +func tableTypeForTbit(bit byte) tableType { + if bit == 0 { + return dynamicTable + } + return staticTable +} + +// tbit produces the T bit corresponding to the table type. +// The input parameter contains a byte with the T bit set to 1, +// and the return is either the input or 0 depending on the table type. +func (t tableType) tbit(bit byte) byte { + return bit & byte(t) +} + +// indexType indicates a literal's indexing status. +// +// The N bit in QPACK instructions indicates whether a literal is "never-indexed". +// A never-indexed literal (N=1) must not be encoded as an indexed literal if it +// forwarded on another connection. +// +// (See https://www.rfc-editor.org/rfc/rfc9204.html#section-7.1 for details on the +// security reasons for never-indexed literals.) +type indexType byte + +const ( + mayIndex = 0x00 // N=0, not a never-indexed literal + neverIndex = 0xff // N=1, never-indexed literal +) + +// indexTypeForNBit returns the index type corresponding to a N bit value. +// The input parameter contains a byte masked to contain only the N bit. +func indexTypeForNBit(bit byte) indexType { + if bit == 0 { + return mayIndex + } + return neverIndex +} + +// nbit produces the N bit corresponding to the table type. +// The input parameter contains a byte with the N bit set to 1, +// and the return is either the input or 0 depending on the table type. +func (t indexType) nbit(bit byte) byte { + return bit & byte(t) +} + +// Indexed Field Line: +// +// 0 1 2 3 4 5 6 7 +// +---+---+---+---+---+---+---+---+ +// | 1 | T | Index (6+) | +// +---+---+-----------------------+ +// +// https://www.rfc-editor.org/rfc/rfc9204.html#section-4.5.2 + +func appendIndexedFieldLine(b []byte, ttype tableType, index int) []byte { + const tbit = 0b_01000000 + return appendPrefixedInt(b, 0b_1000_0000|ttype.tbit(tbit), 6, int64(index)) +} + +func (st *stream) decodeIndexedFieldLine(b byte) (itype indexType, name, value string, err error) { + index, err := st.readPrefixedIntWithByte(b, 6) + if err != nil { + return 0, "", "", err + } + const tbit = 0b_0100_0000 + if tableTypeForTbit(b&tbit) == staticTable { + ent, err := staticTableEntry(index) + if err != nil { + return 0, "", "", err + } + return mayIndex, ent.name, ent.value, nil + } else { + return 0, "", "", errors.New("dynamic table is not supported yet") + } +} + +// Literal Field Line With Name Reference: +// +// 0 1 2 3 4 5 6 7 +// +---+---+---+---+---+---+---+---+ +// | 0 | 1 | N | T |Name Index (4+)| +// +---+---+---+---+---------------+ +// | H | Value Length (7+) | +// +---+---------------------------+ +// | Value String (Length bytes) | +// +-------------------------------+ +// +// https://www.rfc-editor.org/rfc/rfc9204.html#section-4.5.4 + +func appendLiteralFieldLineWithNameReference(b []byte, ttype tableType, itype indexType, nameIndex int, value string) []byte { + const tbit = 0b_0001_0000 + const nbit = 0b_0010_0000 + b = appendPrefixedInt(b, 0b_0100_0000|itype.nbit(nbit)|ttype.tbit(tbit), 4, int64(nameIndex)) + b = appendPrefixedString(b, 0, 7, value) + return b +} + +func (st *stream) decodeLiteralFieldLineWithNameReference(b byte) (itype indexType, name, value string, err error) { + nameIndex, err := st.readPrefixedIntWithByte(b, 4) + if err != nil { + return 0, "", "", err + } + + const tbit = 0b_0001_0000 + if tableTypeForTbit(b&tbit) == staticTable { + ent, err := staticTableEntry(nameIndex) + if err != nil { + return 0, "", "", err + } + name = ent.name + } else { + return 0, "", "", errors.New("dynamic table is not supported yet") + } + + _, value, err = st.readPrefixedString(7) + if err != nil { + return 0, "", "", err + } + + const nbit = 0b_0010_0000 + itype = indexTypeForNBit(b & nbit) + + return itype, name, value, nil +} + +// Literal Field Line with Literal Name: +// +// 0 1 2 3 4 5 6 7 +// +---+---+---+---+---+---+---+---+ +// | 0 | 0 | 1 | N | H |NameLen(3+)| +// +---+---+---+---+---+-----------+ +// | Name String (Length bytes) | +// +---+---------------------------+ +// | H | Value Length (7+) | +// +---+---------------------------+ +// | Value String (Length bytes) | +// +-------------------------------+ +// +// https://www.rfc-editor.org/rfc/rfc9204.html#section-4.5.6 + +func appendLiteralFieldLineWithLiteralName(b []byte, itype indexType, name, value string) []byte { + const nbit = 0b_0001_0000 + b = appendPrefixedString(b, 0b_0010_0000|itype.nbit(nbit), 3, name) + b = appendPrefixedString(b, 0, 7, value) + return b +} + +func (st *stream) decodeLiteralFieldLineWithLiteralName(b byte) (itype indexType, name, value string, err error) { + name, err = st.readPrefixedStringWithByte(b, 3) + if err != nil { + return 0, "", "", err + } + _, value, err = st.readPrefixedString(7) + if err != nil { + return 0, "", "", err + } + const nbit = 0b_0001_0000 + itype = indexTypeForNBit(b & nbit) + return itype, name, value, nil +} + +// Prefixed-integer encoding from RFC 7541, section 5.1 +// +// Prefixed integers consist of some number of bits of data, +// N bits of encoded integer, and 0 or more additional bytes of +// encoded integer. +// +// The RFCs represent this as, for example: +// +// 0 1 2 3 4 5 6 7 +// +---+---+---+---+---+---+---+---+ +// | 0 | 0 | 1 | Capacity (5+) | +// +---+---+---+-------------------+ +// +// "Capacity" is an integer with a 5-bit prefix. +// +// In the following functions, a "prefixLen" parameter is the number +// of integer bits in the first byte (5 in the above example), and +// a "firstByte" parameter is a byte containing the first byte of +// the encoded value (0x001x_xxxx in the above example). +// +// https://www.rfc-editor.org/rfc/rfc9204.html#section-4.1.1 +// https://www.rfc-editor.org/rfc/rfc7541#section-5.1 + +// readPrefixedInt reads an RFC 7541 prefixed integer from st. +func (st *stream) readPrefixedInt(prefixLen uint8) (firstByte byte, v int64, err error) { + firstByte, err = st.ReadByte() + if err != nil { + return 0, 0, errQPACKDecompressionFailed + } + v, err = st.readPrefixedIntWithByte(firstByte, prefixLen) + return firstByte, v, err +} + +// readPrefixedIntWithByte reads an RFC 7541 prefixed integer from st. +// The first byte has already been read from the stream. +func (st *stream) readPrefixedIntWithByte(firstByte byte, prefixLen uint8) (v int64, err error) { + prefixMask := (byte(1) << prefixLen) - 1 + v = int64(firstByte & prefixMask) + if v != int64(prefixMask) { + return v, nil + } + m := 0 + for { + b, err := st.ReadByte() + if err != nil { + return 0, errQPACKDecompressionFailed + } + v += int64(b&127) << m + m += 7 + if b&128 == 0 { + break + } + } + return v, err +} + +// appendPrefixedInt appends an RFC 7541 prefixed integer to b. +// +// The firstByte parameter includes the non-integer bits of the first byte. +// The other bits must be zero. +func appendPrefixedInt(b []byte, firstByte byte, prefixLen uint8, i int64) []byte { + u := uint64(i) + prefixMask := (uint64(1) << prefixLen) - 1 + if u < prefixMask { + return append(b, firstByte|byte(u)) + } + b = append(b, firstByte|byte(prefixMask)) + u -= prefixMask + for u >= 128 { + b = append(b, 0x80|byte(u&0x7f)) + u >>= 7 + } + return append(b, byte(u)) +} + +// String literal encoding from RFC 7541, section 5.2 +// +// String literals consist of a single bit flag indicating +// whether the string is Huffman-encoded, a prefixed integer (see above), +// and the string. +// +// https://www.rfc-editor.org/rfc/rfc9204.html#section-4.1.2 +// https://www.rfc-editor.org/rfc/rfc7541#section-5.2 + +// readPrefixedString reads an RFC 7541 string from st. +func (st *stream) readPrefixedString(prefixLen uint8) (firstByte byte, s string, err error) { + firstByte, err = st.ReadByte() + if err != nil { + return 0, "", errQPACKDecompressionFailed + } + s, err = st.readPrefixedStringWithByte(firstByte, prefixLen) + return firstByte, s, err +} + +// readPrefixedStringWithByte reads an RFC 7541 string from st. +// The first byte has already been read from the stream. +func (st *stream) readPrefixedStringWithByte(firstByte byte, prefixLen uint8) (s string, err error) { + size, err := st.readPrefixedIntWithByte(firstByte, prefixLen) + if err != nil { + return "", errQPACKDecompressionFailed + } + + hbit := byte(1) << prefixLen + isHuffman := firstByte&hbit != 0 + + // TODO: Avoid allocating here. + data := make([]byte, size) + if _, err := io.ReadFull(st, data); err != nil { + return "", errQPACKDecompressionFailed + } + if isHuffman { + // TODO: Move Huffman functions into a new package that hpack (HTTP/2) + // and this package can both import. Most of the hpack package isn't + // relevant to HTTP/3. + s, err := hpack.HuffmanDecodeToString(data) + if err != nil { + return "", errQPACKDecompressionFailed + } + return s, nil + } + return string(data), nil +} + +// appendPrefixedString appends an RFC 7541 string to st, +// applying Huffman encoding and setting the H bit (indicating Huffman encoding) +// when appropriate. +// +// The firstByte parameter includes the non-integer bits of the first byte. +// The other bits must be zero. +func appendPrefixedString(b []byte, firstByte byte, prefixLen uint8, s string) []byte { + huffmanLen := hpack.HuffmanEncodeLength(s) + if huffmanLen < uint64(len(s)) { + hbit := byte(1) << prefixLen + b = appendPrefixedInt(b, firstByte|hbit, prefixLen, int64(huffmanLen)) + b = hpack.AppendHuffmanString(b, s) + } else { + b = appendPrefixedInt(b, firstByte, prefixLen, int64(len(s))) + b = append(b, s...) + } + return b +} diff --git a/src/vendor/golang.org/x/net/internal/http3/qpack_decode.go b/src/vendor/golang.org/x/net/internal/http3/qpack_decode.go new file mode 100644 index 0000000000..7348ae76f0 --- /dev/null +++ b/src/vendor/golang.org/x/net/internal/http3/qpack_decode.go @@ -0,0 +1,81 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http3 + +import ( + "errors" + "math/bits" +) + +type qpackDecoder struct { + // The decoder has no state for now, + // but that'll change once we add dynamic table support. + // + // TODO: dynamic table support. +} + +func (qd *qpackDecoder) decode(st *stream, f func(itype indexType, name, value string) error) error { + // Encoded Field Section prefix. + + // We set SETTINGS_QPACK_MAX_TABLE_CAPACITY to 0, + // so the Required Insert Count must be 0. + _, requiredInsertCount, err := st.readPrefixedInt(8) + if err != nil { + return err + } + if requiredInsertCount != 0 { + return errQPACKDecompressionFailed + } + + // Delta Base. We don't use the dynamic table yet, so this may be ignored. + _, _, err = st.readPrefixedInt(7) + if err != nil { + return err + } + + sawNonPseudo := false + for st.lim > 0 { + firstByte, err := st.ReadByte() + if err != nil { + return err + } + var name, value string + var itype indexType + switch bits.LeadingZeros8(firstByte) { + case 0: + // Indexed Field Line + itype, name, value, err = st.decodeIndexedFieldLine(firstByte) + case 1: + // Literal Field Line With Name Reference + itype, name, value, err = st.decodeLiteralFieldLineWithNameReference(firstByte) + case 2: + // Literal Field Line with Literal Name + itype, name, value, err = st.decodeLiteralFieldLineWithLiteralName(firstByte) + case 3: + // Indexed Field Line With Post-Base Index + err = errors.New("dynamic table is not supported yet") + case 4: + // Indexed Field Line With Post-Base Name Reference + err = errors.New("dynamic table is not supported yet") + } + if err != nil { + return err + } + if len(name) == 0 { + return errH3MessageError + } + if name[0] == ':' { + if sawNonPseudo { + return errH3MessageError + } + } else { + sawNonPseudo = true + } + if err := f(itype, name, value); err != nil { + return err + } + } + return nil +} diff --git a/src/vendor/golang.org/x/net/internal/http3/qpack_encode.go b/src/vendor/golang.org/x/net/internal/http3/qpack_encode.go new file mode 100644 index 0000000000..193f7f93be --- /dev/null +++ b/src/vendor/golang.org/x/net/internal/http3/qpack_encode.go @@ -0,0 +1,45 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http3 + +type qpackEncoder struct { + // The encoder has no state for now, + // but that'll change once we add dynamic table support. + // + // TODO: dynamic table support. +} + +func (qe *qpackEncoder) init() { + staticTableOnce.Do(initStaticTableMaps) +} + +// encode encodes a list of headers into a QPACK encoded field section. +// +// The headers func must produce the same headers on repeated calls, +// although the order may vary. +func (qe *qpackEncoder) encode(headers func(func(itype indexType, name, value string))) []byte { + // Encoded Field Section prefix. + // + // We don't yet use the dynamic table, so both values here are zero. + var b []byte + b = appendPrefixedInt(b, 0, 8, 0) // Required Insert Count + b = appendPrefixedInt(b, 0, 7, 0) // Delta Base + + headers(func(itype indexType, name, value string) { + if itype == mayIndex { + if i, ok := staticTableByNameValue[tableEntry{name, value}]; ok { + b = appendIndexedFieldLine(b, staticTable, i) + return + } + } + if i, ok := staticTableByName[name]; ok { + b = appendLiteralFieldLineWithNameReference(b, staticTable, itype, i, value) + } else { + b = appendLiteralFieldLineWithLiteralName(b, itype, name, value) + } + }) + + return b +} diff --git a/src/vendor/golang.org/x/net/internal/http3/qpack_static.go b/src/vendor/golang.org/x/net/internal/http3/qpack_static.go new file mode 100644 index 0000000000..6c0b51c5e6 --- /dev/null +++ b/src/vendor/golang.org/x/net/internal/http3/qpack_static.go @@ -0,0 +1,142 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http3 + +import "sync" + +type tableEntry struct { + name string + value string +} + +// staticTableEntry returns the static table entry with the given index. +func staticTableEntry(index int64) (tableEntry, error) { + if index >= int64(len(staticTableEntries)) { + return tableEntry{}, errQPACKDecompressionFailed + } + return staticTableEntries[index], nil +} + +func initStaticTableMaps() { + staticTableByName = make(map[string]int) + staticTableByNameValue = make(map[tableEntry]int) + for i, ent := range staticTableEntries { + if _, ok := staticTableByName[ent.name]; !ok { + staticTableByName[ent.name] = i + } + staticTableByNameValue[ent] = i + } +} + +var ( + staticTableOnce sync.Once + staticTableByName map[string]int + staticTableByNameValue map[tableEntry]int +) + +// https://www.rfc-editor.org/rfc/rfc9204.html#appendix-A +// +// Note that this is different from the HTTP/2 static table. +var staticTableEntries = [...]tableEntry{ + 0: {":authority", ""}, + 1: {":path", "/"}, + 2: {"age", "0"}, + 3: {"content-disposition", ""}, + 4: {"content-length", "0"}, + 5: {"cookie", ""}, + 6: {"date", ""}, + 7: {"etag", ""}, + 8: {"if-modified-since", ""}, + 9: {"if-none-match", ""}, + 10: {"last-modified", ""}, + 11: {"link", ""}, + 12: {"location", ""}, + 13: {"referer", ""}, + 14: {"set-cookie", ""}, + 15: {":method", "CONNECT"}, + 16: {":method", "DELETE"}, + 17: {":method", "GET"}, + 18: {":method", "HEAD"}, + 19: {":method", "OPTIONS"}, + 20: {":method", "POST"}, + 21: {":method", "PUT"}, + 22: {":scheme", "http"}, + 23: {":scheme", "https"}, + 24: {":status", "103"}, + 25: {":status", "200"}, + 26: {":status", "304"}, + 27: {":status", "404"}, + 28: {":status", "503"}, + 29: {"accept", "*/*"}, + 30: {"accept", "application/dns-message"}, + 31: {"accept-encoding", "gzip, deflate, br"}, + 32: {"accept-ranges", "bytes"}, + 33: {"access-control-allow-headers", "cache-control"}, + 34: {"access-control-allow-headers", "content-type"}, + 35: {"access-control-allow-origin", "*"}, + 36: {"cache-control", "max-age=0"}, + 37: {"cache-control", "max-age=2592000"}, + 38: {"cache-control", "max-age=604800"}, + 39: {"cache-control", "no-cache"}, + 40: {"cache-control", "no-store"}, + 41: {"cache-control", "public, max-age=31536000"}, + 42: {"content-encoding", "br"}, + 43: {"content-encoding", "gzip"}, + 44: {"content-type", "application/dns-message"}, + 45: {"content-type", "application/javascript"}, + 46: {"content-type", "application/json"}, + 47: {"content-type", "application/x-www-form-urlencoded"}, + 48: {"content-type", "image/gif"}, + 49: {"content-type", "image/jpeg"}, + 50: {"content-type", "image/png"}, + 51: {"content-type", "text/css"}, + 52: {"content-type", "text/html; charset=utf-8"}, + 53: {"content-type", "text/plain"}, + 54: {"content-type", "text/plain;charset=utf-8"}, + 55: {"range", "bytes=0-"}, + 56: {"strict-transport-security", "max-age=31536000"}, + 57: {"strict-transport-security", "max-age=31536000; includesubdomains"}, + 58: {"strict-transport-security", "max-age=31536000; includesubdomains; preload"}, + 59: {"vary", "accept-encoding"}, + 60: {"vary", "origin"}, + 61: {"x-content-type-options", "nosniff"}, + 62: {"x-xss-protection", "1; mode=block"}, + 63: {":status", "100"}, + 64: {":status", "204"}, + 65: {":status", "206"}, + 66: {":status", "302"}, + 67: {":status", "400"}, + 68: {":status", "403"}, + 69: {":status", "421"}, + 70: {":status", "425"}, + 71: {":status", "500"}, + 72: {"accept-language", ""}, + 73: {"access-control-allow-credentials", "FALSE"}, + 74: {"access-control-allow-credentials", "TRUE"}, + 75: {"access-control-allow-headers", "*"}, + 76: {"access-control-allow-methods", "get"}, + 77: {"access-control-allow-methods", "get, post, options"}, + 78: {"access-control-allow-methods", "options"}, + 79: {"access-control-expose-headers", "content-length"}, + 80: {"access-control-request-headers", "content-type"}, + 81: {"access-control-request-method", "get"}, + 82: {"access-control-request-method", "post"}, + 83: {"alt-svc", "clear"}, + 84: {"authorization", ""}, + 85: {"content-security-policy", "script-src 'none'; object-src 'none'; base-uri 'none'"}, + 86: {"early-data", "1"}, + 87: {"expect-ct", ""}, + 88: {"forwarded", ""}, + 89: {"if-range", ""}, + 90: {"origin", ""}, + 91: {"purpose", "prefetch"}, + 92: {"server", ""}, + 93: {"timing-allow-origin", "*"}, + 94: {"upgrade-insecure-requests", "1"}, + 95: {"user-agent", ""}, + 96: {"x-forwarded-for", ""}, + 97: {"x-frame-options", "deny"}, + 98: {"x-frame-options", "sameorigin"}, +} diff --git a/src/vendor/golang.org/x/net/internal/http3/quic.go b/src/vendor/golang.org/x/net/internal/http3/quic.go new file mode 100644 index 0000000000..4f1cca179e --- /dev/null +++ b/src/vendor/golang.org/x/net/internal/http3/quic.go @@ -0,0 +1,40 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http3 + +import ( + "crypto/tls" + + "golang.org/x/net/quic" +) + +func initConfig(config *quic.Config) *quic.Config { + if config == nil { + config = &quic.Config{} + } + + // maybeCloneTLSConfig clones the user-provided tls.Config (but only once) + // prior to us modifying it. + needCloneTLSConfig := true + maybeCloneTLSConfig := func() *tls.Config { + if needCloneTLSConfig { + config.TLSConfig = config.TLSConfig.Clone() + needCloneTLSConfig = false + } + return config.TLSConfig + } + + if config.TLSConfig == nil { + config.TLSConfig = &tls.Config{} + needCloneTLSConfig = false + } + if config.TLSConfig.MinVersion == 0 { + maybeCloneTLSConfig().MinVersion = tls.VersionTLS13 + } + if config.TLSConfig.NextProtos == nil { + maybeCloneTLSConfig().NextProtos = []string{"h3"} + } + return config +} diff --git a/src/vendor/golang.org/x/net/internal/http3/roundtrip.go b/src/vendor/golang.org/x/net/internal/http3/roundtrip.go new file mode 100644 index 0000000000..2ea584b773 --- /dev/null +++ b/src/vendor/golang.org/x/net/internal/http3/roundtrip.go @@ -0,0 +1,417 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http3 + +import ( + "errors" + "io" + "net/http" + "net/http/httptrace" + "net/textproto" + "strconv" + "sync" + + "golang.org/x/net/http/httpguts" + "golang.org/x/net/internal/httpcommon" +) + +type roundTripState struct { + cc *clientConn + st *stream + + // Request body, provided by the caller. + onceCloseReqBody sync.Once + reqBody io.ReadCloser + + reqBodyWriter bodyWriter + + // Response.Body, provided to the caller. + respBody io.ReadCloser + + trace *httptrace.ClientTrace + + errOnce sync.Once + err error +} + +// abort terminates the RoundTrip. +// It returns the first fatal error encountered by the RoundTrip call. +func (rt *roundTripState) abort(err error) error { + rt.errOnce.Do(func() { + rt.err = err + switch e := err.(type) { + case *connectionError: + rt.cc.abort(e) + case *streamError: + rt.st.stream.CloseRead() + rt.st.stream.Reset(uint64(e.code)) + default: + rt.st.stream.CloseRead() + rt.st.stream.Reset(uint64(errH3NoError)) + } + }) + return rt.err +} + +// closeReqBody closes the Request.Body, at most once. +func (rt *roundTripState) closeReqBody() { + if rt.reqBody != nil { + rt.onceCloseReqBody.Do(func() { + rt.reqBody.Close() + }) + } +} + +// TODO: Set up the rest of the hooks that might be in rt.trace. +func (rt *roundTripState) maybeCallGot1xxResponse(status int, h http.Header) error { + if rt.trace == nil || rt.trace.Got1xxResponse == nil { + return nil + } + return rt.trace.Got1xxResponse(status, textproto.MIMEHeader(h)) +} + +func (rt *roundTripState) maybeCallGot100Continue() { + if rt.trace == nil || rt.trace.Got100Continue == nil { + return + } + rt.trace.Got100Continue() +} + +func (rt *roundTripState) maybeCallWait100Continue() { + if rt.trace == nil || rt.trace.Wait100Continue == nil { + return + } + rt.trace.Wait100Continue() +} + +// RoundTrip sends a request on the connection. +func (cc *clientConn) RoundTrip(req *http.Request) (_ *http.Response, err error) { + // Each request gets its own QUIC stream. + st, err := newConnStream(req.Context(), cc.qconn, streamTypeRequest) + if err != nil { + return nil, err + } + rt := &roundTripState{ + cc: cc, + st: st, + trace: httptrace.ContextClientTrace(req.Context()), + } + defer func() { + if err != nil { + err = rt.abort(err) + } + }() + + // Cancel reads/writes on the stream when the request expires. + st.stream.SetReadContext(req.Context()) + st.stream.SetWriteContext(req.Context()) + + headers := cc.enc.encode(func(yield func(itype indexType, name, value string)) { + _, err = httpcommon.EncodeHeaders(req.Context(), httpcommon.EncodeHeadersParam{ + Request: httpcommon.Request{ + URL: req.URL, + Method: req.Method, + Host: req.Host, + Header: req.Header, + Trailer: req.Trailer, + ActualContentLength: actualContentLength(req), + }, + AddGzipHeader: false, // TODO: add when appropriate + PeerMaxHeaderListSize: 0, + DefaultUserAgent: "Go-http-client/3", + }, func(name, value string) { + // Issue #71374: Consider supporting never-indexed fields. + yield(mayIndex, name, value) + }) + }) + if err != nil { + return nil, err + } + + // Write the HEADERS frame. + st.writeVarint(int64(frameTypeHeaders)) + st.writeVarint(int64(len(headers))) + st.Write(headers) + if err := st.Flush(); err != nil { + return nil, err + } + + var bodyAndTrailerWritten bool + is100ContinueReq := httpguts.HeaderValuesContainsToken(req.Header["Expect"], "100-continue") + if is100ContinueReq { + rt.maybeCallWait100Continue() + } else { + bodyAndTrailerWritten = true + go cc.writeBodyAndTrailer(rt, req) + } + + // Read the response headers. + for { + ftype, err := st.readFrameHeader() + if err != nil { + return nil, err + } + switch ftype { + case frameTypeHeaders: + statusCode, h, err := cc.handleHeaders(st) + if err != nil { + return nil, err + } + + // TODO: Handle 1xx responses. + if isInfoStatus(statusCode) { + if err := rt.maybeCallGot1xxResponse(statusCode, h); err != nil { + return nil, err + } + switch statusCode { + case 100: + rt.maybeCallGot100Continue() + if is100ContinueReq && !bodyAndTrailerWritten { + bodyAndTrailerWritten = true + go cc.writeBodyAndTrailer(rt, req) + continue + } + // If we did not send "Expect: 100-continue" request but + // received status 100 anyways, just continue per usual and + // let the caller decide what to do with the response. + default: + continue + } + } + + // We have the response headers. + // Set up the response and return it to the caller. + contentLength, err := parseResponseContentLength(req.Method, statusCode, h) + if err != nil { + return nil, err + } + + trailer := make(http.Header) + extractTrailerFromHeader(h, trailer) + delete(h, "Trailer") + + if (contentLength != 0 && req.Method != http.MethodHead) || len(trailer) > 0 { + rt.respBody = &bodyReader{ + st: st, + remain: contentLength, + trailer: trailer, + } + } else { + rt.respBody = http.NoBody + } + resp := &http.Response{ + Proto: "HTTP/3.0", + ProtoMajor: 3, + Header: h, + StatusCode: statusCode, + Status: strconv.Itoa(statusCode) + " " + http.StatusText(statusCode), + ContentLength: contentLength, + Trailer: trailer, + Body: (*transportResponseBody)(rt), + } + // TODO: Automatic Content-Type: gzip decoding. + return resp, nil + case frameTypePushPromise: + if err := cc.handlePushPromise(st); err != nil { + return nil, err + } + default: + if err := st.discardUnknownFrame(ftype); err != nil { + return nil, err + } + } + } +} + +// actualContentLength returns a sanitized version of req.ContentLength, +// where 0 actually means zero (not unknown) and -1 means unknown. +func actualContentLength(req *http.Request) int64 { + if req.Body == nil || req.Body == http.NoBody { + return 0 + } + if req.ContentLength != 0 { + return req.ContentLength + } + return -1 +} + +// writeBodyAndTrailer handles writing the body and trailer for a given +// request, if any. This function will close the write direction of the stream. +func (cc *clientConn) writeBodyAndTrailer(rt *roundTripState, req *http.Request) { + defer rt.closeReqBody() + + declaredTrailer := req.Trailer.Clone() + + rt.reqBody = req.Body + rt.reqBodyWriter.st = rt.st + rt.reqBodyWriter.remain = actualContentLength(req) + rt.reqBodyWriter.flush = true + rt.reqBodyWriter.name = "request" + rt.reqBodyWriter.trailer = req.Trailer + rt.reqBodyWriter.enc = &cc.enc + if req.Body == nil { + rt.reqBody = http.NoBody + } + + if _, err := io.Copy(&rt.reqBodyWriter, rt.reqBody); err != nil { + rt.abort(err) + } + // Get rid of any trailer that was not declared beforehand, before we + // close the request body which will cause the trailer headers to be + // written. + for name := range req.Trailer { + if _, ok := declaredTrailer[name]; !ok { + delete(req.Trailer, name) + } + } + if err := rt.reqBodyWriter.Close(); err != nil { + rt.abort(err) + } +} + +// transportResponseBody is the Response.Body returned by RoundTrip. +type transportResponseBody roundTripState + +// Read is Response.Body.Read. +func (b *transportResponseBody) Read(p []byte) (n int, err error) { + return b.respBody.Read(p) +} + +var errRespBodyClosed = errors.New("response body closed") + +// Close is Response.Body.Close. +// Closing the response body is how the caller signals that they're done with a request. +func (b *transportResponseBody) Close() error { + rt := (*roundTripState)(b) + // Close the request body, which should wake up copyRequestBody if it's + // currently blocked reading the body. + rt.closeReqBody() + // Close the request stream, since we're done with the request. + // Reset closes the sending half of the stream. + rt.st.stream.Reset(uint64(errH3NoError)) + // respBody.Close is responsible for closing the receiving half. + err := rt.respBody.Close() + if err == nil { + err = errRespBodyClosed + } + err = rt.abort(err) + if err == errRespBodyClosed { + // No other errors occurred before closing Response.Body, + // so consider this a successful request. + return nil + } + return err +} + +func parseResponseContentLength(method string, statusCode int, h http.Header) (int64, error) { + clens := h["Content-Length"] + if len(clens) == 0 { + return -1, nil + } + + // We allow duplicate Content-Length headers, + // but only if they all have the same value. + for _, v := range clens[1:] { + if clens[0] != v { + return -1, &streamError{errH3MessageError, "mismatching Content-Length headers"} + } + } + + // "A server MUST NOT send a Content-Length header field in any response + // with a status code of 1xx (Informational) or 204 (No Content). + // A server MUST NOT send a Content-Length header field in any 2xx (Successful) + // response to a CONNECT request [...]" + // https://www.rfc-editor.org/rfc/rfc9110#section-8.6-8 + if (statusCode >= 100 && statusCode < 200) || + statusCode == 204 || + (method == "CONNECT" && statusCode >= 200 && statusCode < 300) { + // This is a protocol violation, but a fairly harmless one. + // Just ignore the header. + return -1, nil + } + + contentLen, err := strconv.ParseUint(clens[0], 10, 63) + if err != nil { + return -1, &streamError{errH3MessageError, "invalid Content-Length header"} + } + return int64(contentLen), nil +} + +func (cc *clientConn) handleHeaders(st *stream) (statusCode int, h http.Header, err error) { + haveStatus := false + cookie := "" + // Issue #71374: Consider tracking the never-indexed status of headers + // with the N bit set in their QPACK encoding. + err = cc.dec.decode(st, func(_ indexType, name, value string) error { + switch { + case name == ":status": + if haveStatus { + return &streamError{errH3MessageError, "duplicate :status"} + } + haveStatus = true + statusCode, err = strconv.Atoi(value) + if err != nil { + return &streamError{errH3MessageError, "invalid :status"} + } + case name[0] == ':': + // "Endpoints MUST treat a request or response + // that contains undefined or invalid + // pseudo-header fields as malformed." + // https://www.rfc-editor.org/rfc/rfc9114.html#section-4.3-3 + return &streamError{errH3MessageError, "undefined pseudo-header"} + case name == "cookie": + // "If a decompressed field section contains multiple cookie field lines, + // these MUST be concatenated into a single byte string [...]" + // using the two-byte delimiter of "; "'' + // https://www.rfc-editor.org/rfc/rfc9114.html#section-4.2.1-2 + if cookie == "" { + cookie = value + } else { + cookie += "; " + value + } + default: + if h == nil { + h = make(http.Header) + } + // TODO: Use a per-connection canonicalization cache as we do in HTTP/2. + // Maybe we could put this in the QPACK decoder and have it deliver + // pre-canonicalized headers to us here? + cname := httpcommon.CanonicalHeader(name) + // TODO: Consider using a single []string slice for all headers, + // as we do in the HTTP/1 and HTTP/2 cases. + // This is a bit tricky, since we don't know the number of headers + // at the start of decoding. Perhaps it's worth doing a two-pass decode, + // or perhaps we should just allocate header value slices in + // reasonably-sized chunks. + h[cname] = append(h[cname], value) + } + return nil + }) + if !haveStatus { + // "[The :status] pseudo-header field MUST be included in all responses [...]" + // https://www.rfc-editor.org/rfc/rfc9114.html#section-4.3.2-1 + err = errH3MessageError + } + if cookie != "" { + if h == nil { + h = make(http.Header) + } + h["Cookie"] = []string{cookie} + } + if err := st.endFrame(); err != nil { + return 0, nil, err + } + return statusCode, h, err +} + +func (cc *clientConn) handlePushPromise(st *stream) error { + // "A client MUST treat receipt of a PUSH_PROMISE frame that contains a + // larger push ID than the client has advertised as a connection error of H3_ID_ERROR." + // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.5-5 + return &connectionError{ + code: errH3IDError, + message: "PUSH_PROMISE received when no MAX_PUSH_ID has been sent", + } +} diff --git a/src/vendor/golang.org/x/net/internal/http3/server.go b/src/vendor/golang.org/x/net/internal/http3/server.go new file mode 100644 index 0000000000..28c8cda849 --- /dev/null +++ b/src/vendor/golang.org/x/net/internal/http3/server.go @@ -0,0 +1,791 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http3 + +import ( + "context" + "crypto/tls" + "fmt" + "io" + "maps" + "net/http" + "slices" + "strconv" + "strings" + "sync" + "time" + + "golang.org/x/net/http/httpguts" + "golang.org/x/net/internal/httpcommon" + "golang.org/x/net/quic" +) + +// A server is an HTTP/3 server. +// The zero value for server is a valid server. +type server struct { + // handler to invoke for requests, http.DefaultServeMux if nil. + handler http.Handler + + config *quic.Config + + listenQUIC func(addr string, config *quic.Config) (*quic.Endpoint, error) + + initOnce sync.Once + + serveCtx context.Context + serveCtxCancel context.CancelFunc + + // connClosed is used to signal that a connection has been unregistered + // from activeConns. That way, when shutting down gracefully, the server + // can avoid busy-waiting for activeConns to be empty. + connClosed chan any + mu sync.Mutex // Guards fields below. + activeConns map[*serverConn]struct{} +} + +// netHTTPHandler is an interface that is implemented by +// net/http.http3ServerHandler in std. +// +// It provides a way for information to be passed between x/net and net/http +// that would otherwise be inaccessible, such as the TLS configs that users +// have supplied to net/http servers. +// +// This allows us to integrate our HTTP/3 server implementation with the +// net/http server when RegisterServer is called. +type netHTTPHandler interface { + http.Handler + TLSConfig() *tls.Config + BaseContext() context.Context + Addr() string + ListenErrHook(err error) + ShutdownContext() context.Context +} + +type ServerOpts struct { + // ListenQUIC determines how the server will open a QUIC endpoint. + // By default, quic.Listen("udp", addr, config) is used. + ListenQUIC func(addr string, config *quic.Config) (*quic.Endpoint, error) + + // QUICConfig is the QUIC configuration used by the server. + // QUICConfig may be nil and should not be modified after calling + // RegisterServer. + // If QUICConfig.TLSConfig is nil, the TLSConfig of the net/http Server + // given to RegisterServer will be used. + QUICConfig *quic.Config +} + +// RegisterServer adds HTTP/3 support to a net/http Server. +// +// RegisterServer must be called before s begins serving, and only affects +// s.ListenAndServeTLS. +func RegisterServer(s *http.Server, opts ServerOpts) { + if s.TLSNextProto == nil { + s.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) + } + s.TLSNextProto["http/3"] = func(s *http.Server, c *tls.Conn, h http.Handler) { + stdHandler, ok := h.(netHTTPHandler) + if !ok { + panic("RegisterServer was given a server that does not implement netHTTPHandler") + } + if opts.QUICConfig == nil { + opts.QUICConfig = &quic.Config{} + } + if opts.QUICConfig.TLSConfig == nil { + opts.QUICConfig.TLSConfig = stdHandler.TLSConfig() + } + s3 := &server{ + config: opts.QUICConfig, + listenQUIC: opts.ListenQUIC, + handler: stdHandler, + serveCtx: stdHandler.BaseContext(), + } + s3.init() + s.RegisterOnShutdown(func() { + s3.shutdown(stdHandler.ShutdownContext()) + }) + stdHandler.ListenErrHook(s3.listenAndServe(stdHandler.Addr())) + } +} + +func (s *server) init() { + s.initOnce.Do(func() { + s.config = initConfig(s.config) + if s.handler == nil { + s.handler = http.DefaultServeMux + } + if s.serveCtx == nil { + s.serveCtx = context.Background() + } + if s.listenQUIC == nil { + s.listenQUIC = func(addr string, config *quic.Config) (*quic.Endpoint, error) { + return quic.Listen("udp", addr, config) + } + } + s.serveCtx, s.serveCtxCancel = context.WithCancel(s.serveCtx) + s.activeConns = make(map[*serverConn]struct{}) + s.connClosed = make(chan any, 1) + }) +} + +// listenAndServe listens on the UDP network address addr +// and then calls Serve to handle requests on incoming connections. +func (s *server) listenAndServe(addr string) error { + s.init() + e, err := s.listenQUIC(addr, s.config) + if err != nil { + return err + } + go s.serve(e) + return nil +} + +// serve accepts incoming connections on the QUIC endpoint e, +// and handles requests from those connections. +func (s *server) serve(e *quic.Endpoint) error { + s.init() + defer e.Close(canceledCtx) + for { + qconn, err := e.Accept(s.serveCtx) + if err != nil { + return err + } + go s.newServerConn(qconn, s.handler) + } +} + +// shutdown attempts a graceful shutdown for the server. +func (s *server) shutdown(ctx context.Context) { + // Set a reasonable default in case ctx is nil. + if ctx == nil { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(context.Background(), time.Second) + defer cancel() + } + + // Send GOAWAY frames to all active connections to give a chance for them + // to gracefully terminate. + s.mu.Lock() + for sc := range s.activeConns { + // TODO: Modify x/net/quic stream API so that write errors from context + // deadline are sticky. + go sc.sendGoaway() + } + s.mu.Unlock() + + // Complete shutdown as soon as there are no more active connections or ctx + // is done, whichever comes first. + defer func() { + s.mu.Lock() + defer s.mu.Unlock() + s.serveCtxCancel() + for sc := range s.activeConns { + sc.abort(&connectionError{ + code: errH3NoError, + message: "server is shutting down", + }) + } + }() + noMoreConns := func() bool { + s.mu.Lock() + defer s.mu.Unlock() + return len(s.activeConns) == 0 + } + for { + if noMoreConns() { + return + } + select { + case <-ctx.Done(): + return + case <-s.connClosed: + } + } +} + +func (s *server) registerConn(sc *serverConn) { + s.mu.Lock() + defer s.mu.Unlock() + s.activeConns[sc] = struct{}{} +} + +func (s *server) unregisterConn(sc *serverConn) { + s.mu.Lock() + delete(s.activeConns, sc) + s.mu.Unlock() + select { + case s.connClosed <- struct{}{}: + default: + // Channel already full. No need to send more values since we are just + // using this channel as a simpler sync.Cond. + } +} + +type serverConn struct { + qconn *quic.Conn + + genericConn // for handleUnidirectionalStream + enc qpackEncoder + dec qpackDecoder + handler http.Handler + + // For handling shutdown. + controlStream *stream + mu sync.Mutex // Guards everything below. + maxRequestStreamID int64 + goawaySent bool +} + +func (s *server) newServerConn(qconn *quic.Conn, handler http.Handler) { + sc := &serverConn{ + qconn: qconn, + handler: handler, + } + s.registerConn(sc) + defer s.unregisterConn(sc) + sc.enc.init() + + // Create control stream and send SETTINGS frame. + // TODO: Time out on creating stream. + var err error + sc.controlStream, err = newConnStream(context.Background(), sc.qconn, streamTypeControl) + if err != nil { + return + } + sc.controlStream.writeSettings() + sc.controlStream.Flush() + + sc.acceptStreams(sc.qconn, sc) +} + +func (sc *serverConn) handleControlStream(st *stream) error { + // "A SETTINGS frame MUST be sent as the first frame of each control stream [...]" + // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.4-2 + if err := st.readSettings(func(settingsType, settingsValue int64) error { + switch settingsType { + case settingsMaxFieldSectionSize: + _ = settingsValue // TODO + case settingsQPACKMaxTableCapacity: + _ = settingsValue // TODO + case settingsQPACKBlockedStreams: + _ = settingsValue // TODO + default: + // Unknown settings types are ignored. + } + return nil + }); err != nil { + return err + } + + for { + ftype, err := st.readFrameHeader() + if err != nil { + return err + } + switch ftype { + case frameTypeCancelPush: + // "If a server receives a CANCEL_PUSH frame for a push ID + // that has not yet been mentioned by a PUSH_PROMISE frame, + // this MUST be treated as a connection error of type H3_ID_ERROR." + // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.3-8 + return &connectionError{ + code: errH3IDError, + message: "CANCEL_PUSH for unsent push ID", + } + case frameTypeGoaway: + return errH3NoError + default: + // Unknown frames are ignored. + if err := st.discardUnknownFrame(ftype); err != nil { + return err + } + } + } +} + +func (sc *serverConn) handleEncoderStream(*stream) error { + // TODO + return nil +} + +func (sc *serverConn) handleDecoderStream(*stream) error { + // TODO + return nil +} + +func (sc *serverConn) handlePushStream(*stream) error { + // "[...] if a server receives a client-initiated push stream, + // this MUST be treated as a connection error of type H3_STREAM_CREATION_ERROR." + // https://www.rfc-editor.org/rfc/rfc9114.html#section-6.2.2-3 + return &connectionError{ + code: errH3StreamCreationError, + message: "client created push stream", + } +} + +type pseudoHeader struct { + method string + scheme string + path string + authority string +} + +func (sc *serverConn) parseHeader(st *stream) (http.Header, pseudoHeader, error) { + ftype, err := st.readFrameHeader() + if err != nil { + return nil, pseudoHeader{}, err + } + if ftype != frameTypeHeaders { + return nil, pseudoHeader{}, err + } + header := make(http.Header) + var pHeader pseudoHeader + var dec qpackDecoder + if err := dec.decode(st, func(_ indexType, name, value string) error { + switch name { + case ":method": + pHeader.method = value + case ":scheme": + pHeader.scheme = value + case ":path": + pHeader.path = value + case ":authority": + pHeader.authority = value + default: + header.Add(name, value) + } + return nil + }); err != nil { + return nil, pseudoHeader{}, err + } + if err := st.endFrame(); err != nil { + return nil, pseudoHeader{}, err + } + return header, pHeader, nil +} + +func (sc *serverConn) sendGoaway() { + sc.mu.Lock() + if sc.goawaySent || sc.controlStream == nil { + sc.mu.Unlock() + return + } + sc.goawaySent = true + sc.mu.Unlock() + + // No lock in this section in case writing to stream blocks. This is safe + // since sc.maxRequestStreamID is only updated when sc.goawaySent is false. + sc.controlStream.writeVarint(int64(frameTypeGoaway)) + sc.controlStream.writeVarint(int64(sizeVarint(uint64(sc.maxRequestStreamID)))) + sc.controlStream.writeVarint(sc.maxRequestStreamID) + sc.controlStream.Flush() +} + +// requestShouldGoAway returns true if st has a stream ID that is equal or +// greater than the ID we have sent in a GOAWAY frame, if any. +func (sc *serverConn) requestShouldGoaway(st *stream) bool { + sc.mu.Lock() + defer sc.mu.Unlock() + if sc.goawaySent { + return st.stream.ID() >= sc.maxRequestStreamID + } else { + sc.maxRequestStreamID = max(sc.maxRequestStreamID, st.stream.ID()) + return false + } +} + +func (sc *serverConn) handleRequestStream(st *stream) error { + if sc.requestShouldGoaway(st) { + return &streamError{ + code: errH3RequestRejected, + message: "GOAWAY request with equal or lower ID than the stream has been sent", + } + } + header, pHeader, err := sc.parseHeader(st) + if err != nil { + return err + } + + reqInfo := httpcommon.NewServerRequest(httpcommon.ServerRequestParam{ + Method: pHeader.method, + Scheme: pHeader.scheme, + Authority: pHeader.authority, + Path: pHeader.path, + Header: header, + }) + if reqInfo.InvalidReason != "" { + return &streamError{ + code: errH3MessageError, + message: reqInfo.InvalidReason, + } + } + + var body io.ReadCloser + contentLength := int64(-1) + if n, err := strconv.Atoi(header.Get("Content-Length")); err == nil { + contentLength = int64(n) + } + if contentLength != 0 || len(reqInfo.Trailer) != 0 { + body = &bodyReader{ + st: st, + remain: contentLength, + trailer: reqInfo.Trailer, + } + } else { + body = http.NoBody + } + + req := &http.Request{ + Proto: "HTTP/3.0", + Method: pHeader.method, + Host: pHeader.authority, + URL: reqInfo.URL, + RequestURI: reqInfo.RequestURI, + Trailer: reqInfo.Trailer, + ProtoMajor: 3, + RemoteAddr: sc.qconn.RemoteAddr().String(), + Body: body, + Header: header, + ContentLength: contentLength, + } + defer req.Body.Close() + + rw := &responseWriter{ + st: st, + headers: make(http.Header), + trailer: make(http.Header), + bb: make(bodyBuffer, 0, defaultBodyBufferCap), + cannotHaveBody: req.Method == "HEAD", + bw: &bodyWriter{ + st: st, + remain: -1, + flush: false, + name: "response", + enc: &sc.enc, + }, + } + defer rw.close() + if reqInfo.NeedsContinue { + req.Body.(*bodyReader).send100Continue = func() { + rw.WriteHeader(100) + } + } + + // TODO: handle panic coming from the HTTP handler. + sc.handler.ServeHTTP(rw, req) + return nil +} + +// abort closes the connection with an error. +func (sc *serverConn) abort(err error) { + if e, ok := err.(*connectionError); ok { + sc.qconn.Abort(&quic.ApplicationError{ + Code: uint64(e.code), + Reason: e.message, + }) + } else { + sc.qconn.Abort(err) + } +} + +// responseCanHaveBody reports whether a given response status code permits a +// body. See RFC 7230, section 3.3. +func responseCanHaveBody(status int) bool { + switch { + case status >= 100 && status <= 199: + return false + case status == 204: + return false + case status == 304: + return false + } + return true +} + +type responseWriter struct { + st *stream + bw *bodyWriter + mu sync.Mutex + headers http.Header + trailer http.Header + bb bodyBuffer + wroteHeader bool // Non-1xx header has been (logically) written. + statusCode int // Status of the response that will be sent in HEADERS frame. + statusCodeSet bool // Status of the response has been set via a call to WriteHeader. + cannotHaveBody bool // Response should not have a body (e.g. response to a HEAD request). + bodyLenLeft int // How much of the content body is left to be sent, set via "Content-Length" header. -1 if unknown. +} + +func (rw *responseWriter) Header() http.Header { + return rw.headers +} + +// prepareTrailerForWriteLocked populates any pre-declared trailer header with +// its value, and passes it to bodyWriter so it can be written after body EOF. +// Caller must hold rw.mu. +func (rw *responseWriter) prepareTrailerForWriteLocked() { + for name := range rw.trailer { + if val, ok := rw.headers[name]; ok { + rw.trailer[name] = val + } else { + delete(rw.trailer, name) + } + } + if len(rw.trailer) > 0 { + rw.bw.trailer = rw.trailer + } +} + +// writeHeaderLockedOnce writes the final response header. If rw.wroteHeader is +// true, calling this method is a no-op. Sending informational status headers +// should be done using writeInfoHeaderLocked, rather than this method. +// Caller must hold rw.mu. +func (rw *responseWriter) writeHeaderLockedOnce() { + if rw.wroteHeader { + return + } + if !responseCanHaveBody(rw.statusCode) { + rw.cannotHaveBody = true + } + // If there is any Trailer declared in headers, save them so we know which + // trailers have been pre-declared. Also, write back the extracted value, + // which is canonicalized, to rw.Header for consistency. + if _, ok := rw.headers["Trailer"]; ok { + extractTrailerFromHeader(rw.headers, rw.trailer) + rw.headers.Set("Trailer", strings.Join(slices.Sorted(maps.Keys(rw.trailer)), ", ")) + } + + rw.bb.inferHeader(rw.headers, rw.statusCode) + encHeaders := rw.bw.enc.encode(func(f func(itype indexType, name, value string)) { + f(mayIndex, ":status", strconv.Itoa(rw.statusCode)) + for name, values := range rw.headers { + if !httpguts.ValidHeaderFieldName(name) { + continue + } + for _, val := range values { + if !httpguts.ValidHeaderFieldValue(val) { + continue + } + // Issue #71374: Consider supporting never-indexed fields. + f(mayIndex, name, val) + } + } + }) + + rw.st.writeVarint(int64(frameTypeHeaders)) + rw.st.writeVarint(int64(len(encHeaders))) + rw.st.Write(encHeaders) + rw.wroteHeader = true +} + +// writeHeaderLocked writes informational status headers (i.e. status 1XX). +// If a non-informational status header has been written via +// writeHeaderLockedOnce, this method is a no-op. +// Caller must hold rw.mu. +func (rw *responseWriter) writeHeaderLocked(statusCode int) { + if rw.wroteHeader { + return + } + encHeaders := rw.bw.enc.encode(func(f func(itype indexType, name, value string)) { + f(mayIndex, ":status", strconv.Itoa(statusCode)) + for name, values := range rw.headers { + if name == "Content-Length" || name == "Transfer-Encoding" { + continue + } + if !httpguts.ValidHeaderFieldName(name) { + continue + } + for _, val := range values { + if !httpguts.ValidHeaderFieldValue(val) { + continue + } + // Issue #71374: Consider supporting never-indexed fields. + f(mayIndex, name, val) + } + } + }) + rw.st.writeVarint(int64(frameTypeHeaders)) + rw.st.writeVarint(int64(len(encHeaders))) + rw.st.Write(encHeaders) +} + +func isInfoStatus(status int) bool { + return status >= 100 && status < 200 +} + +// checkWriteHeaderCode is a copy of net/http's checkWriteHeaderCode. +func checkWriteHeaderCode(code int) { + // Issue 22880: require valid WriteHeader status codes. + // For now we only enforce that it's three digits. + // In the future we might block things over 599 (600 and above aren't defined + // at http://httpwg.org/specs/rfc7231.html#status.codes). + // But for now any three digits. + // + // We used to send "HTTP/1.1 000 0" on the wire in responses but there's + // no equivalent bogus thing we can realistically send in HTTP/3, + // so we'll consistently panic instead and help people find their bugs + // early. (We can't return an error from WriteHeader even if we wanted to.) + if code < 100 || code > 999 { + panic(fmt.Sprintf("invalid WriteHeader code %v", code)) + } +} + +func (rw *responseWriter) WriteHeader(statusCode int) { + // TODO: handle sending informational status headers (e.g. 103). + rw.mu.Lock() + defer rw.mu.Unlock() + if rw.statusCodeSet { + return + } + checkWriteHeaderCode(statusCode) + + // Informational headers can be sent multiple times, and should be flushed + // immediately. + if isInfoStatus(statusCode) { + rw.writeHeaderLocked(statusCode) + rw.st.Flush() + return + } + + // Non-informational headers should only be set once, and should be + // buffered. + rw.statusCodeSet = true + rw.statusCode = statusCode + if n, err := strconv.Atoi(rw.Header().Get("Content-Length")); err == nil { + rw.bodyLenLeft = n + } else { + rw.bodyLenLeft = -1 // Unknown. + } +} + +// trimWriteLocked trims a byte slice, b, such that the length of b will not +// exceed rw.bodyLenLeft. This method will update rw.bodyLenLeft when trimming +// b, and will also return whether b was trimmed or not. +// Caller must hold rw.mu. +func (rw *responseWriter) trimWriteLocked(b []byte) ([]byte, bool) { + if rw.bodyLenLeft < 0 { + return b, false + } + n := min(len(b), rw.bodyLenLeft) + rw.bodyLenLeft -= n + return b[:n], n != len(b) +} + +func (rw *responseWriter) Write(b []byte) (n int, err error) { + // Calling Write implicitly calls WriteHeader(200) if WriteHeader has not + // been called before. + rw.WriteHeader(http.StatusOK) + rw.mu.Lock() + defer rw.mu.Unlock() + + if rw.statusCode == http.StatusNotModified { + return 0, http.ErrBodyNotAllowed + } + + b, trimmed := rw.trimWriteLocked(b) + if trimmed { + defer func() { + err = http.ErrContentLength + }() + } + + // If b fits entirely in our body buffer, save it to the buffer and return + // early so we can coalesce small writes. + // As a special case, we always want to save b to the buffer even when b is + // big if we had yet to write our header, so we can infer headers like + // "Content-Type" with as much information as possible. + initialBLen := len(b) + initialBufLen := len(rw.bb) + if !rw.wroteHeader || len(b) <= cap(rw.bb)-len(rw.bb) { + b = rw.bb.write(b) + if len(b) == 0 { + return initialBLen, nil + } + } + + // Reaching this point means that our buffer has been sufficiently filled. + // Therefore, we now want to: + // 1. Infer and write response headers based on our body buffer, if not + // done yet. + // 2. Write our body buffer and the rest of b (if any). + // 3. Reset the current body buffer so it can be used again. + rw.writeHeaderLockedOnce() + if rw.cannotHaveBody { + return initialBLen, nil + } + if n, err := rw.bw.write(rw.bb, b); err != nil { + return max(0, n-initialBufLen), err + } + rw.bb.discard() + return initialBLen, nil +} + +func (rw *responseWriter) Flush() { + // Calling Flush implicitly calls WriteHeader(200) if WriteHeader has not + // been called before. + rw.WriteHeader(http.StatusOK) + rw.mu.Lock() + defer rw.mu.Unlock() + rw.writeHeaderLockedOnce() + if !rw.cannotHaveBody { + rw.bw.Write(rw.bb) + rw.bb.discard() + } + rw.st.Flush() +} + +func (rw *responseWriter) close() error { + rw.Flush() + rw.mu.Lock() + defer rw.mu.Unlock() + rw.prepareTrailerForWriteLocked() + if err := rw.bw.Close(); err != nil { + return err + } + return rw.st.stream.Close() +} + +// defaultBodyBufferCap is the default number of bytes of body that we are +// willing to save in a buffer for the sake of inferring headers and coalescing +// small writes. 512 was chosen to be consistent with how much +// http.DetectContentType is willing to read. +const defaultBodyBufferCap = 512 + +// bodyBuffer is a buffer used to store body content of a response. +type bodyBuffer []byte + +// write writes b to the buffer. It returns a new slice of b, which contains +// any remaining data that could not be written to the buffer, if any. +func (bb *bodyBuffer) write(b []byte) []byte { + n := min(len(b), cap(*bb)-len(*bb)) + *bb = append(*bb, b[:n]...) + return b[n:] +} + +// discard resets the buffer so it can be used again. +func (bb *bodyBuffer) discard() { + *bb = (*bb)[:0] +} + +// inferHeader populates h with the header values that we can infer from our +// current buffer content, if not already explicitly set. This method should be +// called only once with as much body content as possible in the buffer, before +// a HEADERS frame is sent, and before discard has been called. Doing so +// properly is the responsibility of the caller. +func (bb *bodyBuffer) inferHeader(h http.Header, status int) { + if _, ok := h["Date"]; !ok { + h.Set("Date", time.Now().UTC().Format(http.TimeFormat)) + } + // If the Content-Encoding is non-blank, we shouldn't + // sniff the body. See Issue golang.org/issue/31753. + _, hasCE := h["Content-Encoding"] + _, hasCT := h["Content-Type"] + if !hasCE && !hasCT && responseCanHaveBody(status) && len(*bb) > 0 { + h.Set("Content-Type", http.DetectContentType(*bb)) + } + // We can technically infer Content-Length too here, as long as the entire + // response body fits within hi.buf and does not require flushing. However, + // we have chosen not to do so for now as Content-Length is not very + // important for HTTP/3, and such inconsistent behavior might be confusing. +} diff --git a/src/vendor/golang.org/x/net/internal/http3/settings.go b/src/vendor/golang.org/x/net/internal/http3/settings.go new file mode 100644 index 0000000000..2d2ca0e705 --- /dev/null +++ b/src/vendor/golang.org/x/net/internal/http3/settings.go @@ -0,0 +1,66 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http3 + +const ( + // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.4.1 + settingsMaxFieldSectionSize = 0x06 + + // https://www.rfc-editor.org/rfc/rfc9204.html#section-5 + settingsQPACKMaxTableCapacity = 0x01 + settingsQPACKBlockedStreams = 0x07 +) + +// writeSettings writes a complete SETTINGS frame. +// Its parameter is a list of alternating setting types and values. +func (st *stream) writeSettings(settings ...int64) { + var size int64 + for _, s := range settings { + // Settings values that don't fit in a QUIC varint ([0,2^62)) will panic here. + size += int64(sizeVarint(uint64(s))) + } + st.writeVarint(int64(frameTypeSettings)) + st.writeVarint(size) + for _, s := range settings { + st.writeVarint(s) + } +} + +// readSettings reads a complete SETTINGS frame, including the frame header. +func (st *stream) readSettings(f func(settingType, value int64) error) error { + frameType, err := st.readFrameHeader() + if err != nil || frameType != frameTypeSettings { + return &connectionError{ + code: errH3MissingSettings, + message: "settings not sent on control stream", + } + } + for st.lim > 0 { + settingsType, err := st.readVarint() + if err != nil { + return err + } + settingsValue, err := st.readVarint() + if err != nil { + return err + } + + // Use of HTTP/2 settings where there is no corresponding HTTP/3 setting + // is an error. + // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.4.1-5 + switch settingsType { + case 0x02, 0x03, 0x04, 0x05: + return &connectionError{ + code: errH3SettingsError, + message: "use of reserved setting", + } + } + + if err := f(settingsType, settingsValue); err != nil { + return err + } + } + return st.endFrame() +} diff --git a/src/vendor/golang.org/x/net/internal/http3/stream.go b/src/vendor/golang.org/x/net/internal/http3/stream.go new file mode 100644 index 0000000000..93294d43de --- /dev/null +++ b/src/vendor/golang.org/x/net/internal/http3/stream.go @@ -0,0 +1,260 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http3 + +import ( + "context" + "io" + + "golang.org/x/net/quic" +) + +// A stream wraps a QUIC stream, providing methods to read/write various values. +type stream struct { + stream *quic.Stream + + // lim is the current read limit. + // Reading a frame header sets the limit to the end of the frame. + // Reading past the limit or reading less than the limit and ending the frame + // results in an error. + // -1 indicates no limit. + lim int64 +} + +// newConnStream creates a new stream on a connection. +// It writes the stream header for unidirectional streams. +// +// The stream returned by newStream is not flushed, +// and will not be sent to the peer until the caller calls +// Flush or writes enough data to the stream. +func newConnStream(ctx context.Context, qconn *quic.Conn, stype streamType) (*stream, error) { + var qs *quic.Stream + var err error + if stype == streamTypeRequest { + // Request streams are bidirectional. + qs, err = qconn.NewStream(ctx) + } else { + // All other streams are unidirectional. + qs, err = qconn.NewSendOnlyStream(ctx) + } + if err != nil { + return nil, err + } + st := &stream{ + stream: qs, + lim: -1, // no limit + } + if stype != streamTypeRequest { + // Unidirectional stream header. + st.writeVarint(int64(stype)) + } + return st, err +} + +func newStream(qs *quic.Stream) *stream { + return &stream{ + stream: qs, + lim: -1, // no limit + } +} + +// readFrameHeader reads the type and length fields of an HTTP/3 frame. +// It sets the read limit to the end of the frame. +// +// https://www.rfc-editor.org/rfc/rfc9114.html#section-7.1 +func (st *stream) readFrameHeader() (ftype frameType, err error) { + if st.lim >= 0 { + // We shouldn't call readFrameHeader before ending the previous frame. + return 0, errH3FrameError + } + ftype, err = readVarint[frameType](st) + if err != nil { + return 0, err + } + size, err := st.readVarint() + if err != nil { + return 0, err + } + st.lim = size + return ftype, nil +} + +// endFrame is called after reading a frame to reset the read limit. +// It returns an error if the entire contents of a frame have not been read. +func (st *stream) endFrame() error { + if st.lim != 0 { + return &connectionError{ + code: errH3FrameError, + message: "invalid HTTP/3 frame", + } + } + st.lim = -1 + return nil +} + +// readFrameData returns the remaining data in the current frame. +func (st *stream) readFrameData() ([]byte, error) { + if st.lim < 0 { + return nil, errH3FrameError + } + // TODO: Pool buffers to avoid allocation here. + b := make([]byte, st.lim) + _, err := io.ReadFull(st, b) + if err != nil { + return nil, err + } + return b, nil +} + +// ReadByte reads one byte from the stream. +func (st *stream) ReadByte() (b byte, err error) { + if err := st.recordBytesRead(1); err != nil { + return 0, err + } + b, err = st.stream.ReadByte() + if err != nil { + if err == io.EOF && st.lim < 0 { + return 0, io.EOF + } + return 0, errH3FrameError + } + return b, nil +} + +// Read reads from the stream. +func (st *stream) Read(b []byte) (int, error) { + n, err := st.stream.Read(b) + if e2 := st.recordBytesRead(n); e2 != nil { + return 0, e2 + } + if err == io.EOF { + if st.lim == 0 { + // EOF at end of frame, ignore. + return n, nil + } else if st.lim > 0 { + // EOF inside frame, error. + return 0, errH3FrameError + } else { + // EOF outside of frame, surface to caller. + return n, io.EOF + } + } + if err != nil { + return 0, errH3FrameError + } + return n, nil +} + +// discardUnknownFrame discards an unknown frame. +// +// HTTP/3 requires that unknown frames be ignored on all streams. +// However, a known frame appearing in an unexpected place is a fatal error, +// so this returns an error if the frame is one we know. +func (st *stream) discardUnknownFrame(ftype frameType) error { + switch ftype { + case frameTypeData, + frameTypeHeaders, + frameTypeCancelPush, + frameTypeSettings, + frameTypePushPromise, + frameTypeGoaway, + frameTypeMaxPushID: + return &connectionError{ + code: errH3FrameUnexpected, + message: "unexpected " + ftype.String() + " frame", + } + } + return st.discardFrame() +} + +// discardFrame discards any remaining data in the current frame and resets the read limit. +func (st *stream) discardFrame() error { + // TODO: Consider adding a *quic.Stream method to discard some amount of data. + for range st.lim { + _, err := st.stream.ReadByte() + if err != nil { + return &streamError{errH3FrameError, err.Error()} + } + } + st.lim = -1 + return nil +} + +// Write writes to the stream. +func (st *stream) Write(b []byte) (int, error) { return st.stream.Write(b) } + +// Flush commits data written to the stream. +func (st *stream) Flush() error { return st.stream.Flush() } + +// readVarint reads a QUIC variable-length integer from the stream. +func (st *stream) readVarint() (v int64, err error) { + b, err := st.stream.ReadByte() + if err != nil { + return 0, err + } + v = int64(b & 0x3f) + n := 1 << (b >> 6) + for i := 1; i < n; i++ { + b, err := st.stream.ReadByte() + if err != nil { + return 0, errH3FrameError + } + v = (v << 8) | int64(b) + } + if err := st.recordBytesRead(n); err != nil { + return 0, err + } + return v, nil +} + +// readVarint reads a varint of a particular type. +func readVarint[T ~int64 | ~uint64](st *stream) (T, error) { + v, err := st.readVarint() + return T(v), err +} + +// writeVarint writes a QUIC variable-length integer to the stream. +func (st *stream) writeVarint(v int64) { + switch { + case v <= (1<<6)-1: + st.stream.WriteByte(byte(v)) + case v <= (1<<14)-1: + st.stream.WriteByte((1 << 6) | byte(v>>8)) + st.stream.WriteByte(byte(v)) + case v <= (1<<30)-1: + st.stream.WriteByte((2 << 6) | byte(v>>24)) + st.stream.WriteByte(byte(v >> 16)) + st.stream.WriteByte(byte(v >> 8)) + st.stream.WriteByte(byte(v)) + case v <= (1<<62)-1: + st.stream.WriteByte((3 << 6) | byte(v>>56)) + st.stream.WriteByte(byte(v >> 48)) + st.stream.WriteByte(byte(v >> 40)) + st.stream.WriteByte(byte(v >> 32)) + st.stream.WriteByte(byte(v >> 24)) + st.stream.WriteByte(byte(v >> 16)) + st.stream.WriteByte(byte(v >> 8)) + st.stream.WriteByte(byte(v)) + default: + panic("varint too large") + } +} + +// recordBytesRead records that n bytes have been read. +// It returns an error if the read passes the current limit. +func (st *stream) recordBytesRead(n int) error { + if st.lim < 0 { + return nil + } + st.lim -= int64(n) + if st.lim < 0 { + st.stream = nil // panic if we try to read again + return &connectionError{ + code: errH3FrameError, + message: "invalid HTTP/3 frame", + } + } + return nil +} diff --git a/src/vendor/golang.org/x/net/internal/http3/transport.go b/src/vendor/golang.org/x/net/internal/http3/transport.go new file mode 100644 index 0000000000..a99824c9f4 --- /dev/null +++ b/src/vendor/golang.org/x/net/internal/http3/transport.go @@ -0,0 +1,284 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http3 + +import ( + "context" + "fmt" + "net/http" + "net/url" + "sync" + + "golang.org/x/net/quic" +) + +// A transport is an HTTP/3 transport. +// +// It does not manage a pool of connections, +// and therefore does not implement net/http.RoundTripper. +// +// TODO: Provide a way to register an HTTP/3 transport with a net/http.transport's +// connection pool. +type transport struct { + // config is the QUIC configuration used for client connections. + config *quic.Config + + mu sync.Mutex // Guards fields below. + // endpoint is the QUIC endpoint used by connections created by the + // transport. If CloseIdleConnections is called when activeConns is empty, + // endpoint will be unset. If unset, endpoint will be initialized by any + // call to dial. + endpoint *quic.Endpoint + activeConns map[*clientConn]struct{} + inFlightDials int +} + +// netHTTPTransport implements the net/http.dialClientConner interface, +// allowing our HTTP/3 transport to integrate with net/http. +type netHTTPTransport struct { + *transport +} + +// RoundTrip is defined since Transport.RegisterProtocol takes in a +// RoundTripper. However, this method will never be used as net/http's +// dialClientConner interface does not have a RoundTrip method and will only +// use DialClientConn to create a new RoundTripper. +func (t netHTTPTransport) RoundTrip(*http.Request) (*http.Response, error) { + panic("netHTTPTransport.RoundTrip should never be called") +} + +func (t netHTTPTransport) DialClientConn(ctx context.Context, addr string, _ *url.URL, _ func()) (http.RoundTripper, error) { + return t.transport.dial(ctx, addr) +} + +// RegisterTransport configures a net/http HTTP/1 Transport to use HTTP/3. +// +// TODO: most likely, add another arg for transport configuration. +func RegisterTransport(tr *http.Transport) { + tr3 := &transport{ + // initConfig will clone the tr.TLSClientConfig. + config: initConfig(&quic.Config{ + TLSConfig: tr.TLSClientConfig, + }), + activeConns: make(map[*clientConn]struct{}), + } + tr.RegisterProtocol("http/3", netHTTPTransport{tr3}) +} + +func (tr *transport) incInFlightDials() { + tr.mu.Lock() + defer tr.mu.Unlock() + tr.inFlightDials++ +} + +func (tr *transport) decInFlightDials() { + tr.mu.Lock() + defer tr.mu.Unlock() + tr.inFlightDials-- +} + +func (tr *transport) initEndpoint() (err error) { + tr.mu.Lock() + defer tr.mu.Unlock() + if tr.endpoint == nil { + tr.endpoint, err = quic.Listen("udp", ":0", nil) + } + return err +} + +// dial creates a new HTTP/3 client connection. +func (tr *transport) dial(ctx context.Context, target string) (*clientConn, error) { + tr.incInFlightDials() + defer tr.decInFlightDials() + + if err := tr.initEndpoint(); err != nil { + return nil, err + } + qconn, err := tr.endpoint.Dial(ctx, "udp", target, tr.config) + if err != nil { + return nil, err + } + return tr.newClientConn(ctx, qconn) +} + +// CloseIdleConnections is called by net/http.Transport.CloseIdleConnections +// after all existing idle connections are closed using http3.clientConn.Close. +// +// When the transport has no active connections anymore, calling this method +// will make the transport clean up any shared resources that are no longer +// required, such as its QUIC endpoint. +func (tr *transport) CloseIdleConnections() { + tr.mu.Lock() + defer tr.mu.Unlock() + if tr.endpoint == nil || len(tr.activeConns) > 0 || tr.inFlightDials > 0 { + return + } + tr.endpoint.Close(canceledCtx) + tr.endpoint = nil +} + +// A clientConn is a client HTTP/3 connection. +// +// Multiple goroutines may invoke methods on a clientConn simultaneously. +type clientConn struct { + qconn *quic.Conn + genericConn + + enc qpackEncoder + dec qpackDecoder +} + +func (tr *transport) registerConn(cc *clientConn) { + tr.mu.Lock() + defer tr.mu.Unlock() + tr.activeConns[cc] = struct{}{} +} + +func (tr *transport) unregisterConn(cc *clientConn) { + tr.mu.Lock() + defer tr.mu.Unlock() + delete(tr.activeConns, cc) +} + +func (tr *transport) newClientConn(ctx context.Context, qconn *quic.Conn) (*clientConn, error) { + cc := &clientConn{ + qconn: qconn, + } + tr.registerConn(cc) + cc.enc.init() + + // Create control stream and send SETTINGS frame. + controlStream, err := newConnStream(ctx, cc.qconn, streamTypeControl) + if err != nil { + tr.unregisterConn(cc) + return nil, fmt.Errorf("http3: cannot create control stream: %v", err) + } + controlStream.writeSettings() + controlStream.Flush() + + go func() { + cc.acceptStreams(qconn, cc) + tr.unregisterConn(cc) + }() + return cc, nil +} + +// TODO: implement the rest of net/http.ClientConn methods beyond Close. +func (cc *clientConn) Close() error { + // We need to use Close rather than Abort on the QUIC connection. + // Otherwise, when a net/http.Transport.CloseIdleConnections is called, it + // might call the http3.transport.CloseIdleConnections prior to all idle + // connections being fully closed; this would make it unable to close its + // QUIC endpoint, making http3.transport.CloseIdleConnections a no-op + // unintentionally. + return cc.qconn.Close() +} + +func (cc *clientConn) Err() error { + return nil +} + +func (cc *clientConn) Reserve() error { + return nil +} + +func (cc *clientConn) Release() { +} + +func (cc *clientConn) Available() int { + return 0 +} + +func (cc *clientConn) InFlight() int { + return 0 +} + +func (cc *clientConn) handleControlStream(st *stream) error { + // "A SETTINGS frame MUST be sent as the first frame of each control stream [...]" + // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.4-2 + if err := st.readSettings(func(settingsType, settingsValue int64) error { + switch settingsType { + case settingsMaxFieldSectionSize: + _ = settingsValue // TODO + case settingsQPACKMaxTableCapacity: + _ = settingsValue // TODO + case settingsQPACKBlockedStreams: + _ = settingsValue // TODO + default: + // Unknown settings types are ignored. + } + return nil + }); err != nil { + return err + } + + for { + ftype, err := st.readFrameHeader() + if err != nil { + return err + } + switch ftype { + case frameTypeCancelPush: + // "If a CANCEL_PUSH frame is received that references a push ID + // greater than currently allowed on the connection, + // this MUST be treated as a connection error of type H3_ID_ERROR." + // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.3-7 + return &connectionError{ + code: errH3IDError, + message: "CANCEL_PUSH received when no MAX_PUSH_ID has been sent", + } + case frameTypeGoaway: + // TODO: Wait for requests to complete before closing connection. + return errH3NoError + default: + // Unknown frames are ignored. + if err := st.discardUnknownFrame(ftype); err != nil { + return err + } + } + } +} + +func (cc *clientConn) handleEncoderStream(*stream) error { + // TODO + return nil +} + +func (cc *clientConn) handleDecoderStream(*stream) error { + // TODO + return nil +} + +func (cc *clientConn) handlePushStream(*stream) error { + // "A client MUST treat receipt of a push stream as a connection error + // of type H3_ID_ERROR when no MAX_PUSH_ID frame has been sent [...]" + // https://www.rfc-editor.org/rfc/rfc9114.html#section-4.6-3 + return &connectionError{ + code: errH3IDError, + message: "push stream created when no MAX_PUSH_ID has been sent", + } +} + +func (cc *clientConn) handleRequestStream(st *stream) error { + // "Clients MUST treat receipt of a server-initiated bidirectional + // stream as a connection error of type H3_STREAM_CREATION_ERROR [...]" + // https://www.rfc-editor.org/rfc/rfc9114.html#section-6.1-3 + return &connectionError{ + code: errH3StreamCreationError, + message: "server created bidirectional stream", + } +} + +// abort closes the connection with an error. +func (cc *clientConn) abort(err error) { + if e, ok := err.(*connectionError); ok { + cc.qconn.Abort(&quic.ApplicationError{ + Code: uint64(e.code), + Reason: e.message, + }) + } else { + cc.qconn.Abort(err) + } +} diff --git a/src/vendor/golang.org/x/net/internal/http3/varint.go b/src/vendor/golang.org/x/net/internal/http3/varint.go new file mode 100644 index 0000000000..bee3d71fb3 --- /dev/null +++ b/src/vendor/golang.org/x/net/internal/http3/varint.go @@ -0,0 +1,23 @@ +// Copyright 2026 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http3 + +// sizeVarint returns the size of the variable-length integer encoding of f. +// Copied from internal/quic/quicwire to break dependency that makes bundling +// into std more complicated. +func sizeVarint(v uint64) int { + switch { + case v <= 63: + return 1 + case v <= 16383: + return 2 + case v <= 1073741823: + return 4 + case v <= 4611686018427387903: + return 8 + default: + panic("varint too large") + } +} diff --git a/src/vendor/golang.org/x/net/internal/httpcommon/ascii.go b/src/vendor/golang.org/x/net/internal/httpcommon/ascii.go new file mode 100644 index 0000000000..ed14da5afc --- /dev/null +++ b/src/vendor/golang.org/x/net/internal/httpcommon/ascii.go @@ -0,0 +1,53 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httpcommon + +import "strings" + +// The HTTP protocols are defined in terms of ASCII, not Unicode. This file +// contains helper functions which may use Unicode-aware functions which would +// otherwise be unsafe and could introduce vulnerabilities if used improperly. + +// asciiEqualFold is strings.EqualFold, ASCII only. It reports whether s and t +// are equal, ASCII-case-insensitively. +func asciiEqualFold(s, t string) bool { + if len(s) != len(t) { + return false + } + for i := 0; i < len(s); i++ { + if lower(s[i]) != lower(t[i]) { + return false + } + } + return true +} + +// lower returns the ASCII lowercase version of b. +func lower(b byte) byte { + if 'A' <= b && b <= 'Z' { + return b + ('a' - 'A') + } + return b +} + +// isASCIIPrint returns whether s is ASCII and printable according to +// https://tools.ietf.org/html/rfc20#section-4.2. +func isASCIIPrint(s string) bool { + for i := 0; i < len(s); i++ { + if s[i] < ' ' || s[i] > '~' { + return false + } + } + return true +} + +// asciiToLower returns the lowercase version of s if s is ASCII and printable, +// and whether or not it was. +func asciiToLower(s string) (lower string, ok bool) { + if !isASCIIPrint(s) { + return "", false + } + return strings.ToLower(s), true +} diff --git a/src/vendor/golang.org/x/net/internal/httpcommon/headermap.go b/src/vendor/golang.org/x/net/internal/httpcommon/headermap.go new file mode 100644 index 0000000000..92483d8e41 --- /dev/null +++ b/src/vendor/golang.org/x/net/internal/httpcommon/headermap.go @@ -0,0 +1,115 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httpcommon + +import ( + "net/textproto" + "sync" +) + +var ( + commonBuildOnce sync.Once + commonLowerHeader map[string]string // Go-Canonical-Case -> lower-case + commonCanonHeader map[string]string // lower-case -> Go-Canonical-Case +) + +func buildCommonHeaderMapsOnce() { + commonBuildOnce.Do(buildCommonHeaderMaps) +} + +func buildCommonHeaderMaps() { + common := []string{ + "accept", + "accept-charset", + "accept-encoding", + "accept-language", + "accept-ranges", + "age", + "access-control-allow-credentials", + "access-control-allow-headers", + "access-control-allow-methods", + "access-control-allow-origin", + "access-control-expose-headers", + "access-control-max-age", + "access-control-request-headers", + "access-control-request-method", + "allow", + "authorization", + "cache-control", + "content-disposition", + "content-encoding", + "content-language", + "content-length", + "content-location", + "content-range", + "content-type", + "cookie", + "date", + "etag", + "expect", + "expires", + "from", + "host", + "if-match", + "if-modified-since", + "if-none-match", + "if-unmodified-since", + "last-modified", + "link", + "location", + "max-forwards", + "origin", + "proxy-authenticate", + "proxy-authorization", + "range", + "referer", + "refresh", + "retry-after", + "server", + "set-cookie", + "strict-transport-security", + "trailer", + "transfer-encoding", + "user-agent", + "vary", + "via", + "www-authenticate", + "x-forwarded-for", + "x-forwarded-proto", + } + commonLowerHeader = make(map[string]string, len(common)) + commonCanonHeader = make(map[string]string, len(common)) + for _, v := range common { + chk := textproto.CanonicalMIMEHeaderKey(v) + commonLowerHeader[chk] = v + commonCanonHeader[v] = chk + } +} + +// LowerHeader returns the lowercase form of a header name, +// used on the wire for HTTP/2 and HTTP/3 requests. +func LowerHeader(v string) (lower string, ascii bool) { + buildCommonHeaderMapsOnce() + if s, ok := commonLowerHeader[v]; ok { + return s, true + } + return asciiToLower(v) +} + +// CanonicalHeader canonicalizes a header name. (For example, "host" becomes "Host".) +func CanonicalHeader(v string) string { + buildCommonHeaderMapsOnce() + if s, ok := commonCanonHeader[v]; ok { + return s + } + return textproto.CanonicalMIMEHeaderKey(v) +} + +// CachedCanonicalHeader returns the canonical form of a well-known header name. +func CachedCanonicalHeader(v string) (string, bool) { + buildCommonHeaderMapsOnce() + s, ok := commonCanonHeader[v] + return s, ok +} diff --git a/src/vendor/golang.org/x/net/internal/httpcommon/request.go b/src/vendor/golang.org/x/net/internal/httpcommon/request.go new file mode 100644 index 0000000000..1e10f89ebf --- /dev/null +++ b/src/vendor/golang.org/x/net/internal/httpcommon/request.go @@ -0,0 +1,467 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httpcommon + +import ( + "context" + "errors" + "fmt" + "net/http/httptrace" + "net/textproto" + "net/url" + "sort" + "strconv" + "strings" + + "golang.org/x/net/http/httpguts" + "golang.org/x/net/http2/hpack" +) + +var ( + ErrRequestHeaderListSize = errors.New("request header list larger than peer's advertised limit") +) + +// Request is a subset of http.Request. +// It'd be simpler to pass an *http.Request, of course, but we can't depend on net/http +// without creating a dependency cycle. +type Request struct { + URL *url.URL + Method string + Host string + Header map[string][]string + Trailer map[string][]string + ActualContentLength int64 // 0 means 0, -1 means unknown +} + +// EncodeHeadersParam is parameters to EncodeHeaders. +type EncodeHeadersParam struct { + Request Request + + // AddGzipHeader indicates that an "accept-encoding: gzip" header should be + // added to the request. + AddGzipHeader bool + + // PeerMaxHeaderListSize, when non-zero, is the peer's MAX_HEADER_LIST_SIZE setting. + PeerMaxHeaderListSize uint64 + + // DefaultUserAgent is the User-Agent header to send when the request + // neither contains a User-Agent nor disables it. + DefaultUserAgent string +} + +// EncodeHeadersResult is the result of EncodeHeaders. +type EncodeHeadersResult struct { + HasBody bool + HasTrailers bool +} + +// EncodeHeaders constructs request headers common to HTTP/2 and HTTP/3. +// It validates a request and calls headerf with each pseudo-header and header +// for the request. +// The headerf function is called with the validated, canonicalized header name. +func EncodeHeaders(ctx context.Context, param EncodeHeadersParam, headerf func(name, value string)) (res EncodeHeadersResult, _ error) { + req := param.Request + + // Check for invalid connection-level headers. + if err := checkConnHeaders(req.Header); err != nil { + return res, err + } + + if req.URL == nil { + return res, errors.New("Request.URL is nil") + } + + host := req.Host + if host == "" { + host = req.URL.Host + } + host, err := httpguts.PunycodeHostPort(host) + if err != nil { + return res, err + } + if !httpguts.ValidHostHeader(host) { + return res, errors.New("invalid Host header") + } + + // isNormalConnect is true if this is a non-extended CONNECT request. + isNormalConnect := false + var protocol string + if vv := req.Header[":protocol"]; len(vv) > 0 { + protocol = vv[0] + } + if req.Method == "CONNECT" && protocol == "" { + isNormalConnect = true + } else if protocol != "" && req.Method != "CONNECT" { + return res, errors.New("invalid :protocol header in non-CONNECT request") + } + + // Validate the path, except for non-extended CONNECT requests which have no path. + var path string + if !isNormalConnect { + path = req.URL.RequestURI() + if !validPseudoPath(path) { + orig := path + path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host) + if !validPseudoPath(path) { + if req.URL.Opaque != "" { + return res, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque) + } else { + return res, fmt.Errorf("invalid request :path %q", orig) + } + } + } + } + + // Check for any invalid headers+trailers and return an error before we + // potentially pollute our hpack state. (We want to be able to + // continue to reuse the hpack encoder for future requests) + if err := validateHeaders(req.Header); err != "" { + return res, fmt.Errorf("invalid HTTP header %s", err) + } + if err := validateHeaders(req.Trailer); err != "" { + return res, fmt.Errorf("invalid HTTP trailer %s", err) + } + + trailers, err := commaSeparatedTrailers(req.Trailer) + if err != nil { + return res, err + } + + enumerateHeaders := func(f func(name, value string)) { + // 8.1.2.3 Request Pseudo-Header Fields + // The :path pseudo-header field includes the path and query parts of the + // target URI (the path-absolute production and optionally a '?' character + // followed by the query production, see Sections 3.3 and 3.4 of + // [RFC3986]). + f(":authority", host) + m := req.Method + if m == "" { + m = "GET" + } + f(":method", m) + if !isNormalConnect { + f(":path", path) + f(":scheme", req.URL.Scheme) + } + if protocol != "" { + f(":protocol", protocol) + } + if trailers != "" { + f("trailer", trailers) + } + + var didUA bool + for k, vv := range req.Header { + if asciiEqualFold(k, "host") || asciiEqualFold(k, "content-length") { + // Host is :authority, already sent. + // Content-Length is automatic, set below. + continue + } else if asciiEqualFold(k, "connection") || + asciiEqualFold(k, "proxy-connection") || + asciiEqualFold(k, "transfer-encoding") || + asciiEqualFold(k, "upgrade") || + asciiEqualFold(k, "keep-alive") { + // Per 8.1.2.2 Connection-Specific Header + // Fields, don't send connection-specific + // fields. We have already checked if any + // are error-worthy so just ignore the rest. + continue + } else if asciiEqualFold(k, "user-agent") { + // Match Go's http1 behavior: at most one + // User-Agent. If set to nil or empty string, + // then omit it. Otherwise if not mentioned, + // include the default (below). + didUA = true + if len(vv) < 1 { + continue + } + vv = vv[:1] + if vv[0] == "" { + continue + } + } else if asciiEqualFold(k, "cookie") { + // Per 8.1.2.5 To allow for better compression efficiency, the + // Cookie header field MAY be split into separate header fields, + // each with one or more cookie-pairs. + for _, v := range vv { + for { + p := strings.IndexByte(v, ';') + if p < 0 { + break + } + f("cookie", v[:p]) + p++ + // strip space after semicolon if any. + for p+1 <= len(v) && v[p] == ' ' { + p++ + } + v = v[p:] + } + if len(v) > 0 { + f("cookie", v) + } + } + continue + } else if k == ":protocol" { + // :protocol pseudo-header was already sent above. + continue + } + + for _, v := range vv { + f(k, v) + } + } + if shouldSendReqContentLength(req.Method, req.ActualContentLength) { + f("content-length", strconv.FormatInt(req.ActualContentLength, 10)) + } + if param.AddGzipHeader { + f("accept-encoding", "gzip") + } + if !didUA { + f("user-agent", param.DefaultUserAgent) + } + } + + // Do a first pass over the headers counting bytes to ensure + // we don't exceed cc.peerMaxHeaderListSize. This is done as a + // separate pass before encoding the headers to prevent + // modifying the hpack state. + if param.PeerMaxHeaderListSize > 0 { + hlSize := uint64(0) + enumerateHeaders(func(name, value string) { + hf := hpack.HeaderField{Name: name, Value: value} + hlSize += uint64(hf.Size()) + }) + + if hlSize > param.PeerMaxHeaderListSize { + return res, ErrRequestHeaderListSize + } + } + + trace := httptrace.ContextClientTrace(ctx) + + // Header list size is ok. Write the headers. + enumerateHeaders(func(name, value string) { + name, ascii := LowerHeader(name) + if !ascii { + // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header + // field names have to be ASCII characters (just as in HTTP/1.x). + return + } + + headerf(name, value) + + if trace != nil && trace.WroteHeaderField != nil { + trace.WroteHeaderField(name, []string{value}) + } + }) + + res.HasBody = req.ActualContentLength != 0 + res.HasTrailers = trailers != "" + return res, nil +} + +// IsRequestGzip reports whether we should add an Accept-Encoding: gzip header +// for a request. +func IsRequestGzip(method string, header map[string][]string, disableCompression bool) bool { + // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? + if !disableCompression && + len(header["Accept-Encoding"]) == 0 && + len(header["Range"]) == 0 && + method != "HEAD" { + // Request gzip only, not deflate. Deflate is ambiguous and + // not as universally supported anyway. + // See: https://zlib.net/zlib_faq.html#faq39 + // + // Note that we don't request this for HEAD requests, + // due to a bug in nginx: + // http://trac.nginx.org/nginx/ticket/358 + // https://golang.org/issue/5522 + // + // We don't request gzip if the request is for a range, since + // auto-decoding a portion of a gzipped document will just fail + // anyway. See https://golang.org/issue/8923 + return true + } + return false +} + +// checkConnHeaders checks whether req has any invalid connection-level headers. +// +// https://www.rfc-editor.org/rfc/rfc9114.html#section-4.2-3 +// https://www.rfc-editor.org/rfc/rfc9113.html#section-8.2.2-1 +// +// Certain headers are special-cased as okay but not transmitted later. +// For example, we allow "Transfer-Encoding: chunked", but drop the header when encoding. +func checkConnHeaders(h map[string][]string) error { + if vv := h["Upgrade"]; len(vv) > 0 && (vv[0] != "" && vv[0] != "chunked") { + return fmt.Errorf("invalid Upgrade request header: %q", vv) + } + if vv := h["Transfer-Encoding"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "chunked") { + return fmt.Errorf("invalid Transfer-Encoding request header: %q", vv) + } + if vv := h["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && !asciiEqualFold(vv[0], "close") && !asciiEqualFold(vv[0], "keep-alive")) { + return fmt.Errorf("invalid Connection request header: %q", vv) + } + return nil +} + +func commaSeparatedTrailers(trailer map[string][]string) (string, error) { + keys := make([]string, 0, len(trailer)) + for k := range trailer { + k = CanonicalHeader(k) + switch k { + case "Transfer-Encoding", "Trailer", "Content-Length": + return "", fmt.Errorf("invalid Trailer key %q", k) + } + keys = append(keys, k) + } + if len(keys) > 0 { + sort.Strings(keys) + return strings.Join(keys, ","), nil + } + return "", nil +} + +// validPseudoPath reports whether v is a valid :path pseudo-header +// value. It must be either: +// +// - a non-empty string starting with '/' +// - the string '*', for OPTIONS requests. +// +// For now this is only used a quick check for deciding when to clean +// up Opaque URLs before sending requests from the Transport. +// See golang.org/issue/16847 +// +// We used to enforce that the path also didn't start with "//", but +// Google's GFE accepts such paths and Chrome sends them, so ignore +// that part of the spec. See golang.org/issue/19103. +func validPseudoPath(v string) bool { + return (len(v) > 0 && v[0] == '/') || v == "*" +} + +func validateHeaders(hdrs map[string][]string) string { + for k, vv := range hdrs { + if !httpguts.ValidHeaderFieldName(k) && k != ":protocol" { + return fmt.Sprintf("name %q", k) + } + for _, v := range vv { + if !httpguts.ValidHeaderFieldValue(v) { + // Don't include the value in the error, + // because it may be sensitive. + return fmt.Sprintf("value for header %q", k) + } + } + } + return "" +} + +// shouldSendReqContentLength reports whether we should send +// a "content-length" request header. This logic is basically a copy of the net/http +// transferWriter.shouldSendContentLength. +// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown). +// -1 means unknown. +func shouldSendReqContentLength(method string, contentLength int64) bool { + if contentLength > 0 { + return true + } + if contentLength < 0 { + return false + } + // For zero bodies, whether we send a content-length depends on the method. + // It also kinda doesn't matter for http2 either way, with END_STREAM. + switch method { + case "POST", "PUT", "PATCH": + return true + default: + return false + } +} + +// ServerRequestParam is parameters to NewServerRequest. +type ServerRequestParam struct { + Method string + Scheme, Authority, Path string + Protocol string + Header map[string][]string +} + +// ServerRequestResult is the result of NewServerRequest. +type ServerRequestResult struct { + // Various http.Request fields. + URL *url.URL + RequestURI string + Trailer map[string][]string + + NeedsContinue bool // client provided an "Expect: 100-continue" header + + // If the request should be rejected, this is a short string suitable for passing + // to the http2 package's CountError function. + // It might be a bit odd to return errors this way rather than returning an error, + // but this ensures we don't forget to include a CountError reason. + InvalidReason string +} + +func NewServerRequest(rp ServerRequestParam) ServerRequestResult { + needsContinue := httpguts.HeaderValuesContainsToken(rp.Header["Expect"], "100-continue") + if needsContinue { + delete(rp.Header, "Expect") + } + // Merge Cookie headers into one "; "-delimited value. + if cookies := rp.Header["Cookie"]; len(cookies) > 1 { + rp.Header["Cookie"] = []string{strings.Join(cookies, "; ")} + } + + // Setup Trailers + var trailer map[string][]string + for _, v := range rp.Header["Trailer"] { + for _, key := range strings.Split(v, ",") { + key = textproto.CanonicalMIMEHeaderKey(textproto.TrimString(key)) + switch key { + case "Transfer-Encoding", "Trailer", "Content-Length": + // Bogus. (copy of http1 rules) + // Ignore. + default: + if trailer == nil { + trailer = make(map[string][]string) + } + trailer[key] = nil + } + } + } + delete(rp.Header, "Trailer") + + // "':authority' MUST NOT include the deprecated userinfo subcomponent + // for "http" or "https" schemed URIs." + // https://www.rfc-editor.org/rfc/rfc9113.html#section-8.3.1-2.3.8 + if strings.IndexByte(rp.Authority, '@') != -1 && (rp.Scheme == "http" || rp.Scheme == "https") { + return ServerRequestResult{ + InvalidReason: "userinfo_in_authority", + } + } + + var url_ *url.URL + var requestURI string + if rp.Method == "CONNECT" && rp.Protocol == "" { + url_ = &url.URL{Host: rp.Authority} + requestURI = rp.Authority // mimic HTTP/1 server behavior + } else { + var err error + url_, err = url.ParseRequestURI(rp.Path) + if err != nil { + return ServerRequestResult{ + InvalidReason: "bad_path", + } + } + requestURI = rp.Path + } + + return ServerRequestResult{ + URL: url_, + NeedsContinue: needsContinue, + RequestURI: requestURI, + Trailer: trailer, + } +} diff --git a/src/vendor/golang.org/x/net/internal/quic/quicwire/wire.go b/src/vendor/golang.org/x/net/internal/quic/quicwire/wire.go new file mode 100644 index 0000000000..1a06a22519 --- /dev/null +++ b/src/vendor/golang.org/x/net/internal/quic/quicwire/wire.go @@ -0,0 +1,150 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package quicwire encodes and decode QUIC/HTTP3 wire encoding types, +// particularly variable-length integers. +package quicwire + +import "encoding/binary" + +const ( + MaxVarintSize = 8 // encoded size in bytes + MaxVarint = (1 << 62) - 1 +) + +// ConsumeVarint parses a variable-length integer, reporting its length. +// It returns a negative length upon an error. +// +// https://www.rfc-editor.org/rfc/rfc9000.html#section-16 +func ConsumeVarint(b []byte) (v uint64, n int) { + if len(b) < 1 { + return 0, -1 + } + b0 := b[0] & 0x3f + switch b[0] >> 6 { + case 0: + return uint64(b0), 1 + case 1: + if len(b) < 2 { + return 0, -1 + } + return uint64(b0)<<8 | uint64(b[1]), 2 + case 2: + if len(b) < 4 { + return 0, -1 + } + return uint64(b0)<<24 | uint64(b[1])<<16 | uint64(b[2])<<8 | uint64(b[3]), 4 + case 3: + if len(b) < 8 { + return 0, -1 + } + return uint64(b0)<<56 | uint64(b[1])<<48 | uint64(b[2])<<40 | uint64(b[3])<<32 | uint64(b[4])<<24 | uint64(b[5])<<16 | uint64(b[6])<<8 | uint64(b[7]), 8 + } + return 0, -1 +} + +// ConsumeVarintInt64 parses a variable-length integer as an int64. +func ConsumeVarintInt64(b []byte) (v int64, n int) { + u, n := ConsumeVarint(b) + // QUIC varints are 62-bits large, so this conversion can never overflow. + return int64(u), n +} + +// AppendVarint appends a variable-length integer to b. +// +// https://www.rfc-editor.org/rfc/rfc9000.html#section-16 +func AppendVarint(b []byte, v uint64) []byte { + switch { + case v <= 63: + return append(b, byte(v)) + case v <= 16383: + return append(b, (1<<6)|byte(v>>8), byte(v)) + case v <= 1073741823: + return append(b, (2<<6)|byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) + case v <= 4611686018427387903: + return append(b, (3<<6)|byte(v>>56), byte(v>>48), byte(v>>40), byte(v>>32), byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) + default: + panic("varint too large") + } +} + +// SizeVarint returns the size of the variable-length integer encoding of f. +func SizeVarint(v uint64) int { + switch { + case v <= 63: + return 1 + case v <= 16383: + return 2 + case v <= 1073741823: + return 4 + case v <= 4611686018427387903: + return 8 + default: + panic("varint too large") + } +} + +// ConsumeUint32 parses a 32-bit fixed-length, big-endian integer, reporting its length. +// It returns a negative length upon an error. +func ConsumeUint32(b []byte) (uint32, int) { + if len(b) < 4 { + return 0, -1 + } + return binary.BigEndian.Uint32(b), 4 +} + +// ConsumeUint64 parses a 64-bit fixed-length, big-endian integer, reporting its length. +// It returns a negative length upon an error. +func ConsumeUint64(b []byte) (uint64, int) { + if len(b) < 8 { + return 0, -1 + } + return binary.BigEndian.Uint64(b), 8 +} + +// ConsumeUint8Bytes parses a sequence of bytes prefixed with an 8-bit length, +// reporting the total number of bytes consumed. +// It returns a negative length upon an error. +func ConsumeUint8Bytes(b []byte) ([]byte, int) { + if len(b) < 1 { + return nil, -1 + } + size := int(b[0]) + const n = 1 + if size > len(b[n:]) { + return nil, -1 + } + return b[n:][:size], size + n +} + +// AppendUint8Bytes appends a sequence of bytes prefixed by an 8-bit length. +func AppendUint8Bytes(b, v []byte) []byte { + if len(v) > 0xff { + panic("uint8-prefixed bytes too large") + } + b = append(b, uint8(len(v))) + b = append(b, v...) + return b +} + +// ConsumeVarintBytes parses a sequence of bytes preceded by a variable-length integer length, +// reporting the total number of bytes consumed. +// It returns a negative length upon an error. +func ConsumeVarintBytes(b []byte) ([]byte, int) { + size, n := ConsumeVarint(b) + if n < 0 { + return nil, -1 + } + if size > uint64(len(b[n:])) { + return nil, -1 + } + return b[n:][:size], int(size) + n +} + +// AppendVarintBytes appends a sequence of bytes prefixed by a variable-length integer length. +func AppendVarintBytes(b, v []byte) []byte { + b = AppendVarint(b, uint64(len(v))) + b = append(b, v...) + return b +} diff --git a/src/vendor/golang.org/x/net/quic/ack_delay.go b/src/vendor/golang.org/x/net/quic/ack_delay.go new file mode 100644 index 0000000000..029ce6faec --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/ack_delay.go @@ -0,0 +1,26 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "math" + "time" +) + +// An unscaledAckDelay is an ACK Delay field value from an ACK packet, +// without the ack_delay_exponent scaling applied. +type unscaledAckDelay int64 + +func unscaledAckDelayFromDuration(d time.Duration, ackDelayExponent uint8) unscaledAckDelay { + return unscaledAckDelay(d.Microseconds() >> ackDelayExponent) +} + +func (d unscaledAckDelay) Duration(ackDelayExponent uint8) time.Duration { + if int64(d) > (math.MaxInt64>>ackDelayExponent)/int64(time.Microsecond) { + // If scaling the delay would overflow, ignore the delay. + return 0 + } + return time.Duration(d<<ackDelayExponent) * time.Microsecond +} diff --git a/src/vendor/golang.org/x/net/quic/acks.go b/src/vendor/golang.org/x/net/quic/acks.go new file mode 100644 index 0000000000..90f82bed03 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/acks.go @@ -0,0 +1,213 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "time" +) + +// ackState tracks packets received from a peer within a number space. +// It handles packet deduplication (don't process the same packet twice) and +// determines the timing and content of ACK frames. +type ackState struct { + seen rangeset[packetNumber] + + // The time at which we must send an ACK frame, even if we have no other data to send. + nextAck time.Time + + // The time we received the largest-numbered packet in seen. + maxRecvTime time.Time + + // The largest-numbered ack-eliciting packet in seen. + maxAckEliciting packetNumber + + // The number of ack-eliciting packets in seen that we have not yet acknowledged. + unackedAckEliciting int + + // Total ECN counters for this packet number space. + ecn ecnCounts +} + +type ecnCounts struct { + t0 int + t1 int + ce int +} + +// shouldProcess reports whether a packet should be handled or discarded. +func (acks *ackState) shouldProcess(num packetNumber) bool { + if packetNumber(acks.seen.min()) > num { + // We've discarded the state for this range of packet numbers. + // Discard the packet rather than potentially processing a duplicate. + // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.2.3-5 + return false + } + if acks.seen.contains(num) { + // Discard duplicate packets. + return false + } + return true +} + +// receive records receipt of a packet. +func (acks *ackState) receive(now time.Time, space numberSpace, num packetNumber, ackEliciting bool, ecn ecnBits) { + if ackEliciting { + acks.unackedAckEliciting++ + if acks.mustAckImmediately(space, num, ecn) { + acks.nextAck = now + } else if acks.nextAck.IsZero() { + // This packet does not need to be acknowledged immediately, + // but the ack must not be intentionally delayed by more than + // the max_ack_delay transport parameter we sent to the peer. + // + // We always delay acks by the maximum allowed, less the timer + // granularity. ("[max_ack_delay] SHOULD include the receiver's + // expected delays in alarms firing.") + // + // https://www.rfc-editor.org/rfc/rfc9000#section-18.2-4.28.1 + acks.nextAck = now.Add(maxAckDelay - timerGranularity) + } + if num > acks.maxAckEliciting { + acks.maxAckEliciting = num + } + } + + acks.seen.add(num, num+1) + if num == acks.seen.max() { + acks.maxRecvTime = now + } + + switch ecn { + case ecnECT0: + acks.ecn.t0++ + case ecnECT1: + acks.ecn.t1++ + case ecnCE: + acks.ecn.ce++ + } + + // Limit the total number of ACK ranges by dropping older ranges. + // + // Remembering more ranges results in larger ACK frames. + // + // Remembering a large number of ranges could result in ACK frames becoming + // too large to fit in a packet, in which case we will silently drop older + // ranges during packet construction. + // + // Remembering fewer ranges can result in unnecessary retransmissions, + // since we cannot accept packets older than the oldest remembered range. + // + // The limit here is completely arbitrary. If it seems wrong, it probably is. + // + // https://www.rfc-editor.org/rfc/rfc9000#section-13.2.3 + const maxAckRanges = 8 + if overflow := acks.seen.numRanges() - maxAckRanges; overflow > 0 { + acks.seen.removeranges(0, overflow) + } +} + +// mustAckImmediately reports whether an ack-eliciting packet must be acknowledged immediately, +// or whether the ack may be deferred. +func (acks *ackState) mustAckImmediately(space numberSpace, num packetNumber, ecn ecnBits) bool { + // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.2.1 + if space != appDataSpace { + // "[...] all ack-eliciting Initial and Handshake packets [...]" + // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.2.1-2 + return true + } + if num < acks.maxAckEliciting { + // "[...] when the received packet has a packet number less than another + // ack-eliciting packet that has been received [...]" + // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.2.1-8.1 + return true + } + if acks.seen.rangeContaining(acks.maxAckEliciting).end != num { + // "[...] when the packet has a packet number larger than the highest-numbered + // ack-eliciting packet that has been received and there are missing packets + // between that packet and this packet." + // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.2.1-8.2 + // + // This case is a bit tricky. Let's say we've received: + // 0, ack-eliciting + // 1, ack-eliciting + // 3, NOT ack eliciting + // + // We have sent ACKs for 0 and 1. If we receive ack-eliciting packet 2, + // we do not need to send an immediate ACK, because there are no missing + // packets between it and the highest-numbered ack-eliciting packet (1). + // If we receive ack-eliciting packet 4, we do need to send an immediate ACK, + // because there's a gap (the missing packet 2). + // + // We check for this by looking up the ACK range which contains the + // highest-numbered ack-eliciting packet: [0, 1) in the above example. + // If the range ends just before the packet we are now processing, + // there are no gaps. If it does not, there must be a gap. + return true + } + // "[...] packets marked with the ECN Congestion Experienced (CE) codepoint + // in the IP header SHOULD be acknowledged immediately [...]" + // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.2.1-9 + if ecn == ecnCE { + return true + } + // "[...] SHOULD send an ACK frame after receiving at least two ack-eliciting packets." + // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.2.2 + // + // This ack frequency takes a substantial toll on performance, however. + // Follow the behavior of Google QUICHE: + // Ack every other packet for the first 100 packets, and then ack every 10th packet. + // This keeps ack frequency high during the beginning of slow start when CWND is + // increasing rapidly. + packetsBeforeAck := 2 + if acks.seen.max() > 100 { + packetsBeforeAck = 10 + } + return acks.unackedAckEliciting >= packetsBeforeAck +} + +// shouldSendAck reports whether the connection should send an ACK frame at this time, +// in an ACK-only packet if necessary. +func (acks *ackState) shouldSendAck(now time.Time) bool { + return !acks.nextAck.IsZero() && !acks.nextAck.After(now) +} + +// acksToSend returns the set of packet numbers to ACK at this time, and the current ack delay. +// It may return acks even if shouldSendAck returns false, when there are unacked +// ack-eliciting packets whose ack is being delayed. +func (acks *ackState) acksToSend(now time.Time) (nums rangeset[packetNumber], ackDelay time.Duration) { + if acks.nextAck.IsZero() && acks.unackedAckEliciting == 0 { + return nil, 0 + } + // "[...] the delays intentionally introduced between the time the packet with the + // largest packet number is received and the time an acknowledgement is sent." + // https://www.rfc-editor.org/rfc/rfc9000#section-13.2.5-1 + delay := now.Sub(acks.maxRecvTime) + if delay < 0 { + delay = 0 + } + return acks.seen, delay +} + +// sentAck records that an ACK frame has been sent. +func (acks *ackState) sentAck() { + acks.nextAck = time.Time{} + acks.unackedAckEliciting = 0 +} + +// handleAck records that an ack has been received for a ACK frame we sent +// containing the given Largest Acknowledged field. +func (acks *ackState) handleAck(largestAcked packetNumber) { + // We can stop acking packets less or equal to largestAcked. + // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.2.4-1 + // + // We rely on acks.seen containing the largest packet number that has been successfully + // processed, so we retain the range containing largestAcked and discard previous ones. + acks.seen.sub(0, acks.seen.rangeContaining(largestAcked).start) +} + +// largestSeen reports the largest seen packet. +func (acks *ackState) largestSeen() packetNumber { + return acks.seen.max() +} diff --git a/src/vendor/golang.org/x/net/quic/atomic_bits.go b/src/vendor/golang.org/x/net/quic/atomic_bits.go new file mode 100644 index 0000000000..5244e04201 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/atomic_bits.go @@ -0,0 +1,31 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import "sync/atomic" + +// atomicBits is an atomic uint32 that supports setting individual bits. +type atomicBits[T ~uint32] struct { + bits atomic.Uint32 +} + +// set sets the bits in mask to the corresponding bits in v. +// It returns the new value. +func (a *atomicBits[T]) set(v, mask T) T { + if v&^mask != 0 { + panic("BUG: bits in v are not in mask") + } + for { + o := a.bits.Load() + n := (o &^ uint32(mask)) | uint32(v) + if a.bits.CompareAndSwap(o, n) { + return T(n) + } + } +} + +func (a *atomicBits[T]) load() T { + return T(a.bits.Load()) +} diff --git a/src/vendor/golang.org/x/net/quic/config.go b/src/vendor/golang.org/x/net/quic/config.go new file mode 100644 index 0000000000..a9ec4bc437 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/config.go @@ -0,0 +1,158 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "crypto/tls" + "log/slog" + "math" + "time" + + "golang.org/x/net/internal/quic/quicwire" +) + +// A Config structure configures a QUIC endpoint. +// A Config must not be modified after it has been passed to a QUIC function. +// A Config may be reused; the quic package will also not modify it. +type Config struct { + // TLSConfig is the endpoint's TLS configuration. + // It must be non-nil and include at least one certificate or else set GetCertificate. + TLSConfig *tls.Config + + // MaxBidiRemoteStreams limits the number of simultaneous bidirectional streams + // a peer may open. + // If zero, the default value of 100 is used. + // If negative, the limit is zero. + MaxBidiRemoteStreams int64 + + // MaxUniRemoteStreams limits the number of simultaneous unidirectional streams + // a peer may open. + // If zero, the default value of 100 is used. + // If negative, the limit is zero. + MaxUniRemoteStreams int64 + + // MaxStreamReadBufferSize is the maximum amount of data sent by the peer that a + // stream will buffer for reading. + // If zero, the default value of 1MiB is used. + // If negative, the limit is zero. + MaxStreamReadBufferSize int64 + + // MaxStreamWriteBufferSize is the maximum amount of data a stream will buffer for + // sending to the peer. + // If zero, the default value of 1MiB is used. + // If negative, the limit is zero. + MaxStreamWriteBufferSize int64 + + // MaxConnReadBufferSize is the maximum amount of data sent by the peer that a + // connection will buffer for reading, across all streams. + // If zero, the default value of 1MiB is used. + // If negative, the limit is zero. + MaxConnReadBufferSize int64 + + // RequireAddressValidation may be set to true to enable address validation + // of client connections prior to starting the handshake. + // + // Enabling this setting reduces the amount of work packets with spoofed + // source address information can cause a server to perform, + // at the cost of increased handshake latency. + RequireAddressValidation bool + + // StatelessResetKey is used to provide stateless reset of connections. + // A restart may leave an endpoint without access to the state of + // existing connections. Stateless reset permits an endpoint to respond + // to a packet for a connection it does not recognize. + // + // This field should be filled with random bytes. + // The contents should remain stable across restarts, + // to permit an endpoint to send a reset for + // connections created before a restart. + // + // The contents of the StatelessResetKey should not be exposed. + // An attacker can use knowledge of this field's value to + // reset existing connections. + // + // If this field is left as zero, stateless reset is disabled. + StatelessResetKey [32]byte + + // HandshakeTimeout is the maximum time in which a connection handshake must complete. + // If zero, the default of 10 seconds is used. + // If negative, there is no handshake timeout. + HandshakeTimeout time.Duration + + // MaxIdleTimeout is the maximum time after which an idle connection will be closed. + // If zero, the default of 30 seconds is used. + // If negative, idle connections are never closed. + // + // The idle timeout for a connection is the minimum of the maximum idle timeouts + // of the endpoints. + MaxIdleTimeout time.Duration + + // KeepAlivePeriod is the time after which a packet will be sent to keep + // an idle connection alive. + // If zero, keep alive packets are not sent. + // If greater than zero, the keep alive period is the smaller of KeepAlivePeriod and + // half the connection idle timeout. + KeepAlivePeriod time.Duration + + // QLogLogger receives qlog events. + // + // Events currently correspond to the definitions in draft-ietf-qlog-quic-events-03. + // This is not the latest version of the draft, but is the latest version supported + // by common event log viewers as of the time this paragraph was written. + // + // The qlog package contains a slog.Handler which serializes qlog events + // to a standard JSON representation. + QLogLogger *slog.Logger +} + +// Clone returns a shallow clone of c, or nil if c is nil. +// It is safe to clone a [Config] that is being used concurrently by a QUIC endpoint. +func (c *Config) Clone() *Config { + n := *c + return &n +} + +func configDefault[T ~int64](v, def, limit T) T { + switch { + case v == 0: + return def + case v < 0: + return 0 + default: + return min(v, limit) + } +} + +func (c *Config) maxBidiRemoteStreams() int64 { + return configDefault(c.MaxBidiRemoteStreams, 100, maxStreamsLimit) +} + +func (c *Config) maxUniRemoteStreams() int64 { + return configDefault(c.MaxUniRemoteStreams, 100, maxStreamsLimit) +} + +func (c *Config) maxStreamReadBufferSize() int64 { + return configDefault(c.MaxStreamReadBufferSize, 1<<20, quicwire.MaxVarint) +} + +func (c *Config) maxStreamWriteBufferSize() int64 { + return configDefault(c.MaxStreamWriteBufferSize, 1<<20, quicwire.MaxVarint) +} + +func (c *Config) maxConnReadBufferSize() int64 { + return configDefault(c.MaxConnReadBufferSize, 1<<20, quicwire.MaxVarint) +} + +func (c *Config) handshakeTimeout() time.Duration { + return configDefault(c.HandshakeTimeout, defaultHandshakeTimeout, math.MaxInt64) +} + +func (c *Config) maxIdleTimeout() time.Duration { + return configDefault(c.MaxIdleTimeout, defaultMaxIdleTimeout, math.MaxInt64) +} + +func (c *Config) keepAlivePeriod() time.Duration { + return configDefault(c.KeepAlivePeriod, defaultKeepAlivePeriod, math.MaxInt64) +} diff --git a/src/vendor/golang.org/x/net/quic/congestion_reno.go b/src/vendor/golang.org/x/net/quic/congestion_reno.go new file mode 100644 index 0000000000..028a2ed6c3 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/congestion_reno.go @@ -0,0 +1,306 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "context" + "log/slog" + "math" + "time" +) + +// ccReno is the NewReno-based congestion controller defined in RFC 9002. +// https://www.rfc-editor.org/rfc/rfc9002.html#section-7 +type ccReno struct { + maxDatagramSize int + + // Maximum number of bytes allowed to be in flight. + congestionWindow int + + // Sum of size of all packets that contain at least one ack-eliciting + // or PADDING frame (i.e., any non-ACK frame), and have neither been + // acknowledged nor declared lost. + bytesInFlight int + + // When the congestion window is below the slow start threshold, + // the controller is in slow start. + slowStartThreshold int + + // The time the current recovery period started, or zero when not + // in a recovery period. + recoveryStartTime time.Time + + // Accumulated count of bytes acknowledged in congestion avoidance. + congestionPendingAcks int + + // When entering a recovery period, we are allowed to send one packet + // before reducing the congestion window. sendOnePacketInRecovery is + // true if we haven't sent that packet yet. + sendOnePacketInRecovery bool + + // inRecovery is set when we are in the recovery state. + inRecovery bool + + // underutilized is set if the congestion window is underutilized + // due to insufficient application data, flow control limits, or + // anti-amplification limits. + underutilized bool + + // ackLastLoss is the sent time of the newest lost packet processed + // in the current batch. + ackLastLoss time.Time + + // Data tracking the duration of the most recently handled sequence of + // contiguous lost packets. If this exceeds the persistent congestion duration, + // persistent congestion is declared. + // + // https://www.rfc-editor.org/rfc/rfc9002#section-7.6 + persistentCongestion [numberSpaceCount]struct { + start time.Time // send time of first lost packet + end time.Time // send time of last lost packet + next packetNumber // one plus the number of the last lost packet + } +} + +func newReno(maxDatagramSize int) *ccReno { + c := &ccReno{ + maxDatagramSize: maxDatagramSize, + } + + // https://www.rfc-editor.org/rfc/rfc9002.html#section-7.2-1 + c.congestionWindow = min(10*maxDatagramSize, max(14720, c.minimumCongestionWindow())) + + // https://www.rfc-editor.org/rfc/rfc9002.html#section-7.3.1-1 + c.slowStartThreshold = math.MaxInt + + for space := range c.persistentCongestion { + c.persistentCongestion[space].next = -1 + } + return c +} + +// canSend reports whether the congestion controller permits sending +// a maximum-size datagram at this time. +// +// "An endpoint MUST NOT send a packet if it would cause bytes_in_flight [...] +// to be larger than the congestion window [...]" +// https://www.rfc-editor.org/rfc/rfc9002#section-7-7 +// +// For simplicity and efficiency, we don't permit sending undersized datagrams. +func (c *ccReno) canSend() bool { + if c.sendOnePacketInRecovery { + return true + } + return c.bytesInFlight+c.maxDatagramSize <= c.congestionWindow +} + +// setUnderutilized indicates that the congestion window is underutilized. +// +// The congestion window is underutilized if bytes in flight is smaller than +// the congestion window and sending is not pacing limited; that is, the +// congestion controller permits sending data, but no data is sent. +// +// https://www.rfc-editor.org/rfc/rfc9002#section-7.8 +func (c *ccReno) setUnderutilized(log *slog.Logger, v bool) { + if c.underutilized == v { + return + } + oldState := c.state() + c.underutilized = v + if logEnabled(log, QLogLevelPacket) { + logCongestionStateUpdated(log, oldState, c.state()) + } +} + +// packetSent indicates that a packet has been sent. +func (c *ccReno) packetSent(now time.Time, log *slog.Logger, space numberSpace, sent *sentPacket) { + if !sent.inFlight { + return + } + c.bytesInFlight += sent.size + if c.sendOnePacketInRecovery { + c.sendOnePacketInRecovery = false + } +} + +// Acked and lost packets are processed in batches +// resulting from either a received ACK frame or +// the loss detection timer expiring. +// +// A batch consists of zero or more calls to packetAcked and packetLost, +// followed by a single call to packetBatchEnd. +// +// Acks may be reported in any order, but lost packets must +// be reported in strictly increasing order. + +// packetAcked indicates that a packet has been newly acknowledged. +func (c *ccReno) packetAcked(now time.Time, sent *sentPacket) { + if !sent.inFlight { + return + } + c.bytesInFlight -= sent.size + + if c.underutilized { + // https://www.rfc-editor.org/rfc/rfc9002.html#section-7.8 + return + } + if sent.time.Before(c.recoveryStartTime) { + // In recovery, and this packet was sent before we entered recovery. + // (If this packet was sent after we entered recovery, receiving an ack + // for it moves us out of recovery into congestion avoidance.) + // https://www.rfc-editor.org/rfc/rfc9002.html#section-7.3.2 + return + } + c.congestionPendingAcks += sent.size +} + +// packetLost indicates that a packet has been newly marked as lost. +// Lost packets must be reported in increasing order. +func (c *ccReno) packetLost(now time.Time, space numberSpace, sent *sentPacket, rtt *rttState) { + // Record state to check for persistent congestion. + // https://www.rfc-editor.org/rfc/rfc9002#section-7.6 + // + // Note that this relies on always receiving loss events in increasing order: + // All packets prior to the one we're examining now have either been + // acknowledged or declared lost. + isValidPersistentCongestionSample := (sent.ackEliciting && + !rtt.firstSampleTime.IsZero() && + !sent.time.Before(rtt.firstSampleTime)) + if isValidPersistentCongestionSample { + // This packet either extends an existing range of lost packets, + // or starts a new one. + if sent.num != c.persistentCongestion[space].next { + c.persistentCongestion[space].start = sent.time + } + c.persistentCongestion[space].end = sent.time + c.persistentCongestion[space].next = sent.num + 1 + } else { + // This packet cannot establish persistent congestion on its own. + // However, if we have an existing range of lost packets, + // this does not break it. + if sent.num == c.persistentCongestion[space].next { + c.persistentCongestion[space].next = sent.num + 1 + } + } + + if !sent.inFlight { + return + } + c.bytesInFlight -= sent.size + if sent.time.After(c.ackLastLoss) { + c.ackLastLoss = sent.time + } +} + +// packetBatchEnd is called at the end of processing a batch of acked or lost packets. +func (c *ccReno) packetBatchEnd(now time.Time, log *slog.Logger, space numberSpace, rtt *rttState, maxAckDelay time.Duration) { + if logEnabled(log, QLogLevelPacket) { + oldState := c.state() + defer func() { logCongestionStateUpdated(log, oldState, c.state()) }() + } + if !c.ackLastLoss.IsZero() && !c.ackLastLoss.Before(c.recoveryStartTime) { + // Enter the recovery state. + // https://www.rfc-editor.org/rfc/rfc9002.html#section-7.3.2 + c.recoveryStartTime = now + c.slowStartThreshold = c.congestionWindow / 2 + c.congestionWindow = max(c.slowStartThreshold, c.minimumCongestionWindow()) + c.sendOnePacketInRecovery = true + // Clear congestionPendingAcks to avoid increasing the congestion + // window based on acks in a frame that sends us into recovery. + c.congestionPendingAcks = 0 + c.inRecovery = true + } else if c.congestionPendingAcks > 0 { + // We are in slow start or congestion avoidance. + c.inRecovery = false + if c.congestionWindow < c.slowStartThreshold { + // When the congestion window is less than the slow start threshold, + // we are in slow start and increase the window by the number of + // bytes acknowledged. + d := min(c.slowStartThreshold-c.congestionWindow, c.congestionPendingAcks) + c.congestionWindow += d + c.congestionPendingAcks -= d + } + // When the congestion window is at or above the slow start threshold, + // we are in congestion avoidance. + // + // RFC 9002 does not specify an algorithm here. The following is + // the recommended algorithm from RFC 5681, in which we increment + // the window by the maximum datagram size every time the number + // of bytes acknowledged reaches cwnd. + for c.congestionPendingAcks > c.congestionWindow { + c.congestionPendingAcks -= c.congestionWindow + c.congestionWindow += c.maxDatagramSize + } + } + if !c.ackLastLoss.IsZero() { + // Check for persistent congestion. + // https://www.rfc-editor.org/rfc/rfc9002.html#section-7.6 + // + // "A sender [...] MAY use state for just the packet number space that + // was acknowledged." + // https://www.rfc-editor.org/rfc/rfc9002#section-7.6.2-5 + // + // For simplicity, we consider each number space independently. + const persistentCongestionThreshold = 3 + d := (rtt.smoothedRTT + max(4*rtt.rttvar, timerGranularity) + maxAckDelay) * + persistentCongestionThreshold + start := c.persistentCongestion[space].start + end := c.persistentCongestion[space].end + if end.Sub(start) >= d { + c.congestionWindow = c.minimumCongestionWindow() + c.recoveryStartTime = time.Time{} + rtt.establishPersistentCongestion() + } + } + c.ackLastLoss = time.Time{} +} + +// packetDiscarded indicates that the keys for a packet's space have been discarded. +func (c *ccReno) packetDiscarded(sent *sentPacket) { + // https://www.rfc-editor.org/rfc/rfc9002#section-6.2.2-3 + if sent.inFlight { + c.bytesInFlight -= sent.size + } +} + +func (c *ccReno) minimumCongestionWindow() int { + // https://www.rfc-editor.org/rfc/rfc9002.html#section-7.2-4 + return 2 * c.maxDatagramSize +} + +func logCongestionStateUpdated(log *slog.Logger, oldState, newState congestionState) { + if oldState == newState { + return + } + log.LogAttrs(context.Background(), QLogLevelPacket, + "recovery:congestion_state_updated", + slog.String("old", oldState.String()), + slog.String("new", newState.String()), + ) +} + +type congestionState string + +func (s congestionState) String() string { return string(s) } + +const ( + congestionSlowStart = congestionState("slow_start") + congestionCongestionAvoidance = congestionState("congestion_avoidance") + congestionApplicationLimited = congestionState("application_limited") + congestionRecovery = congestionState("recovery") +) + +func (c *ccReno) state() congestionState { + switch { + case c.inRecovery: + return congestionRecovery + case c.underutilized: + return congestionApplicationLimited + case c.congestionWindow < c.slowStartThreshold: + return congestionSlowStart + default: + return congestionCongestionAvoidance + } +} diff --git a/src/vendor/golang.org/x/net/quic/conn.go b/src/vendor/golang.org/x/net/quic/conn.go new file mode 100644 index 0000000000..fd812b8a28 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/conn.go @@ -0,0 +1,434 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "context" + cryptorand "crypto/rand" + "crypto/tls" + "errors" + "fmt" + "log/slog" + "math/rand/v2" + "net/netip" + "time" +) + +// A Conn is a QUIC connection. +// +// Multiple goroutines may invoke methods on a Conn simultaneously. +type Conn struct { + side connSide + endpoint *Endpoint + config *Config + testHooks connTestHooks + peerAddr netip.AddrPort + localAddr netip.AddrPort + prng *rand.Rand + + msgc chan any + donec chan struct{} // closed when conn loop exits + + w packetWriter + acks [numberSpaceCount]ackState // indexed by number space + lifetime lifetimeState + idle idleState + connIDState connIDState + loss lossState + streams streamsState + path pathState + skip skipState + + // Packet protection keys, CRYPTO streams, and TLS state. + keysInitial fixedKeyPair + keysHandshake fixedKeyPair + keysAppData updatingKeyPair + crypto [numberSpaceCount]cryptoStream + tls *tls.QUICConn + + // retryToken is the token provided by the peer in a Retry packet. + retryToken []byte + + // handshakeConfirmed is set when the handshake is confirmed. + // For server connections, it tracks sending HANDSHAKE_DONE. + handshakeConfirmed sentVal + + peerAckDelayExponent int8 // -1 when unknown + + // Tests only: Send a PING in a specific number space. + testSendPingSpace numberSpace + testSendPing sentVal + + log *slog.Logger +} + +// connTestHooks override conn behavior in tests. +type connTestHooks interface { + // init is called after a conn is created. + init(first bool) + + // handleTLSEvent is called with each TLS event. + handleTLSEvent(tls.QUICEvent) + + // newConnID is called to generate a new connection ID. + // Permits tests to generate consistent connection IDs rather than random ones. + newConnID(seq int64) ([]byte, error) +} + +// newServerConnIDs is connection IDs associated with a new server connection. +type newServerConnIDs struct { + srcConnID []byte // source from client's current Initial + dstConnID []byte // destination from client's current Initial + originalDstConnID []byte // destination from client's first Initial + retrySrcConnID []byte // source from server's Retry +} + +func newConn(now time.Time, side connSide, cids newServerConnIDs, peerHostname string, peerAddr netip.AddrPort, config *Config, e *Endpoint) (conn *Conn, _ error) { + c := &Conn{ + side: side, + endpoint: e, + config: config, + peerAddr: unmapAddrPort(peerAddr), + donec: make(chan struct{}), + peerAckDelayExponent: -1, + } + defer func() { + // If we hit an error in newConn, close donec so tests don't get stuck waiting for it. + // This is only relevant if we've got a bug, but it makes tracking that bug down + // much easier. + if conn == nil { + close(c.donec) + } + }() + + // A one-element buffer allows us to wake a Conn's event loop as a + // non-blocking operation. + c.msgc = make(chan any, 1) + + if e.testHooks != nil { + e.testHooks.newConn(c) + } + + // initialConnID is the connection ID used to generate Initial packet protection keys. + var initialConnID []byte + if c.side == clientSide { + if err := c.connIDState.initClient(c); err != nil { + return nil, err + } + initialConnID, _ = c.connIDState.dstConnID() + } else { + initialConnID = cids.originalDstConnID + if cids.retrySrcConnID != nil { + initialConnID = cids.retrySrcConnID + } + if err := c.connIDState.initServer(c, cids); err != nil { + return nil, err + } + } + + // A per-conn ChaCha8 PRNG is probably more than we need, + // but at least it's fairly small. + var seed [32]byte + if _, err := cryptorand.Read(seed[:]); err != nil { + panic(err) + } + c.prng = rand.New(rand.NewChaCha8(seed)) + + // TODO: PMTU discovery. + c.logConnectionStarted(cids.originalDstConnID, peerAddr) + c.keysAppData.init() + c.loss.init(c.side, smallestMaxDatagramSize, now) + c.streamsInit() + c.lifetimeInit() + c.restartIdleTimer(now) + c.skip.init(c) + + if err := c.startTLS(now, initialConnID, peerHostname, transportParameters{ + initialSrcConnID: c.connIDState.srcConnID(), + originalDstConnID: cids.originalDstConnID, + retrySrcConnID: cids.retrySrcConnID, + ackDelayExponent: ackDelayExponent, + maxUDPPayloadSize: maxUDPPayloadSize, + maxAckDelay: maxAckDelay, + disableActiveMigration: true, + initialMaxData: config.maxConnReadBufferSize(), + initialMaxStreamDataBidiLocal: config.maxStreamReadBufferSize(), + initialMaxStreamDataBidiRemote: config.maxStreamReadBufferSize(), + initialMaxStreamDataUni: config.maxStreamReadBufferSize(), + initialMaxStreamsBidi: c.streams.remoteLimit[bidiStream].max, + initialMaxStreamsUni: c.streams.remoteLimit[uniStream].max, + activeConnIDLimit: activeConnIDLimit, + }); err != nil { + return nil, err + } + + if c.testHooks != nil { + c.testHooks.init(true) + } + go c.loop(now) + return c, nil +} + +func (c *Conn) String() string { + return fmt.Sprintf("quic.Conn(%v,->%v)", c.side, c.peerAddr) +} + +// LocalAddr returns the local network address, if known. +func (c *Conn) LocalAddr() netip.AddrPort { + return c.localAddr +} + +// RemoteAddr returns the remote network address, if known. +func (c *Conn) RemoteAddr() netip.AddrPort { + return c.peerAddr +} + +// ConnectionState returns basic TLS details about the connection. +func (c *Conn) ConnectionState() tls.ConnectionState { + return c.tls.ConnectionState() +} + +// confirmHandshake is called when the handshake is confirmed. +// https://www.rfc-editor.org/rfc/rfc9001#section-4.1.2 +func (c *Conn) confirmHandshake(now time.Time) { + // If handshakeConfirmed is unset, the handshake is not confirmed. + // If it is unsent, the handshake is confirmed and we need to send a HANDSHAKE_DONE. + // If it is sent, we have sent a HANDSHAKE_DONE. + // If it is received, the handshake is confirmed and we do not need to send anything. + if c.handshakeConfirmed.isSet() { + return // already confirmed + } + if c.side == serverSide { + // When the server confirms the handshake, it sends a HANDSHAKE_DONE. + c.handshakeConfirmed.setUnsent() + c.endpoint.serverConnEstablished(c) + } else { + // The client never sends a HANDSHAKE_DONE, so we set handshakeConfirmed + // to the received state, indicating that the handshake is confirmed and we + // don't need to send anything. + c.handshakeConfirmed.setReceived() + } + c.restartIdleTimer(now) + c.loss.confirmHandshake() + // "An endpoint MUST discard its Handshake keys when the TLS handshake is confirmed" + // https://www.rfc-editor.org/rfc/rfc9001#section-4.9.2-1 + c.discardKeys(now, handshakeSpace) +} + +// discardKeys discards unused packet protection keys. +// https://www.rfc-editor.org/rfc/rfc9001#section-4.9 +func (c *Conn) discardKeys(now time.Time, space numberSpace) { + if err := c.crypto[space].discardKeys(); err != nil { + c.abort(now, err) + } + switch space { + case initialSpace: + c.keysInitial.discard() + case handshakeSpace: + c.keysHandshake.discard() + } + c.loss.discardKeys(now, c.log, space) +} + +// receiveTransportParameters applies transport parameters sent by the peer. +func (c *Conn) receiveTransportParameters(p transportParameters) error { + isRetry := c.retryToken != nil + if err := c.connIDState.validateTransportParameters(c, isRetry, p); err != nil { + return err + } + c.streams.outflow.setMaxData(p.initialMaxData) + c.streams.localLimit[bidiStream].setMax(p.initialMaxStreamsBidi) + c.streams.localLimit[uniStream].setMax(p.initialMaxStreamsUni) + c.streams.peerInitialMaxStreamDataBidiLocal = p.initialMaxStreamDataBidiLocal + c.streams.peerInitialMaxStreamDataRemote[bidiStream] = p.initialMaxStreamDataBidiRemote + c.streams.peerInitialMaxStreamDataRemote[uniStream] = p.initialMaxStreamDataUni + c.receivePeerMaxIdleTimeout(p.maxIdleTimeout) + c.peerAckDelayExponent = p.ackDelayExponent + c.loss.setMaxAckDelay(p.maxAckDelay) + if err := c.connIDState.setPeerActiveConnIDLimit(c, p.activeConnIDLimit); err != nil { + return err + } + if p.preferredAddrConnID != nil { + var ( + seq int64 = 1 // sequence number of this conn id is 1 + retirePriorTo int64 = 0 // retire nothing + resetToken [16]byte + ) + copy(resetToken[:], p.preferredAddrResetToken) + if err := c.connIDState.handleNewConnID(c, seq, retirePriorTo, p.preferredAddrConnID, resetToken); err != nil { + return err + } + } + // TODO: stateless_reset_token + // TODO: max_udp_payload_size + // TODO: disable_active_migration + // TODO: preferred_address + return nil +} + +type ( + timerEvent struct{} + wakeEvent struct{} +) + +var errIdleTimeout = errors.New("idle timeout") + +// loop is the connection main loop. +// +// Except where otherwise noted, all connection state is owned by the loop goroutine. +// +// The loop processes messages from c.msgc and timer events. +// Other goroutines may examine or modify conn state by sending the loop funcs to execute. +func (c *Conn) loop(now time.Time) { + defer c.cleanup() + + // The connection timer sends a message to the connection loop on expiry. + // We need to give it an expiry when creating it, so set the initial timeout to + // an arbitrary large value. The timer will be reset before this expires (and it + // isn't a problem if it does anyway). + var lastTimeout time.Time + timer := time.AfterFunc(1*time.Hour, func() { + c.sendMsg(timerEvent{}) + }) + defer timer.Stop() + + for c.lifetime.state != connStateDone { + sendTimeout := c.maybeSend(now) // try sending + + // Note that we only need to consider the ack timer for the App Data space, + // since the Initial and Handshake spaces always ack immediately. + nextTimeout := sendTimeout + nextTimeout = firstTime(nextTimeout, c.idle.nextTimeout) + if c.isAlive() { + nextTimeout = firstTime(nextTimeout, c.loss.timer) + nextTimeout = firstTime(nextTimeout, c.acks[appDataSpace].nextAck) + } else { + nextTimeout = firstTime(nextTimeout, c.lifetime.drainEndTime) + } + + var m any + if !nextTimeout.IsZero() && nextTimeout.Before(now) { + // A connection timer has expired. + now = time.Now() + m = timerEvent{} + } else { + // Reschedule the connection timer if necessary + // and wait for the next event. + if !nextTimeout.Equal(lastTimeout) && !nextTimeout.IsZero() { + // Resetting a timer created with time.AfterFunc guarantees + // that the timer will run again. We might generate a spurious + // timer event under some circumstances, but that's okay. + timer.Reset(nextTimeout.Sub(now)) + lastTimeout = nextTimeout + } + m = <-c.msgc + now = time.Now() + } + switch m := m.(type) { + case *datagram: + if !c.handleDatagram(now, m) { + if c.logEnabled(QLogLevelPacket) { + c.logPacketDropped(m) + } + } + m.recycle() + case timerEvent: + // A connection timer has expired. + if c.idleAdvance(now) { + // The connection idle timer has expired. + c.abortImmediately(now, errIdleTimeout) + return + } + c.loss.advance(now, c.handleAckOrLoss) + if c.lifetimeAdvance(now) { + // The connection has completed the draining period, + // and may be shut down. + return + } + case wakeEvent: + // We're being woken up to try sending some frames. + case func(time.Time, *Conn): + // Send a func to msgc to run it on the main Conn goroutine + m(now, c) + case func(now, next time.Time, _ *Conn): + // Send a func to msgc to run it on the main Conn goroutine + m(now, nextTimeout, c) + default: + panic(fmt.Sprintf("quic: unrecognized conn message %T", m)) + } + } +} + +func (c *Conn) cleanup() { + c.logConnectionClosed() + c.endpoint.connDrained(c) + c.tls.Close() + close(c.donec) +} + +// sendMsg sends a message to the conn's loop. +// It does not wait for the message to be processed. +// The conn may close before processing the message, in which case it is lost. +func (c *Conn) sendMsg(m any) { + select { + case c.msgc <- m: + case <-c.donec: + } +} + +// wake wakes up the conn's loop. +func (c *Conn) wake() { + select { + case c.msgc <- wakeEvent{}: + default: + } +} + +// runOnLoop executes a function within the conn's loop goroutine. +func (c *Conn) runOnLoop(ctx context.Context, f func(now time.Time, c *Conn)) error { + donec := make(chan struct{}) + msg := func(now time.Time, c *Conn) { + defer close(donec) + f(now, c) + } + c.sendMsg(msg) + select { + case <-donec: + case <-c.donec: + return errors.New("quic: connection closed") + } + return nil +} + +func (c *Conn) waitOnDone(ctx context.Context, ch <-chan struct{}) error { + // Check the channel before the context. + // We always prefer to return results when available, + // even when provided with an already-canceled context. + select { + case <-ch: + return nil + default: + } + select { + case <-ch: + case <-ctx.Done(): + return ctx.Err() + } + return nil +} + +// firstTime returns the earliest non-zero time, or zero if both times are zero. +func firstTime(a, b time.Time) time.Time { + switch { + case a.IsZero(): + return b + case b.IsZero(): + return a + case a.Before(b): + return a + default: + return b + } +} diff --git a/src/vendor/golang.org/x/net/quic/conn_close.go b/src/vendor/golang.org/x/net/quic/conn_close.go new file mode 100644 index 0000000000..d22f3df5c8 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/conn_close.go @@ -0,0 +1,340 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "context" + "errors" + "time" +) + +// connState is the state of a connection. +type connState int + +const ( + // A connection is alive when it is first created. + connStateAlive = connState(iota) + + // The connection has received a CONNECTION_CLOSE frame from the peer, + // and has not yet sent a CONNECTION_CLOSE in response. + // + // We will send a CONNECTION_CLOSE, and then enter the draining state. + connStatePeerClosed + + // The connection is in the closing state. + // + // We will send CONNECTION_CLOSE frames to the peer + // (once upon entering the closing state, and possibly again in response to peer packets). + // + // If we receive a CONNECTION_CLOSE from the peer, we will enter the draining state. + // Otherwise, we will eventually time out and move to the done state. + // + // https://www.rfc-editor.org/rfc/rfc9000#section-10.2.1 + connStateClosing + + // The connection is in the draining state. + // + // We will neither send packets nor process received packets. + // When the drain timer expires, we move to the done state. + // + // https://www.rfc-editor.org/rfc/rfc9000#section-10.2.2 + connStateDraining + + // The connection is done, and the conn loop will exit. + connStateDone +) + +// lifetimeState tracks the state of a connection. +// +// This is fairly coupled to the rest of a Conn, but putting it in a struct of its own helps +// reason about operations that cause state transitions. +type lifetimeState struct { + state connState + + readyc chan struct{} // closed when TLS handshake completes + donec chan struct{} // closed when finalErr is set + + localErr error // error sent to the peer + finalErr error // error sent by the peer, or transport error; set before closing donec + + connCloseSentTime time.Time // send time of last CONNECTION_CLOSE frame + connCloseDelay time.Duration // delay until next CONNECTION_CLOSE frame sent + drainEndTime time.Time // time the connection exits the draining state +} + +func (c *Conn) lifetimeInit() { + c.lifetime.readyc = make(chan struct{}) + c.lifetime.donec = make(chan struct{}) +} + +var ( + errNoPeerResponse = errors.New("peer did not respond to CONNECTION_CLOSE") + errConnClosed = errors.New("connection closed") +) + +// advance is called when time passes. +func (c *Conn) lifetimeAdvance(now time.Time) (done bool) { + if c.lifetime.drainEndTime.IsZero() || c.lifetime.drainEndTime.After(now) { + return false + } + // The connection drain period has ended, and we can shut down. + // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2-7 + c.lifetime.drainEndTime = time.Time{} + if c.lifetime.state != connStateDraining { + // We were in the closing state, waiting for a CONNECTION_CLOSE from the peer. + c.setFinalError(errNoPeerResponse) + } + c.setState(now, connStateDone) + return true +} + +// setState sets the conn state. +func (c *Conn) setState(now time.Time, state connState) { + if c.lifetime.state == state { + return + } + c.lifetime.state = state + switch state { + case connStateClosing, connStateDraining: + if c.lifetime.drainEndTime.IsZero() { + c.lifetime.drainEndTime = now.Add(3 * c.loss.ptoBasePeriod()) + } + case connStateDone: + c.setFinalError(nil) + } + if state != connStateAlive { + c.streamsCleanup() + } +} + +// handshakeDone is called when the TLS handshake completes. +func (c *Conn) handshakeDone() { + close(c.lifetime.readyc) +} + +// isDraining reports whether the conn is in the draining state. +// +// The draining state is entered once an endpoint receives a CONNECTION_CLOSE frame. +// The endpoint will no longer send any packets, but we retain knowledge of the connection +// until the end of the drain period to ensure we discard packets for the connection +// rather than treating them as starting a new connection. +// +// https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2.2 +func (c *Conn) isDraining() bool { + switch c.lifetime.state { + case connStateDraining, connStateDone: + return true + } + return false +} + +// isAlive reports whether the conn is handling packets. +func (c *Conn) isAlive() bool { + return c.lifetime.state == connStateAlive +} + +// sendOK reports whether the conn can send frames at this time. +func (c *Conn) sendOK(now time.Time) bool { + switch c.lifetime.state { + case connStateAlive: + return true + case connStatePeerClosed: + if c.lifetime.localErr == nil { + // We're waiting for the user to close the connection, providing us with + // a final status to send to the peer. + return false + } + // We should send a CONNECTION_CLOSE. + return true + case connStateClosing: + if c.lifetime.connCloseSentTime.IsZero() { + return true + } + maxRecvTime := c.acks[initialSpace].maxRecvTime + if t := c.acks[handshakeSpace].maxRecvTime; t.After(maxRecvTime) { + maxRecvTime = t + } + if t := c.acks[appDataSpace].maxRecvTime; t.After(maxRecvTime) { + maxRecvTime = t + } + if maxRecvTime.Before(c.lifetime.connCloseSentTime.Add(c.lifetime.connCloseDelay)) { + // After sending CONNECTION_CLOSE, ignore packets from the peer for + // a delay. On the next packet received after the delay, send another + // CONNECTION_CLOSE. + return false + } + return true + case connStateDraining: + // We are in the draining state, and will send no more packets. + return false + case connStateDone: + return false + default: + panic("BUG: unhandled connection state") + } +} + +// sentConnectionClose reports that the conn has sent a CONNECTION_CLOSE to the peer. +func (c *Conn) sentConnectionClose(now time.Time) { + switch c.lifetime.state { + case connStatePeerClosed: + c.enterDraining(now) + } + if c.lifetime.connCloseSentTime.IsZero() { + // Set the initial delay before we will send another CONNECTION_CLOSE. + // + // RFC 9000 states that we should rate limit CONNECTION_CLOSE frames, + // but leaves the implementation of the limit up to us. Here, we start + // with the same delay as the PTO timer (RFC 9002, Section 6.2.1), + // not including max_ack_delay, and double it on every CONNECTION_CLOSE sent. + c.lifetime.connCloseDelay = c.loss.rtt.smoothedRTT + max(4*c.loss.rtt.rttvar, timerGranularity) + } else if !c.lifetime.connCloseSentTime.Equal(now) { + // If connCloseSentTime == now, we're sending two CONNECTION_CLOSE frames + // coalesced into the same datagram. We only want to increase the delay once. + c.lifetime.connCloseDelay *= 2 + } + c.lifetime.connCloseSentTime = now +} + +// handlePeerConnectionClose handles a CONNECTION_CLOSE from the peer. +func (c *Conn) handlePeerConnectionClose(now time.Time, err error) { + c.setFinalError(err) + switch c.lifetime.state { + case connStateAlive: + c.setState(now, connStatePeerClosed) + case connStatePeerClosed: + // Duplicate CONNECTION_CLOSE, ignore. + case connStateClosing: + if c.lifetime.connCloseSentTime.IsZero() { + c.setState(now, connStatePeerClosed) + } else { + c.setState(now, connStateDraining) + } + case connStateDraining: + case connStateDone: + } +} + +// setFinalError records the final connection status we report to the user. +func (c *Conn) setFinalError(err error) { + select { + case <-c.lifetime.donec: + return // already set + default: + } + c.lifetime.finalErr = err + close(c.lifetime.donec) +} + +// finalError returns the final connection status reported to the user, +// or nil if a final status has not yet been set. +func (c *Conn) finalError() error { + select { + case <-c.lifetime.donec: + return c.lifetime.finalErr + default: + } + return nil +} + +func (c *Conn) waitReady(ctx context.Context) error { + select { + case <-c.lifetime.readyc: + return nil + case <-c.lifetime.donec: + return c.lifetime.finalErr + default: + } + select { + case <-c.lifetime.readyc: + return nil + case <-c.lifetime.donec: + return c.lifetime.finalErr + case <-ctx.Done(): + return ctx.Err() + } +} + +// Close closes the connection. +// +// Close is equivalent to: +// +// conn.Abort(nil) +// err := conn.Wait(context.Background()) +func (c *Conn) Close() error { + c.Abort(nil) + <-c.lifetime.donec + return c.lifetime.finalErr +} + +// Wait waits for the peer to close the connection. +// +// If the connection is closed locally and the peer does not close its end of the connection, +// Wait will return with a non-nil error after the drain period expires. +// +// If the peer closes the connection with a NO_ERROR transport error, Wait returns nil. +// If the peer closes the connection with an application error, Wait returns an ApplicationError +// containing the peer's error code and reason. +// If the peer closes the connection with any other status, Wait returns a non-nil error. +func (c *Conn) Wait(ctx context.Context) error { + if err := c.waitOnDone(ctx, c.lifetime.donec); err != nil { + return err + } + return c.lifetime.finalErr +} + +// Abort closes the connection and returns immediately. +// +// If err is nil, Abort sends a transport error of NO_ERROR to the peer. +// If err is an ApplicationError, Abort sends its error code and text. +// Otherwise, Abort sends a transport error of APPLICATION_ERROR with the error's text. +func (c *Conn) Abort(err error) { + if err == nil { + err = localTransportError{code: errNo} + } + c.sendMsg(func(now time.Time, c *Conn) { + c.enterClosing(now, err) + }) +} + +// abort terminates a connection with an error. +func (c *Conn) abort(now time.Time, err error) { + c.setFinalError(err) // this error takes precedence over the peer's CONNECTION_CLOSE + c.enterClosing(now, err) +} + +// abortImmediately terminates a connection. +// The connection does not send a CONNECTION_CLOSE, and skips the draining period. +func (c *Conn) abortImmediately(now time.Time, err error) { + c.setFinalError(err) + c.setState(now, connStateDone) +} + +// enterClosing starts an immediate close. +// We will send a CONNECTION_CLOSE to the peer and wait for their response. +func (c *Conn) enterClosing(now time.Time, err error) { + switch c.lifetime.state { + case connStateAlive: + c.lifetime.localErr = err + c.setState(now, connStateClosing) + case connStatePeerClosed: + c.lifetime.localErr = err + } +} + +// enterDraining moves directly to the draining state, without sending a CONNECTION_CLOSE. +func (c *Conn) enterDraining(now time.Time) { + switch c.lifetime.state { + case connStateAlive, connStatePeerClosed, connStateClosing: + c.setState(now, connStateDraining) + } +} + +// exit fully terminates a connection immediately. +func (c *Conn) exit() { + c.sendMsg(func(now time.Time, c *Conn) { + c.abortImmediately(now, errors.New("connection closed")) + }) +} diff --git a/src/vendor/golang.org/x/net/quic/conn_flow.go b/src/vendor/golang.org/x/net/quic/conn_flow.go new file mode 100644 index 0000000000..1d04f45545 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/conn_flow.go @@ -0,0 +1,142 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "sync/atomic" + "time" +) + +// connInflow tracks connection-level flow control for data sent by the peer to us. +// +// There are four byte offsets of significance in the stream of data received from the peer, +// each >= to the previous: +// +// - bytes read by the user +// - bytes received from the peer +// - limit sent to the peer in a MAX_DATA frame +// - potential new limit to sent to the peer +// +// We maintain a flow control window, so as bytes are read by the user +// the potential limit is extended correspondingly. +// +// We keep an atomic counter of bytes read by the user and not yet applied to the +// potential limit (credit). When this count grows large enough, we update the +// new limit to send and mark that we need to send a new MAX_DATA frame. +type connInflow struct { + sent sentVal // set when we need to send a MAX_DATA update to the peer + usedLimit int64 // total bytes sent by the peer, must be less than sentLimit + sentLimit int64 // last MAX_DATA sent to the peer + newLimit int64 // new MAX_DATA to send + + credit atomic.Int64 // bytes read but not yet applied to extending the flow-control window +} + +func (c *Conn) inflowInit() { + // The initial MAX_DATA limit is sent as a transport parameter. + c.streams.inflow.sentLimit = c.config.maxConnReadBufferSize() + c.streams.inflow.newLimit = c.streams.inflow.sentLimit +} + +// handleStreamBytesReadOffLoop records that the user has consumed bytes from a stream. +// We may extend the peer's flow control window. +// +// This is called indirectly by the user, via Read or CloseRead. +func (c *Conn) handleStreamBytesReadOffLoop(n int64) { + if n == 0 { + return + } + if c.shouldUpdateFlowControl(c.streams.inflow.credit.Add(n)) { + // We should send a MAX_DATA update to the peer. + // Record this on the Conn's main loop. + c.sendMsg(func(now time.Time, c *Conn) { + // A MAX_DATA update may have already happened, so check again. + if c.shouldUpdateFlowControl(c.streams.inflow.credit.Load()) { + c.sendMaxDataUpdate() + } + }) + } +} + +// handleStreamBytesReadOnLoop extends the peer's flow control window after +// data has been discarded due to a RESET_STREAM frame. +// +// This is called on the conn's loop. +func (c *Conn) handleStreamBytesReadOnLoop(n int64) { + if c.shouldUpdateFlowControl(c.streams.inflow.credit.Add(n)) { + c.sendMaxDataUpdate() + } +} + +func (c *Conn) sendMaxDataUpdate() { + c.streams.inflow.sent.setUnsent() + // Apply current credit to the limit. + // We don't strictly need to do this here + // since appendMaxDataFrame will do so as well, + // but this avoids redundant trips down this path + // if the MAX_DATA frame doesn't go out right away. + c.streams.inflow.newLimit += c.streams.inflow.credit.Swap(0) +} + +func (c *Conn) shouldUpdateFlowControl(credit int64) bool { + return shouldUpdateFlowControl(c.config.maxConnReadBufferSize(), credit) +} + +// handleStreamBytesReceived records that the peer has sent us stream data. +func (c *Conn) handleStreamBytesReceived(n int64) error { + c.streams.inflow.usedLimit += n + if c.streams.inflow.usedLimit > c.streams.inflow.sentLimit { + return localTransportError{ + code: errFlowControl, + reason: "stream exceeded flow control limit", + } + } + return nil +} + +// appendMaxDataFrame appends a MAX_DATA frame to the current packet. +// +// It returns true if no more frames need appending, +// false if it could not fit a frame in the current packet. +func (c *Conn) appendMaxDataFrame(w *packetWriter, pnum packetNumber, pto bool) bool { + if c.streams.inflow.sent.shouldSendPTO(pto) { + // Add any unapplied credit to the new limit now. + c.streams.inflow.newLimit += c.streams.inflow.credit.Swap(0) + if !w.appendMaxDataFrame(c.streams.inflow.newLimit) { + return false + } + c.streams.inflow.sentLimit += c.streams.inflow.newLimit + c.streams.inflow.sent.setSent(pnum) + } + return true +} + +// ackOrLossMaxData records the fate of a MAX_DATA frame. +func (c *Conn) ackOrLossMaxData(pnum packetNumber, fate packetFate) { + c.streams.inflow.sent.ackLatestOrLoss(pnum, fate) +} + +// connOutflow tracks connection-level flow control for data sent by us to the peer. +type connOutflow struct { + max int64 // largest MAX_DATA received from peer + used int64 // total bytes of STREAM data sent to peer +} + +// setMaxData updates the connection-level flow control limit +// with the initial limit conveyed in transport parameters +// or an update from a MAX_DATA frame. +func (f *connOutflow) setMaxData(maxData int64) { + f.max = max(f.max, maxData) +} + +// avail returns the number of connection-level flow control bytes available. +func (f *connOutflow) avail() int64 { + return f.max - f.used +} + +// consume records consumption of n bytes of flow. +func (f *connOutflow) consume(n int64) { + f.used += n +} diff --git a/src/vendor/golang.org/x/net/quic/conn_id.go b/src/vendor/golang.org/x/net/quic/conn_id.go new file mode 100644 index 0000000000..8749e52b79 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/conn_id.go @@ -0,0 +1,502 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "bytes" + "crypto/rand" + "slices" +) + +// connIDState is a conn's connection IDs. +type connIDState struct { + // The destination connection IDs of packets we receive are local. + // The destination connection IDs of packets we send are remote. + // + // Local IDs are usually issued by us, and remote IDs by the peer. + // The exception is the transient destination connection ID sent in + // a client's Initial packets, which is chosen by the client. + // + // These are []connID rather than []*connID to minimize allocations. + local []connID + remote []remoteConnID + + nextLocalSeq int64 + peerActiveConnIDLimit int64 // peer's active_connection_id_limit + + // Handling of retirement of remote connection IDs. + // The rangesets track ID sequence numbers. + // IDs in need of retirement are added to remoteRetiring, + // moved to remoteRetiringSent once we send a RETIRE_CONECTION_ID frame, + // and removed from the set once retirement completes. + retireRemotePriorTo int64 // largest Retire Prior To value sent by the peer + remoteRetiring rangeset[int64] // remote IDs in need of retirement + remoteRetiringSent rangeset[int64] // remote IDs waiting for ack of retirement + + originalDstConnID []byte // expected original_destination_connection_id param + retrySrcConnID []byte // expected retry_source_connection_id param + + needSend bool +} + +// A connID is a connection ID and associated metadata. +type connID struct { + // cid is the connection ID itself. + cid []byte + + // seq is the connection ID's sequence number: + // https://www.rfc-editor.org/rfc/rfc9000.html#section-5.1.1-1 + // + // For the transient destination ID in a client's Initial packet, this is -1. + seq int64 + + // send is set when the connection ID's state needs to be sent to the peer. + // + // For local IDs, this indicates a new ID that should be sent + // in a NEW_CONNECTION_ID frame. + // + // For remote IDs, this indicates a retired ID that should be sent + // in a RETIRE_CONNECTION_ID frame. + send sentVal +} + +// A remoteConnID is a connection ID and stateless reset token. +type remoteConnID struct { + connID + resetToken statelessResetToken +} + +func (s *connIDState) initClient(c *Conn) error { + // Client chooses its initial connection ID, and sends it + // in the Source Connection ID field of the first Initial packet. + locid, err := c.newConnID(0) + if err != nil { + return err + } + s.local = append(s.local, connID{ + seq: 0, + cid: locid, + }) + s.nextLocalSeq = 1 + c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { + conns.addConnID(c, locid) + }) + + // Client chooses an initial, transient connection ID for the server, + // and sends it in the Destination Connection ID field of the first Initial packet. + remid, err := c.newConnID(-1) + if err != nil { + return err + } + s.remote = append(s.remote, remoteConnID{ + connID: connID{ + seq: -1, + cid: remid, + }, + }) + s.originalDstConnID = remid + return nil +} + +func (s *connIDState) initServer(c *Conn, cids newServerConnIDs) error { + dstConnID := cloneBytes(cids.dstConnID) + // Client-chosen, transient connection ID received in the first Initial packet. + // The server will not use this as the Source Connection ID of packets it sends, + // but remembers it because it may receive packets sent to this destination. + s.local = append(s.local, connID{ + seq: -1, + cid: dstConnID, + }) + + // Server chooses a connection ID, and sends it in the Source Connection ID of + // the response to the clent. + locid, err := c.newConnID(0) + if err != nil { + return err + } + s.local = append(s.local, connID{ + seq: 0, + cid: locid, + }) + s.nextLocalSeq = 1 + c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { + conns.addConnID(c, dstConnID) + conns.addConnID(c, locid) + }) + + // Client chose its own connection ID. + s.remote = append(s.remote, remoteConnID{ + connID: connID{ + seq: 0, + cid: cloneBytes(cids.srcConnID), + }, + }) + return nil +} + +// srcConnID is the Source Connection ID to use in a sent packet. +func (s *connIDState) srcConnID() []byte { + if s.local[0].seq == -1 && len(s.local) > 1 { + // Don't use the transient connection ID if another is available. + return s.local[1].cid + } + return s.local[0].cid +} + +// dstConnID is the Destination Connection ID to use in a sent packet. +func (s *connIDState) dstConnID() (cid []byte, ok bool) { + for i := range s.remote { + return s.remote[i].cid, true + } + return nil, false +} + +// isValidStatelessResetToken reports whether the given reset token is +// associated with a non-retired connection ID which we have used. +func (s *connIDState) isValidStatelessResetToken(resetToken statelessResetToken) bool { + if len(s.remote) == 0 { + return false + } + // We currently only use the first available remote connection ID, + // so any other reset token is not valid. + return s.remote[0].resetToken == resetToken +} + +// setPeerActiveConnIDLimit sets the active_connection_id_limit +// transport parameter received from the peer. +func (s *connIDState) setPeerActiveConnIDLimit(c *Conn, lim int64) error { + s.peerActiveConnIDLimit = lim + return s.issueLocalIDs(c) +} + +func (s *connIDState) issueLocalIDs(c *Conn) error { + toIssue := min(int(s.peerActiveConnIDLimit), maxPeerActiveConnIDLimit) + for i := range s.local { + if s.local[i].seq != -1 { + toIssue-- + } + } + var newIDs [][]byte + for toIssue > 0 { + cid, err := c.newConnID(s.nextLocalSeq) + if err != nil { + return err + } + newIDs = append(newIDs, cid) + s.local = append(s.local, connID{ + seq: s.nextLocalSeq, + cid: cid, + }) + s.local[len(s.local)-1].send.setUnsent() + s.nextLocalSeq++ + s.needSend = true + toIssue-- + } + c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { + for _, cid := range newIDs { + conns.addConnID(c, cid) + } + }) + return nil +} + +// validateTransportParameters verifies the original_destination_connection_id and +// initial_source_connection_id transport parameters match the expected values. +func (s *connIDState) validateTransportParameters(c *Conn, isRetry bool, p transportParameters) error { + // TODO: Consider returning more detailed errors, for debugging. + // Verify original_destination_connection_id matches + // the transient remote connection ID we chose (client) + // or is empty (server). + if !bytes.Equal(s.originalDstConnID, p.originalDstConnID) { + return localTransportError{ + code: errTransportParameter, + reason: "original_destination_connection_id mismatch", + } + } + s.originalDstConnID = nil // we have no further need for this + // Verify retry_source_connection_id matches the value from + // the server's Retry packet (when one was sent), or is empty. + if !bytes.Equal(p.retrySrcConnID, s.retrySrcConnID) { + return localTransportError{ + code: errTransportParameter, + reason: "retry_source_connection_id mismatch", + } + } + s.retrySrcConnID = nil // we have no further need for this + // Verify initial_source_connection_id matches the first remote connection ID. + if len(s.remote) == 0 || s.remote[0].seq != 0 { + return localTransportError{ + code: errInternal, + reason: "remote connection id missing", + } + } + if !bytes.Equal(p.initialSrcConnID, s.remote[0].cid) { + return localTransportError{ + code: errTransportParameter, + reason: "initial_source_connection_id mismatch", + } + } + if len(p.statelessResetToken) > 0 { + if c.side == serverSide { + return localTransportError{ + code: errTransportParameter, + reason: "client sent stateless_reset_token", + } + } + token := statelessResetToken(p.statelessResetToken) + s.remote[0].resetToken = token + c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { + conns.addResetToken(c, token) + }) + } + return nil +} + +// handlePacket updates the connection ID state during the handshake +// (Initial and Handshake packets). +func (s *connIDState) handlePacket(c *Conn, ptype packetType, srcConnID []byte) { + switch { + case ptype == packetTypeInitial && c.side == clientSide: + if len(s.remote) == 1 && s.remote[0].seq == -1 { + // We're a client connection processing the first Initial packet + // from the server. Replace the transient remote connection ID + // with the Source Connection ID from the packet. + s.remote[0] = remoteConnID{ + connID: connID{ + seq: 0, + cid: cloneBytes(srcConnID), + }, + } + } + case ptype == packetTypeHandshake && c.side == serverSide: + if len(s.local) > 0 && s.local[0].seq == -1 { + // We're a server connection processing the first Handshake packet from + // the client. Discard the transient, client-chosen connection ID used + // for Initial packets; the client will never send it again. + cid := s.local[0].cid + c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { + conns.retireConnID(c, cid) + }) + s.local = append(s.local[:0], s.local[1:]...) + } + } +} + +func (s *connIDState) handleRetryPacket(srcConnID []byte) { + if len(s.remote) != 1 || s.remote[0].seq != -1 { + panic("BUG: handling retry with non-transient remote conn id") + } + s.retrySrcConnID = cloneBytes(srcConnID) + s.remote[0].cid = s.retrySrcConnID +} + +func (s *connIDState) handleNewConnID(c *Conn, seq, retire int64, cid []byte, resetToken statelessResetToken) error { + if len(s.remote[0].cid) == 0 { + // "An endpoint that is sending packets with a zero-length + // Destination Connection ID MUST treat receipt of a NEW_CONNECTION_ID + // frame as a connection error of type PROTOCOL_VIOLATION." + // https://www.rfc-editor.org/rfc/rfc9000.html#section-19.15-6 + return localTransportError{ + code: errProtocolViolation, + reason: "NEW_CONNECTION_ID from peer with zero-length DCID", + } + } + + if seq < s.retireRemotePriorTo { + // This ID was already retired by a previous NEW_CONNECTION_ID frame. + // Nothing to do. + return nil + } + + if retire > s.retireRemotePriorTo { + // Add newly-retired connection IDs to the set we need to send + // RETIRE_CONNECTION_ID frames for, and remove them from s.remote. + // + // (This might cause us to send a RETIRE_CONNECTION_ID for an ID we've + // never seen. That's fine.) + s.remoteRetiring.add(s.retireRemotePriorTo, retire) + s.retireRemotePriorTo = retire + s.needSend = true + s.remote = slices.DeleteFunc(s.remote, func(rcid remoteConnID) bool { + return rcid.seq < s.retireRemotePriorTo + }) + } + + have := false // do we already have this connection ID? + for i := range s.remote { + rcid := &s.remote[i] + if rcid.seq == seq { + if !bytes.Equal(rcid.cid, cid) { + return localTransportError{ + code: errProtocolViolation, + reason: "NEW_CONNECTION_ID does not match prior id", + } + } + have = true // yes, we've seen this sequence number + break + } + } + + if !have { + // This is a new connection ID that we have not seen before. + // + // We could take steps to keep the list of remote connection IDs + // sorted by sequence number, but there's no particular need + // so we don't bother. + s.remote = append(s.remote, remoteConnID{ + connID: connID{ + seq: seq, + cid: cloneBytes(cid), + }, + resetToken: resetToken, + }) + c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { + conns.addResetToken(c, resetToken) + }) + } + + if len(s.remote) > activeConnIDLimit { + // Retired connection IDs (including newly-retired ones) do not count + // against the limit. + // https://www.rfc-editor.org/rfc/rfc9000.html#section-5.1.1-5 + return localTransportError{ + code: errConnectionIDLimit, + reason: "active_connection_id_limit exceeded", + } + } + + // "An endpoint SHOULD limit the number of connection IDs it has retired locally + // for which RETIRE_CONNECTION_ID frames have not yet been acknowledged." + // https://www.rfc-editor.org/rfc/rfc9000#section-5.1.2-6 + // + // Set a limit of three times the active_connection_id_limit for + // the total number of remote connection IDs we keep retirement state for. + if s.remoteRetiring.size()+s.remoteRetiringSent.size() > 3*activeConnIDLimit { + return localTransportError{ + code: errConnectionIDLimit, + reason: "too many unacknowledged retired connection ids", + } + } + + return nil +} + +func (s *connIDState) handleRetireConnID(c *Conn, seq int64) error { + if seq >= s.nextLocalSeq { + return localTransportError{ + code: errProtocolViolation, + reason: "RETIRE_CONNECTION_ID for unissued sequence number", + } + } + for i := range s.local { + if s.local[i].seq == seq { + cid := s.local[i].cid + c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { + conns.retireConnID(c, cid) + }) + s.local = append(s.local[:i], s.local[i+1:]...) + break + } + } + s.issueLocalIDs(c) + return nil +} + +func (s *connIDState) ackOrLossNewConnectionID(pnum packetNumber, seq int64, fate packetFate) { + for i := range s.local { + if s.local[i].seq != seq { + continue + } + s.local[i].send.ackOrLoss(pnum, fate) + if fate != packetAcked { + s.needSend = true + } + return + } +} + +func (s *connIDState) ackOrLossRetireConnectionID(pnum packetNumber, seq int64, fate packetFate) { + s.remoteRetiringSent.sub(seq, seq+1) + if fate == packetLost { + // RETIRE_CONNECTION_ID frame was lost, mark for retransmission. + s.remoteRetiring.add(seq, seq+1) + s.needSend = true + } +} + +// appendFrames appends NEW_CONNECTION_ID and RETIRE_CONNECTION_ID frames +// to the current packet. +// +// It returns true if no more frames need appending, +// false if not everything fit in the current packet. +func (s *connIDState) appendFrames(c *Conn, pnum packetNumber, pto bool) bool { + if !s.needSend && !pto { + // Fast path: We don't need to send anything. + return true + } + retireBefore := int64(0) + if s.local[0].seq != -1 { + retireBefore = s.local[0].seq + } + for i := range s.local { + if !s.local[i].send.shouldSendPTO(pto) { + continue + } + if !c.w.appendNewConnectionIDFrame( + s.local[i].seq, + retireBefore, + s.local[i].cid, + c.endpoint.resetGen.tokenForConnID(s.local[i].cid), + ) { + return false + } + s.local[i].send.setSent(pnum) + } + if pto { + for _, r := range s.remoteRetiringSent { + for cid := r.start; cid < r.end; cid++ { + if !c.w.appendRetireConnectionIDFrame(cid) { + return false + } + } + } + } + for s.remoteRetiring.numRanges() > 0 { + cid := s.remoteRetiring.min() + if !c.w.appendRetireConnectionIDFrame(cid) { + return false + } + s.remoteRetiring.sub(cid, cid+1) + s.remoteRetiringSent.add(cid, cid+1) + } + s.needSend = false + return true +} + +func cloneBytes(b []byte) []byte { + n := make([]byte, len(b)) + copy(n, b) + return n +} + +func (c *Conn) newConnID(seq int64) ([]byte, error) { + if c.testHooks != nil { + return c.testHooks.newConnID(seq) + } + return newRandomConnID(seq) +} + +func newRandomConnID(_ int64) ([]byte, error) { + // It is not necessary for connection IDs to be cryptographically secure, + // but it doesn't hurt. + id := make([]byte, connIDLen) + if _, err := rand.Read(id); err != nil { + // TODO: Surface this error as a metric or log event or something. + // rand.Read really shouldn't ever fail, but if it does, we should + // have a way to inform the user. + return nil, err + } + return id, nil +} diff --git a/src/vendor/golang.org/x/net/quic/conn_loss.go b/src/vendor/golang.org/x/net/quic/conn_loss.go new file mode 100644 index 0000000000..bc6d106601 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/conn_loss.go @@ -0,0 +1,85 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import "fmt" + +// handleAckOrLoss deals with the final fate of a packet we sent: +// Either the peer acknowledges it, or we declare it lost. +// +// In order to handle packet loss, we must retain any information sent to the peer +// until the peer has acknowledged it. +// +// When information is acknowledged, we can discard it. +// +// When information is lost, we mark it for retransmission. +// See RFC 9000, Section 13.3 for a complete list of information which is retransmitted on loss. +// https://www.rfc-editor.org/rfc/rfc9000#section-13.3 +func (c *Conn) handleAckOrLoss(space numberSpace, sent *sentPacket, fate packetFate) { + if fate == packetLost && c.logEnabled(QLogLevelPacket) { + c.logPacketLost(space, sent) + } + + // The list of frames in a sent packet is marshaled into a buffer in the sentPacket + // by the packetWriter. Unmarshal that buffer here. This code must be kept in sync with + // packetWriter.append*. + // + // A sent packet meets its fate (acked or lost) only once, so it's okay to consume + // the sentPacket's buffer here. + for !sent.done() { + switch f := sent.next(); f { + default: + panic(fmt.Sprintf("BUG: unhandled acked/lost frame type %x", f)) + case frameTypeAck, frameTypeAckECN: + // Unlike most information, loss of an ACK frame does not trigger + // retransmission. ACKs are sent in response to ack-eliciting packets, + // and always contain the latest information available. + // + // Acknowledgement of an ACK frame may allow us to discard information + // about older packets. + largest := packetNumber(sent.nextInt()) + if fate == packetAcked { + c.acks[space].handleAck(largest) + } + case frameTypeCrypto: + start, end := sent.nextRange() + c.crypto[space].ackOrLoss(start, end, fate) + case frameTypeMaxData: + c.ackOrLossMaxData(sent.num, fate) + case frameTypeResetStream, + frameTypeStopSending, + frameTypeMaxStreamData, + frameTypeStreamDataBlocked: + id := streamID(sent.nextInt()) + s := c.streamForID(id) + if s == nil { + continue + } + s.ackOrLoss(sent.num, f, fate) + case frameTypeStreamBase, + frameTypeStreamBase | streamFinBit: + id := streamID(sent.nextInt()) + start, end := sent.nextRange() + s := c.streamForID(id) + if s == nil { + continue + } + fin := f&streamFinBit != 0 + s.ackOrLossData(sent.num, start, end, fin, fate) + case frameTypeMaxStreamsBidi: + c.streams.remoteLimit[bidiStream].sendMax.ackLatestOrLoss(sent.num, fate) + case frameTypeMaxStreamsUni: + c.streams.remoteLimit[uniStream].sendMax.ackLatestOrLoss(sent.num, fate) + case frameTypeNewConnectionID: + seq := int64(sent.nextInt()) + c.connIDState.ackOrLossNewConnectionID(sent.num, seq, fate) + case frameTypeRetireConnectionID: + seq := int64(sent.nextInt()) + c.connIDState.ackOrLossRetireConnectionID(sent.num, seq, fate) + case frameTypeHandshakeDone: + c.handshakeConfirmed.ackOrLoss(sent.num, fate) + } + } +} diff --git a/src/vendor/golang.org/x/net/quic/conn_recv.go b/src/vendor/golang.org/x/net/quic/conn_recv.go new file mode 100644 index 0000000000..2bf127a479 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/conn_recv.go @@ -0,0 +1,631 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "bytes" + "encoding/binary" + "errors" + "time" +) + +func (c *Conn) handleDatagram(now time.Time, dgram *datagram) (handled bool) { + if !c.localAddr.IsValid() { + // We don't have any way to tell in the general case what address we're + // sending packets from. Set our address from the destination address of + // the first packet received from the peer. + c.localAddr = dgram.localAddr + } + if dgram.peerAddr.IsValid() && dgram.peerAddr != c.peerAddr { + if c.side == clientSide { + // "If a client receives packets from an unknown server address, + // the client MUST discard these packets." + // https://www.rfc-editor.org/rfc/rfc9000#section-9-6 + return false + } + // We currently don't support connection migration, + // so for now the server also drops packets from an unknown address. + return false + } + buf := dgram.b + c.loss.datagramReceived(now, len(buf)) + if c.isDraining() { + return false + } + for len(buf) > 0 { + var n int + ptype := getPacketType(buf) + switch ptype { + case packetTypeInitial: + if c.side == serverSide && len(dgram.b) < paddedInitialDatagramSize { + // Discard client-sent Initial packets in too-short datagrams. + // https://www.rfc-editor.org/rfc/rfc9000#section-14.1-4 + return false + } + n = c.handleLongHeader(now, dgram, ptype, initialSpace, c.keysInitial.r, buf) + case packetTypeHandshake: + n = c.handleLongHeader(now, dgram, ptype, handshakeSpace, c.keysHandshake.r, buf) + case packetType1RTT: + n = c.handle1RTT(now, dgram, buf) + case packetTypeRetry: + c.handleRetry(now, buf) + return true + case packetTypeVersionNegotiation: + c.handleVersionNegotiation(now, buf) + return true + default: + n = -1 + } + if n <= 0 { + // We don't expect to get a stateless reset with a valid + // destination connection ID, since the sender of a stateless + // reset doesn't know what the connection ID is. + // + // We're required to perform this check anyway. + // + // "[...] the comparison MUST be performed when the first packet + // in an incoming datagram [...] cannot be decrypted." + // https://www.rfc-editor.org/rfc/rfc9000#section-10.3.1-2 + if len(buf) == len(dgram.b) && len(buf) > statelessResetTokenLen { + var token statelessResetToken + copy(token[:], buf[len(buf)-len(token):]) + if c.handleStatelessReset(now, token) { + return true + } + } + // Invalid data at the end of a datagram is ignored. + return false + } + c.idleHandlePacketReceived(now) + buf = buf[n:] + } + return true +} + +func (c *Conn) handleLongHeader(now time.Time, dgram *datagram, ptype packetType, space numberSpace, k fixedKeys, buf []byte) int { + if !k.isSet() { + return skipLongHeaderPacket(buf) + } + + pnumMax := c.acks[space].largestSeen() + p, n := parseLongHeaderPacket(buf, k, pnumMax) + if n < 0 { + return -1 + } + if buf[0]&reservedLongBits != 0 { + // Reserved header bits must be 0. + // https://www.rfc-editor.org/rfc/rfc9000#section-17.2-8.2.1 + c.abort(now, localTransportError{ + code: errProtocolViolation, + reason: "reserved header bits are not zero", + }) + return -1 + } + if p.version != quicVersion1 { + // The peer has changed versions on us mid-handshake? + c.abort(now, localTransportError{ + code: errProtocolViolation, + reason: "protocol version changed during handshake", + }) + return -1 + } + + if !c.acks[space].shouldProcess(p.num) { + return n + } + + if logPackets { + logInboundLongPacket(c, p) + } + if c.logEnabled(QLogLevelPacket) { + c.logLongPacketReceived(p, buf[:n]) + } + c.connIDState.handlePacket(c, p.ptype, p.srcConnID) + ackEliciting := c.handleFrames(now, dgram, ptype, space, p.payload) + c.acks[space].receive(now, space, p.num, ackEliciting, dgram.ecn) + if p.ptype == packetTypeHandshake && c.side == serverSide { + c.loss.validateClientAddress() + + // "[...] a server MUST discard Initial keys when it first successfully + // processes a Handshake packet [...]" + // https://www.rfc-editor.org/rfc/rfc9001#section-4.9.1-2 + c.discardKeys(now, initialSpace) + } + return n +} + +func (c *Conn) handle1RTT(now time.Time, dgram *datagram, buf []byte) int { + if !c.keysAppData.canRead() { + // 1-RTT packets extend to the end of the datagram, + // so skip the remainder of the datagram if we can't parse this. + return len(buf) + } + + pnumMax := c.acks[appDataSpace].largestSeen() + p, err := parse1RTTPacket(buf, &c.keysAppData, connIDLen, pnumMax) + if err != nil { + // A localTransportError terminates the connection. + // Other errors indicate an unparsable packet, but otherwise may be ignored. + if _, ok := err.(localTransportError); ok { + c.abort(now, err) + } + return -1 + } + if buf[0]&reserved1RTTBits != 0 { + // Reserved header bits must be 0. + // https://www.rfc-editor.org/rfc/rfc9000#section-17.3.1-4.8.1 + c.abort(now, localTransportError{ + code: errProtocolViolation, + reason: "reserved header bits are not zero", + }) + return -1 + } + + if !c.acks[appDataSpace].shouldProcess(p.num) { + return len(buf) + } + + if logPackets { + logInboundShortPacket(c, p) + } + if c.logEnabled(QLogLevelPacket) { + c.log1RTTPacketReceived(p, buf) + } + ackEliciting := c.handleFrames(now, dgram, packetType1RTT, appDataSpace, p.payload) + c.acks[appDataSpace].receive(now, appDataSpace, p.num, ackEliciting, dgram.ecn) + return len(buf) +} + +func (c *Conn) handleRetry(now time.Time, pkt []byte) { + if c.side != clientSide { + return // clients don't send Retry packets + } + // "After the client has received and processed an Initial or Retry packet + // from the server, it MUST discard any subsequent Retry packets that it receives." + // https://www.rfc-editor.org/rfc/rfc9000#section-17.2.5.2-1 + if !c.keysInitial.canRead() { + return // discarded Initial keys, connection is already established + } + if c.acks[initialSpace].seen.numRanges() != 0 { + return // processed at least one packet + } + if c.retryToken != nil { + return // received a Retry already + } + // "Clients MUST discard Retry packets that have a Retry Integrity Tag + // that cannot be validated." + // https://www.rfc-editor.org/rfc/rfc9000#section-17.2.5.2-2 + p, ok := parseRetryPacket(pkt, c.connIDState.originalDstConnID) + if !ok { + return + } + // "A client MUST discard a Retry packet with a zero-length Retry Token field." + // https://www.rfc-editor.org/rfc/rfc9000#section-17.2.5.2-2 + if len(p.token) == 0 { + return + } + c.retryToken = cloneBytes(p.token) + c.connIDState.handleRetryPacket(p.srcConnID) + c.keysInitial = initialKeys(p.srcConnID, c.side) + // We need to resend any data we've already sent in Initial packets. + // We must not reuse already sent packet numbers. + c.loss.discardPackets(initialSpace, c.log, c.handleAckOrLoss) + // TODO: Discard 0-RTT packets as well, once we support 0-RTT. + if c.testHooks != nil { + c.testHooks.init(false) + } +} + +var errVersionNegotiation = errors.New("server does not support QUIC version 1") + +func (c *Conn) handleVersionNegotiation(now time.Time, pkt []byte) { + if c.side != clientSide { + return // servers don't handle Version Negotiation packets + } + // "A client MUST discard any Version Negotiation packet if it has + // received and successfully processed any other packet [...]" + // https://www.rfc-editor.org/rfc/rfc9000#section-6.2-2 + if !c.keysInitial.canRead() { + return // discarded Initial keys, connection is already established + } + if c.acks[initialSpace].seen.numRanges() != 0 { + return // processed at least one packet + } + _, srcConnID, versions := parseVersionNegotiation(pkt) + if len(c.connIDState.remote) < 1 || !bytes.Equal(c.connIDState.remote[0].cid, srcConnID) { + return // Source Connection ID doesn't match what we sent + } + for len(versions) >= 4 { + ver := binary.BigEndian.Uint32(versions) + if ver == 1 { + // "A client MUST discard a Version Negotiation packet that lists + // the QUIC version selected by the client." + // https://www.rfc-editor.org/rfc/rfc9000#section-6.2-2 + return + } + versions = versions[4:] + } + // "A client that supports only this version of QUIC MUST + // abandon the current connection attempt if it receives + // a Version Negotiation packet, [with the two exceptions handled above]." + // https://www.rfc-editor.org/rfc/rfc9000#section-6.2-2 + c.abortImmediately(now, errVersionNegotiation) +} + +func (c *Conn) handleFrames(now time.Time, dgram *datagram, ptype packetType, space numberSpace, payload []byte) (ackEliciting bool) { + if len(payload) == 0 { + // "An endpoint MUST treat receipt of a packet containing no frames + // as a connection error of type PROTOCOL_VIOLATION." + // https://www.rfc-editor.org/rfc/rfc9000#section-12.4-3 + c.abort(now, localTransportError{ + code: errProtocolViolation, + reason: "packet contains no frames", + }) + return false + } + // frameOK verifies that ptype is one of the packets in mask. + frameOK := func(c *Conn, ptype, mask packetType) (ok bool) { + if ptype&mask == 0 { + // "An endpoint MUST treat receipt of a frame in a packet type + // that is not permitted as a connection error of type + // PROTOCOL_VIOLATION." + // https://www.rfc-editor.org/rfc/rfc9000#section-12.4-3 + c.abort(now, localTransportError{ + code: errProtocolViolation, + reason: "frame not allowed in packet", + }) + return false + } + return true + } + // Packet masks from RFC 9000 Table 3. + // https://www.rfc-editor.org/rfc/rfc9000#table-3 + const ( + IH_1 = packetTypeInitial | packetTypeHandshake | packetType1RTT + __01 = packetType0RTT | packetType1RTT + ___1 = packetType1RTT + ) + hasCrypto := false + for len(payload) > 0 { + switch payload[0] { + case frameTypePadding, frameTypeAck, frameTypeAckECN, + frameTypeConnectionCloseTransport, frameTypeConnectionCloseApplication: + default: + ackEliciting = true + } + n := -1 + switch payload[0] { + case frameTypePadding: + // PADDING is OK in all spaces. + n = 1 + case frameTypePing: + // PING is OK in all spaces. + // + // A PING frame causes us to respond with an ACK by virtue of being + // an ack-eliciting frame, but requires no other action. + n = 1 + case frameTypeAck, frameTypeAckECN: + if !frameOK(c, ptype, IH_1) { + return + } + n = c.handleAckFrame(now, space, payload) + case frameTypeResetStream: + if !frameOK(c, ptype, __01) { + return + } + n = c.handleResetStreamFrame(now, space, payload) + case frameTypeStopSending: + if !frameOK(c, ptype, __01) { + return + } + n = c.handleStopSendingFrame(now, space, payload) + case frameTypeCrypto: + if !frameOK(c, ptype, IH_1) { + return + } + hasCrypto = true + n = c.handleCryptoFrame(now, space, payload) + case frameTypeNewToken: + if !frameOK(c, ptype, ___1) { + return + } + _, n = consumeNewTokenFrame(payload) + case 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f: // STREAM + if !frameOK(c, ptype, __01) { + return + } + n = c.handleStreamFrame(now, space, payload) + case frameTypeMaxData: + if !frameOK(c, ptype, __01) { + return + } + n = c.handleMaxDataFrame(now, payload) + case frameTypeMaxStreamData: + if !frameOK(c, ptype, __01) { + return + } + n = c.handleMaxStreamDataFrame(now, payload) + case frameTypeMaxStreamsBidi, frameTypeMaxStreamsUni: + if !frameOK(c, ptype, __01) { + return + } + n = c.handleMaxStreamsFrame(now, payload) + case frameTypeDataBlocked: + if !frameOK(c, ptype, __01) { + return + } + _, n = consumeDataBlockedFrame(payload) + case frameTypeStreamsBlockedBidi, frameTypeStreamsBlockedUni: + if !frameOK(c, ptype, __01) { + return + } + _, _, n = consumeStreamsBlockedFrame(payload) + case frameTypeStreamDataBlocked: + if !frameOK(c, ptype, __01) { + return + } + _, _, n = consumeStreamDataBlockedFrame(payload) + case frameTypeNewConnectionID: + if !frameOK(c, ptype, __01) { + return + } + n = c.handleNewConnectionIDFrame(now, space, payload) + case frameTypeRetireConnectionID: + if !frameOK(c, ptype, __01) { + return + } + n = c.handleRetireConnectionIDFrame(now, space, payload) + case frameTypePathChallenge: + if !frameOK(c, ptype, __01) { + return + } + n = c.handlePathChallengeFrame(now, dgram, space, payload) + case frameTypePathResponse: + if !frameOK(c, ptype, ___1) { + return + } + n = c.handlePathResponseFrame(now, space, payload) + case frameTypeConnectionCloseTransport: + // Transport CONNECTION_CLOSE is OK in all spaces. + n = c.handleConnectionCloseTransportFrame(now, payload) + case frameTypeConnectionCloseApplication: + if !frameOK(c, ptype, __01) { + return + } + n = c.handleConnectionCloseApplicationFrame(now, payload) + case frameTypeHandshakeDone: + if !frameOK(c, ptype, ___1) { + return + } + n = c.handleHandshakeDoneFrame(now, space, payload) + } + if n < 0 { + c.abort(now, localTransportError{ + code: errFrameEncoding, + reason: "frame encoding error", + }) + return false + } + payload = payload[n:] + } + if hasCrypto { + // Process TLS events after handling all frames in a packet. + // TLS events can cause us to drop state for a number space, + // so do that last, to avoid handling frames differently + // depending on whether they come before or after a CRYPTO frame. + if err := c.handleTLSEvents(now); err != nil { + c.abort(now, err) + } + } + return ackEliciting +} + +func (c *Conn) handleAckFrame(now time.Time, space numberSpace, payload []byte) int { + c.loss.receiveAckStart() + largest, ackDelay, ecn, n := consumeAckFrame(payload, func(rangeIndex int, start, end packetNumber) { + if err := c.loss.receiveAckRange(now, space, rangeIndex, start, end, c.handleAckOrLoss); err != nil { + c.abort(now, err) + return + } + }) + // TODO: Make use of ECN feedback. + // https://www.rfc-editor.org/rfc/rfc9000.html#section-19.3.2 + _ = ecn + // Prior to receiving the peer's transport parameters, we cannot + // interpret the ACK Delay field because we don't know the ack_delay_exponent + // to apply. + // + // For servers, we should always know the ack_delay_exponent because the + // client's transport parameters are carried in its Initial packets and we + // won't send an ack-eliciting Initial packet until after receiving the last + // client Initial packet. + // + // For clients, we won't receive the server's transport parameters until handling + // its Handshake flight, which will probably happen after reading its ACK for our + // Initial packet(s). However, the peer's acknowledgement delay cannot reduce our + // adjusted RTT sample below min_rtt, and min_rtt is generally going to be set + // by the packet containing the ACK for our Initial flight. Therefore, the + // ACK Delay for an ACK in the Initial space is likely to be ignored anyway. + // + // Long story short, setting the delay to 0 prior to reading transport parameters + // is usually going to have no effect, will have only a minor effect in the rare + // cases when it happens, and there aren't any good alternatives anyway since we + // can't interpret the ACK Delay field without knowing the exponent. + var delay time.Duration + if c.peerAckDelayExponent >= 0 { + delay = ackDelay.Duration(uint8(c.peerAckDelayExponent)) + } + c.loss.receiveAckEnd(now, c.log, space, delay, c.handleAckOrLoss) + if space == appDataSpace { + c.keysAppData.handleAckFor(largest) + } + return n +} + +func (c *Conn) handleMaxDataFrame(now time.Time, payload []byte) int { + maxData, n := consumeMaxDataFrame(payload) + if n < 0 { + return -1 + } + c.streams.outflow.setMaxData(maxData) + return n +} + +func (c *Conn) handleMaxStreamDataFrame(now time.Time, payload []byte) int { + id, maxStreamData, n := consumeMaxStreamDataFrame(payload) + if n < 0 { + return -1 + } + if s := c.streamForFrame(now, id, sendStream); s != nil { + if err := s.handleMaxStreamData(maxStreamData); err != nil { + c.abort(now, err) + return -1 + } + } + return n +} + +func (c *Conn) handleMaxStreamsFrame(now time.Time, payload []byte) int { + styp, max, n := consumeMaxStreamsFrame(payload) + if n < 0 { + return -1 + } + c.streams.localLimit[styp].setMax(max) + return n +} + +func (c *Conn) handleResetStreamFrame(now time.Time, space numberSpace, payload []byte) int { + id, code, finalSize, n := consumeResetStreamFrame(payload) + if n < 0 { + return -1 + } + if s := c.streamForFrame(now, id, recvStream); s != nil { + if err := s.handleReset(code, finalSize); err != nil { + c.abort(now, err) + } + } + return n +} + +func (c *Conn) handleStopSendingFrame(now time.Time, space numberSpace, payload []byte) int { + id, code, n := consumeStopSendingFrame(payload) + if n < 0 { + return -1 + } + if s := c.streamForFrame(now, id, sendStream); s != nil { + if err := s.handleStopSending(code); err != nil { + c.abort(now, err) + } + } + return n +} + +func (c *Conn) handleCryptoFrame(now time.Time, space numberSpace, payload []byte) int { + off, data, n := consumeCryptoFrame(payload) + err := c.handleCrypto(now, space, off, data) + if err != nil { + c.abort(now, err) + return -1 + } + return n +} + +func (c *Conn) handleStreamFrame(now time.Time, space numberSpace, payload []byte) int { + id, off, fin, b, n := consumeStreamFrame(payload) + if n < 0 { + return -1 + } + if s := c.streamForFrame(now, id, recvStream); s != nil { + if err := s.handleData(off, b, fin); err != nil { + c.abort(now, err) + } + } + return n +} + +func (c *Conn) handleNewConnectionIDFrame(now time.Time, space numberSpace, payload []byte) int { + seq, retire, connID, resetToken, n := consumeNewConnectionIDFrame(payload) + if n < 0 { + return -1 + } + if err := c.connIDState.handleNewConnID(c, seq, retire, connID, resetToken); err != nil { + c.abort(now, err) + } + return n +} + +func (c *Conn) handleRetireConnectionIDFrame(now time.Time, space numberSpace, payload []byte) int { + seq, n := consumeRetireConnectionIDFrame(payload) + if n < 0 { + return -1 + } + if err := c.connIDState.handleRetireConnID(c, seq); err != nil { + c.abort(now, err) + } + return n +} + +func (c *Conn) handlePathChallengeFrame(now time.Time, dgram *datagram, space numberSpace, payload []byte) int { + data, n := consumePathChallengeFrame(payload) + if n < 0 { + return -1 + } + c.handlePathChallenge(now, dgram, data) + return n +} + +func (c *Conn) handlePathResponseFrame(now time.Time, space numberSpace, payload []byte) int { + data, n := consumePathResponseFrame(payload) + if n < 0 { + return -1 + } + c.handlePathResponse(now, data) + return n +} + +func (c *Conn) handleConnectionCloseTransportFrame(now time.Time, payload []byte) int { + code, _, reason, n := consumeConnectionCloseTransportFrame(payload) + if n < 0 { + return -1 + } + c.handlePeerConnectionClose(now, peerTransportError{code: code, reason: reason}) + return n +} + +func (c *Conn) handleConnectionCloseApplicationFrame(now time.Time, payload []byte) int { + code, reason, n := consumeConnectionCloseApplicationFrame(payload) + if n < 0 { + return -1 + } + c.handlePeerConnectionClose(now, &ApplicationError{Code: code, Reason: reason}) + return n +} + +func (c *Conn) handleHandshakeDoneFrame(now time.Time, space numberSpace, payload []byte) int { + if c.side == serverSide { + // Clients should never send HANDSHAKE_DONE. + // https://www.rfc-editor.org/rfc/rfc9000#section-19.20-4 + c.abort(now, localTransportError{ + code: errProtocolViolation, + reason: "client sent HANDSHAKE_DONE", + }) + return -1 + } + if c.isAlive() { + c.confirmHandshake(now) + } + return 1 +} + +var errStatelessReset = errors.New("received stateless reset") + +func (c *Conn) handleStatelessReset(now time.Time, resetToken statelessResetToken) (valid bool) { + if !c.connIDState.isValidStatelessResetToken(resetToken) { + return false + } + c.setFinalError(errStatelessReset) + c.enterDraining(now) + return true +} diff --git a/src/vendor/golang.org/x/net/quic/conn_send.go b/src/vendor/golang.org/x/net/quic/conn_send.go new file mode 100644 index 0000000000..3e8cf526b5 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/conn_send.go @@ -0,0 +1,407 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "crypto/tls" + "errors" + "time" +) + +// maybeSend sends datagrams, if possible. +// +// If sending is blocked by pacing, it returns the next time +// a datagram may be sent. +// +// If sending is blocked indefinitely, it returns the zero Time. +func (c *Conn) maybeSend(now time.Time) (next time.Time) { + // Assumption: The congestion window is not underutilized. + // If congestion control, pacing, and anti-amplification all permit sending, + // but we have no packet to send, then we will declare the window underutilized. + underutilized := false + defer func() { + c.loss.cc.setUnderutilized(c.log, underutilized) + }() + + // Send one datagram on each iteration of this loop, + // until we hit a limit or run out of data to send. + // + // For each number space where we have write keys, + // attempt to construct a packet in that space. + // If the packet contains no frames (we have no data in need of sending), + // abandon the packet. + // + // Speculatively constructing packets means we don't need + // separate code paths for "do we have data to send?" and + // "send the data" that need to be kept in sync. + for { + limit, next := c.loss.sendLimit(now) + if limit == ccBlocked { + // If anti-amplification blocks sending, then no packet can be sent. + return next + } + if !c.sendOK(now) { + return time.Time{} + } + // We may still send ACKs, even if congestion control or pacing limit sending. + + // Prepare to write a datagram of at most maxSendSize bytes. + c.w.reset(c.loss.maxSendSize()) + + dstConnID, ok := c.connIDState.dstConnID() + if !ok { + // It is currently not possible for us to end up without a connection ID, + // but handle the case anyway. + return time.Time{} + } + + // Initial packet. + pad := false + var sentInitial *sentPacket + if c.keysInitial.canWrite() { + pnumMaxAcked := c.loss.spaces[initialSpace].maxAcked + pnum := c.loss.nextNumber(initialSpace) + p := longPacket{ + ptype: packetTypeInitial, + version: quicVersion1, + num: pnum, + dstConnID: dstConnID, + srcConnID: c.connIDState.srcConnID(), + extra: c.retryToken, + } + c.w.startProtectedLongHeaderPacket(pnumMaxAcked, p) + c.appendFrames(now, initialSpace, pnum, limit) + if logPackets { + logSentPacket(c, packetTypeInitial, pnum, p.srcConnID, p.dstConnID, c.w.payload()) + } + if c.logEnabled(QLogLevelPacket) && len(c.w.payload()) > 0 { + c.logPacketSent(packetTypeInitial, pnum, p.srcConnID, p.dstConnID, c.w.packetLen(), c.w.payload()) + } + sentInitial = c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, c.keysInitial.w, p) + if sentInitial != nil { + // Client initial packets and ack-eliciting server initial packaets + // need to be sent in a datagram padded to at least 1200 bytes. + // We can't add the padding yet, however, since we may want to + // coalesce additional packets with this one. + if c.side == clientSide || sentInitial.ackEliciting { + pad = true + } + } + } + + // Handshake packet. + if c.keysHandshake.canWrite() { + pnumMaxAcked := c.loss.spaces[handshakeSpace].maxAcked + pnum := c.loss.nextNumber(handshakeSpace) + p := longPacket{ + ptype: packetTypeHandshake, + version: quicVersion1, + num: pnum, + dstConnID: dstConnID, + srcConnID: c.connIDState.srcConnID(), + } + c.w.startProtectedLongHeaderPacket(pnumMaxAcked, p) + c.appendFrames(now, handshakeSpace, pnum, limit) + if logPackets { + logSentPacket(c, packetTypeHandshake, pnum, p.srcConnID, p.dstConnID, c.w.payload()) + } + if c.logEnabled(QLogLevelPacket) && len(c.w.payload()) > 0 { + c.logPacketSent(packetTypeHandshake, pnum, p.srcConnID, p.dstConnID, c.w.packetLen(), c.w.payload()) + } + if sent := c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, c.keysHandshake.w, p); sent != nil { + c.packetSent(now, handshakeSpace, sent) + if c.side == clientSide { + // "[...] a client MUST discard Initial keys when it first + // sends a Handshake packet [...]" + // https://www.rfc-editor.org/rfc/rfc9001.html#section-4.9.1-2 + c.discardKeys(now, initialSpace) + } + } + } + + // 1-RTT packet. + if c.keysAppData.canWrite() { + pnumMaxAcked := c.loss.spaces[appDataSpace].maxAcked + pnum := c.loss.nextNumber(appDataSpace) + c.w.start1RTTPacket(pnum, pnumMaxAcked, dstConnID) + c.appendFrames(now, appDataSpace, pnum, limit) + if pad && len(c.w.payload()) > 0 { + // 1-RTT packets have no length field and extend to the end + // of the datagram, so if we're sending a datagram that needs + // padding we need to add it inside the 1-RTT packet. + c.w.appendPaddingTo(paddedInitialDatagramSize) + pad = false + } + if logPackets { + logSentPacket(c, packetType1RTT, pnum, nil, dstConnID, c.w.payload()) + } + if c.logEnabled(QLogLevelPacket) && len(c.w.payload()) > 0 { + c.logPacketSent(packetType1RTT, pnum, nil, dstConnID, c.w.packetLen(), c.w.payload()) + } + if sent := c.w.finish1RTTPacket(pnum, pnumMaxAcked, dstConnID, &c.keysAppData); sent != nil { + c.packetSent(now, appDataSpace, sent) + if c.skip.shouldSkip(pnum + 1) { + c.loss.skipNumber(now, appDataSpace) + c.skip.updateNumberSkip(c) + } + } + } + + buf := c.w.datagram() + if len(buf) == 0 { + if limit == ccOK { + // We have nothing to send, and congestion control does not + // block sending. The congestion window is underutilized. + underutilized = true + } + return next + } + + if sentInitial != nil { + if pad { + // Pad out the datagram with zeros, coalescing the Initial + // packet with invalid packets that will be ignored by the peer. + // https://www.rfc-editor.org/rfc/rfc9000.html#section-14.1-1 + for len(buf) < paddedInitialDatagramSize { + buf = append(buf, 0) + // Technically this padding isn't in any packet, but + // account it to the Initial packet in this datagram + // for purposes of flow control and loss recovery. + sentInitial.size++ + sentInitial.inFlight = true + } + } + // If we're a client and this Initial packet is coalesced + // with a Handshake packet, then we've discarded Initial keys + // since constructing the packet and shouldn't record it as in-flight. + if c.keysInitial.canWrite() { + c.packetSent(now, initialSpace, sentInitial) + } + } + + c.endpoint.sendDatagram(datagram{ + b: buf, + peerAddr: c.peerAddr, + }) + } +} + +func (c *Conn) packetSent(now time.Time, space numberSpace, sent *sentPacket) { + c.idleHandlePacketSent(now, sent) + c.loss.packetSent(now, c.log, space, sent) +} + +func (c *Conn) appendFrames(now time.Time, space numberSpace, pnum packetNumber, limit ccLimit) { + if c.lifetime.localErr != nil { + c.appendConnectionCloseFrame(now, space, c.lifetime.localErr) + return + } + + shouldSendAck := c.acks[space].shouldSendAck(now) + if limit != ccOK { + // ACKs are not limited by congestion control. + if shouldSendAck && c.appendAckFrame(now, space) { + c.acks[space].sentAck() + } + return + } + // We want to send an ACK frame if the ack controller wants to send a frame now, + // OR if we are sending a packet anyway and have ack-eliciting packets which we + // have not yet acked. + // + // We speculatively add ACK frames here, to put them at the front of the packet + // to avoid truncation. + // + // After adding all frames, if we don't need to send an ACK frame and have not + // added any other frames, we abandon the packet. + if c.appendAckFrame(now, space) { + defer func() { + // All frames other than ACK and PADDING are ack-eliciting, + // so if the packet is ack-eliciting we've added additional + // frames to it. + if !shouldSendAck && !c.w.sent.ackEliciting { + // There's nothing in this packet but ACK frames, and + // we don't want to send an ACK-only packet at this time. + // Abandoning the packet means we wrote an ACK frame for + // nothing, but constructing the frame is cheap. + c.w.abandonPacket() + return + } + // Either we are willing to send an ACK-only packet, + // or we've added additional frames. + c.acks[space].sentAck() + if !c.w.sent.ackEliciting && c.shouldMakePacketAckEliciting() { + c.w.appendPingFrame() + } + }() + } + if limit != ccOK { + return + } + pto := c.loss.ptoExpired + + // TODO: Add all the other frames we can send. + + // CRYPTO + c.crypto[space].dataToSend(pto, func(off, size int64) int64 { + b, _ := c.w.appendCryptoFrame(off, int(size)) + c.crypto[space].sendData(off, b) + return int64(len(b)) + }) + + // Test-only PING frames. + if space == c.testSendPingSpace && c.testSendPing.shouldSendPTO(pto) { + if !c.w.appendPingFrame() { + return + } + c.testSendPing.setSent(pnum) + } + + if space == appDataSpace { + // HANDSHAKE_DONE + if c.handshakeConfirmed.shouldSendPTO(pto) { + if !c.w.appendHandshakeDoneFrame() { + return + } + c.handshakeConfirmed.setSent(pnum) + } + + // NEW_CONNECTION_ID, RETIRE_CONNECTION_ID + if !c.connIDState.appendFrames(c, pnum, pto) { + return + } + + // PATH_RESPONSE + if pad, ok := c.appendPathFrames(); !ok { + return + } else if pad { + defer c.w.appendPaddingTo(smallestMaxDatagramSize) + } + + // All stream-related frames. This should come last in the packet, + // so large amounts of STREAM data don't crowd out other frames + // we may need to send. + if !c.appendStreamFrames(&c.w, pnum, pto) { + return + } + + if !c.appendKeepAlive(now) { + return + } + } + + // If this is a PTO probe and we haven't added an ack-eliciting frame yet, + // add a PING to make this an ack-eliciting probe. + // + // Technically, there are separate PTO timers for each number space. + // When a PTO timer expires, we MUST send an ack-eliciting packet in the + // timer's space. We SHOULD send ack-eliciting packets in every other space + // with in-flight data. (RFC 9002, section 6.2.4) + // + // What we actually do is send a single datagram containing an ack-eliciting packet + // for every space for which we have keys. + // + // We fill the PTO probe packets with new or unacknowledged data. For example, + // a PTO probe sent for the Initial space will generally retransmit previously + // sent but unacknowledged CRYPTO data. + // + // When sending a PTO probe datagram containing multiple packets, it is + // possible that an earlier packet will fill up the datagram, leaving no + // space for the remaining probe packet(s). This is not a problem in practice. + // + // A client discards Initial keys when it first sends a Handshake packet + // (RFC 9001 Section 4.9.1). Handshake keys are discarded when the handshake + // is confirmed (RFC 9001 Section 4.9.2). The PTO timer is not set for the + // Application Data packet number space until the handshake is confirmed + // (RFC 9002 Section 6.2.1). Therefore, the only times a PTO probe can fire + // while data for multiple spaces is in flight are: + // + // - a server's Initial or Handshake timers can fire while Initial and Handshake + // data is in flight; and + // + // - a client's Handshake timer can fire while Handshake and Application Data + // data is in flight. + // + // It is theoretically possible for a server's Initial CRYPTO data to overflow + // the maximum datagram size, but unlikely in practice; this space contains + // only the ServerHello TLS message, which is small. It's also unlikely that + // the Handshake PTO probe will fire while Initial data is in flight (this + // requires not just that the Initial CRYPTO data completely fill a datagram, + // but a quite specific arrangement of lost and retransmitted packets.) + // We don't bother worrying about this case here, since the worst case is + // that we send a PTO probe for the in-flight Initial data and drop the + // Handshake probe. + // + // If a client's Handshake PTO timer fires while Application Data data is in + // flight, it is possible that the resent Handshake CRYPTO data will crowd + // out the probe for the Application Data space. However, since this probe is + // optional (recall that the Application Data PTO timer is never set until + // after Handshake keys have been discarded), dropping it is acceptable. + if pto && !c.w.sent.ackEliciting { + c.w.appendPingFrame() + } +} + +// shouldMakePacketAckEliciting is called when sending a packet containing nothing but an ACK frame. +// It reports whether we should add a PING frame to the packet to make it ack-eliciting. +func (c *Conn) shouldMakePacketAckEliciting() bool { + if c.keysAppData.needAckEliciting() { + // The peer has initiated a key update. + // We haven't sent them any packets yet in the new phase. + // Make this an ack-eliciting packet. + // Their ack of this packet will complete the key update. + return true + } + if c.loss.consecutiveNonAckElicitingPackets >= 19 { + // We've sent a run of non-ack-eliciting packets. + // Add in an ack-eliciting one every once in a while so the peer + // lets us know which ones have arrived. + // + // Google QUICHE injects a PING after sending 19 packets. We do the same. + // + // https://www.rfc-editor.org/rfc/rfc9000#section-13.2.4-2 + return true + } + // TODO: Consider making every packet sent when in PTO ack-eliciting to speed up recovery. + return false +} + +func (c *Conn) appendAckFrame(now time.Time, space numberSpace) bool { + seen, delay := c.acks[space].acksToSend(now) + if len(seen) == 0 { + return false + } + d := unscaledAckDelayFromDuration(delay, ackDelayExponent) + return c.w.appendAckFrame(seen, d, c.acks[space].ecn) +} + +func (c *Conn) appendConnectionCloseFrame(now time.Time, space numberSpace, err error) { + c.sentConnectionClose(now) + switch e := err.(type) { + case localTransportError: + c.w.appendConnectionCloseTransportFrame(e.code, 0, e.reason) + case *ApplicationError: + if space != appDataSpace { + // "CONNECTION_CLOSE frames signaling application errors (type 0x1d) + // MUST only appear in the application data packet number space." + // https://www.rfc-editor.org/rfc/rfc9000#section-12.5-2.2 + c.w.appendConnectionCloseTransportFrame(errApplicationError, 0, "") + } else { + c.w.appendConnectionCloseApplicationFrame(e.Code, e.Reason) + } + default: + // TLS alerts are sent using error codes [0x0100,0x01ff). + // https://www.rfc-editor.org/rfc/rfc9000#section-20.1-2.36.1 + var alert tls.AlertError + switch { + case errors.As(err, &alert): + // tls.AlertError is a uint8, so this can't exceed 0x01ff. + code := errTLSBase + transportError(alert) + c.w.appendConnectionCloseTransportFrame(code, 0, "") + default: + c.w.appendConnectionCloseTransportFrame(errInternal, 0, "") + } + } +} diff --git a/src/vendor/golang.org/x/net/quic/conn_streams.go b/src/vendor/golang.org/x/net/quic/conn_streams.go new file mode 100644 index 0000000000..0e4bf50094 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/conn_streams.go @@ -0,0 +1,483 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "context" + "sync" + "sync/atomic" + "time" +) + +type streamsState struct { + queue queue[*Stream] // new, peer-created streams + + // All peer-created streams. + // + // Implicitly created streams are included as an empty entry in the map. + // (For example, if we receive a frame for stream 4, we implicitly create stream 0 and + // insert an empty entry for it to the map.) + // + // The map value is maybeStream rather than *Stream as a reminder that values can be nil. + streams map[streamID]maybeStream + + // Limits on the number of streams, indexed by streamType. + localLimit [streamTypeCount]localStreamLimits + remoteLimit [streamTypeCount]remoteStreamLimits + + // Peer configuration provided in transport parameters. + peerInitialMaxStreamDataRemote [streamTypeCount]int64 // streams opened by us + peerInitialMaxStreamDataBidiLocal int64 // streams opened by them + + // Connection-level flow control. + inflow connInflow + outflow connOutflow + + // Streams with frames to send are stored in one of two circular linked lists, + // depending on whether they require connection-level flow control. + needSend atomic.Bool + sendMu sync.Mutex + queueMeta streamRing // streams with any non-flow-controlled frames + queueData streamRing // streams with only flow-controlled frames +} + +// maybeStream is a possibly nil *Stream. See streamsState.streams. +type maybeStream struct { + s *Stream +} + +func (c *Conn) streamsInit() { + c.streams.streams = make(map[streamID]maybeStream) + c.streams.queue = newQueue[*Stream]() + c.streams.localLimit[bidiStream].init() + c.streams.localLimit[uniStream].init() + c.streams.remoteLimit[bidiStream].init(c.config.maxBidiRemoteStreams()) + c.streams.remoteLimit[uniStream].init(c.config.maxUniRemoteStreams()) + c.inflowInit() +} + +func (c *Conn) streamsCleanup() { + c.streams.queue.close(errConnClosed) + c.streams.localLimit[bidiStream].connHasClosed() + c.streams.localLimit[uniStream].connHasClosed() + for _, s := range c.streams.streams { + if s.s != nil { + s.s.connHasClosed() + } + } +} + +// AcceptStream waits for and returns the next stream created by the peer. +func (c *Conn) AcceptStream(ctx context.Context) (*Stream, error) { + return c.streams.queue.get(ctx) +} + +// NewStream creates a stream. +// +// If the peer's maximum stream limit for the connection has been reached, +// NewStream blocks until the limit is increased or the context expires. +func (c *Conn) NewStream(ctx context.Context) (*Stream, error) { + return c.newLocalStream(ctx, bidiStream) +} + +// NewSendOnlyStream creates a unidirectional, send-only stream. +// +// If the peer's maximum stream limit for the connection has been reached, +// NewSendOnlyStream blocks until the limit is increased or the context expires. +func (c *Conn) NewSendOnlyStream(ctx context.Context) (*Stream, error) { + return c.newLocalStream(ctx, uniStream) +} + +func (c *Conn) newLocalStream(ctx context.Context, styp streamType) (*Stream, error) { + num, err := c.streams.localLimit[styp].open(ctx, c) + if err != nil { + return nil, err + } + + s := newStream(c, newStreamID(c.side, styp, num)) + s.outmaxbuf = c.config.maxStreamWriteBufferSize() + s.outwin = c.streams.peerInitialMaxStreamDataRemote[styp] + if styp == bidiStream { + s.inmaxbuf = c.config.maxStreamReadBufferSize() + s.inwin = c.config.maxStreamReadBufferSize() + } + s.inUnlock() + s.outUnlock() + + // Modify c.streams on the conn's loop. + if err := c.runOnLoop(ctx, func(now time.Time, c *Conn) { + c.streams.streams[s.id] = maybeStream{s} + }); err != nil { + return nil, err + } + return s, nil +} + +// streamFrameType identifies which direction of a stream, +// from the local perspective, a frame is associated with. +// +// For example, STREAM is a recvStream frame, +// because it carries data from the peer to us. +type streamFrameType uint8 + +const ( + sendStream = streamFrameType(iota) // for example, MAX_DATA + recvStream // for example, STREAM_DATA_BLOCKED +) + +// streamForID returns the stream with the given id. +// If the stream does not exist, it returns nil. +func (c *Conn) streamForID(id streamID) *Stream { + return c.streams.streams[id].s +} + +// streamForFrame returns the stream with the given id. +// If the stream does not exist, it may be created. +// +// streamForFrame aborts the connection if the stream id, state, and frame type don't align. +// For example, it aborts the connection with a STREAM_STATE error if a MAX_DATA frame +// is received for a receive-only stream, or if the peer attempts to create a stream that +// should be originated locally. +// +// streamForFrame returns nil if the stream no longer exists or if an error occurred. +func (c *Conn) streamForFrame(now time.Time, id streamID, ftype streamFrameType) *Stream { + if id.streamType() == uniStream { + if (id.initiator() == c.side) != (ftype == sendStream) { + // Received an invalid frame for unidirectional stream. + // For example, a RESET_STREAM frame for a send-only stream. + c.abort(now, localTransportError{ + code: errStreamState, + reason: "invalid frame for unidirectional stream", + }) + return nil + } + } + + ms, isOpen := c.streams.streams[id] + if ms.s != nil { + return ms.s + } + + num := id.num() + styp := id.streamType() + if id.initiator() == c.side { + if num < c.streams.localLimit[styp].opened { + // This stream was created by us, and has been closed. + return nil + } + // Received a frame for a stream that should be originated by us, + // but which we never created. + c.abort(now, localTransportError{ + code: errStreamState, + reason: "received frame for unknown stream", + }) + return nil + } else { + // if isOpen, this is a stream that was implicitly opened by a + // previous frame for a larger-numbered stream, but we haven't + // actually created it yet. + if !isOpen && num < c.streams.remoteLimit[styp].opened { + // This stream was created by the peer, and has been closed. + return nil + } + } + + prevOpened := c.streams.remoteLimit[styp].opened + if err := c.streams.remoteLimit[styp].open(id); err != nil { + c.abort(now, err) + return nil + } + + // Receiving a frame for a stream implicitly creates all streams + // with the same initiator and type and a lower number. + // Add a nil entry to the streams map for each implicitly created stream. + for n := newStreamID(id.initiator(), id.streamType(), prevOpened); n < id; n += 4 { + c.streams.streams[n] = maybeStream{} + } + + s := newStream(c, id) + s.inmaxbuf = c.config.maxStreamReadBufferSize() + s.inwin = c.config.maxStreamReadBufferSize() + if id.streamType() == bidiStream { + s.outmaxbuf = c.config.maxStreamWriteBufferSize() + s.outwin = c.streams.peerInitialMaxStreamDataBidiLocal + } + s.inUnlock() + s.outUnlock() + + c.streams.streams[id] = maybeStream{s} + c.streams.queue.put(s) + return s +} + +// maybeQueueStreamForSend marks a stream as containing frames that need sending. +func (c *Conn) maybeQueueStreamForSend(s *Stream, state streamState) { + if state.wantQueue() == state.inQueue() { + return // already on the right queue + } + c.streams.sendMu.Lock() + defer c.streams.sendMu.Unlock() + state = s.state.load() // may have changed while waiting + c.queueStreamForSendLocked(s, state) + + c.streams.needSend.Store(true) + c.wake() +} + +// queueStreamForSendLocked moves a stream to the correct send queue, +// or removes it from all queues. +// +// state is the last known stream state. +func (c *Conn) queueStreamForSendLocked(s *Stream, state streamState) { + for { + wantQueue := state.wantQueue() + inQueue := state.inQueue() + if inQueue == wantQueue { + return // already on the right queue + } + + switch inQueue { + case metaQueue: + c.streams.queueMeta.remove(s) + case dataQueue: + c.streams.queueData.remove(s) + } + + switch wantQueue { + case metaQueue: + c.streams.queueMeta.append(s) + state = s.state.set(streamQueueMeta, streamQueueMeta|streamQueueData) + case dataQueue: + c.streams.queueData.append(s) + state = s.state.set(streamQueueData, streamQueueMeta|streamQueueData) + case noQueue: + state = s.state.set(0, streamQueueMeta|streamQueueData) + } + + // If the stream state changed while we were moving the stream, + // we might now be on the wrong queue. + // + // For example: + // - stream has data to send: streamOutSendData|streamQueueData + // - appendStreamFrames sends all the data: streamQueueData + // - concurrently, more data is written: streamOutSendData|streamQueueData + // - appendStreamFrames calls us with the last state it observed + // (streamQueueData). + // - We remove the stream from the queue and observe the updated state: + // streamOutSendData + // - We realize that the stream needs to go back on the data queue. + // + // Go back around the loop to confirm we're on the correct queue. + } +} + +// appendStreamFrames writes stream-related frames to the current packet. +// +// It returns true if no more frames need appending, +// false if not everything fit in the current packet. +func (c *Conn) appendStreamFrames(w *packetWriter, pnum packetNumber, pto bool) bool { + // MAX_DATA + if !c.appendMaxDataFrame(w, pnum, pto) { + return false + } + + if pto { + return c.appendStreamFramesPTO(w, pnum) + } + if !c.streams.needSend.Load() { + // If queueMeta includes newly-finished streams, we may extend the peer's + // stream limits. When there are no streams to process, add MAX_STREAMS + // frames here. Otherwise, wait until after we've processed queueMeta. + return c.appendMaxStreams(w, pnum, pto) + } + c.streams.sendMu.Lock() + defer c.streams.sendMu.Unlock() + // queueMeta contains streams with non-flow-controlled frames to send. + for c.streams.queueMeta.head != nil { + s := c.streams.queueMeta.head + state := s.state.load() + if state&(streamQueueMeta|streamConnRemoved) != streamQueueMeta { + panic("BUG: queueMeta stream is not streamQueueMeta") + } + if state&streamInSendMeta != 0 { + s.ingate.lock() + ok := s.appendInFramesLocked(w, pnum, pto) + state = s.inUnlockNoQueue() + if !ok { + return false + } + if state&streamInSendMeta != 0 { + panic("BUG: streamInSendMeta set after successfully appending frames") + } + } + if state&streamOutSendMeta != 0 { + s.outgate.lock() + // This might also append flow-controlled frames if we have any + // and available conn-level quota. That's fine. + ok := s.appendOutFramesLocked(w, pnum, pto) + state = s.outUnlockNoQueue() + // We're checking both ok and state, because appendOutFramesLocked + // might have filled up the packet with flow-controlled data. + // If so, we want to move the stream to queueData for any remaining frames. + if !ok && state&streamOutSendMeta != 0 { + return false + } + if state&streamOutSendMeta != 0 { + panic("BUG: streamOutSendMeta set after successfully appending frames") + } + } + // We've sent all frames for this stream, so remove it from the send queue. + c.streams.queueMeta.remove(s) + if state&(streamInDone|streamOutDone) == streamInDone|streamOutDone { + // Stream is finished, remove it from the conn. + state = s.state.set(streamConnRemoved, streamQueueMeta|streamConnRemoved) + delete(c.streams.streams, s.id) + + // Record finalization of remote streams, to know when + // to extend the peer's stream limit. + if s.id.initiator() != c.side { + c.streams.remoteLimit[s.id.streamType()].close() + } + } else { + state = s.state.set(0, streamQueueMeta|streamConnRemoved) + } + // The stream may have flow-controlled data to send, + // or something might have added non-flow-controlled frames after we + // unlocked the stream. + // If so, put the stream back on a queue. + c.queueStreamForSendLocked(s, state) + } + + // MAX_STREAMS (possibly triggered by finalization of remote streams above). + if !c.appendMaxStreams(w, pnum, pto) { + return false + } + + // queueData contains streams with flow-controlled frames. + for c.streams.queueData.head != nil { + avail := c.streams.outflow.avail() + if avail == 0 { + break // no flow control quota available + } + s := c.streams.queueData.head + s.outgate.lock() + ok := s.appendOutFramesLocked(w, pnum, pto) + state := s.outUnlockNoQueue() + if !ok { + // We've sent some data for this stream, but it still has more to send. + // If the stream got a reasonable chance to put data in a packet, + // advance sendHead to the next stream in line, to avoid starvation. + // We'll come back to this stream after going through the others. + // + // If the packet was already mostly out of space, leave sendHead alone + // and come back to this stream again on the next packet. + if avail > 512 { + c.streams.queueData.head = s.next + } + return false + } + if state&streamQueueData == 0 { + panic("BUG: queueData stream is not streamQueueData") + } + if state&streamOutSendData != 0 { + // We must have run out of connection-level flow control: + // appendOutFramesLocked says it wrote all it can, but there's + // still data to send. + // + // Advance sendHead to the next stream in line to avoid starvation. + if c.streams.outflow.avail() != 0 { + panic("BUG: streamOutSendData set and flow control available after send") + } + c.streams.queueData.head = s.next + return true + } + c.streams.queueData.remove(s) + state = s.state.set(0, streamQueueData) + c.queueStreamForSendLocked(s, state) + } + if c.streams.queueMeta.head == nil && c.streams.queueData.head == nil { + c.streams.needSend.Store(false) + } + return true +} + +// appendStreamFramesPTO writes stream-related frames to the current packet +// for a PTO probe. +// +// It returns true if no more frames need appending, +// false if not everything fit in the current packet. +func (c *Conn) appendStreamFramesPTO(w *packetWriter, pnum packetNumber) bool { + const pto = true + if !c.appendMaxStreams(w, pnum, pto) { + return false + } + c.streams.sendMu.Lock() + defer c.streams.sendMu.Unlock() + for _, ms := range c.streams.streams { + s := ms.s + if s == nil { + continue + } + const pto = true + s.ingate.lock() + inOK := s.appendInFramesLocked(w, pnum, pto) + s.inUnlockNoQueue() + if !inOK { + return false + } + + s.outgate.lock() + outOK := s.appendOutFramesLocked(w, pnum, pto) + s.outUnlockNoQueue() + if !outOK { + return false + } + } + return true +} + +func (c *Conn) appendMaxStreams(w *packetWriter, pnum packetNumber, pto bool) bool { + if !c.streams.remoteLimit[uniStream].appendFrame(w, uniStream, pnum, pto) { + return false + } + if !c.streams.remoteLimit[bidiStream].appendFrame(w, bidiStream, pnum, pto) { + return false + } + return true +} + +// A streamRing is a circular linked list of streams. +type streamRing struct { + head *Stream +} + +// remove removes s from the ring. +// s must be on the ring. +func (r *streamRing) remove(s *Stream) { + if s.next == s { + r.head = nil // s was the last stream in the ring + } else { + s.prev.next = s.next + s.next.prev = s.prev + if r.head == s { + r.head = s.next + } + } +} + +// append places s at the last position in the ring. +// s must not be attached to any ring. +func (r *streamRing) append(s *Stream) { + if r.head == nil { + r.head = s + s.next = s + s.prev = s + } else { + s.prev = r.head.prev + s.next = r.head + s.prev.next = s + s.next.prev = s + } +} diff --git a/src/vendor/golang.org/x/net/quic/crypto_stream.go b/src/vendor/golang.org/x/net/quic/crypto_stream.go new file mode 100644 index 0000000000..a5b9818296 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/crypto_stream.go @@ -0,0 +1,157 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +// "Implementations MUST support buffering at least 4096 bytes of data +// received in out-of-order CRYPTO frames." +// https://www.rfc-editor.org/rfc/rfc9000.html#section-7.5-2 +// +// 4096 is too small for real-world cases, however, so we allow more. +const cryptoBufferSize = 1 << 20 + +// A cryptoStream is the stream of data passed in CRYPTO frames. +// There is one cryptoStream per packet number space. +type cryptoStream struct { + // CRYPTO data received from the peer. + in pipe + inset rangeset[int64] // bytes received + + // CRYPTO data queued for transmission to the peer. + out pipe + outunsent rangeset[int64] // bytes in need of sending + outacked rangeset[int64] // bytes acked by peer +} + +// handleCrypto processes data received in a CRYPTO frame. +func (s *cryptoStream) handleCrypto(off int64, b []byte, f func([]byte) error) error { + end := off + int64(len(b)) + if end-s.inset.min() > cryptoBufferSize { + return localTransportError{ + code: errCryptoBufferExceeded, + reason: "crypto buffer exceeded", + } + } + s.inset.add(off, end) + if off == s.in.start { + // Fast path: This is the next chunk of data in the stream, + // so just handle it immediately. + if err := f(b); err != nil { + return err + } + s.in.discardBefore(end) + } else { + // This is either data we've already processed, + // data we can't process yet, or a mix of both. + s.in.writeAt(b, off) + } + // s.in.start is the next byte in sequence. + // If it's in s.inset, we have bytes to provide. + // If it isn't, we don't--we're either out of data, + // or only have data that comes after the next byte. + if !s.inset.contains(s.in.start) { + return nil + } + // size is the size of the first contiguous chunk of bytes + // that have not been processed yet. + size := int(s.inset[0].end - s.in.start) + if size <= 0 { + return nil + } + err := s.in.read(s.in.start, size, f) + s.in.discardBefore(s.inset[0].end) + return err +} + +// write queues data for sending to the peer. +// It does not block or limit the amount of buffered data. +// QUIC connections don't communicate the amount of CRYPTO data they are willing to buffer, +// so we send what we have and the peer can close the connection if it is too much. +func (s *cryptoStream) write(b []byte) { + start := s.out.end + s.out.writeAt(b, start) + s.outunsent.add(start, s.out.end) +} + +// ackOrLoss reports that an CRYPTO frame sent by us has been acknowledged by the peer, or lost. +func (s *cryptoStream) ackOrLoss(start, end int64, fate packetFate) { + switch fate { + case packetAcked: + s.outacked.add(start, end) + s.outunsent.sub(start, end) + // If this ack is for data at the start of the send buffer, we can now discard it. + if s.outacked.contains(s.out.start) { + s.out.discardBefore(s.outacked[0].end) + } + case packetLost: + // Mark everything lost, but not previously acked, as needing retransmission. + // We do this by adding all the lost bytes to outunsent, and then + // removing everything already acked. + s.outunsent.add(start, end) + for _, a := range s.outacked { + s.outunsent.sub(a.start, a.end) + } + } +} + +// dataToSend reports what data should be sent in CRYPTO frames to the peer. +// It calls f with each range of data to send. +// f uses sendData to get the bytes to send, and returns the number of bytes sent. +// dataToSend calls f until no data is left, or f returns 0. +// +// This function is unusually indirect (why not just return a []byte, +// or implement io.Reader?). +// +// Returning a []byte to the caller either requires that we store the +// data to send contiguously (which we don't), allocate a temporary buffer +// and copy into it (inefficient), or return less data than we have available +// (requires complexity to avoid unnecessarily breaking data across frames). +// +// Accepting a []byte from the caller (io.Reader) makes packet construction +// difficult. Since CRYPTO data is encoded with a varint length prefix, the +// location of the data depends on the length of the data. (We could hardcode +// a 2-byte length, of course.) +// +// Instead, we tell the caller how much data is, the caller figures out where +// to put it (and possibly decides that it doesn't have space for this data +// in the packet after all), and the caller then makes a separate call to +// copy the data it wants into position. +func (s *cryptoStream) dataToSend(pto bool, f func(off, size int64) (sent int64)) { + for { + off, size := dataToSend(s.out.start, s.out.end, s.outunsent, s.outacked, pto) + if size == 0 { + return + } + n := f(off, size) + if n == 0 || pto { + return + } + } +} + +// sendData fills b with data to send to the peer, starting at off, +// and marks the data as sent. The caller must have already ascertained +// that there is data to send in this region using dataToSend. +func (s *cryptoStream) sendData(off int64, b []byte) { + s.out.copy(off, b) + s.outunsent.sub(off, off+int64(len(b))) +} + +// discardKeys is called when the packet protection keys for the stream are dropped. +func (s *cryptoStream) discardKeys() error { + if s.in.end-s.in.start != 0 { + // The peer sent some unprocessed CRYPTO data that we're about to discard. + // Close the connection with a TLS unexpected_message alert. + // https://www.rfc-editor.org/rfc/rfc5246#section-7.2.2 + const unexpectedMessage = 10 + return localTransportError{ + code: errTLSBase + unexpectedMessage, + reason: "excess crypto data", + } + } + // Discard any unacked (but presumably received) data in our output buffer. + s.out.discardBefore(s.out.end) + *s = cryptoStream{} + return nil +} diff --git a/src/vendor/golang.org/x/net/quic/dgram.go b/src/vendor/golang.org/x/net/quic/dgram.go new file mode 100644 index 0000000000..cea03694ee --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/dgram.go @@ -0,0 +1,53 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "net/netip" + "sync" +) + +type datagram struct { + b []byte + localAddr netip.AddrPort + peerAddr netip.AddrPort + ecn ecnBits +} + +// Explicit Congestion Notification bits. +// +// https://www.rfc-editor.org/rfc/rfc3168.html#section-5 +type ecnBits byte + +const ( + ecnMask = 0b000000_11 + ecnNotECT = 0b000000_00 + ecnECT1 = 0b000000_01 + ecnECT0 = 0b000000_10 + ecnCE = 0b000000_11 +) + +var datagramPool = sync.Pool{ + New: func() any { + return &datagram{ + b: make([]byte, maxUDPPayloadSize), + } + }, +} + +func newDatagram() *datagram { + m := datagramPool.Get().(*datagram) + *m = datagram{ + b: m.b[:cap(m.b)], + } + return m +} + +func (m *datagram) recycle() { + if cap(m.b) != maxUDPPayloadSize { + return + } + datagramPool.Put(m) +} diff --git a/src/vendor/golang.org/x/net/quic/doc.go b/src/vendor/golang.org/x/net/quic/doc.go new file mode 100644 index 0000000000..8d5a78f8a8 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/doc.go @@ -0,0 +1,50 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package quic implements the QUIC protocol. +// +// This package is a work in progress. +// It is not ready for production usage. +// Its API is subject to change without notice. +// +// This package is low-level. +// Most users will use it indirectly through an HTTP/3 implementation. +// +// # Usage +// +// An [Endpoint] sends and receives traffic on a network address. +// Create an Endpoint to either accept inbound QUIC connections +// or create outbound ones. +// +// A [Conn] is a QUIC connection. +// +// A [Stream] is a QUIC stream, an ordered, reliable byte stream. +// +// # Cancellation +// +// All blocking operations may be canceled using a context.Context. +// When performing an operation with a canceled context, the operation +// will succeed if doing so does not require blocking. For example, +// reading from a stream will return data when buffered data is available, +// even if the stream context is canceled. +// +// # Limitations +// +// This package is a work in progress. +// Known limitations include: +// +// - Performance is untuned. +// - 0-RTT is not supported. +// - Address migration is not supported. +// - Server preferred addresses are not supported. +// - The latency spin bit is not supported. +// - Stream send/receive windows are configurable, +// but are fixed and do not adapt to available throughput. +// - Path MTU discovery is not implemented. +// +// # Security Policy +// +// This package is a work in progress, +// and not yet covered by the Go security policy. +package quic diff --git a/src/vendor/golang.org/x/net/quic/endpoint.go b/src/vendor/golang.org/x/net/quic/endpoint.go new file mode 100644 index 0000000000..3d68073cd6 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/endpoint.go @@ -0,0 +1,472 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "context" + "crypto/rand" + "errors" + "net" + "net/netip" + "sync" + "sync/atomic" + "time" +) + +// An Endpoint handles QUIC traffic on a network address. +// It can accept inbound connections or create outbound ones. +// +// Multiple goroutines may invoke methods on an Endpoint simultaneously. +type Endpoint struct { + listenConfig *Config + packetConn packetConn + testHooks endpointTestHooks + resetGen statelessResetTokenGenerator + retry retryState + + acceptQueue queue[*Conn] // new inbound connections + connsMap connsMap // only accessed by the listen loop + + connsMu sync.Mutex + conns map[*Conn]struct{} + closing bool // set when Close is called + closec chan struct{} // closed when the listen loop exits +} + +type endpointTestHooks interface { + newConn(c *Conn) +} + +// A packetConn is the interface to sending and receiving UDP packets. +type packetConn interface { + Close() error + LocalAddr() netip.AddrPort + Read(f func(*datagram)) + Write(datagram) error +} + +// Listen listens on a local network address. +// +// The config is used to for connections accepted by the endpoint. +// If the config is nil, the endpoint will not accept connections. +func Listen(network, address string, listenConfig *Config) (*Endpoint, error) { + if listenConfig != nil && listenConfig.TLSConfig == nil { + return nil, errors.New("TLSConfig is not set") + } + a, err := net.ResolveUDPAddr(network, address) + if err != nil { + return nil, err + } + udpConn, err := net.ListenUDP(network, a) + if err != nil { + return nil, err + } + pc, err := newNetUDPConn(udpConn) + if err != nil { + return nil, err + } + return newEndpoint(pc, listenConfig, nil) +} + +// NewEndpoint creates an endpoint using a net.PacketConn as the underlying transport. +// +// If the PacketConn is not a *net.UDPConn, the endpoint may be slower and lack +// access to some features of the network. +func NewEndpoint(conn net.PacketConn, config *Config) (*Endpoint, error) { + var pc packetConn + var err error + switch conn := conn.(type) { + case *net.UDPConn: + pc, err = newNetUDPConn(conn) + default: + pc, err = newNetPacketConn(conn) + } + if err != nil { + return nil, err + } + return newEndpoint(pc, config, nil) +} + +func newEndpoint(pc packetConn, config *Config, hooks endpointTestHooks) (*Endpoint, error) { + e := &Endpoint{ + listenConfig: config, + packetConn: pc, + testHooks: hooks, + conns: make(map[*Conn]struct{}), + acceptQueue: newQueue[*Conn](), + closec: make(chan struct{}), + } + var statelessResetKey [32]byte + if config != nil { + statelessResetKey = config.StatelessResetKey + } + e.resetGen.init(statelessResetKey) + e.connsMap.init() + if config != nil && config.RequireAddressValidation { + if err := e.retry.init(); err != nil { + return nil, err + } + } + go e.listen() + return e, nil +} + +// LocalAddr returns the local network address. +func (e *Endpoint) LocalAddr() netip.AddrPort { + return e.packetConn.LocalAddr() +} + +// Close closes the Endpoint. +// Any blocked operations on the Endpoint or associated Conns and Stream will be unblocked +// and return errors. +// +// Close aborts every open connection. +// Data in stream read and write buffers is discarded. +// It waits for the peers of any open connection to acknowledge the connection has been closed. +func (e *Endpoint) Close(ctx context.Context) error { + e.acceptQueue.close(errors.New("endpoint closed")) + + // It isn't safe to call Conn.Abort or conn.exit with connsMu held, + // so copy the list of conns. + var conns []*Conn + e.connsMu.Lock() + if !e.closing { + e.closing = true // setting e.closing prevents new conns from being created + for c := range e.conns { + conns = append(conns, c) + } + if len(e.conns) == 0 { + e.packetConn.Close() + } + } + e.connsMu.Unlock() + + for _, c := range conns { + c.Abort(localTransportError{code: errNo}) + } + select { + case <-e.closec: + case <-ctx.Done(): + for _, c := range conns { + c.exit() + } + return ctx.Err() + } + return nil +} + +// Accept waits for and returns the next connection. +func (e *Endpoint) Accept(ctx context.Context) (*Conn, error) { + return e.acceptQueue.get(ctx) +} + +// Dial creates and returns a connection to a network address. +// The config cannot be nil. +func (e *Endpoint) Dial(ctx context.Context, network, address string, config *Config) (*Conn, error) { + u, err := net.ResolveUDPAddr(network, address) + if err != nil { + return nil, err + } + addr := u.AddrPort() + addr = netip.AddrPortFrom(addr.Addr().Unmap(), addr.Port()) + c, err := e.newConn(time.Now(), config, clientSide, newServerConnIDs{}, address, addr) + if err != nil { + return nil, err + } + if err := c.waitReady(ctx); err != nil { + c.Abort(nil) + return nil, err + } + return c, nil +} + +func (e *Endpoint) newConn(now time.Time, config *Config, side connSide, cids newServerConnIDs, peerHostname string, peerAddr netip.AddrPort) (*Conn, error) { + e.connsMu.Lock() + defer e.connsMu.Unlock() + if e.closing { + return nil, errors.New("endpoint closed") + } + c, err := newConn(now, side, cids, peerHostname, peerAddr, config, e) + if err != nil { + return nil, err + } + e.conns[c] = struct{}{} + return c, nil +} + +// serverConnEstablished is called by a conn when the handshake completes +// for an inbound (serverSide) connection. +func (e *Endpoint) serverConnEstablished(c *Conn) { + e.acceptQueue.put(c) +} + +// connDrained is called by a conn when it leaves the draining state, +// either when the peer acknowledges connection closure or the drain timeout expires. +func (e *Endpoint) connDrained(c *Conn) { + var cids [][]byte + for i := range c.connIDState.local { + cids = append(cids, c.connIDState.local[i].cid) + } + var tokens []statelessResetToken + for i := range c.connIDState.remote { + tokens = append(tokens, c.connIDState.remote[i].resetToken) + } + e.connsMap.updateConnIDs(func(conns *connsMap) { + for _, cid := range cids { + conns.retireConnID(c, cid) + } + for _, token := range tokens { + conns.retireResetToken(c, token) + } + }) + e.connsMu.Lock() + defer e.connsMu.Unlock() + delete(e.conns, c) + if e.closing && len(e.conns) == 0 { + e.packetConn.Close() + } +} + +func (e *Endpoint) listen() { + defer close(e.closec) + e.packetConn.Read(func(m *datagram) { + if e.connsMap.updateNeeded.Load() { + e.connsMap.applyUpdates() + } + e.handleDatagram(m) + }) +} + +func (e *Endpoint) handleDatagram(m *datagram) { + dstConnID, ok := dstConnIDForDatagram(m.b) + if !ok { + m.recycle() + return + } + c := e.connsMap.byConnID[string(dstConnID)] + if c == nil { + // TODO: Move this branch into a separate goroutine to avoid blocking + // the endpoint while processing packets. + e.handleUnknownDestinationDatagram(m) + return + } + + // TODO: This can block the endpoint while waiting for the conn to accept the dgram. + // Think about buffering between the receive loop and the conn. + c.sendMsg(m) +} + +func (e *Endpoint) handleUnknownDestinationDatagram(m *datagram) { + defer func() { + if m != nil { + m.recycle() + } + }() + const minimumValidPacketSize = 21 + if len(m.b) < minimumValidPacketSize { + return + } + now := time.Now() + // Check to see if this is a stateless reset. + var token statelessResetToken + copy(token[:], m.b[len(m.b)-len(token):]) + if c := e.connsMap.byResetToken[token]; c != nil { + c.sendMsg(func(now time.Time, c *Conn) { + c.handleStatelessReset(now, token) + }) + return + } + // If this is a 1-RTT packet, there's nothing productive we can do with it. + // Send a stateless reset if possible. + if !isLongHeader(m.b[0]) { + e.maybeSendStatelessReset(m.b, m.peerAddr) + return + } + p, ok := parseGenericLongHeaderPacket(m.b) + if !ok || len(m.b) < paddedInitialDatagramSize { + return + } + switch p.version { + case quicVersion1: + case 0: + // Version Negotiation for an unknown connection. + return + default: + // Unknown version. + e.sendVersionNegotiation(p, m.peerAddr) + return + } + if getPacketType(m.b) != packetTypeInitial { + // This packet isn't trying to create a new connection. + // It might be associated with some connection we've lost state for. + // We are technically permitted to send a stateless reset for + // a long-header packet, but this isn't generally useful. See: + // https://www.rfc-editor.org/rfc/rfc9000#section-10.3-16 + return + } + if e.listenConfig == nil { + // We are not configured to accept connections. + return + } + cids := newServerConnIDs{ + srcConnID: p.srcConnID, + dstConnID: p.dstConnID, + } + if e.listenConfig.RequireAddressValidation { + var ok bool + cids.retrySrcConnID = p.dstConnID + cids.originalDstConnID, ok = e.validateInitialAddress(now, p, m.peerAddr) + if !ok { + return + } + } else { + cids.originalDstConnID = p.dstConnID + } + var err error + c, err := e.newConn(now, e.listenConfig, serverSide, cids, "", m.peerAddr) + if err != nil { + // The accept queue is probably full. + // We could send a CONNECTION_CLOSE to the peer to reject the connection. + // Currently, we just drop the datagram. + // https://www.rfc-editor.org/rfc/rfc9000.html#section-5.2.2-5 + return + } + c.sendMsg(m) + m = nil // don't recycle, sendMsg takes ownership +} + +func (e *Endpoint) maybeSendStatelessReset(b []byte, peerAddr netip.AddrPort) { + if !e.resetGen.canReset { + // Config.StatelessResetKey isn't set, so we don't send stateless resets. + return + } + // The smallest possible valid packet a peer can send us is: + // 1 byte of header + // connIDLen bytes of destination connection ID + // 1 byte of packet number + // 1 byte of payload + // 16 bytes AEAD expansion + if len(b) < 1+connIDLen+1+1+16 { + return + } + // TODO: Rate limit stateless resets. + cid := b[1:][:connIDLen] + token := e.resetGen.tokenForConnID(cid) + // We want to generate a stateless reset that is as short as possible, + // but long enough to be difficult to distinguish from a 1-RTT packet. + // + // The minimal 1-RTT packet is: + // 1 byte of header + // 0-20 bytes of destination connection ID + // 1-4 bytes of packet number + // 1 byte of payload + // 16 bytes AEAD expansion + // + // Assuming the maximum possible connection ID and packet number size, + // this gives 1 + 20 + 4 + 1 + 16 = 42 bytes. + // + // We also must generate a stateless reset that is shorter than the datagram + // we are responding to, in order to ensure that reset loops terminate. + // + // See: https://www.rfc-editor.org/rfc/rfc9000#section-10.3 + size := min(len(b)-1, 42) + // Reuse the input buffer for generating the stateless reset. + b = b[:size] + rand.Read(b[:len(b)-statelessResetTokenLen]) + b[0] &^= headerFormLong // clear long header bit + b[0] |= fixedBit // set fixed bit + copy(b[len(b)-statelessResetTokenLen:], token[:]) + e.sendDatagram(datagram{ + b: b, + peerAddr: peerAddr, + }) +} + +func (e *Endpoint) sendVersionNegotiation(p genericLongPacket, peerAddr netip.AddrPort) { + m := newDatagram() + m.b = appendVersionNegotiation(m.b[:0], p.srcConnID, p.dstConnID, quicVersion1) + m.peerAddr = peerAddr + e.sendDatagram(*m) + m.recycle() +} + +func (e *Endpoint) sendConnectionClose(in genericLongPacket, peerAddr netip.AddrPort, code transportError) { + keys := initialKeys(in.dstConnID, serverSide) + var w packetWriter + p := longPacket{ + ptype: packetTypeInitial, + version: quicVersion1, + num: 0, + dstConnID: in.srcConnID, + srcConnID: in.dstConnID, + } + const pnumMaxAcked = 0 + w.reset(paddedInitialDatagramSize) + w.startProtectedLongHeaderPacket(pnumMaxAcked, p) + w.appendConnectionCloseTransportFrame(code, 0, "") + w.finishProtectedLongHeaderPacket(pnumMaxAcked, keys.w, p) + buf := w.datagram() + if len(buf) == 0 { + return + } + e.sendDatagram(datagram{ + b: buf, + peerAddr: peerAddr, + }) +} + +func (e *Endpoint) sendDatagram(dgram datagram) error { + return e.packetConn.Write(dgram) +} + +// A connsMap is an endpoint's mapping of conn ids and reset tokens to conns. +type connsMap struct { + byConnID map[string]*Conn + byResetToken map[statelessResetToken]*Conn + + updateMu sync.Mutex + updateNeeded atomic.Bool + updates []func(*connsMap) +} + +func (m *connsMap) init() { + m.byConnID = map[string]*Conn{} + m.byResetToken = map[statelessResetToken]*Conn{} +} + +func (m *connsMap) addConnID(c *Conn, cid []byte) { + m.byConnID[string(cid)] = c +} + +func (m *connsMap) retireConnID(c *Conn, cid []byte) { + delete(m.byConnID, string(cid)) +} + +func (m *connsMap) addResetToken(c *Conn, token statelessResetToken) { + m.byResetToken[token] = c +} + +func (m *connsMap) retireResetToken(c *Conn, token statelessResetToken) { + delete(m.byResetToken, token) +} + +func (m *connsMap) updateConnIDs(f func(*connsMap)) { + m.updateMu.Lock() + defer m.updateMu.Unlock() + m.updates = append(m.updates, f) + m.updateNeeded.Store(true) +} + +// applyUpdates is called by the datagram receive loop to update its connection ID map. +func (m *connsMap) applyUpdates() { + m.updateMu.Lock() + defer m.updateMu.Unlock() + for _, f := range m.updates { + f(m) + } + clear(m.updates) + m.updates = m.updates[:0] + m.updateNeeded.Store(false) +} diff --git a/src/vendor/golang.org/x/net/quic/errors.go b/src/vendor/golang.org/x/net/quic/errors.go new file mode 100644 index 0000000000..1226370d26 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/errors.go @@ -0,0 +1,129 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "fmt" +) + +// A transportError is a transport error code from RFC 9000 Section 20.1. +// +// The transportError type doesn't implement the error interface to ensure we always +// distinguish between errors sent to and received from the peer. +// See the localTransportError and peerTransportError types below. +type transportError uint64 + +// https://www.rfc-editor.org/rfc/rfc9000.html#section-20.1 +const ( + errNo = transportError(0x00) + errInternal = transportError(0x01) + errConnectionRefused = transportError(0x02) + errFlowControl = transportError(0x03) + errStreamLimit = transportError(0x04) + errStreamState = transportError(0x05) + errFinalSize = transportError(0x06) + errFrameEncoding = transportError(0x07) + errTransportParameter = transportError(0x08) + errConnectionIDLimit = transportError(0x09) + errProtocolViolation = transportError(0x0a) + errInvalidToken = transportError(0x0b) + errApplicationError = transportError(0x0c) + errCryptoBufferExceeded = transportError(0x0d) + errKeyUpdateError = transportError(0x0e) + errAEADLimitReached = transportError(0x0f) + errNoViablePath = transportError(0x10) + errTLSBase = transportError(0x0100) // 0x0100-0x01ff; base + TLS code +) + +func (e transportError) String() string { + switch e { + case errNo: + return "NO_ERROR" + case errInternal: + return "INTERNAL_ERROR" + case errConnectionRefused: + return "CONNECTION_REFUSED" + case errFlowControl: + return "FLOW_CONTROL_ERROR" + case errStreamLimit: + return "STREAM_LIMIT_ERROR" + case errStreamState: + return "STREAM_STATE_ERROR" + case errFinalSize: + return "FINAL_SIZE_ERROR" + case errFrameEncoding: + return "FRAME_ENCODING_ERROR" + case errTransportParameter: + return "TRANSPORT_PARAMETER_ERROR" + case errConnectionIDLimit: + return "CONNECTION_ID_LIMIT_ERROR" + case errProtocolViolation: + return "PROTOCOL_VIOLATION" + case errInvalidToken: + return "INVALID_TOKEN" + case errApplicationError: + return "APPLICATION_ERROR" + case errCryptoBufferExceeded: + return "CRYPTO_BUFFER_EXCEEDED" + case errKeyUpdateError: + return "KEY_UPDATE_ERROR" + case errAEADLimitReached: + return "AEAD_LIMIT_REACHED" + case errNoViablePath: + return "NO_VIABLE_PATH" + } + if e >= 0x0100 && e <= 0x01ff { + return fmt.Sprintf("CRYPTO_ERROR(%v)", uint64(e)&0xff) + } + return fmt.Sprintf("ERROR %d", uint64(e)) +} + +// A localTransportError is an error sent to the peer. +type localTransportError struct { + code transportError + reason string +} + +func (e localTransportError) Error() string { + if e.reason == "" { + return fmt.Sprintf("closed connection: %v", e.code) + } + return fmt.Sprintf("closed connection: %v: %q", e.code, e.reason) +} + +// A peerTransportError is an error received from the peer. +type peerTransportError struct { + code transportError + reason string +} + +func (e peerTransportError) Error() string { + return fmt.Sprintf("peer closed connection: %v: %q", e.code, e.reason) +} + +// A StreamErrorCode is an application protocol error code (RFC 9000, Section 20.2) +// indicating why a stream is being closed. +type StreamErrorCode uint64 + +func (e StreamErrorCode) Error() string { + return fmt.Sprintf("stream error code %v", uint64(e)) +} + +// An ApplicationError is an application protocol error code (RFC 9000, Section 20.2). +// Application protocol errors may be sent when terminating a stream or connection. +type ApplicationError struct { + Code uint64 + Reason string +} + +func (e *ApplicationError) Error() string { + return fmt.Sprintf("peer closed connection: %v: %q", e.Code, e.Reason) +} + +// Is reports a match if err is an *ApplicationError with a matching Code. +func (e *ApplicationError) Is(err error) bool { + e2, ok := err.(*ApplicationError) + return ok && e2.Code == e.Code +} diff --git a/src/vendor/golang.org/x/net/quic/frame_debug.go b/src/vendor/golang.org/x/net/quic/frame_debug.go new file mode 100644 index 0000000000..8d8fd54517 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/frame_debug.go @@ -0,0 +1,730 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "fmt" + "log/slog" + "strconv" + "time" +) + +// A debugFrame is a representation of the contents of a QUIC frame, +// used for debug logs and testing but not the primary serving path. +type debugFrame interface { + String() string + write(w *packetWriter) bool + LogValue() slog.Value +} + +func parseDebugFrame(b []byte) (f debugFrame, n int) { + if len(b) == 0 { + return nil, -1 + } + switch b[0] { + case frameTypePadding: + f, n = parseDebugFramePadding(b) + case frameTypePing: + f, n = parseDebugFramePing(b) + case frameTypeAck, frameTypeAckECN: + f, n = parseDebugFrameAck(b) + case frameTypeResetStream: + f, n = parseDebugFrameResetStream(b) + case frameTypeStopSending: + f, n = parseDebugFrameStopSending(b) + case frameTypeCrypto: + f, n = parseDebugFrameCrypto(b) + case frameTypeNewToken: + f, n = parseDebugFrameNewToken(b) + case frameTypeStreamBase, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f: + f, n = parseDebugFrameStream(b) + case frameTypeMaxData: + f, n = parseDebugFrameMaxData(b) + case frameTypeMaxStreamData: + f, n = parseDebugFrameMaxStreamData(b) + case frameTypeMaxStreamsBidi, frameTypeMaxStreamsUni: + f, n = parseDebugFrameMaxStreams(b) + case frameTypeDataBlocked: + f, n = parseDebugFrameDataBlocked(b) + case frameTypeStreamDataBlocked: + f, n = parseDebugFrameStreamDataBlocked(b) + case frameTypeStreamsBlockedBidi, frameTypeStreamsBlockedUni: + f, n = parseDebugFrameStreamsBlocked(b) + case frameTypeNewConnectionID: + f, n = parseDebugFrameNewConnectionID(b) + case frameTypeRetireConnectionID: + f, n = parseDebugFrameRetireConnectionID(b) + case frameTypePathChallenge: + f, n = parseDebugFramePathChallenge(b) + case frameTypePathResponse: + f, n = parseDebugFramePathResponse(b) + case frameTypeConnectionCloseTransport: + f, n = parseDebugFrameConnectionCloseTransport(b) + case frameTypeConnectionCloseApplication: + f, n = parseDebugFrameConnectionCloseApplication(b) + case frameTypeHandshakeDone: + f, n = parseDebugFrameHandshakeDone(b) + default: + return nil, -1 + } + return f, n +} + +// debugFramePadding is a sequence of PADDING frames. +type debugFramePadding struct { + size int + to int // alternate for writing packets: pad to +} + +func parseDebugFramePadding(b []byte) (f debugFramePadding, n int) { + for n < len(b) && b[n] == frameTypePadding { + n++ + } + f.size = n + return f, n +} + +func (f debugFramePadding) String() string { + return fmt.Sprintf("PADDING*%v", f.size) +} + +func (f debugFramePadding) write(w *packetWriter) bool { + if w.avail() == 0 { + return false + } + if f.to > 0 { + w.appendPaddingTo(f.to) + return true + } + for i := 0; i < f.size && w.avail() > 0; i++ { + w.b = append(w.b, frameTypePadding) + } + return true +} + +func (f debugFramePadding) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "padding"), + slog.Int("length", f.size), + ) +} + +// debugFramePing is a PING frame. +type debugFramePing struct{} + +func parseDebugFramePing(b []byte) (f debugFramePing, n int) { + return f, 1 +} + +func (f debugFramePing) String() string { + return "PING" +} + +func (f debugFramePing) write(w *packetWriter) bool { + return w.appendPingFrame() +} + +func (f debugFramePing) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "ping"), + ) +} + +// debugFrameAck is an ACK frame. +type debugFrameAck struct { + ackDelay unscaledAckDelay + ranges []i64range[packetNumber] + ecn ecnCounts +} + +func parseDebugFrameAck(b []byte) (f debugFrameAck, n int) { + f.ranges = nil + _, f.ackDelay, f.ecn, n = consumeAckFrame(b, func(_ int, start, end packetNumber) { + f.ranges = append(f.ranges, i64range[packetNumber]{ + start: start, + end: end, + }) + }) + // Ranges are parsed high to low; reverse ranges slice to order them low to high. + for i := 0; i < len(f.ranges)/2; i++ { + j := len(f.ranges) - 1 + f.ranges[i], f.ranges[j] = f.ranges[j], f.ranges[i] + } + return f, n +} + +func (f debugFrameAck) String() string { + s := fmt.Sprintf("ACK Delay=%v", f.ackDelay) + for _, r := range f.ranges { + s += fmt.Sprintf(" [%v,%v)", r.start, r.end) + } + + if (f.ecn != ecnCounts{}) { + s += fmt.Sprintf(" ECN=[%d,%d,%d]", f.ecn.t0, f.ecn.t1, f.ecn.ce) + } + return s +} + +func (f debugFrameAck) write(w *packetWriter) bool { + return w.appendAckFrame(rangeset[packetNumber](f.ranges), f.ackDelay, f.ecn) +} + +func (f debugFrameAck) LogValue() slog.Value { + return slog.StringValue("error: debugFrameAck should not appear as a slog Value") +} + +// debugFrameScaledAck is an ACK frame with scaled ACK Delay. +// +// This type is used in qlog events, which need access to the delay as a duration. +type debugFrameScaledAck struct { + ackDelay time.Duration + ranges []i64range[packetNumber] +} + +func (f debugFrameScaledAck) LogValue() slog.Value { + var ackDelay slog.Attr + if f.ackDelay >= 0 { + ackDelay = slog.Duration("ack_delay", f.ackDelay) + } + return slog.GroupValue( + slog.String("frame_type", "ack"), + // Rather than trying to convert the ack ranges into the slog data model, + // pass a value that can JSON-encode itself. + slog.Any("acked_ranges", debugAckRanges(f.ranges)), + ackDelay, + ) +} + +type debugAckRanges []i64range[packetNumber] + +// AppendJSON appends a JSON encoding of the ack ranges to b, and returns it. +// This is different than the standard json.Marshaler, but more efficient. +// Since we only use this in cooperation with the qlog package, +// encoding/json compatibility is irrelevant. +func (r debugAckRanges) AppendJSON(b []byte) []byte { + b = append(b, '[') + for i, ar := range r { + start, end := ar.start, ar.end-1 // qlog ranges are closed-closed + if i != 0 { + b = append(b, ',') + } + b = append(b, '[') + b = strconv.AppendInt(b, int64(start), 10) + if start != end { + b = append(b, ',') + b = strconv.AppendInt(b, int64(end), 10) + } + b = append(b, ']') + } + b = append(b, ']') + return b +} + +func (r debugAckRanges) String() string { + return string(r.AppendJSON(nil)) +} + +// debugFrameResetStream is a RESET_STREAM frame. +type debugFrameResetStream struct { + id streamID + code uint64 + finalSize int64 +} + +func parseDebugFrameResetStream(b []byte) (f debugFrameResetStream, n int) { + f.id, f.code, f.finalSize, n = consumeResetStreamFrame(b) + return f, n +} + +func (f debugFrameResetStream) String() string { + return fmt.Sprintf("RESET_STREAM ID=%v Code=%v FinalSize=%v", f.id, f.code, f.finalSize) +} + +func (f debugFrameResetStream) write(w *packetWriter) bool { + return w.appendResetStreamFrame(f.id, f.code, f.finalSize) +} + +func (f debugFrameResetStream) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "reset_stream"), + slog.Uint64("stream_id", uint64(f.id)), + slog.Uint64("final_size", uint64(f.finalSize)), + ) +} + +// debugFrameStopSending is a STOP_SENDING frame. +type debugFrameStopSending struct { + id streamID + code uint64 +} + +func parseDebugFrameStopSending(b []byte) (f debugFrameStopSending, n int) { + f.id, f.code, n = consumeStopSendingFrame(b) + return f, n +} + +func (f debugFrameStopSending) String() string { + return fmt.Sprintf("STOP_SENDING ID=%v Code=%v", f.id, f.code) +} + +func (f debugFrameStopSending) write(w *packetWriter) bool { + return w.appendStopSendingFrame(f.id, f.code) +} + +func (f debugFrameStopSending) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "stop_sending"), + slog.Uint64("stream_id", uint64(f.id)), + slog.Uint64("error_code", uint64(f.code)), + ) +} + +// debugFrameCrypto is a CRYPTO frame. +type debugFrameCrypto struct { + off int64 + data []byte +} + +func parseDebugFrameCrypto(b []byte) (f debugFrameCrypto, n int) { + f.off, f.data, n = consumeCryptoFrame(b) + return f, n +} + +func (f debugFrameCrypto) String() string { + return fmt.Sprintf("CRYPTO Offset=%v Length=%v", f.off, len(f.data)) +} + +func (f debugFrameCrypto) write(w *packetWriter) bool { + b, added := w.appendCryptoFrame(f.off, len(f.data)) + copy(b, f.data) + return added +} + +func (f debugFrameCrypto) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "crypto"), + slog.Int64("offset", f.off), + slog.Int("length", len(f.data)), + ) +} + +// debugFrameNewToken is a NEW_TOKEN frame. +type debugFrameNewToken struct { + token []byte +} + +func parseDebugFrameNewToken(b []byte) (f debugFrameNewToken, n int) { + f.token, n = consumeNewTokenFrame(b) + return f, n +} + +func (f debugFrameNewToken) String() string { + return fmt.Sprintf("NEW_TOKEN Token=%x", f.token) +} + +func (f debugFrameNewToken) write(w *packetWriter) bool { + return w.appendNewTokenFrame(f.token) +} + +func (f debugFrameNewToken) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "new_token"), + slogHexstring("token", f.token), + ) +} + +// debugFrameStream is a STREAM frame. +type debugFrameStream struct { + id streamID + fin bool + off int64 + data []byte +} + +func parseDebugFrameStream(b []byte) (f debugFrameStream, n int) { + f.id, f.off, f.fin, f.data, n = consumeStreamFrame(b) + return f, n +} + +func (f debugFrameStream) String() string { + fin := "" + if f.fin { + fin = " FIN" + } + return fmt.Sprintf("STREAM ID=%v%v Offset=%v Length=%v", f.id, fin, f.off, len(f.data)) +} + +func (f debugFrameStream) write(w *packetWriter) bool { + b, added := w.appendStreamFrame(f.id, f.off, len(f.data), f.fin) + copy(b, f.data) + return added +} + +func (f debugFrameStream) LogValue() slog.Value { + var fin slog.Attr + if f.fin { + fin = slog.Bool("fin", true) + } + return slog.GroupValue( + slog.String("frame_type", "stream"), + slog.Uint64("stream_id", uint64(f.id)), + slog.Int64("offset", f.off), + slog.Int("length", len(f.data)), + fin, + ) +} + +// debugFrameMaxData is a MAX_DATA frame. +type debugFrameMaxData struct { + max int64 +} + +func parseDebugFrameMaxData(b []byte) (f debugFrameMaxData, n int) { + f.max, n = consumeMaxDataFrame(b) + return f, n +} + +func (f debugFrameMaxData) String() string { + return fmt.Sprintf("MAX_DATA Max=%v", f.max) +} + +func (f debugFrameMaxData) write(w *packetWriter) bool { + return w.appendMaxDataFrame(f.max) +} + +func (f debugFrameMaxData) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "max_data"), + slog.Int64("maximum", f.max), + ) +} + +// debugFrameMaxStreamData is a MAX_STREAM_DATA frame. +type debugFrameMaxStreamData struct { + id streamID + max int64 +} + +func parseDebugFrameMaxStreamData(b []byte) (f debugFrameMaxStreamData, n int) { + f.id, f.max, n = consumeMaxStreamDataFrame(b) + return f, n +} + +func (f debugFrameMaxStreamData) String() string { + return fmt.Sprintf("MAX_STREAM_DATA ID=%v Max=%v", f.id, f.max) +} + +func (f debugFrameMaxStreamData) write(w *packetWriter) bool { + return w.appendMaxStreamDataFrame(f.id, f.max) +} + +func (f debugFrameMaxStreamData) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "max_stream_data"), + slog.Uint64("stream_id", uint64(f.id)), + slog.Int64("maximum", f.max), + ) +} + +// debugFrameMaxStreams is a MAX_STREAMS frame. +type debugFrameMaxStreams struct { + streamType streamType + max int64 +} + +func parseDebugFrameMaxStreams(b []byte) (f debugFrameMaxStreams, n int) { + f.streamType, f.max, n = consumeMaxStreamsFrame(b) + return f, n +} + +func (f debugFrameMaxStreams) String() string { + return fmt.Sprintf("MAX_STREAMS Type=%v Max=%v", f.streamType, f.max) +} + +func (f debugFrameMaxStreams) write(w *packetWriter) bool { + return w.appendMaxStreamsFrame(f.streamType, f.max) +} + +func (f debugFrameMaxStreams) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "max_streams"), + slog.String("stream_type", f.streamType.qlogString()), + slog.Int64("maximum", f.max), + ) +} + +// debugFrameDataBlocked is a DATA_BLOCKED frame. +type debugFrameDataBlocked struct { + max int64 +} + +func parseDebugFrameDataBlocked(b []byte) (f debugFrameDataBlocked, n int) { + f.max, n = consumeDataBlockedFrame(b) + return f, n +} + +func (f debugFrameDataBlocked) String() string { + return fmt.Sprintf("DATA_BLOCKED Max=%v", f.max) +} + +func (f debugFrameDataBlocked) write(w *packetWriter) bool { + return w.appendDataBlockedFrame(f.max) +} + +func (f debugFrameDataBlocked) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "data_blocked"), + slog.Int64("limit", f.max), + ) +} + +// debugFrameStreamDataBlocked is a STREAM_DATA_BLOCKED frame. +type debugFrameStreamDataBlocked struct { + id streamID + max int64 +} + +func parseDebugFrameStreamDataBlocked(b []byte) (f debugFrameStreamDataBlocked, n int) { + f.id, f.max, n = consumeStreamDataBlockedFrame(b) + return f, n +} + +func (f debugFrameStreamDataBlocked) String() string { + return fmt.Sprintf("STREAM_DATA_BLOCKED ID=%v Max=%v", f.id, f.max) +} + +func (f debugFrameStreamDataBlocked) write(w *packetWriter) bool { + return w.appendStreamDataBlockedFrame(f.id, f.max) +} + +func (f debugFrameStreamDataBlocked) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "stream_data_blocked"), + slog.Uint64("stream_id", uint64(f.id)), + slog.Int64("limit", f.max), + ) +} + +// debugFrameStreamsBlocked is a STREAMS_BLOCKED frame. +type debugFrameStreamsBlocked struct { + streamType streamType + max int64 +} + +func parseDebugFrameStreamsBlocked(b []byte) (f debugFrameStreamsBlocked, n int) { + f.streamType, f.max, n = consumeStreamsBlockedFrame(b) + return f, n +} + +func (f debugFrameStreamsBlocked) String() string { + return fmt.Sprintf("STREAMS_BLOCKED Type=%v Max=%v", f.streamType, f.max) +} + +func (f debugFrameStreamsBlocked) write(w *packetWriter) bool { + return w.appendStreamsBlockedFrame(f.streamType, f.max) +} + +func (f debugFrameStreamsBlocked) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "streams_blocked"), + slog.String("stream_type", f.streamType.qlogString()), + slog.Int64("limit", f.max), + ) +} + +// debugFrameNewConnectionID is a NEW_CONNECTION_ID frame. +type debugFrameNewConnectionID struct { + seq int64 + retirePriorTo int64 + connID []byte + token statelessResetToken +} + +func parseDebugFrameNewConnectionID(b []byte) (f debugFrameNewConnectionID, n int) { + f.seq, f.retirePriorTo, f.connID, f.token, n = consumeNewConnectionIDFrame(b) + return f, n +} + +func (f debugFrameNewConnectionID) String() string { + return fmt.Sprintf("NEW_CONNECTION_ID Seq=%v Retire=%v ID=%x Token=%x", f.seq, f.retirePriorTo, f.connID, f.token[:]) +} + +func (f debugFrameNewConnectionID) write(w *packetWriter) bool { + return w.appendNewConnectionIDFrame(f.seq, f.retirePriorTo, f.connID, f.token) +} + +func (f debugFrameNewConnectionID) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "new_connection_id"), + slog.Int64("sequence_number", f.seq), + slog.Int64("retire_prior_to", f.retirePriorTo), + slogHexstring("connection_id", f.connID), + slogHexstring("stateless_reset_token", f.token[:]), + ) +} + +// debugFrameRetireConnectionID is a NEW_CONNECTION_ID frame. +type debugFrameRetireConnectionID struct { + seq int64 +} + +func parseDebugFrameRetireConnectionID(b []byte) (f debugFrameRetireConnectionID, n int) { + f.seq, n = consumeRetireConnectionIDFrame(b) + return f, n +} + +func (f debugFrameRetireConnectionID) String() string { + return fmt.Sprintf("RETIRE_CONNECTION_ID Seq=%v", f.seq) +} + +func (f debugFrameRetireConnectionID) write(w *packetWriter) bool { + return w.appendRetireConnectionIDFrame(f.seq) +} + +func (f debugFrameRetireConnectionID) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "retire_connection_id"), + slog.Int64("sequence_number", f.seq), + ) +} + +// debugFramePathChallenge is a PATH_CHALLENGE frame. +type debugFramePathChallenge struct { + data pathChallengeData +} + +func parseDebugFramePathChallenge(b []byte) (f debugFramePathChallenge, n int) { + f.data, n = consumePathChallengeFrame(b) + return f, n +} + +func (f debugFramePathChallenge) String() string { + return fmt.Sprintf("PATH_CHALLENGE Data=%x", f.data) +} + +func (f debugFramePathChallenge) write(w *packetWriter) bool { + return w.appendPathChallengeFrame(f.data) +} + +func (f debugFramePathChallenge) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "path_challenge"), + slog.String("data", fmt.Sprintf("%x", f.data)), + ) +} + +// debugFramePathResponse is a PATH_RESPONSE frame. +type debugFramePathResponse struct { + data pathChallengeData +} + +func parseDebugFramePathResponse(b []byte) (f debugFramePathResponse, n int) { + f.data, n = consumePathResponseFrame(b) + return f, n +} + +func (f debugFramePathResponse) String() string { + return fmt.Sprintf("PATH_RESPONSE Data=%x", f.data) +} + +func (f debugFramePathResponse) write(w *packetWriter) bool { + return w.appendPathResponseFrame(f.data) +} + +func (f debugFramePathResponse) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "path_response"), + slog.String("data", fmt.Sprintf("%x", f.data)), + ) +} + +// debugFrameConnectionCloseTransport is a CONNECTION_CLOSE frame carrying a transport error. +type debugFrameConnectionCloseTransport struct { + code transportError + frameType uint64 + reason string +} + +func parseDebugFrameConnectionCloseTransport(b []byte) (f debugFrameConnectionCloseTransport, n int) { + f.code, f.frameType, f.reason, n = consumeConnectionCloseTransportFrame(b) + return f, n +} + +func (f debugFrameConnectionCloseTransport) String() string { + s := fmt.Sprintf("CONNECTION_CLOSE Code=%v", f.code) + if f.frameType != 0 { + s += fmt.Sprintf(" FrameType=%v", f.frameType) + } + if f.reason != "" { + s += fmt.Sprintf(" Reason=%q", f.reason) + } + return s +} + +func (f debugFrameConnectionCloseTransport) write(w *packetWriter) bool { + return w.appendConnectionCloseTransportFrame(f.code, f.frameType, f.reason) +} + +func (f debugFrameConnectionCloseTransport) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "connection_close"), + slog.String("error_space", "transport"), + slog.Uint64("error_code_value", uint64(f.code)), + slog.String("reason", f.reason), + ) +} + +// debugFrameConnectionCloseApplication is a CONNECTION_CLOSE frame carrying an application error. +type debugFrameConnectionCloseApplication struct { + code uint64 + reason string +} + +func parseDebugFrameConnectionCloseApplication(b []byte) (f debugFrameConnectionCloseApplication, n int) { + f.code, f.reason, n = consumeConnectionCloseApplicationFrame(b) + return f, n +} + +func (f debugFrameConnectionCloseApplication) String() string { + s := fmt.Sprintf("CONNECTION_CLOSE AppCode=%v", f.code) + if f.reason != "" { + s += fmt.Sprintf(" Reason=%q", f.reason) + } + return s +} + +func (f debugFrameConnectionCloseApplication) write(w *packetWriter) bool { + return w.appendConnectionCloseApplicationFrame(f.code, f.reason) +} + +func (f debugFrameConnectionCloseApplication) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "connection_close"), + slog.String("error_space", "application"), + slog.Uint64("error_code_value", uint64(f.code)), + slog.String("reason", f.reason), + ) +} + +// debugFrameHandshakeDone is a HANDSHAKE_DONE frame. +type debugFrameHandshakeDone struct{} + +func parseDebugFrameHandshakeDone(b []byte) (f debugFrameHandshakeDone, n int) { + return f, 1 +} + +func (f debugFrameHandshakeDone) String() string { + return "HANDSHAKE_DONE" +} + +func (f debugFrameHandshakeDone) write(w *packetWriter) bool { + return w.appendHandshakeDoneFrame() +} + +func (f debugFrameHandshakeDone) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "handshake_done"), + ) +} diff --git a/src/vendor/golang.org/x/net/quic/gate.go b/src/vendor/golang.org/x/net/quic/gate.go new file mode 100644 index 0000000000..b8b8605e62 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/gate.go @@ -0,0 +1,86 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import "context" + +// An gate is a monitor (mutex + condition variable) with one bit of state. +// +// The condition may be either set or unset. +// Lock operations may be unconditional, or wait for the condition to be set. +// Unlock operations record the new state of the condition. +type gate struct { + // When unlocked, exactly one of set or unset contains a value. + // When locked, neither chan contains a value. + set chan struct{} + unset chan struct{} +} + +// newGate returns a new, unlocked gate with the condition unset. +func newGate() gate { + g := newLockedGate() + g.unlock(false) + return g +} + +// newLockedGate returns a new, locked gate. +func newLockedGate() gate { + return gate{ + set: make(chan struct{}, 1), + unset: make(chan struct{}, 1), + } +} + +// lock acquires the gate unconditionally. +// It reports whether the condition is set. +func (g *gate) lock() (set bool) { + select { + case <-g.set: + return true + case <-g.unset: + return false + } +} + +// waitAndLock waits until the condition is set before acquiring the gate. +// If the context expires, waitAndLock returns an error and does not acquire the gate. +func (g *gate) waitAndLock(ctx context.Context) error { + select { + case <-g.set: + return nil + default: + } + select { + case <-g.set: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +// lockIfSet acquires the gate if and only if the condition is set. +func (g *gate) lockIfSet() (acquired bool) { + select { + case <-g.set: + return true + default: + return false + } +} + +// unlock sets the condition and releases the gate. +func (g *gate) unlock(set bool) { + if set { + g.set <- struct{}{} + } else { + g.unset <- struct{}{} + } +} + +// unlockFunc sets the condition to the result of f and releases the gate. +// Useful in defers. +func (g *gate) unlockFunc(f func() bool) { + g.unlock(f()) +} diff --git a/src/vendor/golang.org/x/net/quic/idle.go b/src/vendor/golang.org/x/net/quic/idle.go new file mode 100644 index 0000000000..6b1dfd1d25 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/idle.go @@ -0,0 +1,168 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "time" +) + +// idleState tracks connection idle events. +// +// Before the handshake is confirmed, the idle timeout is Config.HandshakeTimeout. +// +// After the handshake is confirmed, the idle timeout is +// the minimum of Config.MaxIdleTimeout and the peer's max_idle_timeout transport parameter. +// +// If KeepAlivePeriod is set, keep-alive pings are sent. +// Keep-alives are only sent after the handshake is confirmed. +// +// https://www.rfc-editor.org/rfc/rfc9000#section-10.1 +type idleState struct { + // idleDuration is the negotiated idle timeout for the connection. + idleDuration time.Duration + + // idleTimeout is the time at which the connection will be closed due to inactivity. + idleTimeout time.Time + + // nextTimeout is the time of the next idle event. + // If nextTimeout == idleTimeout, this is the idle timeout. + // Otherwise, this is the keep-alive timeout. + nextTimeout time.Time + + // sentSinceLastReceive is set if we have sent an ack-eliciting packet + // since the last time we received and processed a packet from the peer. + sentSinceLastReceive bool +} + +// receivePeerMaxIdleTimeout handles the peer's max_idle_timeout transport parameter. +func (c *Conn) receivePeerMaxIdleTimeout(peerMaxIdleTimeout time.Duration) { + localMaxIdleTimeout := c.config.maxIdleTimeout() + switch { + case localMaxIdleTimeout == 0: + c.idle.idleDuration = peerMaxIdleTimeout + case peerMaxIdleTimeout == 0: + c.idle.idleDuration = localMaxIdleTimeout + default: + c.idle.idleDuration = min(localMaxIdleTimeout, peerMaxIdleTimeout) + } +} + +func (c *Conn) idleHandlePacketReceived(now time.Time) { + if !c.handshakeConfirmed.isSet() { + return + } + // "An endpoint restarts its idle timer when a packet from its peer is + // received and processed successfully." + // https://www.rfc-editor.org/rfc/rfc9000#section-10.1-3 + c.idle.sentSinceLastReceive = false + c.restartIdleTimer(now) +} + +func (c *Conn) idleHandlePacketSent(now time.Time, sent *sentPacket) { + // "An endpoint also restarts its idle timer when sending an ack-eliciting packet + // if no other ack-eliciting packets have been sent since + // last receiving and processing a packet." + // https://www.rfc-editor.org/rfc/rfc9000#section-10.1-3 + if c.idle.sentSinceLastReceive || !sent.ackEliciting || !c.handshakeConfirmed.isSet() { + return + } + c.idle.sentSinceLastReceive = true + c.restartIdleTimer(now) +} + +func (c *Conn) restartIdleTimer(now time.Time) { + if !c.isAlive() { + // Connection is closing, disable timeouts. + c.idle.idleTimeout = time.Time{} + c.idle.nextTimeout = time.Time{} + return + } + var idleDuration time.Duration + if c.handshakeConfirmed.isSet() { + idleDuration = c.idle.idleDuration + } else { + idleDuration = c.config.handshakeTimeout() + } + if idleDuration == 0 { + c.idle.idleTimeout = time.Time{} + } else { + // "[...] endpoints MUST increase the idle timeout period to be + // at least three times the current Probe Timeout (PTO)." + // https://www.rfc-editor.org/rfc/rfc9000#section-10.1-4 + idleDuration = max(idleDuration, 3*c.loss.ptoPeriod()) + c.idle.idleTimeout = now.Add(idleDuration) + } + // Set the time of our next event: + // The idle timer if no keep-alive is set, or the keep-alive timer if one is. + c.idle.nextTimeout = c.idle.idleTimeout + keepAlive := c.config.keepAlivePeriod() + switch { + case !c.handshakeConfirmed.isSet(): + // We do not send keep-alives before the handshake is complete. + case keepAlive <= 0: + // Keep-alives are not enabled. + case c.idle.sentSinceLastReceive: + // We have sent an ack-eliciting packet to the peer. + // If they don't acknowledge it, loss detection will follow up with PTO probes, + // which will function as keep-alives. + // We don't need to send further pings. + case idleDuration == 0: + // The connection does not have a negotiated idle timeout. + // Send keep-alives anyway, since they may be required to keep middleboxes + // from losing state. + c.idle.nextTimeout = now.Add(keepAlive) + default: + // Schedule our next keep-alive. + // If our configured keep-alive period is greater than half the negotiated + // connection idle timeout, we reduce the keep-alive period to half + // the idle timeout to ensure we have time for the ping to arrive. + c.idle.nextTimeout = now.Add(min(keepAlive, idleDuration/2)) + } +} + +func (c *Conn) appendKeepAlive(now time.Time) bool { + if c.idle.nextTimeout.IsZero() || c.idle.nextTimeout.After(now) { + return true // timer has not expired + } + if c.idle.nextTimeout.Equal(c.idle.idleTimeout) { + return true // no keepalive timer set, only idle + } + if c.idle.sentSinceLastReceive { + return true // already sent an ack-eliciting packet + } + if c.w.sent.ackEliciting { + return true // this packet is already ack-eliciting + } + // Send an ack-eliciting PING frame to the peer to keep the connection alive. + return c.w.appendPingFrame() +} + +var errHandshakeTimeout error = localTransportError{ + code: errConnectionRefused, + reason: "handshake timeout", +} + +func (c *Conn) idleAdvance(now time.Time) (shouldExit bool) { + if c.idle.idleTimeout.IsZero() || now.Before(c.idle.idleTimeout) { + return false + } + c.idle.idleTimeout = time.Time{} + c.idle.nextTimeout = time.Time{} + if !c.handshakeConfirmed.isSet() { + // Handshake timeout has expired. + // If we're a server, we're refusing the too-slow client. + // If we're a client, we're giving up. + // In either case, we're going to send a CONNECTION_CLOSE frame and + // enter the closing state rather than unceremoniously dropping the connection, + // since the peer might still be trying to complete the handshake. + c.abort(now, errHandshakeTimeout) + return false + } + // Idle timeout has expired. + // + // "[...] the connection is silently closed and its state is discarded [...]" + // https://www.rfc-editor.org/rfc/rfc9000#section-10.1-1 + return true +} diff --git a/src/vendor/golang.org/x/net/quic/log.go b/src/vendor/golang.org/x/net/quic/log.go new file mode 100644 index 0000000000..eee2b5fd61 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/log.go @@ -0,0 +1,67 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "fmt" + "os" + "strings" +) + +var logPackets bool + +// Parse GODEBUG settings. +// +// GODEBUG=quiclogpackets=1 -- log every packet sent and received. +func init() { + s := os.Getenv("GODEBUG") + for len(s) > 0 { + var opt string + opt, s, _ = strings.Cut(s, ",") + switch opt { + case "quiclogpackets=1": + logPackets = true + } + } +} + +func logInboundLongPacket(c *Conn, p longPacket) { + if !logPackets { + return + } + prefix := c.String() + fmt.Printf("%v recv %v %v\n", prefix, p.ptype, p.num) + logFrames(prefix+" <- ", p.payload) +} + +func logInboundShortPacket(c *Conn, p shortPacket) { + if !logPackets { + return + } + prefix := c.String() + fmt.Printf("%v recv 1-RTT %v\n", prefix, p.num) + logFrames(prefix+" <- ", p.payload) +} + +func logSentPacket(c *Conn, ptype packetType, pnum packetNumber, src, dst, payload []byte) { + if !logPackets || len(payload) == 0 { + return + } + prefix := c.String() + fmt.Printf("%v send %v %v\n", prefix, ptype, pnum) + logFrames(prefix+" -> ", payload) +} + +func logFrames(prefix string, payload []byte) { + for len(payload) > 0 { + f, n := parseDebugFrame(payload) + if n < 0 { + fmt.Printf("%vBAD DATA\n", prefix) + break + } + payload = payload[n:] + fmt.Printf("%v%v\n", prefix, f) + } +} diff --git a/src/vendor/golang.org/x/net/quic/loss.go b/src/vendor/golang.org/x/net/quic/loss.go new file mode 100644 index 0000000000..95feaba2d4 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/loss.go @@ -0,0 +1,521 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "context" + "log/slog" + "math" + "time" +) + +type lossState struct { + side connSide + + // True when the handshake is confirmed. + // https://www.rfc-editor.org/rfc/rfc9001#section-4.1.2 + handshakeConfirmed bool + + // Peer's max_ack_delay transport parameter. + // https://www.rfc-editor.org/rfc/rfc9000.html#section-18.2-4.28.1 + maxAckDelay time.Duration + + // Time of the next event: PTO expiration (if ptoTimerArmed is true), + // or loss detection. + // The connection must call lossState.advance when the timer expires. + timer time.Time + + // True when the PTO timer is set. + ptoTimerArmed bool + + // True when the PTO timer has expired and a probe packet has not yet been sent. + ptoExpired bool + + // Count of PTO expirations since the lack received acknowledgement. + // https://www.rfc-editor.org/rfc/rfc9002#section-6.2.1-9 + ptoBackoffCount int + + // Anti-amplification limit: Three times the amount of data received from + // the peer, less the amount of data sent. + // + // Set to antiAmplificationUnlimited (MaxInt) to disable the limit. + // The limit is always disabled for clients, and for servers after the + // peer's address is validated. + // + // Anti-amplification is per-address; this will need to change if/when we + // support address migration. + // + // https://www.rfc-editor.org/rfc/rfc9000#section-8-2 + antiAmplificationLimit int + + // Count of non-ack-eliciting packets (ACKs) sent since the last ack-eliciting one. + consecutiveNonAckElicitingPackets int + + rtt rttState + pacer pacerState + cc *ccReno + + // Per-space loss detection state. + spaces [numberSpaceCount]struct { + sentPacketList + maxAcked packetNumber + lastAckEliciting packetNumber + } + + // Temporary state used when processing an ACK frame. + ackFrameRTT time.Duration // RTT from latest packet in frame + ackFrameContainsAckEliciting bool // newly acks an ack-eliciting packet? +} + +const antiAmplificationUnlimited = math.MaxInt + +func (c *lossState) init(side connSide, maxDatagramSize int, now time.Time) { + c.side = side + if side == clientSide { + // Clients don't have an anti-amplification limit. + c.antiAmplificationLimit = antiAmplificationUnlimited + } + c.rtt.init() + c.cc = newReno(maxDatagramSize) + c.pacer.init(now, c.cc.congestionWindow, timerGranularity) + + // Peer's assumed max_ack_delay, prior to receiving transport parameters. + // https://www.rfc-editor.org/rfc/rfc9000#section-18.2 + c.maxAckDelay = 25 * time.Millisecond + + for space := range c.spaces { + c.spaces[space].maxAcked = -1 + c.spaces[space].lastAckEliciting = -1 + } +} + +// setMaxAckDelay sets the max_ack_delay transport parameter received from the peer. +func (c *lossState) setMaxAckDelay(d time.Duration) { + if d >= (1<<14)*time.Millisecond { + // Values of 2^14 or greater are invalid. + // https://www.rfc-editor.org/rfc/rfc9000.html#section-18.2-4.28.1 + return + } + c.maxAckDelay = d +} + +// confirmHandshake indicates the handshake has been confirmed. +func (c *lossState) confirmHandshake() { + c.handshakeConfirmed = true +} + +// validateClientAddress disables the anti-amplification limit after +// a server validates a client's address. +func (c *lossState) validateClientAddress() { + c.antiAmplificationLimit = antiAmplificationUnlimited +} + +// minDatagramSize is the minimum datagram size permitted by +// anti-amplification protection. +// +// Defining a minimum size avoids the case where, say, anti-amplification +// technically allows us to send a 1-byte datagram, but no such datagram +// can be constructed. +const minPacketSize = 128 + +type ccLimit int + +const ( + ccOK = ccLimit(iota) // OK to send + ccBlocked // sending blocked by anti-amplification + ccLimited // sending blocked by congestion control + ccPaced // sending allowed by congestion, but delayed by pacer +) + +// sendLimit reports whether sending is possible at this time. +// When sending is pacing limited, it returns the next time a packet may be sent. +func (c *lossState) sendLimit(now time.Time) (limit ccLimit, next time.Time) { + if c.antiAmplificationLimit < minPacketSize { + // When at the anti-amplification limit, we may not send anything. + return ccBlocked, time.Time{} + } + if c.ptoExpired { + // On PTO expiry, send a probe. + return ccOK, time.Time{} + } + if !c.cc.canSend() { + // Congestion control blocks sending. + return ccLimited, time.Time{} + } + if c.cc.bytesInFlight == 0 { + // If no bytes are in flight, send packet unpaced. + return ccOK, time.Time{} + } + canSend, next := c.pacer.canSend(now) + if !canSend { + // Pacer blocks sending. + return ccPaced, next + } + return ccOK, time.Time{} +} + +// maxSendSize reports the maximum datagram size that may be sent. +func (c *lossState) maxSendSize() int { + return min(c.antiAmplificationLimit, c.cc.maxDatagramSize) +} + +// advance is called when time passes. +// The lossf function is called for each packet newly detected as lost. +func (c *lossState) advance(now time.Time, lossf func(numberSpace, *sentPacket, packetFate)) { + c.pacer.advance(now, c.cc.congestionWindow, c.rtt.smoothedRTT) + if c.ptoTimerArmed && !c.timer.IsZero() && !c.timer.After(now) { + c.ptoExpired = true + c.timer = time.Time{} + c.ptoBackoffCount++ + } + c.detectLoss(now, lossf) +} + +// nextNumber returns the next packet number to use in a space. +func (c *lossState) nextNumber(space numberSpace) packetNumber { + return c.spaces[space].nextNum +} + +// skipNumber skips a packet number as a defense against optimistic ACK attacks. +func (c *lossState) skipNumber(now time.Time, space numberSpace) { + sent := newSentPacket() + sent.num = c.spaces[space].nextNum + sent.time = now + sent.state = sentPacketUnsent + c.spaces[space].add(sent) +} + +// packetSent records a sent packet. +func (c *lossState) packetSent(now time.Time, log *slog.Logger, space numberSpace, sent *sentPacket) { + sent.time = now + c.spaces[space].add(sent) + size := sent.size + if c.antiAmplificationLimit != antiAmplificationUnlimited { + c.antiAmplificationLimit = max(0, c.antiAmplificationLimit-size) + } + if sent.inFlight { + c.cc.packetSent(now, log, space, sent) + c.pacer.packetSent(now, size, c.cc.congestionWindow, c.rtt.smoothedRTT) + if sent.ackEliciting { + c.spaces[space].lastAckEliciting = sent.num + c.ptoExpired = false // reset expired PTO timer after sending probe + } + c.scheduleTimer(now) + if logEnabled(log, QLogLevelPacket) { + logBytesInFlight(log, c.cc.bytesInFlight) + } + } + if sent.ackEliciting { + c.consecutiveNonAckElicitingPackets = 0 + } else { + c.consecutiveNonAckElicitingPackets++ + } +} + +// datagramReceived records a datagram (not packet!) received from the peer. +func (c *lossState) datagramReceived(now time.Time, size int) { + if c.antiAmplificationLimit != antiAmplificationUnlimited { + c.antiAmplificationLimit += 3 * size + // Reset the PTO timer, possibly to a point in the past, in which + // case the caller should execute it immediately. + // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.2.1-2 + c.scheduleTimer(now) + if c.ptoTimerArmed && !c.timer.IsZero() && !c.timer.After(now) { + c.ptoExpired = true + c.timer = time.Time{} + } + } +} + +// receiveAckStart starts processing an ACK frame. +// Call receiveAckRange for each range in the frame. +// Call receiveAckFrameEnd after all ranges are processed. +func (c *lossState) receiveAckStart() { + c.ackFrameContainsAckEliciting = false + c.ackFrameRTT = -1 +} + +// receiveAckRange processes a range within an ACK frame. +// The ackf function is called for each newly-acknowledged packet. +func (c *lossState) receiveAckRange(now time.Time, space numberSpace, rangeIndex int, start, end packetNumber, ackf func(numberSpace, *sentPacket, packetFate)) error { + // Limit our range to the intersection of the ACK range and + // the in-flight packets we have state for. + if s := c.spaces[space].start(); start < s { + start = s + } + if e := c.spaces[space].end(); end > e { + return localTransportError{ + code: errProtocolViolation, + reason: "acknowledgement for unsent packet", + } + } + if start >= end { + return nil + } + if rangeIndex == 0 { + // If the latest packet in the ACK frame is newly-acked, + // record the RTT in c.ackFrameRTT. + sent := c.spaces[space].num(end - 1) + if sent.state == sentPacketSent { + c.ackFrameRTT = max(0, now.Sub(sent.time)) + } + } + for pnum := start; pnum < end; pnum++ { + sent := c.spaces[space].num(pnum) + if sent.state == sentPacketUnsent { + return localTransportError{ + code: errProtocolViolation, + reason: "acknowledgement for unsent packet", + } + } + if sent.state != sentPacketSent { + continue + } + // This is a newly-acknowledged packet. + if pnum > c.spaces[space].maxAcked { + c.spaces[space].maxAcked = pnum + } + sent.state = sentPacketAcked + c.cc.packetAcked(now, sent) + ackf(space, sent, packetAcked) + if sent.ackEliciting { + c.ackFrameContainsAckEliciting = true + } + } + return nil +} + +// receiveAckEnd finishes processing an ack frame. +// The lossf function is called for each packet newly detected as lost. +func (c *lossState) receiveAckEnd(now time.Time, log *slog.Logger, space numberSpace, ackDelay time.Duration, lossf func(numberSpace, *sentPacket, packetFate)) { + c.spaces[space].sentPacketList.clean() + // Update the RTT sample when the largest acknowledged packet in the ACK frame + // is newly acknowledged, and at least one newly acknowledged packet is ack-eliciting. + // https://www.rfc-editor.org/rfc/rfc9002.html#section-5.1-2.2 + if c.ackFrameRTT >= 0 && c.ackFrameContainsAckEliciting { + c.rtt.updateSample(now, c.handshakeConfirmed, space, c.ackFrameRTT, ackDelay, c.maxAckDelay) + } + // Reset the PTO backoff. + // Exception: A client does not reset the backoff on acks for Initial packets. + // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.1-9 + if !(c.side == clientSide && space == initialSpace) { + c.ptoBackoffCount = 0 + } + // If the client has set a PTO timer with no packets in flight + // we want to restart that timer now. Clearing c.timer does this. + // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.2.1-3 + c.timer = time.Time{} + c.detectLoss(now, lossf) + c.cc.packetBatchEnd(now, log, space, &c.rtt, c.maxAckDelay) + + if logEnabled(log, QLogLevelPacket) { + var ssthresh slog.Attr + if c.cc.slowStartThreshold != math.MaxInt { + ssthresh = slog.Int("ssthresh", c.cc.slowStartThreshold) + } + log.LogAttrs(context.Background(), QLogLevelPacket, + "recovery:metrics_updated", + slog.Duration("min_rtt", c.rtt.minRTT), + slog.Duration("smoothed_rtt", c.rtt.smoothedRTT), + slog.Duration("latest_rtt", c.rtt.latestRTT), + slog.Duration("rtt_variance", c.rtt.rttvar), + slog.Int("congestion_window", c.cc.congestionWindow), + slog.Int("bytes_in_flight", c.cc.bytesInFlight), + ssthresh, + ) + } +} + +// discardPackets declares that packets within a number space will not be delivered +// and that data contained in them should be resent. +// For example, after receiving a Retry packet we discard already-sent Initial packets. +func (c *lossState) discardPackets(space numberSpace, log *slog.Logger, lossf func(numberSpace, *sentPacket, packetFate)) { + for i := 0; i < c.spaces[space].size; i++ { + sent := c.spaces[space].nth(i) + if sent.state != sentPacketSent { + // This should not be possible, since we only discard packets + // in spaces which have never received an ack, but check anyway. + continue + } + sent.state = sentPacketLost + c.cc.packetDiscarded(sent) + lossf(numberSpace(space), sent, packetLost) + } + c.spaces[space].clean() + if logEnabled(log, QLogLevelPacket) { + logBytesInFlight(log, c.cc.bytesInFlight) + } +} + +// discardKeys is called when dropping packet protection keys for a number space. +func (c *lossState) discardKeys(now time.Time, log *slog.Logger, space numberSpace) { + // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.4 + for i := 0; i < c.spaces[space].size; i++ { + sent := c.spaces[space].nth(i) + if sent.state != sentPacketSent { + continue + } + c.cc.packetDiscarded(sent) + } + c.spaces[space].discard() + c.spaces[space].maxAcked = -1 + c.spaces[space].lastAckEliciting = -1 + c.scheduleTimer(now) + if logEnabled(log, QLogLevelPacket) { + logBytesInFlight(log, c.cc.bytesInFlight) + } +} + +func (c *lossState) lossDuration() time.Duration { + // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.1.2 + return max((9*max(c.rtt.smoothedRTT, c.rtt.latestRTT))/8, timerGranularity) +} + +func (c *lossState) detectLoss(now time.Time, lossf func(numberSpace, *sentPacket, packetFate)) { + // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.1.1-1 + const lossThreshold = 3 + + lossTime := now.Add(-c.lossDuration()) + for space := numberSpace(0); space < numberSpaceCount; space++ { + for i := 0; i < c.spaces[space].size; i++ { + sent := c.spaces[space].nth(i) + if sent.state != sentPacketSent { + continue + } + // RFC 9002 Section 6.1 states that a packet is only declared lost if it + // is "in flight", which excludes packets that contain only ACK frames. + // However, we need some way to determine when to drop state for ACK-only + // packets, and the loss algorithm in Appendix A handles loss detection of + // not-in-flight packets identically to all others, so we do the same here. + switch { + case c.spaces[space].maxAcked-sent.num >= lossThreshold: + // Packet threshold + // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.1.1 + fallthrough + case sent.num <= c.spaces[space].maxAcked && !sent.time.After(lossTime): + // Time threshold + // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.1.2 + sent.state = sentPacketLost + lossf(space, sent, packetLost) + if sent.inFlight { + c.cc.packetLost(now, space, sent, &c.rtt) + } + } + if sent.state != sentPacketLost { + break + } + } + c.spaces[space].clean() + } + c.scheduleTimer(now) +} + +// scheduleTimer sets the loss or PTO timer. +// +// The connection is responsible for arranging for advance to be called after +// the timer expires. +// +// The timer may be set to a point in the past, in which advance should be called +// immediately. We don't do this here, because executing the timer can cause +// packet loss events, and it's simpler for the connection if loss events only +// occur when advancing time. +func (c *lossState) scheduleTimer(now time.Time) { + c.ptoTimerArmed = false + + // Loss timer for sent packets. + // The loss timer is only started once a later packet has been acknowledged, + // and takes precedence over the PTO timer. + // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.1.2 + var oldestPotentiallyLost time.Time + for space := numberSpace(0); space < numberSpaceCount; space++ { + if c.spaces[space].size > 0 && c.spaces[space].start() <= c.spaces[space].maxAcked { + firstTime := c.spaces[space].nth(0).time + if oldestPotentiallyLost.IsZero() || firstTime.Before(oldestPotentiallyLost) { + oldestPotentiallyLost = firstTime + } + } + } + if !oldestPotentiallyLost.IsZero() { + c.timer = oldestPotentiallyLost.Add(c.lossDuration()) + return + } + + // PTO timer. + if c.ptoExpired { + // PTO timer has expired, don't restart it until we send a probe. + c.timer = time.Time{} + return + } + if c.antiAmplificationLimit >= 0 && c.antiAmplificationLimit < minPacketSize { + // Server is at its anti-amplification limit and can't send any more data. + // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.2.1-1 + c.timer = time.Time{} + return + } + // Timer starts at the most recently sent ack-eliciting packet. + // Prior to confirming the handshake, we consider the Initial and Handshake + // number spaces; after, we consider only Application Data. + var last time.Time + if !c.handshakeConfirmed { + for space := initialSpace; space <= handshakeSpace; space++ { + sent := c.spaces[space].num(c.spaces[space].lastAckEliciting) + if sent == nil { + continue + } + if last.IsZero() || last.After(sent.time) { + last = sent.time + } + } + } else { + sent := c.spaces[appDataSpace].num(c.spaces[appDataSpace].lastAckEliciting) + if sent != nil { + last = sent.time + } + } + if last.IsZero() && + c.side == clientSide && + c.spaces[handshakeSpace].maxAcked < 0 && + !c.handshakeConfirmed { + // The client must always set a PTO timer prior to receiving an ack for a + // handshake packet or the handshake being confirmed. + // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.2.1 + if !c.timer.IsZero() { + // If c.timer is non-zero here, we've already set the PTO timer and + // should leave it as-is rather than moving it forward. + c.ptoTimerArmed = true + return + } + last = now + } else if last.IsZero() { + c.timer = time.Time{} + return + } + c.timer = last.Add(c.ptoPeriod()) + c.ptoTimerArmed = true +} + +func (c *lossState) ptoPeriod() time.Duration { + // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.1 + return c.ptoBasePeriod() << c.ptoBackoffCount +} + +func (c *lossState) ptoBasePeriod() time.Duration { + // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.1 + pto := c.rtt.smoothedRTT + max(4*c.rtt.rttvar, timerGranularity) + if c.handshakeConfirmed { + // The max_ack_delay is the maximum amount of time the peer might delay sending + // an ack to us. We only take it into account for the Application Data space. + // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.1-4 + pto += c.maxAckDelay + } + return pto +} + +func logBytesInFlight(log *slog.Logger, bytesInFlight int) { + log.LogAttrs(context.Background(), QLogLevelPacket, + "recovery:metrics_updated", + slog.Int("bytes_in_flight", bytesInFlight), + ) +} diff --git a/src/vendor/golang.org/x/net/quic/math.go b/src/vendor/golang.org/x/net/quic/math.go new file mode 100644 index 0000000000..d1e8a80025 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/math.go @@ -0,0 +1,12 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +func abs[T ~int | ~int64](a T) T { + if a < 0 { + return -a + } + return a +} diff --git a/src/vendor/golang.org/x/net/quic/pacer.go b/src/vendor/golang.org/x/net/quic/pacer.go new file mode 100644 index 0000000000..5891f42597 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/pacer.go @@ -0,0 +1,129 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "time" +) + +// A pacerState controls the rate at which packets are sent using a leaky-bucket rate limiter. +// +// The pacer limits the maximum size of a burst of packets. +// When a burst exceeds this limit, it spreads subsequent packets +// over time. +// +// The bucket is initialized to the maximum burst size (ten packets by default), +// and fills at the rate: +// +// 1.25 * congestion_window / smoothed_rtt +// +// A sender can send one congestion window of packets per RTT, +// since the congestion window consumed by each packet is returned +// one round-trip later by the responding ack. +// The pacer permits sending at slightly faster than this rate to +// avoid underutilizing the congestion window. +// +// The pacer permits the bucket to become negative, and permits +// sending when non-negative. This biases slightly in favor of +// sending packets over limiting them, and permits bursts one +// packet greater than the configured maximum, but permits the pacer +// to be ignorant of the maximum packet size. +// +// https://www.rfc-editor.org/rfc/rfc9002.html#section-7.7 +type pacerState struct { + bucket int // measured in bytes + maxBucket int + timerGranularity time.Duration + lastUpdate time.Time + nextSend time.Time +} + +func (p *pacerState) init(now time.Time, maxBurst int, timerGranularity time.Duration) { + // Bucket is limited to maximum burst size, which is the initial congestion window. + // https://www.rfc-editor.org/rfc/rfc9002#section-7.7-2 + p.maxBucket = maxBurst + p.bucket = p.maxBucket + p.timerGranularity = timerGranularity + p.lastUpdate = now + p.nextSend = now +} + +// pacerBytesForInterval returns the number of bytes permitted over an interval. +// +// rate = 1.25 * congestion_window / smoothed_rtt +// bytes = interval * rate +// +// https://www.rfc-editor.org/rfc/rfc9002#section-7.7-6 +func pacerBytesForInterval(interval time.Duration, congestionWindow int, rtt time.Duration) int { + bytes := (int64(interval) * int64(congestionWindow)) / int64(rtt) + bytes = (bytes * 5) / 4 // bytes *= 1.25 + return int(bytes) +} + +// pacerIntervalForBytes returns the amount of time required for a number of bytes. +// +// time_per_byte = (smoothed_rtt / congestion_window) / 1.25 +// interval = time_per_byte * bytes +// +// https://www.rfc-editor.org/rfc/rfc9002#section-7.7-8 +func pacerIntervalForBytes(bytes int, congestionWindow int, rtt time.Duration) time.Duration { + interval := (int64(rtt) * int64(bytes)) / int64(congestionWindow) + interval = (interval * 4) / 5 // interval /= 1.25 + return time.Duration(interval) +} + +// advance is called when time passes. +func (p *pacerState) advance(now time.Time, congestionWindow int, rtt time.Duration) { + elapsed := now.Sub(p.lastUpdate) + if elapsed < 0 { + // Time has gone backward? + elapsed = 0 + p.nextSend = now // allow a packet through to get back on track + if p.bucket < 0 { + p.bucket = 0 + } + } + p.lastUpdate = now + if rtt == 0 { + // Avoid divide by zero in the implausible case that we measure no RTT. + p.bucket = p.maxBucket + return + } + // Refill the bucket. + delta := pacerBytesForInterval(elapsed, congestionWindow, rtt) + p.bucket = min(p.bucket+delta, p.maxBucket) +} + +// packetSent is called to record transmission of a packet. +func (p *pacerState) packetSent(now time.Time, size, congestionWindow int, rtt time.Duration) { + p.bucket -= size + if p.bucket < -congestionWindow { + // Never allow the bucket to fall more than one congestion window in arrears. + // We can only fall this far behind if the sender is sending unpaced packets, + // the congestion window has been exceeded, or the RTT is less than the + // timer granularity. + // + // Limiting the minimum bucket size limits the maximum pacer delay + // to RTT/1.25. + p.bucket = -congestionWindow + } + if p.bucket >= 0 { + p.nextSend = now + return + } + // Next send occurs when the bucket has refilled to 0. + delay := pacerIntervalForBytes(-p.bucket, congestionWindow, rtt) + p.nextSend = now.Add(delay) +} + +// canSend reports whether a packet can be sent now. +// If it returns false, next is the time when the next packet can be sent. +func (p *pacerState) canSend(now time.Time) (canSend bool, next time.Time) { + // If the next send time is within the timer granularity, send immediately. + if p.nextSend.After(now.Add(p.timerGranularity)) { + return false, p.nextSend + } + return true, time.Time{} +} diff --git a/src/vendor/golang.org/x/net/quic/packet.go b/src/vendor/golang.org/x/net/quic/packet.go new file mode 100644 index 0000000000..b9fa333c54 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/packet.go @@ -0,0 +1,267 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "encoding/binary" + "fmt" + + "golang.org/x/net/internal/quic/quicwire" +) + +// packetType is a QUIC packet type. +// https://www.rfc-editor.org/rfc/rfc9000.html#section-17 +type packetType byte + +const ( + packetTypeInvalid = packetType(iota) + packetTypeInitial + packetType0RTT + packetTypeHandshake + packetTypeRetry + packetType1RTT + packetTypeVersionNegotiation +) + +func (p packetType) String() string { + switch p { + case packetTypeInitial: + return "Initial" + case packetType0RTT: + return "0-RTT" + case packetTypeHandshake: + return "Handshake" + case packetTypeRetry: + return "Retry" + case packetType1RTT: + return "1-RTT" + } + return fmt.Sprintf("unknown packet type %v", byte(p)) +} + +func (p packetType) qlogString() string { + switch p { + case packetTypeInitial: + return "initial" + case packetType0RTT: + return "0RTT" + case packetTypeHandshake: + return "handshake" + case packetTypeRetry: + return "retry" + case packetType1RTT: + return "1RTT" + } + return "unknown" +} + +// Bits set in the first byte of a packet. +const ( + headerFormLong = 0x80 // https://www.rfc-editor.org/rfc/rfc9000.html#section-17.2-3.2.1 + headerFormShort = 0x00 // https://www.rfc-editor.org/rfc/rfc9000.html#section-17.3.1-4.2.1 + fixedBit = 0x40 // https://www.rfc-editor.org/rfc/rfc9000.html#section-17.2-3.4.1 + reservedLongBits = 0x0c // https://www.rfc-editor.org/rfc/rfc9000#section-17.2-8.2.1 + reserved1RTTBits = 0x18 // https://www.rfc-editor.org/rfc/rfc9000#section-17.3.1-4.8.1 + keyPhaseBit = 0x04 // https://www.rfc-editor.org/rfc/rfc9000#section-17.3.1-4.10.1 +) + +// Long Packet Type bits. +// https://www.rfc-editor.org/rfc/rfc9000.html#section-17.2-3.6.1 +const ( + longPacketTypeInitial = 0 << 4 + longPacketType0RTT = 1 << 4 + longPacketTypeHandshake = 2 << 4 + longPacketTypeRetry = 3 << 4 +) + +// Frame types. +// https://www.rfc-editor.org/rfc/rfc9000.html#section-19 +const ( + frameTypePadding = 0x00 + frameTypePing = 0x01 + frameTypeAck = 0x02 + frameTypeAckECN = 0x03 + frameTypeResetStream = 0x04 + frameTypeStopSending = 0x05 + frameTypeCrypto = 0x06 + frameTypeNewToken = 0x07 + frameTypeStreamBase = 0x08 // low three bits carry stream flags + frameTypeMaxData = 0x10 + frameTypeMaxStreamData = 0x11 + frameTypeMaxStreamsBidi = 0x12 + frameTypeMaxStreamsUni = 0x13 + frameTypeDataBlocked = 0x14 + frameTypeStreamDataBlocked = 0x15 + frameTypeStreamsBlockedBidi = 0x16 + frameTypeStreamsBlockedUni = 0x17 + frameTypeNewConnectionID = 0x18 + frameTypeRetireConnectionID = 0x19 + frameTypePathChallenge = 0x1a + frameTypePathResponse = 0x1b + frameTypeConnectionCloseTransport = 0x1c + frameTypeConnectionCloseApplication = 0x1d + frameTypeHandshakeDone = 0x1e +) + +// The low three bits of STREAM frames. +// https://www.rfc-editor.org/rfc/rfc9000.html#section-19.8 +const ( + streamOffBit = 0x04 + streamLenBit = 0x02 + streamFinBit = 0x01 +) + +// Maximum length of a connection ID. +const maxConnIDLen = 20 + +// isLongHeader returns true if b is the first byte of a long header. +func isLongHeader(b byte) bool { + return b&headerFormLong == headerFormLong +} + +// getPacketType returns the type of a packet. +func getPacketType(b []byte) packetType { + if len(b) == 0 { + return packetTypeInvalid + } + if !isLongHeader(b[0]) { + if b[0]&fixedBit != fixedBit { + return packetTypeInvalid + } + return packetType1RTT + } + if len(b) < 5 { + return packetTypeInvalid + } + if b[1] == 0 && b[2] == 0 && b[3] == 0 && b[4] == 0 { + // Version Negotiation packets don't necessarily set the fixed bit. + return packetTypeVersionNegotiation + } + if b[0]&fixedBit != fixedBit { + return packetTypeInvalid + } + switch b[0] & 0x30 { + case longPacketTypeInitial: + return packetTypeInitial + case longPacketType0RTT: + return packetType0RTT + case longPacketTypeHandshake: + return packetTypeHandshake + case longPacketTypeRetry: + return packetTypeRetry + } + return packetTypeInvalid +} + +// dstConnIDForDatagram returns the destination connection ID field of the +// first QUIC packet in a datagram. +func dstConnIDForDatagram(pkt []byte) (id []byte, ok bool) { + if len(pkt) < 1 { + return nil, false + } + var n int + var b []byte + if isLongHeader(pkt[0]) { + if len(pkt) < 6 { + return nil, false + } + n = int(pkt[5]) + b = pkt[6:] + } else { + n = connIDLen + b = pkt[1:] + } + if len(b) < n { + return nil, false + } + return b[:n], true +} + +// parseVersionNegotiation parses a Version Negotiation packet. +// The returned versions is a slice of big-endian uint32s. +// It returns (nil, nil, nil) for an invalid packet. +func parseVersionNegotiation(pkt []byte) (dstConnID, srcConnID, versions []byte) { + p, ok := parseGenericLongHeaderPacket(pkt) + if !ok { + return nil, nil, nil + } + if len(p.data)%4 != 0 { + return nil, nil, nil + } + return p.dstConnID, p.srcConnID, p.data +} + +// appendVersionNegotiation appends a Version Negotiation packet to pkt, +// returning the result. +func appendVersionNegotiation(pkt, dstConnID, srcConnID []byte, versions ...uint32) []byte { + pkt = append(pkt, headerFormLong|fixedBit) // header byte + pkt = append(pkt, 0, 0, 0, 0) // Version (0 for Version Negotiation) + pkt = quicwire.AppendUint8Bytes(pkt, dstConnID) // Destination Connection ID + pkt = quicwire.AppendUint8Bytes(pkt, srcConnID) // Source Connection ID + for _, v := range versions { + pkt = binary.BigEndian.AppendUint32(pkt, v) // Supported Version + } + return pkt +} + +// A longPacket is a long header packet. +type longPacket struct { + ptype packetType + version uint32 + num packetNumber + dstConnID []byte + srcConnID []byte + payload []byte + + // The extra data depends on the packet type: + // Initial: Token. + // Retry: Retry token and integrity tag. + extra []byte +} + +// A shortPacket is a short header (1-RTT) packet. +type shortPacket struct { + num packetNumber + payload []byte +} + +// A genericLongPacket is a long header packet of an arbitrary QUIC version. +// https://www.rfc-editor.org/rfc/rfc8999#section-5.1 +type genericLongPacket struct { + version uint32 + dstConnID []byte + srcConnID []byte + data []byte +} + +func parseGenericLongHeaderPacket(b []byte) (p genericLongPacket, ok bool) { + if len(b) < 5 || !isLongHeader(b[0]) { + return genericLongPacket{}, false + } + b = b[1:] + // Version (32), + var n int + p.version, n = quicwire.ConsumeUint32(b) + if n < 0 { + return genericLongPacket{}, false + } + b = b[n:] + // Destination Connection ID Length (8), + // Destination Connection ID (0..2048), + p.dstConnID, n = quicwire.ConsumeUint8Bytes(b) + if n < 0 || len(p.dstConnID) > 2048/8 { + return genericLongPacket{}, false + } + b = b[n:] + // Source Connection ID Length (8), + // Source Connection ID (0..2048), + p.srcConnID, n = quicwire.ConsumeUint8Bytes(b) + if n < 0 || len(p.dstConnID) > 2048/8 { + return genericLongPacket{}, false + } + b = b[n:] + p.data = b + return p, true +} diff --git a/src/vendor/golang.org/x/net/quic/packet_number.go b/src/vendor/golang.org/x/net/quic/packet_number.go new file mode 100644 index 0000000000..9e9f0ad003 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/packet_number.go @@ -0,0 +1,72 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +// A packetNumber is a QUIC packet number. +// Packet numbers are integers in the range [0, 2^62-1]. +// +// https://www.rfc-editor.org/rfc/rfc9000.html#section-12.3 +type packetNumber int64 + +const maxPacketNumber = 1<<62 - 1 // https://www.rfc-editor.org/rfc/rfc9000.html#section-17.1-1 + +// decodePacketNumber decodes a truncated packet number, given +// the largest acknowledged packet number in this number space, +// the truncated number received in a packet, and the size of the +// number received in bytes. +// +// https://www.rfc-editor.org/rfc/rfc9000.html#section-17.1 +// https://www.rfc-editor.org/rfc/rfc9000.html#section-a.3 +func decodePacketNumber(largest, truncated packetNumber, numLenInBytes int) packetNumber { + expected := largest + 1 + win := packetNumber(1) << (uint(numLenInBytes) * 8) + hwin := win / 2 + mask := win - 1 + candidate := (expected &^ mask) | truncated + if candidate <= expected-hwin && candidate < (1<<62)-win { + return candidate + win + } + if candidate > expected+hwin && candidate >= win { + return candidate - win + } + return candidate +} + +// appendPacketNumber appends an encoded packet number to b. +// The packet number must be larger than the largest acknowledged packet number. +// When no packets have been acknowledged yet, largestAck is -1. +// +// https://www.rfc-editor.org/rfc/rfc9000.html#section-17.1-5 +func appendPacketNumber(b []byte, pnum, largestAck packetNumber) []byte { + switch packetNumberLength(pnum, largestAck) { + case 1: + return append(b, byte(pnum)) + case 2: + return append(b, byte(pnum>>8), byte(pnum)) + case 3: + return append(b, byte(pnum>>16), byte(pnum>>8), byte(pnum)) + default: + return append(b, byte(pnum>>24), byte(pnum>>16), byte(pnum>>8), byte(pnum)) + } +} + +// packetNumberLength returns the minimum length, in bytes, needed to encode +// a packet number given the largest acknowledged packet number. +// The packet number must be larger than the largest acknowledged packet number. +// +// https://www.rfc-editor.org/rfc/rfc9000.html#section-17.1-5 +func packetNumberLength(pnum, largestAck packetNumber) int { + d := pnum - largestAck + switch { + case d < 0x80: + return 1 + case d < 0x8000: + return 2 + case d < 0x800000: + return 3 + default: + return 4 + } +} diff --git a/src/vendor/golang.org/x/net/quic/packet_parser.go b/src/vendor/golang.org/x/net/quic/packet_parser.go new file mode 100644 index 0000000000..265c4aeb3a --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/packet_parser.go @@ -0,0 +1,517 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import "golang.org/x/net/internal/quic/quicwire" + +// parseLongHeaderPacket parses a QUIC long header packet. +// +// It does not parse Version Negotiation packets. +// +// On input, pkt contains a long header packet (possibly followed by more packets), +// k the decryption keys for the packet, and pnumMax the largest packet number seen +// in the number space of this packet. +// +// parseLongHeaderPacket returns the parsed packet with protection removed +// and its length in bytes. +// +// It returns an empty packet and -1 if the packet could not be parsed. +func parseLongHeaderPacket(pkt []byte, k fixedKeys, pnumMax packetNumber) (p longPacket, n int) { + if len(pkt) < 5 || !isLongHeader(pkt[0]) { + return longPacket{}, -1 + } + + // Header Form (1) = 1, + // Fixed Bit (1) = 1, + // Long Packet Type (2), + // Type-Specific Bits (4), + b := pkt + p.ptype = getPacketType(b) + if p.ptype == packetTypeInvalid { + return longPacket{}, -1 + } + b = b[1:] + // Version (32), + p.version, n = quicwire.ConsumeUint32(b) + if n < 0 { + return longPacket{}, -1 + } + b = b[n:] + if p.version == 0 { + // Version Negotiation packet; not handled here. + return longPacket{}, -1 + } + + // Destination Connection ID Length (8), + // Destination Connection ID (0..160), + p.dstConnID, n = quicwire.ConsumeUint8Bytes(b) + if n < 0 || len(p.dstConnID) > maxConnIDLen { + return longPacket{}, -1 + } + b = b[n:] + + // Source Connection ID Length (8), + // Source Connection ID (0..160), + p.srcConnID, n = quicwire.ConsumeUint8Bytes(b) + if n < 0 || len(p.dstConnID) > maxConnIDLen { + return longPacket{}, -1 + } + b = b[n:] + + switch p.ptype { + case packetTypeInitial: + // Token Length (i), + // Token (..), + p.extra, n = quicwire.ConsumeVarintBytes(b) + if n < 0 { + return longPacket{}, -1 + } + b = b[n:] + case packetTypeRetry: + // Retry Token (..), + // Retry Integrity Tag (128), + p.extra = b + return p, len(pkt) + } + + // Length (i), + payLen, n := quicwire.ConsumeVarint(b) + if n < 0 { + return longPacket{}, -1 + } + b = b[n:] + if uint64(len(b)) < payLen { + return longPacket{}, -1 + } + + // Packet Number (8..32), + // Packet Payload (..), + pnumOff := len(pkt) - len(b) + pkt = pkt[:pnumOff+int(payLen)] + + if k.isSet() { + var err error + p.payload, p.num, err = k.unprotect(pkt, pnumOff, pnumMax) + if err != nil { + return longPacket{}, -1 + } + } + return p, len(pkt) +} + +// skipLongHeaderPacket returns the length of the long header packet at the start of pkt, +// or -1 if the buffer does not contain a valid packet. +func skipLongHeaderPacket(pkt []byte) int { + // Header byte, 4 bytes of version. + n := 5 + if len(pkt) <= n { + return -1 + } + // Destination connection ID length, destination connection ID. + n += 1 + int(pkt[n]) + if len(pkt) <= n { + return -1 + } + // Source connection ID length, source connection ID. + n += 1 + int(pkt[n]) + if len(pkt) <= n { + return -1 + } + if getPacketType(pkt) == packetTypeInitial { + // Token length, token. + _, nn := quicwire.ConsumeVarintBytes(pkt[n:]) + if nn < 0 { + return -1 + } + n += nn + } + // Length, packet number, payload. + _, nn := quicwire.ConsumeVarintBytes(pkt[n:]) + if nn < 0 { + return -1 + } + n += nn + if len(pkt) < n { + return -1 + } + return n +} + +// parse1RTTPacket parses a QUIC 1-RTT (short header) packet. +// +// On input, pkt contains a short header packet, k the decryption keys for the packet, +// and pnumMax the largest packet number seen in the number space of this packet. +func parse1RTTPacket(pkt []byte, k *updatingKeyPair, dstConnIDLen int, pnumMax packetNumber) (p shortPacket, err error) { + pay, pnum, err := k.unprotect(pkt, 1+dstConnIDLen, pnumMax) + if err != nil { + return shortPacket{}, err + } + p.num = pnum + p.payload = pay + return p, nil +} + +// Consume functions return n=-1 on conditions which result in FRAME_ENCODING_ERROR, +// which includes both general parse failures and specific violations of frame +// constraints. + +func consumeAckFrame(frame []byte, f func(rangeIndex int, start, end packetNumber)) (largest packetNumber, ackDelay unscaledAckDelay, ecn ecnCounts, n int) { + b := frame[1:] // type + + largestAck, n := quicwire.ConsumeVarint(b) + if n < 0 { + return 0, 0, ecnCounts{}, -1 + } + b = b[n:] + + v, n := quicwire.ConsumeVarintInt64(b) + if n < 0 { + return 0, 0, ecnCounts{}, -1 + } + b = b[n:] + ackDelay = unscaledAckDelay(v) + + ackRangeCount, n := quicwire.ConsumeVarint(b) + if n < 0 { + return 0, 0, ecnCounts{}, -1 + } + b = b[n:] + + rangeMax := packetNumber(largestAck) + for i := uint64(0); ; i++ { + rangeLen, n := quicwire.ConsumeVarint(b) + if n < 0 { + return 0, 0, ecnCounts{}, -1 + } + b = b[n:] + rangeMin := rangeMax - packetNumber(rangeLen) + if rangeMin < 0 || rangeMin > rangeMax { + return 0, 0, ecnCounts{}, -1 + } + f(int(i), rangeMin, rangeMax+1) + + if i == ackRangeCount { + break + } + + gap, n := quicwire.ConsumeVarint(b) + if n < 0 { + return 0, 0, ecnCounts{}, -1 + } + b = b[n:] + + rangeMax = rangeMin - packetNumber(gap) - 2 + } + + if frame[0] != frameTypeAckECN { + return packetNumber(largestAck), ackDelay, ecnCounts{}, len(frame) - len(b) + } + + ect0Count, n := quicwire.ConsumeVarint(b) + if n < 0 { + return 0, 0, ecnCounts{}, -1 + } + b = b[n:] + ect1Count, n := quicwire.ConsumeVarint(b) + if n < 0 { + return 0, 0, ecnCounts{}, -1 + } + b = b[n:] + ecnCECount, n := quicwire.ConsumeVarint(b) + if n < 0 { + return 0, 0, ecnCounts{}, -1 + } + b = b[n:] + + ecn.t0 = int(ect0Count) + ecn.t1 = int(ect1Count) + ecn.ce = int(ecnCECount) + + return packetNumber(largestAck), ackDelay, ecn, len(frame) - len(b) +} + +func consumeResetStreamFrame(b []byte) (id streamID, code uint64, finalSize int64, n int) { + n = 1 + idInt, nn := quicwire.ConsumeVarint(b[n:]) + if nn < 0 { + return 0, 0, 0, -1 + } + n += nn + code, nn = quicwire.ConsumeVarint(b[n:]) + if nn < 0 { + return 0, 0, 0, -1 + } + n += nn + v, nn := quicwire.ConsumeVarint(b[n:]) + if nn < 0 { + return 0, 0, 0, -1 + } + n += nn + finalSize = int64(v) + return streamID(idInt), code, finalSize, n +} + +func consumeStopSendingFrame(b []byte) (id streamID, code uint64, n int) { + n = 1 + idInt, nn := quicwire.ConsumeVarint(b[n:]) + if nn < 0 { + return 0, 0, -1 + } + n += nn + code, nn = quicwire.ConsumeVarint(b[n:]) + if nn < 0 { + return 0, 0, -1 + } + n += nn + return streamID(idInt), code, n +} + +func consumeCryptoFrame(b []byte) (off int64, data []byte, n int) { + n = 1 + v, nn := quicwire.ConsumeVarint(b[n:]) + if nn < 0 { + return 0, nil, -1 + } + off = int64(v) + n += nn + data, nn = quicwire.ConsumeVarintBytes(b[n:]) + if nn < 0 { + return 0, nil, -1 + } + n += nn + return off, data, n +} + +func consumeNewTokenFrame(b []byte) (token []byte, n int) { + n = 1 + data, nn := quicwire.ConsumeVarintBytes(b[n:]) + if nn < 0 { + return nil, -1 + } + if len(data) == 0 { + return nil, -1 + } + n += nn + return data, n +} + +func consumeStreamFrame(b []byte) (id streamID, off int64, fin bool, data []byte, n int) { + fin = (b[0] & 0x01) != 0 + n = 1 + idInt, nn := quicwire.ConsumeVarint(b[n:]) + if nn < 0 { + return 0, 0, false, nil, -1 + } + n += nn + if b[0]&0x04 != 0 { + v, nn := quicwire.ConsumeVarint(b[n:]) + if nn < 0 { + return 0, 0, false, nil, -1 + } + n += nn + off = int64(v) + } + if b[0]&0x02 != 0 { + data, nn = quicwire.ConsumeVarintBytes(b[n:]) + if nn < 0 { + return 0, 0, false, nil, -1 + } + n += nn + } else { + data = b[n:] + n += len(data) + } + if off+int64(len(data)) >= 1<<62 { + return 0, 0, false, nil, -1 + } + return streamID(idInt), off, fin, data, n +} + +func consumeMaxDataFrame(b []byte) (max int64, n int) { + n = 1 + v, nn := quicwire.ConsumeVarint(b[n:]) + if nn < 0 { + return 0, -1 + } + n += nn + return int64(v), n +} + +func consumeMaxStreamDataFrame(b []byte) (id streamID, max int64, n int) { + n = 1 + v, nn := quicwire.ConsumeVarint(b[n:]) + if nn < 0 { + return 0, 0, -1 + } + n += nn + id = streamID(v) + v, nn = quicwire.ConsumeVarint(b[n:]) + if nn < 0 { + return 0, 0, -1 + } + n += nn + max = int64(v) + return id, max, n +} + +func consumeMaxStreamsFrame(b []byte) (typ streamType, max int64, n int) { + switch b[0] { + case frameTypeMaxStreamsBidi: + typ = bidiStream + case frameTypeMaxStreamsUni: + typ = uniStream + default: + return 0, 0, -1 + } + n = 1 + v, nn := quicwire.ConsumeVarint(b[n:]) + if nn < 0 { + return 0, 0, -1 + } + n += nn + if v > maxStreamsLimit { + return 0, 0, -1 + } + return typ, int64(v), n +} + +func consumeStreamDataBlockedFrame(b []byte) (id streamID, max int64, n int) { + n = 1 + v, nn := quicwire.ConsumeVarint(b[n:]) + if nn < 0 { + return 0, 0, -1 + } + n += nn + id = streamID(v) + max, nn = quicwire.ConsumeVarintInt64(b[n:]) + if nn < 0 { + return 0, 0, -1 + } + n += nn + return id, max, n +} + +func consumeDataBlockedFrame(b []byte) (max int64, n int) { + n = 1 + max, nn := quicwire.ConsumeVarintInt64(b[n:]) + if nn < 0 { + return 0, -1 + } + n += nn + return max, n +} + +func consumeStreamsBlockedFrame(b []byte) (typ streamType, max int64, n int) { + if b[0] == frameTypeStreamsBlockedBidi { + typ = bidiStream + } else { + typ = uniStream + } + n = 1 + max, nn := quicwire.ConsumeVarintInt64(b[n:]) + if nn < 0 { + return 0, 0, -1 + } + n += nn + return typ, max, n +} + +func consumeNewConnectionIDFrame(b []byte) (seq, retire int64, connID []byte, resetToken statelessResetToken, n int) { + n = 1 + var nn int + seq, nn = quicwire.ConsumeVarintInt64(b[n:]) + if nn < 0 { + return 0, 0, nil, statelessResetToken{}, -1 + } + n += nn + retire, nn = quicwire.ConsumeVarintInt64(b[n:]) + if nn < 0 { + return 0, 0, nil, statelessResetToken{}, -1 + } + n += nn + if seq < retire { + return 0, 0, nil, statelessResetToken{}, -1 + } + connID, nn = quicwire.ConsumeVarintBytes(b[n:]) + if nn < 0 { + return 0, 0, nil, statelessResetToken{}, -1 + } + if len(connID) < 1 || len(connID) > 20 { + return 0, 0, nil, statelessResetToken{}, -1 + } + n += nn + if len(b[n:]) < len(resetToken) { + return 0, 0, nil, statelessResetToken{}, -1 + } + copy(resetToken[:], b[n:]) + n += len(resetToken) + return seq, retire, connID, resetToken, n +} + +func consumeRetireConnectionIDFrame(b []byte) (seq int64, n int) { + n = 1 + var nn int + seq, nn = quicwire.ConsumeVarintInt64(b[n:]) + if nn < 0 { + return 0, -1 + } + n += nn + return seq, n +} + +func consumePathChallengeFrame(b []byte) (data pathChallengeData, n int) { + n = 1 + nn := copy(data[:], b[n:]) + if nn != len(data) { + return data, -1 + } + n += nn + return data, n +} + +func consumePathResponseFrame(b []byte) (data pathChallengeData, n int) { + return consumePathChallengeFrame(b) // identical frame format +} + +func consumeConnectionCloseTransportFrame(b []byte) (code transportError, frameType uint64, reason string, n int) { + n = 1 + var nn int + var codeInt uint64 + codeInt, nn = quicwire.ConsumeVarint(b[n:]) + if nn < 0 { + return 0, 0, "", -1 + } + code = transportError(codeInt) + n += nn + frameType, nn = quicwire.ConsumeVarint(b[n:]) + if nn < 0 { + return 0, 0, "", -1 + } + n += nn + reasonb, nn := quicwire.ConsumeVarintBytes(b[n:]) + if nn < 0 { + return 0, 0, "", -1 + } + n += nn + reason = string(reasonb) + return code, frameType, reason, n +} + +func consumeConnectionCloseApplicationFrame(b []byte) (code uint64, reason string, n int) { + n = 1 + var nn int + code, nn = quicwire.ConsumeVarint(b[n:]) + if nn < 0 { + return 0, "", -1 + } + n += nn + reasonb, nn := quicwire.ConsumeVarintBytes(b[n:]) + if nn < 0 { + return 0, "", -1 + } + n += nn + reason = string(reasonb) + return code, reason, n +} diff --git a/src/vendor/golang.org/x/net/quic/packet_protection.go b/src/vendor/golang.org/x/net/quic/packet_protection.go new file mode 100644 index 0000000000..7856d6b5d8 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/packet_protection.go @@ -0,0 +1,539 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "crypto" + "crypto/aes" + "crypto/cipher" + "crypto/sha256" + "crypto/tls" + "errors" + "hash" + + "golang.org/x/crypto/chacha20" + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/cryptobyte" + "golang.org/x/crypto/hkdf" +) + +var errInvalidPacket = errors.New("quic: invalid packet") + +// headerProtectionSampleSize is the size of the ciphertext sample used for header protection. +// https://www.rfc-editor.org/rfc/rfc9001#section-5.4.2 +const headerProtectionSampleSize = 16 + +// aeadOverhead is the difference in size between the AEAD output and input. +// All cipher suites defined for use with QUIC have 16 bytes of overhead. +const aeadOverhead = 16 + +// A headerKey applies or removes header protection. +// https://www.rfc-editor.org/rfc/rfc9001#section-5.4 +type headerKey struct { + hp headerProtection +} + +func (k headerKey) isSet() bool { + return k.hp != nil +} + +func (k *headerKey) init(suite uint16, secret []byte) { + h, keySize := hashForSuite(suite) + hpKey := hkdfExpandLabel(h.New, secret, "quic hp", nil, keySize) + switch suite { + case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384: + c, err := aes.NewCipher(hpKey) + if err != nil { + panic(err) + } + k.hp = &aesHeaderProtection{cipher: c} + case tls.TLS_CHACHA20_POLY1305_SHA256: + k.hp = chaCha20HeaderProtection{hpKey} + default: + panic("BUG: unknown cipher suite") + } +} + +// protect applies header protection. +// pnumOff is the offset of the packet number in the packet. +func (k headerKey) protect(hdr []byte, pnumOff int) { + // Apply header protection. + pnumSize := int(hdr[0]&0x03) + 1 + sample := hdr[pnumOff+4:][:headerProtectionSampleSize] + mask := k.hp.headerProtection(sample) + if isLongHeader(hdr[0]) { + hdr[0] ^= mask[0] & 0x0f + } else { + hdr[0] ^= mask[0] & 0x1f + } + for i := 0; i < pnumSize; i++ { + hdr[pnumOff+i] ^= mask[1+i] + } +} + +// unprotect removes header protection. +// pnumOff is the offset of the packet number in the packet. +// pnumMax is the largest packet number seen in the number space of this packet. +func (k headerKey) unprotect(pkt []byte, pnumOff int, pnumMax packetNumber) (hdr, pay []byte, pnum packetNumber, _ error) { + if len(pkt) < pnumOff+4+headerProtectionSampleSize { + return nil, nil, 0, errInvalidPacket + } + numpay := pkt[pnumOff:] + sample := numpay[4:][:headerProtectionSampleSize] + mask := k.hp.headerProtection(sample) + if isLongHeader(pkt[0]) { + pkt[0] ^= mask[0] & 0x0f + } else { + pkt[0] ^= mask[0] & 0x1f + } + pnumLen := int(pkt[0]&0x03) + 1 + pnum = packetNumber(0) + for i := 0; i < pnumLen; i++ { + numpay[i] ^= mask[1+i] + pnum = (pnum << 8) | packetNumber(numpay[i]) + } + pnum = decodePacketNumber(pnumMax, pnum, pnumLen) + hdr = pkt[:pnumOff+pnumLen] + pay = numpay[pnumLen:] + return hdr, pay, pnum, nil +} + +// headerProtection is the header_protection function as defined in: +// https://www.rfc-editor.org/rfc/rfc9001#section-5.4.1 +// +// This function takes a sample of the packet ciphertext +// and returns a 5-byte mask which will be applied to the +// protected portions of the packet header. +type headerProtection interface { + headerProtection(sample []byte) (mask [5]byte) +} + +// AES-based header protection. +// https://www.rfc-editor.org/rfc/rfc9001#section-5.4.3 +type aesHeaderProtection struct { + cipher cipher.Block + scratch [aes.BlockSize]byte +} + +func (hp *aesHeaderProtection) headerProtection(sample []byte) (mask [5]byte) { + hp.cipher.Encrypt(hp.scratch[:], sample) + copy(mask[:], hp.scratch[:]) + return mask +} + +// ChaCha20-based header protection. +// https://www.rfc-editor.org/rfc/rfc9001#section-5.4.4 +type chaCha20HeaderProtection struct { + key []byte +} + +func (hp chaCha20HeaderProtection) headerProtection(sample []byte) (mask [5]byte) { + counter := uint32(sample[3])<<24 | uint32(sample[2])<<16 | uint32(sample[1])<<8 | uint32(sample[0]) + nonce := sample[4:16] + c, err := chacha20.NewUnauthenticatedCipher(hp.key, nonce) + if err != nil { + panic(err) + } + c.SetCounter(counter) + c.XORKeyStream(mask[:], mask[:]) + return mask +} + +// A packetKey applies or removes packet protection. +// https://www.rfc-editor.org/rfc/rfc9001#section-5.1 +type packetKey struct { + aead cipher.AEAD // AEAD function used for packet protection. + iv []byte // IV used to construct the AEAD nonce. +} + +func (k *packetKey) init(suite uint16, secret []byte) { + // https://www.rfc-editor.org/rfc/rfc9001#section-5.1 + h, keySize := hashForSuite(suite) + key := hkdfExpandLabel(h.New, secret, "quic key", nil, keySize) + switch suite { + case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384: + k.aead = newAESAEAD(key) + case tls.TLS_CHACHA20_POLY1305_SHA256: + k.aead = newChaCha20AEAD(key) + default: + panic("BUG: unknown cipher suite") + } + k.iv = hkdfExpandLabel(h.New, secret, "quic iv", nil, k.aead.NonceSize()) +} + +func newAESAEAD(key []byte) cipher.AEAD { + c, err := aes.NewCipher(key) + if err != nil { + panic(err) + } + aead, err := cipher.NewGCM(c) + if err != nil { + panic(err) + } + return aead +} + +func newChaCha20AEAD(key []byte) cipher.AEAD { + var err error + aead, err := chacha20poly1305.New(key) + if err != nil { + panic(err) + } + return aead +} + +func (k packetKey) protect(hdr, pay []byte, pnum packetNumber) []byte { + k.xorIV(pnum) + defer k.xorIV(pnum) + return k.aead.Seal(hdr, k.iv, pay, hdr) +} + +func (k packetKey) unprotect(hdr, pay []byte, pnum packetNumber) (dec []byte, err error) { + k.xorIV(pnum) + defer k.xorIV(pnum) + return k.aead.Open(pay[:0], k.iv, pay, hdr) +} + +// xorIV xors the packet protection IV with the packet number. +func (k packetKey) xorIV(pnum packetNumber) { + k.iv[len(k.iv)-8] ^= uint8(pnum >> 56) + k.iv[len(k.iv)-7] ^= uint8(pnum >> 48) + k.iv[len(k.iv)-6] ^= uint8(pnum >> 40) + k.iv[len(k.iv)-5] ^= uint8(pnum >> 32) + k.iv[len(k.iv)-4] ^= uint8(pnum >> 24) + k.iv[len(k.iv)-3] ^= uint8(pnum >> 16) + k.iv[len(k.iv)-2] ^= uint8(pnum >> 8) + k.iv[len(k.iv)-1] ^= uint8(pnum) +} + +// A fixedKeys is a header protection key and fixed packet protection key. +// The packet protection key is fixed (it does not update). +// +// Fixed keys are used for Initial and Handshake keys, which do not update. +type fixedKeys struct { + hdr headerKey + pkt packetKey +} + +func (k *fixedKeys) init(suite uint16, secret []byte) { + k.hdr.init(suite, secret) + k.pkt.init(suite, secret) +} + +func (k fixedKeys) isSet() bool { + return k.hdr.hp != nil +} + +// protect applies packet protection to a packet. +// +// On input, hdr contains the packet header, pay the unencrypted payload, +// pnumOff the offset of the packet number in the header, and pnum the untruncated +// packet number. +// +// protect returns the result of appending the encrypted payload to hdr and +// applying header protection. +func (k fixedKeys) protect(hdr, pay []byte, pnumOff int, pnum packetNumber) []byte { + pkt := k.pkt.protect(hdr, pay, pnum) + k.hdr.protect(pkt, pnumOff) + return pkt +} + +// unprotect removes packet protection from a packet. +// +// On input, pkt contains the full protected packet, pnumOff the offset of +// the packet number in the header, and pnumMax the largest packet number +// seen in the number space of this packet. +// +// unprotect removes header protection from the header in pkt, and returns +// the unprotected payload and packet number. +func (k fixedKeys) unprotect(pkt []byte, pnumOff int, pnumMax packetNumber) (pay []byte, num packetNumber, err error) { + hdr, pay, pnum, err := k.hdr.unprotect(pkt, pnumOff, pnumMax) + if err != nil { + return nil, 0, err + } + pay, err = k.pkt.unprotect(hdr, pay, pnum) + if err != nil { + return nil, 0, err + } + return pay, pnum, nil +} + +// A fixedKeyPair is a read/write pair of fixed keys. +type fixedKeyPair struct { + r, w fixedKeys +} + +func (k *fixedKeyPair) discard() { + *k = fixedKeyPair{} +} + +func (k *fixedKeyPair) canRead() bool { + return k.r.isSet() +} + +func (k *fixedKeyPair) canWrite() bool { + return k.w.isSet() +} + +// An updatingKeys is a header protection key and updatable packet protection key. +// updatingKeys are used for 1-RTT keys, where the packet protection key changes +// over the lifetime of a connection. +// https://www.rfc-editor.org/rfc/rfc9001#section-6 +type updatingKeys struct { + suite uint16 + hdr headerKey + pkt [2]packetKey // current, next + nextSecret []byte // secret used to generate pkt[1] +} + +func (k *updatingKeys) init(suite uint16, secret []byte) { + k.suite = suite + k.hdr.init(suite, secret) + // Initialize pkt[1] with secret_0, and then call update to generate secret_1. + k.pkt[1].init(suite, secret) + k.nextSecret = secret + k.update() +} + +// update performs a key update. +// The current key in pkt[0] is discarded. +// The next key in pkt[1] becomes the current key. +// A new next key is generated in pkt[1]. +func (k *updatingKeys) update() { + k.nextSecret = updateSecret(k.suite, k.nextSecret) + k.pkt[0] = k.pkt[1] + k.pkt[1].init(k.suite, k.nextSecret) +} + +func updateSecret(suite uint16, secret []byte) (nextSecret []byte) { + h, _ := hashForSuite(suite) + return hkdfExpandLabel(h.New, secret, "quic ku", nil, len(secret)) +} + +// An updatingKeyPair is a read/write pair of updating keys. +// +// We keep two keys (current and next) in both read and write directions. +// When an incoming packet's phase matches the current phase bit, +// we unprotect it using the current keys; otherwise we use the next keys. +// +// When updating=false, outgoing packets are protected using the current phase. +// +// An update is initiated and updating is set to true when: +// - we decide to initiate a key update; or +// - we successfully unprotect a packet using the next keys, +// indicating the peer has initiated a key update. +// +// When updating=true, outgoing packets are protected using the next phase. +// We do not change the current phase bit or generate new keys yet. +// +// The update concludes when we receive an ACK frame for a packet sent +// with the next keys. At this time, we set updating to false, flip the +// phase bit, and update the keys. This permits us to handle up to 1-RTT +// of reordered packets before discarding the previous phase's keys after +// an update. +type updatingKeyPair struct { + phase uint8 // current key phase (r.pkt[0], w.pkt[0]) + updating bool + authFailures int64 // total packet unprotect failures + minSent packetNumber // min packet number sent since entering the updating state + minReceived packetNumber // min packet number received in the next phase + updateAfter packetNumber // packet number after which to initiate key update + r, w updatingKeys +} + +func (k *updatingKeyPair) init() { + // 1-RTT packets until the first key update. + // + // We perform the first key update early in the connection so a peer + // which does not support key updates will fail rapidly, + // rather than after the connection has been long established. + // + // The QUIC interop runner "keyupdate" test requires that the client + // initiate a key rotation early in the connection. Increasing this + // value may cause interop test failures; if we do want to increase it, + // we should either skip the keyupdate test or provide a way to override + // the setting in interop tests. + k.updateAfter = 100 +} + +func (k *updatingKeyPair) canRead() bool { + return k.r.hdr.hp != nil +} + +func (k *updatingKeyPair) canWrite() bool { + return k.w.hdr.hp != nil +} + +// handleAckFor finishes a key update after receiving an ACK for a packet in the next phase. +func (k *updatingKeyPair) handleAckFor(pnum packetNumber) { + if k.updating && pnum >= k.minSent { + k.updating = false + k.phase ^= keyPhaseBit + k.r.update() + k.w.update() + } +} + +// needAckEliciting reports whether we should send an ack-eliciting packet in the next phase. +// The first packet sent in a phase is ack-eliciting, since the peer must acknowledge a +// packet in the new phase for us to finish the update. +func (k *updatingKeyPair) needAckEliciting() bool { + return k.updating && k.minSent == maxPacketNumber +} + +// protect applies packet protection to a packet. +// Parameters and returns are as for fixedKeyPair.protect. +func (k *updatingKeyPair) protect(hdr, pay []byte, pnumOff int, pnum packetNumber) []byte { + var pkt []byte + if k.updating { + hdr[0] |= k.phase ^ keyPhaseBit + pkt = k.w.pkt[1].protect(hdr, pay, pnum) + k.minSent = min(pnum, k.minSent) + } else { + hdr[0] |= k.phase + pkt = k.w.pkt[0].protect(hdr, pay, pnum) + if pnum >= k.updateAfter { + // Initiate a key update, starting with the next packet we send. + // + // We do this after protecting the current packet + // to allow Conn.appendFrames to ensure that the first packet sent + // in the new phase is ack-eliciting. + k.updating = true + k.minSent = maxPacketNumber + k.minReceived = maxPacketNumber + // The lowest confidentiality limit for a supported AEAD is 2^23 packets. + // https://www.rfc-editor.org/rfc/rfc9001#section-6.6-5 + // + // Schedule our next update for half that. + k.updateAfter += (1 << 22) + } + } + k.w.hdr.protect(pkt, pnumOff) + return pkt +} + +// unprotect removes packet protection from a packet. +// Parameters and returns are as for fixedKeyPair.unprotect. +func (k *updatingKeyPair) unprotect(pkt []byte, pnumOff int, pnumMax packetNumber) (pay []byte, pnum packetNumber, err error) { + hdr, pay, pnum, err := k.r.hdr.unprotect(pkt, pnumOff, pnumMax) + if err != nil { + return nil, 0, err + } + // To avoid timing signals that might indicate the key phase bit is invalid, + // we always attempt to unprotect the packet with one key. + // + // If the key phase bit matches and the packet number doesn't come after + // the start of an in-progress update, use the current phase. + // Otherwise, use the next phase. + if hdr[0]&keyPhaseBit == k.phase && (!k.updating || pnum < k.minReceived) { + pay, err = k.r.pkt[0].unprotect(hdr, pay, pnum) + } else { + pay, err = k.r.pkt[1].unprotect(hdr, pay, pnum) + if err == nil { + if !k.updating { + // The peer has initiated a key update. + k.updating = true + k.minSent = maxPacketNumber + k.minReceived = pnum + } else { + k.minReceived = min(pnum, k.minReceived) + } + } + } + if err != nil { + k.authFailures++ + if k.authFailures >= aeadIntegrityLimit(k.r.suite) { + return nil, 0, localTransportError{code: errAEADLimitReached} + } + return nil, 0, err + } + return pay, pnum, nil +} + +// aeadIntegrityLimit returns the integrity limit for an AEAD: +// The maximum number of received packets that may fail authentication +// before closing the connection. +// +// https://www.rfc-editor.org/rfc/rfc9001#section-6.6-4 +func aeadIntegrityLimit(suite uint16) int64 { + switch suite { + case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384: + return 1 << 52 + case tls.TLS_CHACHA20_POLY1305_SHA256: + return 1 << 36 + default: + panic("BUG: unknown cipher suite") + } +} + +// https://www.rfc-editor.org/rfc/rfc9001#section-5.2-2 +var initialSalt = []byte{0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a} + +// initialKeys returns the keys used to protect Initial packets. +// +// The Initial packet keys are derived from the Destination Connection ID +// field in the client's first Initial packet. +// +// https://www.rfc-editor.org/rfc/rfc9001#section-5.2 +func initialKeys(cid []byte, side connSide) fixedKeyPair { + initialSecret := hkdf.Extract(sha256.New, cid, initialSalt) + var clientKeys fixedKeys + clientSecret := hkdfExpandLabel(sha256.New, initialSecret, "client in", nil, sha256.Size) + clientKeys.init(tls.TLS_AES_128_GCM_SHA256, clientSecret) + var serverKeys fixedKeys + serverSecret := hkdfExpandLabel(sha256.New, initialSecret, "server in", nil, sha256.Size) + serverKeys.init(tls.TLS_AES_128_GCM_SHA256, serverSecret) + if side == clientSide { + return fixedKeyPair{r: serverKeys, w: clientKeys} + } else { + return fixedKeyPair{w: serverKeys, r: clientKeys} + } +} + +// checkCipherSuite returns an error if suite is not a supported cipher suite. +func checkCipherSuite(suite uint16) error { + switch suite { + case tls.TLS_AES_128_GCM_SHA256: + case tls.TLS_AES_256_GCM_SHA384: + case tls.TLS_CHACHA20_POLY1305_SHA256: + default: + return errors.New("invalid cipher suite") + } + return nil +} + +func hashForSuite(suite uint16) (h crypto.Hash, keySize int) { + switch suite { + case tls.TLS_AES_128_GCM_SHA256: + return crypto.SHA256, 128 / 8 + case tls.TLS_AES_256_GCM_SHA384: + return crypto.SHA384, 256 / 8 + case tls.TLS_CHACHA20_POLY1305_SHA256: + return crypto.SHA256, chacha20.KeySize + default: + panic("BUG: unknown cipher suite") + } +} + +// hkdfExpandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1. +// +// Copied from crypto/tls/key_schedule.go. +func hkdfExpandLabel(hash func() hash.Hash, secret []byte, label string, context []byte, length int) []byte { + var hkdfLabel cryptobyte.Builder + hkdfLabel.AddUint16(uint16(length)) + hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte("tls13 ")) + b.AddBytes([]byte(label)) + }) + hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(context) + }) + out := make([]byte, length) + n, err := hkdf.Expand(hash, secret, hkdfLabel.BytesOrPanic()).Read(out) + if err != nil || n != length { + panic("quic: HKDF-Expand-Label invocation failed unexpectedly") + } + return out +} diff --git a/src/vendor/golang.org/x/net/quic/packet_writer.go b/src/vendor/golang.org/x/net/quic/packet_writer.go new file mode 100644 index 0000000000..f446521d2b --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/packet_writer.go @@ -0,0 +1,566 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "encoding/binary" + + "golang.org/x/net/internal/quic/quicwire" +) + +// A packetWriter constructs QUIC datagrams. +// +// A datagram consists of one or more packets. +// A packet consists of a header followed by one or more frames. +// +// Packets are written in three steps: +// - startProtectedLongHeaderPacket or start1RTT packet prepare the packet; +// - append*Frame appends frames to the payload; and +// - finishProtectedLongHeaderPacket or finish1RTT finalize the packet. +// +// The start functions are efficient, so we can start speculatively +// writing a packet before we know whether we have any frames to +// put in it. The finish functions will abandon the packet if the +// payload contains no data. +type packetWriter struct { + dgramLim int // max datagram size + pktLim int // max packet size + pktOff int // offset of the start of the current packet + payOff int // offset of the payload of the current packet + b []byte + sent *sentPacket +} + +// reset prepares to write a datagram of at most lim bytes. +func (w *packetWriter) reset(lim int) { + if cap(w.b) < lim { + w.b = make([]byte, 0, lim) + } + w.dgramLim = lim + w.b = w.b[:0] +} + +// datagram returns the current datagram. +func (w *packetWriter) datagram() []byte { + return w.b +} + +// packetLen returns the size of the current packet. +func (w *packetWriter) packetLen() int { + return len(w.b[w.pktOff:]) + aeadOverhead +} + +// payload returns the payload of the current packet. +func (w *packetWriter) payload() []byte { + return w.b[w.payOff:] +} + +func (w *packetWriter) abandonPacket() { + w.b = w.b[:w.payOff] + w.sent.reset() +} + +// startProtectedLongHeaderPacket starts writing an Initial, 0-RTT, or Handshake packet. +func (w *packetWriter) startProtectedLongHeaderPacket(pnumMaxAcked packetNumber, p longPacket) { + if w.sent == nil { + w.sent = newSentPacket() + } + w.pktOff = len(w.b) + hdrSize := 1 // packet type + hdrSize += 4 // version + hdrSize += 1 + len(p.dstConnID) + hdrSize += 1 + len(p.srcConnID) + switch p.ptype { + case packetTypeInitial: + hdrSize += quicwire.SizeVarint(uint64(len(p.extra))) + len(p.extra) + } + hdrSize += 2 // length, hardcoded to a 2-byte varint + pnumOff := len(w.b) + hdrSize + hdrSize += packetNumberLength(p.num, pnumMaxAcked) + payOff := len(w.b) + hdrSize + // Check if we have enough space to hold the packet, including the header, + // header protection sample (RFC 9001, section 5.4.2), and encryption overhead. + if pnumOff+4+headerProtectionSampleSize+aeadOverhead >= w.dgramLim { + // Set the limit on the packet size to be the current write buffer length, + // ensuring that any writes to the payload fail. + w.payOff = len(w.b) + w.pktLim = len(w.b) + return + } + w.payOff = payOff + w.pktLim = w.dgramLim - aeadOverhead + // We hardcode the payload length field to be 2 bytes, which limits the payload + // (including the packet number) to 16383 bytes (the largest 2-byte QUIC varint). + // + // Most networks don't support datagrams over 1472 bytes, and even Ethernet + // jumbo frames are generally only about 9000 bytes. + if lim := pnumOff + 16383 - aeadOverhead; lim < w.pktLim { + w.pktLim = lim + } + w.b = w.b[:payOff] +} + +// finishProtectedLongHeaderPacket finishes writing an Initial, 0-RTT, or Handshake packet, +// canceling the packet if it contains no payload. +// It returns a sentPacket describing the packet, or nil if no packet was written. +func (w *packetWriter) finishProtectedLongHeaderPacket(pnumMaxAcked packetNumber, k fixedKeys, p longPacket) *sentPacket { + if len(w.b) == w.payOff { + // The payload is empty, so just abandon the packet. + w.b = w.b[:w.pktOff] + return nil + } + pnumLen := packetNumberLength(p.num, pnumMaxAcked) + plen := w.padPacketLength(pnumLen) + hdr := w.b[:w.pktOff] + var typeBits byte + switch p.ptype { + case packetTypeInitial: + typeBits = longPacketTypeInitial + case packetType0RTT: + typeBits = longPacketType0RTT + case packetTypeHandshake: + typeBits = longPacketTypeHandshake + case packetTypeRetry: + typeBits = longPacketTypeRetry + } + hdr = append(hdr, headerFormLong|fixedBit|typeBits|byte(pnumLen-1)) + hdr = binary.BigEndian.AppendUint32(hdr, p.version) + hdr = quicwire.AppendUint8Bytes(hdr, p.dstConnID) + hdr = quicwire.AppendUint8Bytes(hdr, p.srcConnID) + switch p.ptype { + case packetTypeInitial: + hdr = quicwire.AppendVarintBytes(hdr, p.extra) // token + } + + // Packet length, always encoded as a 2-byte varint. + hdr = append(hdr, 0x40|byte(plen>>8), byte(plen)) + + pnumOff := len(hdr) + hdr = appendPacketNumber(hdr, p.num, pnumMaxAcked) + + k.protect(hdr[w.pktOff:], w.b[len(hdr):], pnumOff-w.pktOff, p.num) + return w.finish(p.ptype, p.num) +} + +// start1RTTPacket starts writing a 1-RTT (short header) packet. +func (w *packetWriter) start1RTTPacket(pnum, pnumMaxAcked packetNumber, dstConnID []byte) { + if w.sent == nil { + w.sent = newSentPacket() + } + w.pktOff = len(w.b) + hdrSize := 1 // packet type + hdrSize += len(dstConnID) + // Ensure we have enough space to hold the packet, including the header, + // header protection sample (RFC 9001, section 5.4.2), and encryption overhead. + if len(w.b)+hdrSize+4+headerProtectionSampleSize+aeadOverhead >= w.dgramLim { + w.payOff = len(w.b) + w.pktLim = len(w.b) + return + } + hdrSize += packetNumberLength(pnum, pnumMaxAcked) + w.payOff = len(w.b) + hdrSize + w.pktLim = w.dgramLim - aeadOverhead + w.b = w.b[:w.payOff] +} + +// finish1RTTPacket finishes writing a 1-RTT packet, +// canceling the packet if it contains no payload. +// It returns a sentPacket describing the packet, or nil if no packet was written. +func (w *packetWriter) finish1RTTPacket(pnum, pnumMaxAcked packetNumber, dstConnID []byte, k *updatingKeyPair) *sentPacket { + if len(w.b) == w.payOff { + // The payload is empty, so just abandon the packet. + w.b = w.b[:w.pktOff] + return nil + } + // TODO: Spin + pnumLen := packetNumberLength(pnum, pnumMaxAcked) + hdr := w.b[:w.pktOff] + hdr = append(hdr, 0x40|byte(pnumLen-1)) + hdr = append(hdr, dstConnID...) + pnumOff := len(hdr) + hdr = appendPacketNumber(hdr, pnum, pnumMaxAcked) + w.padPacketLength(pnumLen) + k.protect(hdr[w.pktOff:], w.b[len(hdr):], pnumOff-w.pktOff, pnum) + return w.finish(packetType1RTT, pnum) +} + +// padPacketLength pads out the payload of the current packet to the minimum size, +// and returns the combined length of the packet number and payload (used for the Length +// field of long header packets). +func (w *packetWriter) padPacketLength(pnumLen int) int { + plen := len(w.b) - w.payOff + pnumLen + aeadOverhead + // "To ensure that sufficient data is available for sampling, packets are + // padded so that the combined lengths of the encoded packet number and + // protected payload is at least 4 bytes longer than the sample required + // for header protection." + // https://www.rfc-editor.org/rfc/rfc9001.html#section-5.4.2 + for plen < 4+headerProtectionSampleSize { + w.b = append(w.b, 0) + plen++ + } + return plen +} + +// finish finishes the current packet after protection is applied. +func (w *packetWriter) finish(ptype packetType, pnum packetNumber) *sentPacket { + w.b = w.b[:len(w.b)+aeadOverhead] + w.sent.size = len(w.b) - w.pktOff + w.sent.ptype = ptype + w.sent.num = pnum + sent := w.sent + w.sent = nil + return sent +} + +// avail reports how many more bytes may be written to the current packet. +func (w *packetWriter) avail() int { + return w.pktLim - len(w.b) +} + +// appendPaddingTo appends PADDING frames until the total datagram size +// (including AEAD overhead of the current packet) is n. +func (w *packetWriter) appendPaddingTo(n int) { + n -= aeadOverhead + lim := w.pktLim + if n < lim { + lim = n + } + if len(w.b) >= lim { + return + } + for len(w.b) < lim { + w.b = append(w.b, frameTypePadding) + } + // Packets are considered in flight when they contain a PADDING frame. + // https://www.rfc-editor.org/rfc/rfc9002.html#section-2-3.6.1 + w.sent.inFlight = true +} + +func (w *packetWriter) appendPingFrame() (added bool) { + if len(w.b) >= w.pktLim { + return false + } + w.b = append(w.b, frameTypePing) + w.sent.markAckEliciting() // no need to record the frame itself + return true +} + +// appendAckFrame appends an ACK frame to the payload. +// It includes at least the most recent range in the rangeset +// (the range with the largest packet numbers), +// followed by as many additional ranges as fit within the packet. +// +// We always place ACK frames at the start of packets, +// we limit the number of ack ranges retained, and +// we set a minimum packet payload size. +// As a result, appendAckFrame will rarely if ever drop ranges +// in practice. +// +// In the event that ranges are dropped, the impact is limited +// to the peer potentially failing to receive an acknowledgement +// for an older packet during a period of high packet loss or +// reordering. This may result in unnecessary retransmissions. +func (w *packetWriter) appendAckFrame(seen rangeset[packetNumber], delay unscaledAckDelay, ecn ecnCounts) (added bool) { + if len(seen) == 0 { + return false + } + var ( + largest = uint64(seen.max()) + firstRange = uint64(seen[len(seen)-1].size() - 1) + ) + var ecnLen int + ackType := byte(frameTypeAck) + if (ecn != ecnCounts{}) { + // "Even if an endpoint does not set an ECT field in packets it sends, + // the endpoint MUST provide feedback about ECN markings it receives, if + // these are accessible." + // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.4.1-2 + ecnLen = quicwire.SizeVarint(uint64(ecn.ce)) + quicwire.SizeVarint(uint64(ecn.t0)) + quicwire.SizeVarint(uint64(ecn.t1)) + ackType = frameTypeAckECN + } + if w.avail() < 1+quicwire.SizeVarint(largest)+quicwire.SizeVarint(uint64(delay))+1+quicwire.SizeVarint(firstRange)+ecnLen { + return false + } + w.b = append(w.b, ackType) + w.b = quicwire.AppendVarint(w.b, largest) + w.b = quicwire.AppendVarint(w.b, uint64(delay)) + // The range count is technically a varint, but we'll reserve a single byte for it + // and never add more than 62 ranges (the maximum varint that fits in a byte). + rangeCountOff := len(w.b) + w.b = append(w.b, 0) + w.b = quicwire.AppendVarint(w.b, firstRange) + rangeCount := byte(0) + for i := len(seen) - 2; i >= 0; i-- { + gap := uint64(seen[i+1].start - seen[i].end - 1) + size := uint64(seen[i].size() - 1) + if w.avail() < quicwire.SizeVarint(gap)+quicwire.SizeVarint(size)+ecnLen || rangeCount > 62 { + break + } + w.b = quicwire.AppendVarint(w.b, gap) + w.b = quicwire.AppendVarint(w.b, size) + rangeCount++ + } + w.b[rangeCountOff] = rangeCount + if ackType == frameTypeAckECN { + w.b = quicwire.AppendVarint(w.b, uint64(ecn.t0)) + w.b = quicwire.AppendVarint(w.b, uint64(ecn.t1)) + w.b = quicwire.AppendVarint(w.b, uint64(ecn.ce)) + } + w.sent.appendNonAckElicitingFrame(ackType) + w.sent.appendInt(uint64(seen.max())) + return true +} + +func (w *packetWriter) appendNewTokenFrame(token []byte) (added bool) { + if w.avail() < 1+quicwire.SizeVarint(uint64(len(token)))+len(token) { + return false + } + w.b = append(w.b, frameTypeNewToken) + w.b = quicwire.AppendVarintBytes(w.b, token) + return true +} + +func (w *packetWriter) appendResetStreamFrame(id streamID, code uint64, finalSize int64) (added bool) { + if w.avail() < 1+quicwire.SizeVarint(uint64(id))+quicwire.SizeVarint(code)+quicwire.SizeVarint(uint64(finalSize)) { + return false + } + w.b = append(w.b, frameTypeResetStream) + w.b = quicwire.AppendVarint(w.b, uint64(id)) + w.b = quicwire.AppendVarint(w.b, code) + w.b = quicwire.AppendVarint(w.b, uint64(finalSize)) + w.sent.appendAckElicitingFrame(frameTypeResetStream) + w.sent.appendInt(uint64(id)) + return true +} + +func (w *packetWriter) appendStopSendingFrame(id streamID, code uint64) (added bool) { + if w.avail() < 1+quicwire.SizeVarint(uint64(id))+quicwire.SizeVarint(code) { + return false + } + w.b = append(w.b, frameTypeStopSending) + w.b = quicwire.AppendVarint(w.b, uint64(id)) + w.b = quicwire.AppendVarint(w.b, code) + w.sent.appendAckElicitingFrame(frameTypeStopSending) + w.sent.appendInt(uint64(id)) + return true +} + +// appendCryptoFrame appends a CRYPTO frame. +// It returns a []byte into which the data should be written and whether a frame was added. +// The returned []byte may be smaller than size if the packet cannot hold all the data. +func (w *packetWriter) appendCryptoFrame(off int64, size int) (_ []byte, added bool) { + max := w.avail() + max -= 1 // frame type + max -= quicwire.SizeVarint(uint64(off)) // offset + max -= quicwire.SizeVarint(uint64(size)) // maximum length + if max <= 0 { + return nil, false + } + if max < size { + size = max + } + w.b = append(w.b, frameTypeCrypto) + w.b = quicwire.AppendVarint(w.b, uint64(off)) + w.b = quicwire.AppendVarint(w.b, uint64(size)) + start := len(w.b) + w.b = w.b[:start+size] + w.sent.appendAckElicitingFrame(frameTypeCrypto) + w.sent.appendOffAndSize(off, size) + return w.b[start:][:size], true +} + +// appendStreamFrame appends a STREAM frame. +// It returns a []byte into which the data should be written and whether a frame was added. +// The returned []byte may be smaller than size if the packet cannot hold all the data. +func (w *packetWriter) appendStreamFrame(id streamID, off int64, size int, fin bool) (_ []byte, added bool) { + typ := uint8(frameTypeStreamBase | streamLenBit) + max := w.avail() + max -= 1 // frame type + max -= quicwire.SizeVarint(uint64(id)) + if off != 0 { + max -= quicwire.SizeVarint(uint64(off)) + typ |= streamOffBit + } + max -= quicwire.SizeVarint(uint64(size)) // maximum length + if max < 0 || (max == 0 && size > 0) { + return nil, false + } + if max < size { + size = max + } else if fin { + typ |= streamFinBit + } + w.b = append(w.b, typ) + w.b = quicwire.AppendVarint(w.b, uint64(id)) + if off != 0 { + w.b = quicwire.AppendVarint(w.b, uint64(off)) + } + w.b = quicwire.AppendVarint(w.b, uint64(size)) + start := len(w.b) + w.b = w.b[:start+size] + w.sent.appendAckElicitingFrame(typ & (frameTypeStreamBase | streamFinBit)) + w.sent.appendInt(uint64(id)) + w.sent.appendOffAndSize(off, size) + return w.b[start:][:size], true +} + +func (w *packetWriter) appendMaxDataFrame(max int64) (added bool) { + if w.avail() < 1+quicwire.SizeVarint(uint64(max)) { + return false + } + w.b = append(w.b, frameTypeMaxData) + w.b = quicwire.AppendVarint(w.b, uint64(max)) + w.sent.appendAckElicitingFrame(frameTypeMaxData) + return true +} + +func (w *packetWriter) appendMaxStreamDataFrame(id streamID, max int64) (added bool) { + if w.avail() < 1+quicwire.SizeVarint(uint64(id))+quicwire.SizeVarint(uint64(max)) { + return false + } + w.b = append(w.b, frameTypeMaxStreamData) + w.b = quicwire.AppendVarint(w.b, uint64(id)) + w.b = quicwire.AppendVarint(w.b, uint64(max)) + w.sent.appendAckElicitingFrame(frameTypeMaxStreamData) + w.sent.appendInt(uint64(id)) + return true +} + +func (w *packetWriter) appendMaxStreamsFrame(streamType streamType, max int64) (added bool) { + if w.avail() < 1+quicwire.SizeVarint(uint64(max)) { + return false + } + var typ byte + if streamType == bidiStream { + typ = frameTypeMaxStreamsBidi + } else { + typ = frameTypeMaxStreamsUni + } + w.b = append(w.b, typ) + w.b = quicwire.AppendVarint(w.b, uint64(max)) + w.sent.appendAckElicitingFrame(typ) + return true +} + +func (w *packetWriter) appendDataBlockedFrame(max int64) (added bool) { + if w.avail() < 1+quicwire.SizeVarint(uint64(max)) { + return false + } + w.b = append(w.b, frameTypeDataBlocked) + w.b = quicwire.AppendVarint(w.b, uint64(max)) + w.sent.appendAckElicitingFrame(frameTypeDataBlocked) + return true +} + +func (w *packetWriter) appendStreamDataBlockedFrame(id streamID, max int64) (added bool) { + if w.avail() < 1+quicwire.SizeVarint(uint64(id))+quicwire.SizeVarint(uint64(max)) { + return false + } + w.b = append(w.b, frameTypeStreamDataBlocked) + w.b = quicwire.AppendVarint(w.b, uint64(id)) + w.b = quicwire.AppendVarint(w.b, uint64(max)) + w.sent.appendAckElicitingFrame(frameTypeStreamDataBlocked) + w.sent.appendInt(uint64(id)) + return true +} + +func (w *packetWriter) appendStreamsBlockedFrame(typ streamType, max int64) (added bool) { + if w.avail() < 1+quicwire.SizeVarint(uint64(max)) { + return false + } + var ftype byte + if typ == bidiStream { + ftype = frameTypeStreamsBlockedBidi + } else { + ftype = frameTypeStreamsBlockedUni + } + w.b = append(w.b, ftype) + w.b = quicwire.AppendVarint(w.b, uint64(max)) + w.sent.appendAckElicitingFrame(ftype) + return true +} + +func (w *packetWriter) appendNewConnectionIDFrame(seq, retirePriorTo int64, connID []byte, token [16]byte) (added bool) { + if w.avail() < 1+quicwire.SizeVarint(uint64(seq))+quicwire.SizeVarint(uint64(retirePriorTo))+1+len(connID)+len(token) { + return false + } + w.b = append(w.b, frameTypeNewConnectionID) + w.b = quicwire.AppendVarint(w.b, uint64(seq)) + w.b = quicwire.AppendVarint(w.b, uint64(retirePriorTo)) + w.b = quicwire.AppendUint8Bytes(w.b, connID) + w.b = append(w.b, token[:]...) + w.sent.appendAckElicitingFrame(frameTypeNewConnectionID) + w.sent.appendInt(uint64(seq)) + return true +} + +func (w *packetWriter) appendRetireConnectionIDFrame(seq int64) (added bool) { + if w.avail() < 1+quicwire.SizeVarint(uint64(seq)) { + return false + } + w.b = append(w.b, frameTypeRetireConnectionID) + w.b = quicwire.AppendVarint(w.b, uint64(seq)) + w.sent.appendAckElicitingFrame(frameTypeRetireConnectionID) + w.sent.appendInt(uint64(seq)) + return true +} + +func (w *packetWriter) appendPathChallengeFrame(data pathChallengeData) (added bool) { + if w.avail() < 1+8 { + return false + } + w.b = append(w.b, frameTypePathChallenge) + w.b = append(w.b, data[:]...) + w.sent.markAckEliciting() // no need to record the frame itself + return true +} + +func (w *packetWriter) appendPathResponseFrame(data pathChallengeData) (added bool) { + if w.avail() < 1+8 { + return false + } + w.b = append(w.b, frameTypePathResponse) + w.b = append(w.b, data[:]...) + w.sent.markAckEliciting() // no need to record the frame itself + return true +} + +// appendConnectionCloseTransportFrame appends a CONNECTION_CLOSE frame +// carrying a transport error code. +func (w *packetWriter) appendConnectionCloseTransportFrame(code transportError, frameType uint64, reason string) (added bool) { + if w.avail() < 1+quicwire.SizeVarint(uint64(code))+quicwire.SizeVarint(frameType)+quicwire.SizeVarint(uint64(len(reason)))+len(reason) { + return false + } + w.b = append(w.b, frameTypeConnectionCloseTransport) + w.b = quicwire.AppendVarint(w.b, uint64(code)) + w.b = quicwire.AppendVarint(w.b, frameType) + w.b = quicwire.AppendVarintBytes(w.b, []byte(reason)) + // We don't record CONNECTION_CLOSE frames in w.sent, since they are never acked or + // detected as lost. + return true +} + +// appendConnectionCloseApplicationFrame appends a CONNECTION_CLOSE frame +// carrying an application protocol error code. +func (w *packetWriter) appendConnectionCloseApplicationFrame(code uint64, reason string) (added bool) { + if w.avail() < 1+quicwire.SizeVarint(code)+quicwire.SizeVarint(uint64(len(reason)))+len(reason) { + return false + } + w.b = append(w.b, frameTypeConnectionCloseApplication) + w.b = quicwire.AppendVarint(w.b, code) + w.b = quicwire.AppendVarintBytes(w.b, []byte(reason)) + // We don't record CONNECTION_CLOSE frames in w.sent, since they are never acked or + // detected as lost. + return true +} + +func (w *packetWriter) appendHandshakeDoneFrame() (added bool) { + if w.avail() < 1 { + return false + } + w.b = append(w.b, frameTypeHandshakeDone) + w.sent.appendAckElicitingFrame(frameTypeHandshakeDone) + return true +} diff --git a/src/vendor/golang.org/x/net/quic/path.go b/src/vendor/golang.org/x/net/quic/path.go new file mode 100644 index 0000000000..5170562c74 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/path.go @@ -0,0 +1,87 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import "time" + +type pathState struct { + // Response to a peer's PATH_CHALLENGE. + // This is not a sentVal, because we don't resend lost PATH_RESPONSE frames. + // We only track the most recent PATH_CHALLENGE. + // If the peer sends a second PATH_CHALLENGE before we respond to the first, + // we'll drop the first response. + sendPathResponse pathResponseType + data pathChallengeData +} + +// pathChallengeData is data carried in a PATH_CHALLENGE or PATH_RESPONSE frame. +type pathChallengeData [64 / 8]byte + +type pathResponseType uint8 + +const ( + pathResponseNotNeeded = pathResponseType(iota) + pathResponseSmall // send PATH_RESPONSE, do not expand datagram + pathResponseExpanded // send PATH_RESPONSE, expand datagram to 1200 bytes +) + +func (c *Conn) handlePathChallenge(_ time.Time, dgram *datagram, data pathChallengeData) { + // A PATH_RESPONSE is sent in a datagram expanded to 1200 bytes, + // except when this would exceed the anti-amplification limit. + // + // Rather than maintaining anti-amplification state for each path + // we may be sending a PATH_RESPONSE on, follow the following heuristic: + // + // If we receive a PATH_CHALLENGE in an expanded datagram, + // respond with an expanded datagram. + // + // If we receive a PATH_CHALLENGE in a non-expanded datagram, + // then the peer is presumably blocked by its own anti-amplification limit. + // Respond with a non-expanded datagram. Receiving this PATH_RESPONSE + // will validate the path to the peer, remove its anti-amplification limit, + // and permit it to send a followup PATH_CHALLENGE in an expanded datagram. + // https://www.rfc-editor.org/rfc/rfc9000.html#section-8.2.1 + if len(dgram.b) >= smallestMaxDatagramSize { + c.path.sendPathResponse = pathResponseExpanded + } else { + c.path.sendPathResponse = pathResponseSmall + } + c.path.data = data +} + +func (c *Conn) handlePathResponse(now time.Time, _ pathChallengeData) { + // "If the content of a PATH_RESPONSE frame does not match the content of + // a PATH_CHALLENGE frame previously sent by the endpoint, + // the endpoint MAY generate a connection error of type PROTOCOL_VIOLATION." + // https://www.rfc-editor.org/rfc/rfc9000.html#section-19.18-4 + // + // We never send PATH_CHALLENGE frames. + c.abort(now, localTransportError{ + code: errProtocolViolation, + reason: "PATH_RESPONSE received when no PATH_CHALLENGE sent", + }) +} + +// appendPathFrames appends path validation related frames to the current packet. +// If the return value pad is true, then the packet should be padded to 1200 bytes. +func (c *Conn) appendPathFrames() (pad, ok bool) { + if c.path.sendPathResponse == pathResponseNotNeeded { + return pad, true + } + // We're required to send the PATH_RESPONSE on the path where the + // PATH_CHALLENGE was received (RFC 9000, Section 8.2.2). + // + // At the moment, we don't support path migration and reject packets if + // the peer changes its source address, so just sending the PATH_RESPONSE + // in a regular datagram is fine. + if !c.w.appendPathResponseFrame(c.path.data) { + return pad, false + } + if c.path.sendPathResponse == pathResponseExpanded { + pad = true + } + c.path.sendPathResponse = pathResponseNotNeeded + return pad, true +} diff --git a/src/vendor/golang.org/x/net/quic/ping.go b/src/vendor/golang.org/x/net/quic/ping.go new file mode 100644 index 0000000000..e604f014bf --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/ping.go @@ -0,0 +1,14 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import "time" + +func (c *Conn) ping(space numberSpace) { + c.sendMsg(func(now time.Time, c *Conn) { + c.testSendPing.setUnsent() + c.testSendPingSpace = space + }) +} diff --git a/src/vendor/golang.org/x/net/quic/pipe.go b/src/vendor/golang.org/x/net/quic/pipe.go new file mode 100644 index 0000000000..2ae651ea3e --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/pipe.go @@ -0,0 +1,173 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "sync" +) + +// A pipe is a byte buffer used in implementing streams. +// +// A pipe contains a window of stream data. +// Random access reads and writes are supported within the window. +// Writing past the end of the window extends it. +// Data may be discarded from the start of the pipe, advancing the window. +type pipe struct { + start int64 // stream position of first stored byte + end int64 // stream position just past the last stored byte + head *pipebuf // if non-nil, then head.off + len(head.b) > start + tail *pipebuf // if non-nil, then tail.off + len(tail.b) == end +} + +type pipebuf struct { + off int64 // stream position of b[0] + b []byte + next *pipebuf +} + +func (pb *pipebuf) end() int64 { + return pb.off + int64(len(pb.b)) +} + +var pipebufPool = sync.Pool{ + New: func() any { + return &pipebuf{ + b: make([]byte, 4096), + } + }, +} + +func newPipebuf() *pipebuf { + return pipebufPool.Get().(*pipebuf) +} + +func (b *pipebuf) recycle() { + b.off = 0 + b.next = nil + pipebufPool.Put(b) +} + +// writeAt writes len(b) bytes to the pipe at offset off. +// +// Writes to offsets before p.start are discarded. +// Writes to offsets after p.end extend the pipe window. +func (p *pipe) writeAt(b []byte, off int64) { + end := off + int64(len(b)) + if end > p.end { + p.end = end + } else if end <= p.start { + return + } + + if off < p.start { + // Discard the portion of b which falls before p.start. + trim := p.start - off + b = b[trim:] + off = p.start + } + + if p.head == nil { + p.head = newPipebuf() + p.head.off = p.start + p.tail = p.head + } + pb := p.head + if off >= p.tail.off { + // Common case: Writing past the end of the pipe. + pb = p.tail + } + for { + pboff := off - pb.off + if pboff < int64(len(pb.b)) { + n := copy(pb.b[pboff:], b) + if n == len(b) { + return + } + off += int64(n) + b = b[n:] + } + if pb.next == nil { + pb.next = newPipebuf() + pb.next.off = pb.off + int64(len(pb.b)) + p.tail = pb.next + } + pb = pb.next + } +} + +// copy copies len(b) bytes into b starting from off. +// The pipe must contain [off, off+len(b)). +func (p *pipe) copy(off int64, b []byte) { + dst := b[:0] + p.read(off, len(b), func(c []byte) error { + dst = append(dst, c...) + return nil + }) +} + +// read calls f with the data in [off, off+n) +// The data may be provided sequentially across multiple calls to f. +// Note that read (unlike an io.Reader) does not consume the read data. +func (p *pipe) read(off int64, n int, f func([]byte) error) error { + if off < p.start { + panic("invalid read range") + } + for pb := p.head; pb != nil && n > 0; pb = pb.next { + if off >= pb.end() { + continue + } + b := pb.b[off-pb.off:] + if len(b) > n { + b = b[:n] + } + off += int64(len(b)) + n -= len(b) + if err := f(b); err != nil { + return err + } + } + if n > 0 { + panic("invalid read range") + } + return nil +} + +// peek returns a reference to up to n bytes of internal data buffer, starting at p.start. +// The returned slice is valid until the next call to discardBefore. +// The length of the returned slice will be in the range [0,n]. +func (p *pipe) peek(n int64) []byte { + pb := p.head + if pb == nil { + return nil + } + b := pb.b[p.start-pb.off:] + return b[:min(int64(len(b)), n)] +} + +// availableBuffer returns the available contiguous, allocated buffer space +// following the pipe window. +// +// This is used by the stream write fast path, which makes multiple writes into the pipe buffer +// without a lock, and then adjusts p.end at a later time with a lock held. +func (p *pipe) availableBuffer() []byte { + if p.tail == nil { + return nil + } + return p.tail.b[p.end-p.tail.off:] +} + +// discardBefore discards all data prior to off. +func (p *pipe) discardBefore(off int64) { + for p.head != nil && p.head.end() < off { + head := p.head + p.head = p.head.next + head.recycle() + } + if p.head == nil { + p.tail = nil + } + p.start = off + p.end = max(p.end, off) +} diff --git a/src/vendor/golang.org/x/net/quic/qlog.go b/src/vendor/golang.org/x/net/quic/qlog.go new file mode 100644 index 0000000000..5d2fd0fc1e --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/qlog.go @@ -0,0 +1,272 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "context" + "encoding/hex" + "log/slog" + "net/netip" + "time" +) + +// Log levels for qlog events. +const ( + // QLogLevelFrame includes per-frame information. + // When this level is enabled, packet_sent and packet_received events will + // contain information on individual frames sent/received. + QLogLevelFrame = slog.Level(-6) + + // QLogLevelPacket events occur at most once per packet sent or received. + // + // For example: packet_sent, packet_received. + QLogLevelPacket = slog.Level(-4) + + // QLogLevelConn events occur multiple times over a connection's lifetime, + // but less often than the frequency of individual packets. + // + // For example: connection_state_updated. + QLogLevelConn = slog.Level(-2) + + // QLogLevelEndpoint events occur at most once per connection. + // + // For example: connection_started, connection_closed. + QLogLevelEndpoint = slog.Level(0) +) + +func (c *Conn) logEnabled(level slog.Level) bool { + return logEnabled(c.log, level) +} + +func logEnabled(log *slog.Logger, level slog.Level) bool { + return log != nil && log.Enabled(context.Background(), level) +} + +// slogHexstring returns a slog.Attr for a value of the hexstring type. +// +// https://www.ietf.org/archive/id/draft-ietf-quic-qlog-main-schema-04.html#section-1.1.1 +func slogHexstring(key string, value []byte) slog.Attr { + return slog.String(key, hex.EncodeToString(value)) +} + +func slogAddr(key string, value netip.Addr) slog.Attr { + return slog.String(key, value.String()) +} + +func (c *Conn) logConnectionStarted(originalDstConnID []byte, peerAddr netip.AddrPort) { + if c.config.QLogLogger == nil || + !c.config.QLogLogger.Enabled(context.Background(), QLogLevelEndpoint) { + return + } + var vantage string + if c.side == clientSide { + vantage = "client" + originalDstConnID = c.connIDState.originalDstConnID + } else { + vantage = "server" + } + // A qlog Trace container includes some metadata (title, description, vantage_point) + // and a list of Events. The Trace also includes a common_fields field setting field + // values common to all events in the trace. + // + // Trace = { + // ? title: text + // ? description: text + // ? configuration: Configuration + // ? common_fields: CommonFields + // ? vantage_point: VantagePoint + // events: [* Event] + // } + // + // To map this into slog's data model, we start each per-connection trace with a With + // call that includes both the trace metadata and the common fields. + // + // This means that in slog's model, each trace event will also include + // the Trace metadata fields (vantage_point), which is a divergence from the qlog model. + c.log = c.config.QLogLogger.With( + // The group_id permits associating traces taken from different vantage points + // for the same connection. + // + // We use the original destination connection ID as the group ID. + // + // https://www.ietf.org/archive/id/draft-ietf-quic-qlog-main-schema-04.html#section-3.4.6 + slogHexstring("group_id", originalDstConnID), + slog.Group("vantage_point", + slog.String("name", "go quic"), + slog.String("type", vantage), + ), + ) + localAddr := c.endpoint.LocalAddr() + // https://www.ietf.org/archive/id/draft-ietf-quic-qlog-quic-events-03.html#section-4.2 + c.log.LogAttrs(context.Background(), QLogLevelEndpoint, + "connectivity:connection_started", + slogAddr("src_ip", localAddr.Addr()), + slog.Int("src_port", int(localAddr.Port())), + slogHexstring("src_cid", c.connIDState.local[0].cid), + slogAddr("dst_ip", peerAddr.Addr()), + slog.Int("dst_port", int(peerAddr.Port())), + slogHexstring("dst_cid", c.connIDState.remote[0].cid), + ) +} + +func (c *Conn) logConnectionClosed() { + if !c.logEnabled(QLogLevelEndpoint) { + return + } + err := c.lifetime.finalErr + trigger := "error" + switch e := err.(type) { + case *ApplicationError: + // TODO: Distinguish between peer and locally-initiated close. + trigger = "application" + case localTransportError: + switch err { + case errHandshakeTimeout: + trigger = "handshake_timeout" + default: + if e.code == errNo { + trigger = "clean" + } + } + case peerTransportError: + if e.code == errNo { + trigger = "clean" + } + default: + switch err { + case errIdleTimeout: + trigger = "idle_timeout" + case errStatelessReset: + trigger = "stateless_reset" + } + } + // https://www.ietf.org/archive/id/draft-ietf-quic-qlog-quic-events-03.html#section-4.3 + c.log.LogAttrs(context.Background(), QLogLevelEndpoint, + "connectivity:connection_closed", + slog.String("trigger", trigger), + ) +} + +func (c *Conn) logPacketDropped(dgram *datagram) { + c.log.LogAttrs(context.Background(), QLogLevelPacket, + "connectivity:packet_dropped", + ) +} + +func (c *Conn) logLongPacketReceived(p longPacket, pkt []byte) { + var frames slog.Attr + if c.logEnabled(QLogLevelFrame) { + frames = c.packetFramesAttr(p.payload) + } + c.log.LogAttrs(context.Background(), QLogLevelPacket, + "transport:packet_received", + slog.Group("header", + slog.String("packet_type", p.ptype.qlogString()), + slog.Uint64("packet_number", uint64(p.num)), + slog.Uint64("flags", uint64(pkt[0])), + slogHexstring("scid", p.srcConnID), + slogHexstring("dcid", p.dstConnID), + ), + slog.Group("raw", + slog.Int("length", len(pkt)), + ), + frames, + ) +} + +func (c *Conn) log1RTTPacketReceived(p shortPacket, pkt []byte) { + var frames slog.Attr + if c.logEnabled(QLogLevelFrame) { + frames = c.packetFramesAttr(p.payload) + } + dstConnID, _ := dstConnIDForDatagram(pkt) + c.log.LogAttrs(context.Background(), QLogLevelPacket, + "transport:packet_received", + slog.Group("header", + slog.String("packet_type", packetType1RTT.qlogString()), + slog.Uint64("packet_number", uint64(p.num)), + slog.Uint64("flags", uint64(pkt[0])), + slogHexstring("dcid", dstConnID), + ), + slog.Group("raw", + slog.Int("length", len(pkt)), + ), + frames, + ) +} + +func (c *Conn) logPacketSent(ptype packetType, pnum packetNumber, src, dst []byte, pktLen int, payload []byte) { + var frames slog.Attr + if c.logEnabled(QLogLevelFrame) { + frames = c.packetFramesAttr(payload) + } + var scid slog.Attr + if len(src) > 0 { + scid = slogHexstring("scid", src) + } + c.log.LogAttrs(context.Background(), QLogLevelPacket, + "transport:packet_sent", + slog.Group("header", + slog.String("packet_type", ptype.qlogString()), + slog.Uint64("packet_number", uint64(pnum)), + scid, + slogHexstring("dcid", dst), + ), + slog.Group("raw", + slog.Int("length", pktLen), + ), + frames, + ) +} + +// packetFramesAttr returns the "frames" attribute containing the frames in a packet. +// We currently pass this as a slog Any containing a []slog.Value, +// where each Value is a debugFrame that implements slog.LogValuer. +// +// This isn't tremendously efficient, but avoids the need to put a JSON encoder +// in the quic package or a frame parser in the qlog package. +func (c *Conn) packetFramesAttr(payload []byte) slog.Attr { + var frames []slog.Value + for len(payload) > 0 { + f, n := parseDebugFrame(payload) + if n < 0 { + break + } + payload = payload[n:] + switch f := f.(type) { + case debugFrameAck: + // The qlog ACK frame contains the ACK Delay field as a duration. + // Interpreting the contents of this field as a duration requires + // knowing the peer's ack_delay_exponent transport parameter, + // and it's possible for us to parse an ACK frame before we've + // received that parameter. + // + // We could plumb connection state down into the frame parser, + // but for now let's minimize the amount of code that needs to + // deal with this and convert the unscaled value into a scaled one here. + ackDelay := time.Duration(-1) + if c.peerAckDelayExponent >= 0 { + ackDelay = f.ackDelay.Duration(uint8(c.peerAckDelayExponent)) + } + frames = append(frames, slog.AnyValue(debugFrameScaledAck{ + ranges: f.ranges, + ackDelay: ackDelay, + })) + default: + frames = append(frames, slog.AnyValue(f)) + } + } + return slog.Any("frames", frames) +} + +func (c *Conn) logPacketLost(space numberSpace, sent *sentPacket) { + c.log.LogAttrs(context.Background(), QLogLevelPacket, + "recovery:packet_lost", + slog.Group("header", + slog.String("packet_type", sent.ptype.qlogString()), + slog.Uint64("packet_number", uint64(sent.num)), + ), + ) +} diff --git a/src/vendor/golang.org/x/net/quic/queue.go b/src/vendor/golang.org/x/net/quic/queue.go new file mode 100644 index 0000000000..f2712f4012 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/queue.go @@ -0,0 +1,63 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import "context" + +// A queue is an unbounded queue of some item (new connections and streams). +type queue[T any] struct { + // The gate condition is set if the queue is non-empty or closed. + gate gate + err error + q []T +} + +func newQueue[T any]() queue[T] { + return queue[T]{gate: newGate()} +} + +// close closes the queue, causing pending and future pop operations +// to return immediately with err. +func (q *queue[T]) close(err error) { + q.gate.lock() + defer q.unlock() + if q.err == nil { + q.err = err + } +} + +// put appends an item to the queue. +// It returns true if the item was added, false if the queue is closed. +func (q *queue[T]) put(v T) bool { + q.gate.lock() + defer q.unlock() + if q.err != nil { + return false + } + q.q = append(q.q, v) + return true +} + +// get removes the first item from the queue, blocking until ctx is done, an item is available, +// or the queue is closed. +func (q *queue[T]) get(ctx context.Context) (T, error) { + var zero T + if err := q.gate.waitAndLock(ctx); err != nil { + return zero, err + } + defer q.unlock() + if q.err != nil { + return zero, q.err + } + v := q.q[0] + copy(q.q[:], q.q[1:]) + q.q[len(q.q)-1] = zero + q.q = q.q[:len(q.q)-1] + return v, nil +} + +func (q *queue[T]) unlock() { + q.gate.unlock(q.err != nil || len(q.q) > 0) +} diff --git a/src/vendor/golang.org/x/net/quic/quic.go b/src/vendor/golang.org/x/net/quic/quic.go new file mode 100644 index 0000000000..26256bf422 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/quic.go @@ -0,0 +1,221 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "time" +) + +// QUIC versions. +// We only support v1 at this time. +const ( + quicVersion1 = 1 + quicVersion2 = 0x6b3343cf // https://www.rfc-editor.org/rfc/rfc9369 +) + +// connIDLen is the length in bytes of connection IDs chosen by this package. +// Since 1-RTT packets don't include a connection ID length field, +// we use a consistent length for all our IDs. +// https://www.rfc-editor.org/rfc/rfc9000.html#section-5.1-6 +const connIDLen = 8 + +// Local values of various transport parameters. +// https://www.rfc-editor.org/rfc/rfc9000.html#section-18.2 +const ( + defaultMaxIdleTimeout = 30 * time.Second // max_idle_timeout + + // The max_udp_payload_size transport parameter is the size of our + // network receive buffer. + // + // Set this to the largest UDP packet that can be sent over + // Ethernet without using jumbo frames: 1500 byte Ethernet frame, + // minus 20 byte IPv4 header and 8 byte UDP header. + // + // The maximum possible UDP payload is 65527 bytes. Supporting this + // without wasting memory in unused receive buffers will require some + // care. For now, just limit ourselves to the most common case. + maxUDPPayloadSize = 1472 + + ackDelayExponent = 3 // ack_delay_exponent + maxAckDelay = 25 * time.Millisecond // max_ack_delay + + // The active_conn_id_limit transport parameter is the maximum + // number of connection IDs from the peer we're willing to store. + // + // maxPeerActiveConnIDLimit is the maximum number of connection IDs + // we're willing to send to the peer. + // + // https://www.rfc-editor.org/rfc/rfc9000.html#section-18.2-6.2.1 + activeConnIDLimit = 2 + maxPeerActiveConnIDLimit = 4 +) + +// Time limit for completing the handshake. +const defaultHandshakeTimeout = 10 * time.Second + +// Keep-alive ping frequency. +const defaultKeepAlivePeriod = 0 + +// Local timer granularity. +// https://www.rfc-editor.org/rfc/rfc9002.html#section-6.1.2-6 +const timerGranularity = 1 * time.Millisecond + +// The smallest allowed maximum datagram size. +// https://www.rfc-editor.org/rfc/rfc9000#section-14 +const smallestMaxDatagramSize = 1200 + +// Minimum size of a UDP datagram sent by a client carrying an Initial packet, +// or a server containing an ack-eliciting Initial packet. +// https://www.rfc-editor.org/rfc/rfc9000#section-14.1 +const paddedInitialDatagramSize = smallestMaxDatagramSize + +// Maximum number of streams of a given type which may be created. +// https://www.rfc-editor.org/rfc/rfc9000.html#section-4.6-2 +const maxStreamsLimit = 1 << 60 + +// Maximum number of streams we will allow the peer to create implicitly. +// A stream ID that is used out of order results in all streams of that type +// with lower-numbered IDs also being opened. To limit the amount of work we +// will do in response to a single frame, we cap the peer's stream limit to +// this value. +const implicitStreamLimit = 100 + +// A connSide distinguishes between the client and server sides of a connection. +type connSide int8 + +const ( + clientSide = connSide(iota) + serverSide +) + +func (s connSide) String() string { + switch s { + case clientSide: + return "client" + case serverSide: + return "server" + default: + return "BUG" + } +} + +func (s connSide) peer() connSide { + if s == clientSide { + return serverSide + } else { + return clientSide + } +} + +// A numberSpace is the context in which a packet number applies. +// https://www.rfc-editor.org/rfc/rfc9000.html#section-12.3-7 +type numberSpace byte + +const ( + initialSpace = numberSpace(iota) + handshakeSpace + appDataSpace + numberSpaceCount +) + +func (n numberSpace) String() string { + switch n { + case initialSpace: + return "Initial" + case handshakeSpace: + return "Handshake" + case appDataSpace: + return "AppData" + default: + return "BUG" + } +} + +// A streamType is the type of a stream: bidirectional or unidirectional. +type streamType uint8 + +const ( + bidiStream = streamType(iota) + uniStream + streamTypeCount +) + +func (s streamType) qlogString() string { + switch s { + case bidiStream: + return "bidirectional" + case uniStream: + return "unidirectional" + default: + return "BUG" + } +} + +func (s streamType) String() string { + switch s { + case bidiStream: + return "bidi" + case uniStream: + return "uni" + default: + return "BUG" + } +} + +// A streamID is a QUIC stream ID. +// https://www.rfc-editor.org/rfc/rfc9000.html#section-2.1 +type streamID uint64 + +// The two least significant bits of a stream ID indicate the initiator +// and directionality of the stream. The upper bits are the stream number. +// Each of the four possible combinations of initiator and direction +// each has a distinct number space. +const ( + clientInitiatedStreamBit = 0x0 + serverInitiatedStreamBit = 0x1 + initiatorStreamBitMask = 0x1 + + bidiStreamBit = 0x0 + uniStreamBit = 0x2 + dirStreamBitMask = 0x2 +) + +func newStreamID(initiator connSide, typ streamType, num int64) streamID { + id := streamID(num << 2) + if typ == uniStream { + id |= uniStreamBit + } + if initiator == serverSide { + id |= serverInitiatedStreamBit + } + return id +} + +func (s streamID) initiator() connSide { + if s&initiatorStreamBitMask == serverInitiatedStreamBit { + return serverSide + } + return clientSide +} + +func (s streamID) num() int64 { + return int64(s) >> 2 +} + +func (s streamID) streamType() streamType { + if s&dirStreamBitMask == uniStreamBit { + return uniStream + } + return bidiStream +} + +// packetFate is the fate of a sent packet: Either acknowledged by the peer, +// or declared lost. +type packetFate byte + +const ( + packetLost = packetFate(iota) + packetAcked +) diff --git a/src/vendor/golang.org/x/net/quic/rangeset.go b/src/vendor/golang.org/x/net/quic/rangeset.go new file mode 100644 index 0000000000..3d6f5f9799 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/rangeset.go @@ -0,0 +1,193 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +// A rangeset is a set of int64s, stored as an ordered list of non-overlapping, +// non-empty ranges. +// +// Rangesets are efficient for small numbers of ranges, +// which is expected to be the common case. +type rangeset[T ~int64] []i64range[T] + +type i64range[T ~int64] struct { + start, end T // [start, end) +} + +// size returns the size of the range. +func (r i64range[T]) size() T { + return r.end - r.start +} + +// contains reports whether v is in the range. +func (r i64range[T]) contains(v T) bool { + return r.start <= v && v < r.end +} + +// add adds [start, end) to the set, combining it with existing ranges if necessary. +func (s *rangeset[T]) add(start, end T) { + if start == end { + return + } + for i := range *s { + r := &(*s)[i] + if r.start > end { + // The new range comes before range i. + s.insertrange(i, start, end) + return + } + if start > r.end { + // The new range comes after range i. + continue + } + // The new range is adjacent to or overlapping range i. + if start < r.start { + r.start = start + } + if end <= r.end { + return + } + // Possibly coalesce subsequent ranges into range i. + r.end = end + j := i + 1 + for ; j < len(*s) && r.end >= (*s)[j].start; j++ { + if e := (*s)[j].end; e > r.end { + // Range j ends after the new range. + r.end = e + } + } + s.removeranges(i+1, j) + return + } + *s = append(*s, i64range[T]{start, end}) +} + +// sub removes [start, end) from the set. +func (s *rangeset[T]) sub(start, end T) { + removefrom, removeto := -1, -1 + for i := range *s { + r := &(*s)[i] + if end < r.start { + break + } + if r.end < start { + continue + } + switch { + case start <= r.start && end >= r.end: + // Remove the entire range. + if removefrom == -1 { + removefrom = i + } + removeto = i + 1 + case start <= r.start: + // Remove a prefix. + r.start = end + case end >= r.end: + // Remove a suffix. + r.end = start + default: + // Remove the middle, leaving two new ranges. + rend := r.end + r.end = start + s.insertrange(i+1, end, rend) + return + } + } + if removefrom != -1 { + s.removeranges(removefrom, removeto) + } +} + +// contains reports whether s contains v. +func (s rangeset[T]) contains(v T) bool { + for _, r := range s { + if v >= r.end { + continue + } + if r.start <= v { + return true + } + return false + } + return false +} + +// rangeContaining returns the range containing v, or the range [0,0) if v is not in s. +func (s rangeset[T]) rangeContaining(v T) i64range[T] { + for _, r := range s { + if v >= r.end { + continue + } + if r.start <= v { + return r + } + break + } + return i64range[T]{0, 0} +} + +// min returns the minimum value in the set, or 0 if empty. +func (s rangeset[T]) min() T { + if len(s) == 0 { + return 0 + } + return s[0].start +} + +// max returns the maximum value in the set, or 0 if empty. +func (s rangeset[T]) max() T { + if len(s) == 0 { + return 0 + } + return s[len(s)-1].end - 1 +} + +// end returns the end of the last range in the set, or 0 if empty. +func (s rangeset[T]) end() T { + if len(s) == 0 { + return 0 + } + return s[len(s)-1].end +} + +// numRanges returns the number of ranges in the rangeset. +func (s rangeset[T]) numRanges() int { + return len(s) +} + +// size returns the size of all ranges in the rangeset. +func (s rangeset[T]) size() (total T) { + for _, r := range s { + total += r.size() + } + return total +} + +// isrange reports if the rangeset covers exactly the range [start, end). +func (s rangeset[T]) isrange(start, end T) bool { + switch len(s) { + case 0: + return start == 0 && end == 0 + case 1: + return s[0].start == start && s[0].end == end + } + return false +} + +// removeranges removes ranges [i,j). +func (s *rangeset[T]) removeranges(i, j int) { + if i == j { + return + } + copy((*s)[i:], (*s)[j:]) + *s = (*s)[:len(*s)-(j-i)] +} + +// insert adds a new range at index i. +func (s *rangeset[T]) insertrange(i int, start, end T) { + *s = append(*s, i64range[T]{}) + copy((*s)[i+1:], (*s)[i:]) + (*s)[i] = i64range[T]{start, end} +} diff --git a/src/vendor/golang.org/x/net/quic/retry.go b/src/vendor/golang.org/x/net/quic/retry.go new file mode 100644 index 0000000000..0392ca9159 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/retry.go @@ -0,0 +1,240 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/binary" + "net/netip" + "time" + + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/net/internal/quic/quicwire" +) + +// AEAD and nonce used to compute the Retry Integrity Tag. +// https://www.rfc-editor.org/rfc/rfc9001#section-5.8 +var ( + retrySecret = []byte{0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e} + retryNonce = []byte{0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb} + retryAEAD = func() cipher.AEAD { + c, err := aes.NewCipher(retrySecret) + if err != nil { + panic(err) + } + aead, err := cipher.NewGCM(c) + if err != nil { + panic(err) + } + return aead + }() +) + +// retryTokenValidityPeriod is how long we accept a Retry packet token after sending it. +const retryTokenValidityPeriod = 5 * time.Second + +// retryState generates and validates an endpoint's retry tokens. +type retryState struct { + aead cipher.AEAD +} + +func (rs *retryState) init() error { + // Retry tokens are authenticated using a per-server key chosen at start time. + // TODO: Provide a way for the user to set this key. + secret := make([]byte, chacha20poly1305.KeySize) + if _, err := rand.Read(secret); err != nil { + return err + } + aead, err := chacha20poly1305.NewX(secret) + if err != nil { + panic(err) + } + rs.aead = aead + return nil +} + +// Retry tokens are encrypted with an AEAD. +// The plaintext contains the time the token was created and +// the original destination connection ID. +// The additional data contains the sender's source address and original source connection ID. +// The token nonce is randomly generated. +// We use the nonce as the Source Connection ID of the Retry packet. +// Since the 24-byte XChaCha20-Poly1305 nonce is too large to fit in a 20-byte connection ID, +// we include the remaining 4 bytes of nonce in the token. +// +// Token { +// Last 4 Bytes of Nonce (32), +// Ciphertext (..), +// } +// +// Plaintext { +// Timestamp (64), +// Original Destination Connection ID, +// } +// +// +// Additional Data { +// Original Source Connection ID Length (8), +// Original Source Connection ID (..), +// IP Address (32..128), +// Port (16), +// } +// +// TODO: Consider using AES-256-GCM-SIV once crypto/tls supports it. + +func (rs *retryState) makeToken(now time.Time, srcConnID, origDstConnID []byte, addr netip.AddrPort) (token, newDstConnID []byte, err error) { + nonce := make([]byte, rs.aead.NonceSize()) + if _, err := rand.Read(nonce); err != nil { + return nil, nil, err + } + + var plaintext []byte + plaintext = binary.BigEndian.AppendUint64(plaintext, uint64(now.Unix())) + plaintext = append(plaintext, origDstConnID...) + + token = append(token, nonce[maxConnIDLen:]...) + token = rs.aead.Seal(token, nonce, plaintext, rs.additionalData(srcConnID, addr)) + return token, nonce[:maxConnIDLen], nil +} + +func (rs *retryState) validateToken(now time.Time, token, srcConnID, dstConnID []byte, addr netip.AddrPort) (origDstConnID []byte, ok bool) { + tokenNonceLen := rs.aead.NonceSize() - maxConnIDLen + if len(token) < tokenNonceLen { + return nil, false + } + nonce := append([]byte{}, dstConnID...) + nonce = append(nonce, token[:tokenNonceLen]...) + ciphertext := token[tokenNonceLen:] + if len(nonce) != rs.aead.NonceSize() { + return nil, false + } + + plaintext, err := rs.aead.Open(nil, nonce, ciphertext, rs.additionalData(srcConnID, addr)) + if err != nil { + return nil, false + } + if len(plaintext) < 8 { + return nil, false + } + when := time.Unix(int64(binary.BigEndian.Uint64(plaintext)), 0) + origDstConnID = plaintext[8:] + + // We allow for tokens created in the future (up to the validity period), + // which likely indicates that the system clock was adjusted backwards. + if d := abs(now.Sub(when)); d > retryTokenValidityPeriod { + return nil, false + } + + return origDstConnID, true +} + +func (rs *retryState) additionalData(srcConnID []byte, addr netip.AddrPort) []byte { + var additional []byte + additional = quicwire.AppendUint8Bytes(additional, srcConnID) + additional = append(additional, addr.Addr().AsSlice()...) + additional = binary.BigEndian.AppendUint16(additional, addr.Port()) + return additional +} + +func (e *Endpoint) validateInitialAddress(now time.Time, p genericLongPacket, peerAddr netip.AddrPort) (origDstConnID []byte, ok bool) { + // The retry token is at the start of an Initial packet's data. + token, n := quicwire.ConsumeUint8Bytes(p.data) + if n < 0 { + // We've already validated that the packet is at least 1200 bytes long, + // so there's no way for even a maximum size token to not fit. + // Check anyway. + return nil, false + } + if len(token) == 0 { + // The sender has not provided a token. + // Send a Retry packet to them with one. + e.sendRetry(now, p, peerAddr) + return nil, false + } + origDstConnID, ok = e.retry.validateToken(now, token, p.srcConnID, p.dstConnID, peerAddr) + if !ok { + // This does not seem to be a valid token. + // Close the connection with an INVALID_TOKEN error. + // https://www.rfc-editor.org/rfc/rfc9000#section-8.1.2-5 + e.sendConnectionClose(p, peerAddr, errInvalidToken) + return nil, false + } + return origDstConnID, true +} + +func (e *Endpoint) sendRetry(now time.Time, p genericLongPacket, peerAddr netip.AddrPort) { + token, srcConnID, err := e.retry.makeToken(now, p.srcConnID, p.dstConnID, peerAddr) + if err != nil { + return + } + b := encodeRetryPacket(p.dstConnID, retryPacket{ + dstConnID: p.srcConnID, + srcConnID: srcConnID, + token: token, + }) + e.sendDatagram(datagram{ + b: b, + peerAddr: peerAddr, + }) +} + +type retryPacket struct { + dstConnID []byte + srcConnID []byte + token []byte +} + +func encodeRetryPacket(originalDstConnID []byte, p retryPacket) []byte { + // Retry packets include an integrity tag, computed by AEAD_AES_128_GCM over + // the original destination connection ID followed by the Retry packet + // (less the integrity tag itself). + // https://www.rfc-editor.org/rfc/rfc9001#section-5.8 + // + // Create the pseudo-packet (including the original DCID), append the tag, + // and return the Retry packet. + var b []byte + b = quicwire.AppendUint8Bytes(b, originalDstConnID) // Original Destination Connection ID + start := len(b) // start of the Retry packet + b = append(b, headerFormLong|fixedBit|longPacketTypeRetry) + b = binary.BigEndian.AppendUint32(b, quicVersion1) // Version + b = quicwire.AppendUint8Bytes(b, p.dstConnID) // Destination Connection ID + b = quicwire.AppendUint8Bytes(b, p.srcConnID) // Source Connection ID + b = append(b, p.token...) // Token + b = retryAEAD.Seal(b, retryNonce, nil, b) // Retry Integrity Tag + return b[start:] +} + +func parseRetryPacket(b, origDstConnID []byte) (p retryPacket, ok bool) { + const retryIntegrityTagLength = 128 / 8 + + lp, ok := parseGenericLongHeaderPacket(b) + if !ok { + return retryPacket{}, false + } + if len(lp.data) < retryIntegrityTagLength { + return retryPacket{}, false + } + gotTag := lp.data[len(lp.data)-retryIntegrityTagLength:] + + // Create the pseudo-packet consisting of the original destination connection ID + // followed by the Retry packet (less the integrity tag). + // Use this to validate the packet integrity tag. + pseudo := quicwire.AppendUint8Bytes(nil, origDstConnID) + pseudo = append(pseudo, b[:len(b)-retryIntegrityTagLength]...) + wantTag := retryAEAD.Seal(nil, retryNonce, nil, pseudo) + if !bytes.Equal(gotTag, wantTag) { + return retryPacket{}, false + } + + token := lp.data[:len(lp.data)-retryIntegrityTagLength] + return retryPacket{ + dstConnID: lp.dstConnID, + srcConnID: lp.srcConnID, + token: token, + }, true +} diff --git a/src/vendor/golang.org/x/net/quic/rtt.go b/src/vendor/golang.org/x/net/quic/rtt.go new file mode 100644 index 0000000000..0dc9bf5bf2 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/rtt.go @@ -0,0 +1,71 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "time" +) + +type rttState struct { + minRTT time.Duration + latestRTT time.Duration + smoothedRTT time.Duration + rttvar time.Duration // RTT variation + firstSampleTime time.Time // time of first RTT sample +} + +func (r *rttState) init() { + r.minRTT = -1 // -1 indicates the first sample has not been taken yet + + // "[...] the initial RTT SHOULD be set to 333 milliseconds." + // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.2-1 + const initialRTT = 333 * time.Millisecond + + // https://www.rfc-editor.org/rfc/rfc9002.html#section-5.3-12 + r.smoothedRTT = initialRTT + r.rttvar = initialRTT / 2 +} + +func (r *rttState) establishPersistentCongestion() { + // "Endpoints SHOULD set the min_rtt to the newest RTT sample + // after persistent congestion is established." + // https://www.rfc-editor.org/rfc/rfc9002#section-5.2-5 + r.minRTT = r.latestRTT +} + +// updateSample is called when we generate a new RTT sample. +// https://www.rfc-editor.org/rfc/rfc9002.html#section-5 +func (r *rttState) updateSample(now time.Time, handshakeConfirmed bool, spaceID numberSpace, latestRTT, ackDelay, maxAckDelay time.Duration) { + r.latestRTT = latestRTT + + if r.minRTT < 0 { + // First RTT sample. + // "min_rtt MUST be set to the latest_rtt on the first RTT sample." + // https://www.rfc-editor.org/rfc/rfc9002.html#section-5.2-2 + r.minRTT = latestRTT + // https://www.rfc-editor.org/rfc/rfc9002.html#section-5.3-14 + r.smoothedRTT = latestRTT + r.rttvar = latestRTT / 2 + r.firstSampleTime = now + return + } + + // "min_rtt MUST be set to the lesser of min_rtt and latest_rtt [...] + // on all other samples." + // https://www.rfc-editor.org/rfc/rfc9002.html#section-5.2-2 + r.minRTT = min(r.minRTT, latestRTT) + + // https://www.rfc-editor.org/rfc/rfc9002.html#section-5.3-16 + if handshakeConfirmed { + ackDelay = min(ackDelay, maxAckDelay) + } + adjustedRTT := latestRTT - ackDelay + if adjustedRTT < r.minRTT { + adjustedRTT = latestRTT + } + rttvarSample := abs(r.smoothedRTT - adjustedRTT) + r.rttvar = (3*r.rttvar + rttvarSample) / 4 + r.smoothedRTT = ((7 * r.smoothedRTT) + adjustedRTT) / 8 +} diff --git a/src/vendor/golang.org/x/net/quic/sent_packet.go b/src/vendor/golang.org/x/net/quic/sent_packet.go new file mode 100644 index 0000000000..f67606b353 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/sent_packet.go @@ -0,0 +1,119 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "sync" + "time" + + "golang.org/x/net/internal/quic/quicwire" +) + +// A sentPacket tracks state related to an in-flight packet we sent, +// to be committed when the peer acks it or resent if the packet is lost. +type sentPacket struct { + num packetNumber + size int // size in bytes + time time.Time // time sent + ptype packetType + + state sentPacketState + ackEliciting bool // https://www.rfc-editor.org/rfc/rfc9002.html#section-2-3.4.1 + inFlight bool // https://www.rfc-editor.org/rfc/rfc9002.html#section-2-3.6.1 + + // Frames sent in the packet. + // + // This is an abbreviated version of the packet payload, containing only the information + // we need to process an ack for or loss of this packet. + // For example, a CRYPTO frame is recorded as the frame type (0x06), offset, and length, + // but does not include the sent data. + // + // This buffer is written by packetWriter.append* and read by Conn.handleAckOrLoss. + b []byte + n int // read offset into b +} + +type sentPacketState uint8 + +const ( + sentPacketSent = sentPacketState(iota) // sent but neither acked nor lost + sentPacketAcked // acked + sentPacketLost // declared lost + sentPacketUnsent // never sent +) + +var sentPool = sync.Pool{ + New: func() any { + return &sentPacket{} + }, +} + +func newSentPacket() *sentPacket { + sent := sentPool.Get().(*sentPacket) + sent.reset() + return sent +} + +// recycle returns a sentPacket to the pool. +func (sent *sentPacket) recycle() { + sentPool.Put(sent) +} + +func (sent *sentPacket) reset() { + *sent = sentPacket{ + b: sent.b[:0], + } +} + +// markAckEliciting marks the packet as containing an ack-eliciting frame. +func (sent *sentPacket) markAckEliciting() { + sent.ackEliciting = true + sent.inFlight = true +} + +// The append* methods record information about frames in the packet. + +func (sent *sentPacket) appendNonAckElicitingFrame(frameType byte) { + sent.b = append(sent.b, frameType) +} + +func (sent *sentPacket) appendAckElicitingFrame(frameType byte) { + sent.ackEliciting = true + sent.inFlight = true + sent.b = append(sent.b, frameType) +} + +func (sent *sentPacket) appendInt(v uint64) { + sent.b = quicwire.AppendVarint(sent.b, v) +} + +func (sent *sentPacket) appendOffAndSize(start int64, size int) { + sent.b = quicwire.AppendVarint(sent.b, uint64(start)) + sent.b = quicwire.AppendVarint(sent.b, uint64(size)) +} + +// The next* methods read back information about frames in the packet. + +func (sent *sentPacket) next() (frameType byte) { + f := sent.b[sent.n] + sent.n++ + return f +} + +func (sent *sentPacket) nextInt() uint64 { + v, n := quicwire.ConsumeVarint(sent.b[sent.n:]) + sent.n += n + return v +} + +func (sent *sentPacket) nextRange() (start, end int64) { + start = int64(sent.nextInt()) + end = start + int64(sent.nextInt()) + return start, end +} + +func (sent *sentPacket) done() bool { + return sent.n == len(sent.b) +} diff --git a/src/vendor/golang.org/x/net/quic/sent_packet_list.go b/src/vendor/golang.org/x/net/quic/sent_packet_list.go new file mode 100644 index 0000000000..04116f2109 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/sent_packet_list.go @@ -0,0 +1,93 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +// A sentPacketList is a ring buffer of sentPackets. +// +// Processing an ack for a packet causes all older packets past a small threshold +// to be discarded (RFC 9002, Section 6.1.1), so the list of in-flight packets is +// not sparse and will contain at most a few acked/lost packets we no longer +// care about. +type sentPacketList struct { + nextNum packetNumber // next packet number to add to the buffer + off int // offset of first packet in the buffer + size int // number of packets + p []*sentPacket +} + +// start is the first packet in the list. +func (s *sentPacketList) start() packetNumber { + return s.nextNum - packetNumber(s.size) +} + +// end is one after the last packet in the list. +// If the list is empty, start == end. +func (s *sentPacketList) end() packetNumber { + return s.nextNum +} + +// discard clears the list. +func (s *sentPacketList) discard() { + *s = sentPacketList{} +} + +// add appends a packet to the list. +func (s *sentPacketList) add(sent *sentPacket) { + if s.nextNum != sent.num { + panic("inserting out-of-order packet") + } + s.nextNum++ + if s.size >= len(s.p) { + s.grow() + } + i := (s.off + s.size) % len(s.p) + s.size++ + s.p[i] = sent +} + +// nth returns a packet by index. +func (s *sentPacketList) nth(n int) *sentPacket { + index := (s.off + n) % len(s.p) + return s.p[index] +} + +// num returns a packet by number. +// It returns nil if the packet is not in the list. +func (s *sentPacketList) num(num packetNumber) *sentPacket { + i := int(num - s.start()) + if i < 0 || i >= s.size { + return nil + } + return s.nth(i) +} + +// clean removes all acked or lost packets from the head of the list. +func (s *sentPacketList) clean() { + for s.size > 0 { + sent := s.p[s.off] + if sent.state == sentPacketSent { + return + } + sent.recycle() + s.p[s.off] = nil + s.off = (s.off + 1) % len(s.p) + s.size-- + } + s.off = 0 +} + +// grow increases the buffer to hold more packaets. +func (s *sentPacketList) grow() { + newSize := len(s.p) * 2 + if newSize == 0 { + newSize = 64 + } + p := make([]*sentPacket, newSize) + for i := 0; i < s.size; i++ { + p[i] = s.nth(i) + } + s.p = p + s.off = 0 +} diff --git a/src/vendor/golang.org/x/net/quic/sent_val.go b/src/vendor/golang.org/x/net/quic/sent_val.go new file mode 100644 index 0000000000..f1682dbd78 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/sent_val.go @@ -0,0 +1,103 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +// A sentVal tracks sending some piece of information to the peer. +// It tracks whether the information has been sent, acked, and +// (when in-flight) the most recent packet to carry it. +// +// For example, a sentVal can track sending of a RESET_STREAM frame. +// +// - unset: stream is active, no need to send RESET_STREAM +// - unsent: we should send a RESET_STREAM, but have not yet +// - sent: we have sent a RESET_STREAM, but have not received an ack +// - received: we have sent a RESET_STREAM, and the peer has acked the packet that contained it +// +// In the "sent" state, a sentVal also tracks the latest packet number to carry +// the information. (QUIC packet numbers are always at most 62 bits in size, +// so the sentVal keeps the number in the low 62 bits and the state in the high 2 bits.) +type sentVal uint64 + +const ( + sentValUnset = 0 // unset + sentValUnsent = 1 << 62 // set, not sent to the peer + sentValSent = 2 << 62 // set, sent to the peer but not yet acked; pnum is set + sentValReceived = 3 << 62 // set, peer acked receipt + + sentValStateMask = 3 << 62 +) + +// isSet reports whether the value is set. +func (s sentVal) isSet() bool { return s != 0 } + +// shouldSend reports whether the value is set and has not been sent to the peer. +func (s sentVal) shouldSend() bool { return s.state() == sentValUnsent } + +// shouldSendPTO reports whether the value needs to be sent to the peer. +// The value needs to be sent if it is set and has not been sent. +// If pto is true, indicating that we are sending a PTO probe, the value +// should also be sent if it is set and has not been acknowledged. +func (s sentVal) shouldSendPTO(pto bool) bool { + st := s.state() + return st == sentValUnsent || (pto && st == sentValSent) +} + +// isReceived reports whether the value has been received by the peer. +func (s sentVal) isReceived() bool { return s == sentValReceived } + +// set sets the value and records that it should be sent to the peer. +// If the value has already been sent, it is not resent. +func (s *sentVal) set() { + if *s == 0 { + *s = sentValUnsent + } +} + +// reset sets the value to the unsent state. +func (s *sentVal) setUnsent() { *s = sentValUnsent } + +// clear sets the value to the unset state. +func (s *sentVal) clear() { *s = sentValUnset } + +// setSent sets the value to the send state and records the number of the most recent +// packet containing the value. +func (s *sentVal) setSent(pnum packetNumber) { + *s = sentValSent | sentVal(pnum) +} + +// setReceived sets the value to the received state. +func (s *sentVal) setReceived() { *s = sentValReceived } + +// ackOrLoss reports that an acknowledgement has been received for the value, +// or that the packet carrying the value has been lost. +func (s *sentVal) ackOrLoss(pnum packetNumber, fate packetFate) { + if fate == packetAcked { + *s = sentValReceived + } else if *s == sentVal(pnum)|sentValSent { + *s = sentValUnsent + } +} + +// ackLatestOrLoss reports that an acknowledgement has been received for the value, +// or that the packet carrying the value has been lost. +// The value is set to the acked state only if pnum is the latest packet containing it. +// +// We use this to handle acks for data that varies every time it is sent. +// For example, if we send a MAX_DATA frame followed by an updated MAX_DATA value in a +// second packet, we consider the data sent only upon receiving an ack for the most +// recent value. +func (s *sentVal) ackLatestOrLoss(pnum packetNumber, fate packetFate) { + if fate == packetAcked { + if *s == sentVal(pnum)|sentValSent { + *s = sentValReceived + } + } else { + if *s == sentVal(pnum)|sentValSent { + *s = sentValUnsent + } + } +} + +func (s sentVal) state() uint64 { return uint64(s) & sentValStateMask } diff --git a/src/vendor/golang.org/x/net/quic/skip.go b/src/vendor/golang.org/x/net/quic/skip.go new file mode 100644 index 0000000000..f0d0234ee6 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/skip.go @@ -0,0 +1,62 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +// skipState is state for optimistic ACK defenses. +// +// An endpoint performs an optimistic ACK attack by sending acknowledgements for packets +// which it has not received, potentially convincing the sender's congestion controller to +// send at rates beyond what the network supports. +// +// We defend against this by periodically skipping packet numbers. +// Receiving an ACK for an unsent packet number is a PROTOCOL_VIOLATION error. +// +// We only skip packet numbers in the Application Data number space. +// The total data sent in the Initial/Handshake spaces should generally fit into +// the initial congestion window. +// +// https://www.rfc-editor.org/rfc/rfc9000.html#section-21.4 +type skipState struct { + // skip is the next packet number (in the Application Data space) we should skip. + skip packetNumber + + // maxSkip is the maximum number of packets to send before skipping another number. + // Increases over time. + maxSkip int64 +} + +func (ss *skipState) init(c *Conn) { + ss.maxSkip = 256 // skip our first packet number within this range + ss.updateNumberSkip(c) +} + +// shouldSkip returns whether we should skip the given packet number. +func (ss *skipState) shouldSkip(num packetNumber) bool { + return ss.skip == num +} + +// updateNumberSkip schedules a packet to be skipped after skipping lastSkipped. +func (ss *skipState) updateNumberSkip(c *Conn) { + // Send at least this many packets before skipping. + // Limits the impact of skipping a little, + // plus allows most tests to ignore skipping. + const minSkip = 64 + + skip := minSkip + c.prng.Int64N(ss.maxSkip-minSkip) + ss.skip += packetNumber(skip) + + // Double the size of the skip each time until we reach 128k. + // The idea here is that an attacker needs to correctly ack ~N packets in order + // to send an optimistic ack for another ~N packets. + // Skipping packet numbers comes with a small cost (it causes the receiver to + // send an immediate ACK rather than the usual delayed ACK), so we increase the + // time between skips as a connection's lifetime grows. + // + // The 128k cap is arbitrary, chosen so that we skip a packet number + // about once a second when sending full-size datagrams at 1Gbps. + if ss.maxSkip < 128*1024 { + ss.maxSkip *= 2 + } +} diff --git a/src/vendor/golang.org/x/net/quic/stateless_reset.go b/src/vendor/golang.org/x/net/quic/stateless_reset.go new file mode 100644 index 0000000000..8907e2e58b --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/stateless_reset.go @@ -0,0 +1,59 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "hash" + "sync" +) + +const statelessResetTokenLen = 128 / 8 + +// A statelessResetToken is a stateless reset token. +// https://www.rfc-editor.org/rfc/rfc9000#section-10.3 +type statelessResetToken [statelessResetTokenLen]byte + +type statelessResetTokenGenerator struct { + canReset bool + + // The hash.Hash interface is not concurrency safe, + // so we need a mutex here. + // + // There shouldn't be much contention on stateless reset token generation. + // If this proves to be a problem, we could avoid the mutex by using a separate + // generator per Conn, or by using a concurrency-safe generator. + mu sync.Mutex + mac hash.Hash +} + +func (g *statelessResetTokenGenerator) init(secret [32]byte) { + zero := true + for _, b := range secret { + if b != 0 { + zero = false + break + } + } + if zero { + // Generate tokens using a random secret, but don't send stateless resets. + rand.Read(secret[:]) + g.canReset = false + } else { + g.canReset = true + } + g.mac = hmac.New(sha256.New, secret[:]) +} + +func (g *statelessResetTokenGenerator) tokenForConnID(cid []byte) (token statelessResetToken) { + g.mu.Lock() + defer g.mu.Unlock() + defer g.mac.Reset() + g.mac.Write(cid) + copy(token[:], g.mac.Sum(nil)) + return token +} diff --git a/src/vendor/golang.org/x/net/quic/stream.go b/src/vendor/golang.org/x/net/quic/stream.go new file mode 100644 index 0000000000..383a6c160a --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/stream.go @@ -0,0 +1,1041 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "context" + "errors" + "fmt" + "io" + "math" + + "golang.org/x/net/internal/quic/quicwire" +) + +// A Stream is an ordered byte stream. +// +// Streams may be bidirectional, read-only, or write-only. +// Methods inappropriate for a stream's direction +// (for example, [Write] to a read-only stream) +// return errors. +// +// It is not safe to perform concurrent reads from or writes to a stream. +// It is safe, however, to read and write at the same time. +// +// Reads and writes are buffered. +// It is generally not necessary to wrap a stream in a [bufio.ReadWriter] +// or otherwise apply additional buffering. +// +// To cancel reads or writes, use the [SetReadContext] and [SetWriteContext] methods. +type Stream struct { + id streamID + conn *Conn + + // Contexts used for read/write operations. + // Intentionally not mutex-guarded, to allow the race detector to catch concurrent access. + inctx context.Context + outctx context.Context + + // ingate's lock guards receive-related state. + // + // The gate condition is set if a read from the stream will not block, + // either because the stream has available data or because the read will fail. + ingate gate + in pipe // received data + inwin int64 // last MAX_STREAM_DATA sent to the peer + insendmax sentVal // set when we should send MAX_STREAM_DATA to the peer + inmaxbuf int64 // maximum amount of data we will buffer + insize int64 // stream final size; -1 before this is known + inset rangeset[int64] // received ranges + inclosed sentVal // set by CloseRead + inresetcode int64 // RESET_STREAM code received from the peer; -1 if not reset + + // outgate's lock guards send-related state. + // + // The gate condition is set if a write to the stream will not block, + // either because the stream has available flow control or because + // the write will fail. + outgate gate + out pipe // buffered data to send + outflushed int64 // offset of last flush call + outwin int64 // maximum MAX_STREAM_DATA received from the peer + outmaxsent int64 // maximum data offset we've sent to the peer + outmaxbuf int64 // maximum amount of data we will buffer + outunsent rangeset[int64] // ranges buffered but not yet sent (only flushed data) + outacked rangeset[int64] // ranges sent and acknowledged + outopened sentVal // set if we should open the stream + outclosed sentVal // set by CloseWrite + outblocked sentVal // set when a write to the stream is blocked by flow control + outreset sentVal // set by Reset + outresetcode uint64 // reset code to send in RESET_STREAM + outdone chan struct{} // closed when all data sent + + // Unsynchronized buffers, used for lock-free fast path. + inbuf []byte // received data + inbufoff int // bytes of inbuf which have been consumed + outbuf []byte // written data + outbufoff int // bytes of outbuf which contain data to write + + // Atomic stream state bits. + // + // These bits provide a fast way to coordinate between the + // send and receive sides of the stream, and the conn's loop. + // + // streamIn* bits must be set with ingate held. + // streamOut* bits must be set with outgate held. + // streamConn* bits are set by the conn's loop. + // streamQueue* bits must be set with streamsState.sendMu held. + state atomicBits[streamState] + + prev, next *Stream // guarded by streamsState.sendMu +} + +type streamState uint32 + +const ( + // streamInSendMeta is set when there are frames to send for the + // inbound side of the stream. For example, MAX_STREAM_DATA. + // Inbound frames are never flow-controlled. + streamInSendMeta = streamState(1 << iota) + + // streamOutSendMeta is set when there are non-flow-controlled frames + // to send for the outbound side of the stream. For example, STREAM_DATA_BLOCKED. + // streamOutSendData is set when there are no non-flow-controlled outbound frames + // and the stream has data to send. + // + // At most one of streamOutSendMeta and streamOutSendData is set at any time. + streamOutSendMeta + streamOutSendData + + // streamInDone and streamOutDone are set when the inbound or outbound + // sides of the stream are finished. When both are set, the stream + // can be removed from the Conn and forgotten. + streamInDone + streamOutDone + + // streamConnRemoved is set when the stream has been removed from the conn. + streamConnRemoved + + // streamQueueMeta and streamQueueData indicate which of the streamsState + // send queues the conn is currently on. + streamQueueMeta + streamQueueData +) + +type streamQueue int + +const ( + noQueue = streamQueue(iota) + metaQueue // streamsState.queueMeta + dataQueue // streamsState.queueData +) + +// streamResetByConnClose is assigned to Stream.inresetcode to indicate that a stream +// was implicitly reset when the connection closed. It's out of the range of +// possible reset codes the peer can send. +const streamResetByConnClose = math.MaxInt64 + +// wantQueue returns the send queue the stream should be on. +func (s streamState) wantQueue() streamQueue { + switch { + case s&(streamInSendMeta|streamOutSendMeta) != 0: + return metaQueue + case s&(streamInDone|streamOutDone|streamConnRemoved) == streamInDone|streamOutDone: + return metaQueue + case s&streamOutSendData != 0: + // The stream has no non-flow-controlled frames to send, + // but does have data. Put it on the data queue, which is only + // processed when flow control is available. + return dataQueue + } + return noQueue +} + +// inQueue returns the send queue the stream is currently on. +func (s streamState) inQueue() streamQueue { + switch { + case s&streamQueueMeta != 0: + return metaQueue + case s&streamQueueData != 0: + return dataQueue + } + return noQueue +} + +// newStream returns a new stream. +// +// The stream's ingate and outgate are locked. +// (We create the stream with locked gates so after the caller +// initializes the flow control window, +// unlocking outgate will set the stream writability state.) +func newStream(c *Conn, id streamID) *Stream { + s := &Stream{ + conn: c, + id: id, + insize: -1, // -1 indicates the stream size is unknown + inresetcode: -1, // -1 indicates no RESET_STREAM received + ingate: newLockedGate(), + outgate: newLockedGate(), + inctx: context.Background(), + outctx: context.Background(), + } + if !s.IsReadOnly() { + s.outdone = make(chan struct{}) + } + return s +} + +// ID returns the QUIC stream ID of s. +// +// As specified in RFC 9000, the two least significant bits of a stream ID +// indicate the initiator and directionality of the stream. The upper bits are +// the stream number. +func (s *Stream) ID() int64 { + return int64(s.id) +} + +// SetReadContext sets the context used for reads from the stream. +// +// It is not safe to call SetReadContext concurrently. +func (s *Stream) SetReadContext(ctx context.Context) { + s.inctx = ctx +} + +// SetWriteContext sets the context used for writes to the stream. +// The write context is also used by Close when waiting for writes to be +// received by the peer. +// +// It is not safe to call SetWriteContext concurrently. +func (s *Stream) SetWriteContext(ctx context.Context) { + s.outctx = ctx +} + +// IsReadOnly reports whether the stream is read-only +// (a unidirectional stream created by the peer). +func (s *Stream) IsReadOnly() bool { + return s.id.streamType() == uniStream && s.id.initiator() != s.conn.side +} + +// IsWriteOnly reports whether the stream is write-only +// (a unidirectional stream created locally). +func (s *Stream) IsWriteOnly() bool { + return s.id.streamType() == uniStream && s.id.initiator() == s.conn.side +} + +// Read reads data from the stream. +// +// Read returns as soon as at least one byte of data is available. +// +// If the peer closes the stream cleanly, Read returns io.EOF after +// returning all data sent by the peer. +// If the peer aborts reads on the stream, Read returns +// an error wrapping StreamResetCode. +// +// It is not safe to call Read concurrently. +func (s *Stream) Read(b []byte) (n int, err error) { + if s.IsWriteOnly() { + return 0, errors.New("read from write-only stream") + } + if len(s.inbuf) > s.inbufoff { + // Fast path: If s.inbuf contains unread bytes, return them immediately + // without taking a lock. + n = copy(b, s.inbuf[s.inbufoff:]) + s.inbufoff += n + return n, nil + } + if err := s.ingate.waitAndLock(s.inctx); err != nil { + return 0, err + } + if s.inbufoff > 0 { + // Discard bytes consumed by the fast path above. + s.in.discardBefore(s.in.start + int64(s.inbufoff)) + s.inbufoff = 0 + s.inbuf = nil + } + // bytesRead contains the number of bytes of connection-level flow control to return. + // We return flow control for bytes read by this Read call, as well as bytes moved + // to the fast-path read buffer (s.inbuf). + var bytesRead int64 + defer func() { + s.inUnlock() + s.conn.handleStreamBytesReadOffLoop(bytesRead) // must be done with ingate unlocked + }() + if s.inresetcode != -1 { + if s.inresetcode == streamResetByConnClose { + if err := s.conn.finalError(); err != nil { + return 0, err + } + } + return 0, fmt.Errorf("stream reset by peer: %w", StreamErrorCode(s.inresetcode)) + } + if s.inclosed.isSet() { + return 0, errors.New("read from closed stream") + } + if s.insize == s.in.start { + return 0, io.EOF + } + // Getting here indicates the stream contains data to be read. + if len(s.inset) < 1 || s.inset[0].start != 0 || s.inset[0].end <= s.in.start { + panic("BUG: inconsistent input stream state") + } + if size := int(s.inset[0].end - s.in.start); size < len(b) { + b = b[:size] + } + bytesRead = int64(len(b)) + start := s.in.start + end := start + int64(len(b)) + s.in.copy(start, b) + s.in.discardBefore(end) + if end == s.insize { + // We have read up to the end of the stream. + // No need to update stream flow control. + return len(b), io.EOF + } + if len(s.inset) > 0 && s.inset[0].start <= s.in.start && s.inset[0].end > s.in.start { + // If we have more readable bytes available, put the next chunk of data + // in s.inbuf for lock-free reads. + s.inbuf = s.in.peek(s.inset[0].end - s.in.start) + bytesRead += int64(len(s.inbuf)) + } + if s.insize == -1 || s.insize > s.inwin { + newWindow := s.in.start + int64(len(s.inbuf)) + s.inmaxbuf + addedWindow := newWindow - s.inwin + if shouldUpdateFlowControl(s.inmaxbuf, addedWindow) { + // Update stream flow control with a STREAM_MAX_DATA frame. + s.insendmax.setUnsent() + } + } + return len(b), nil +} + +// ReadByte reads and returns a single byte from the stream. +// +// It is not safe to call ReadByte concurrently. +func (s *Stream) ReadByte() (byte, error) { + if len(s.inbuf) > s.inbufoff { + b := s.inbuf[s.inbufoff] + s.inbufoff++ + return b, nil + } + var b [1]byte + n, err := s.Read(b[:]) + if n > 0 { + return b[0], nil + } + return 0, err +} + +// shouldUpdateFlowControl determines whether to send a flow control window update. +// +// We want to balance keeping the peer well-supplied with flow control with not sending +// many small updates. +func shouldUpdateFlowControl(maxWindow, addedWindow int64) bool { + return addedWindow >= maxWindow/8 +} + +// Write writes data to the stream. +// +// Write writes data to the stream write buffer. +// Buffered data is only sent when the buffer is sufficiently full. +// Call the Flush method to ensure buffered data is sent. +func (s *Stream) Write(b []byte) (n int, err error) { + if s.IsReadOnly() { + return 0, errors.New("write to read-only stream") + } + if len(b) > 0 && len(s.outbuf)-s.outbufoff >= len(b) { + // Fast path: The data to write fits in s.outbuf. + copy(s.outbuf[s.outbufoff:], b) + s.outbufoff += len(b) + return len(b), nil + } + canWrite := s.outgate.lock() + s.flushFastOutputBuffer() + for { + // The first time through this loop, we may or may not be write blocked. + // We exit the loop after writing all data, so on subsequent passes through + // the loop we are always write blocked. + if len(b) > 0 && !canWrite { + // Our send buffer is full. Wait for the peer to ack some data. + s.outUnlock() + if err := s.outgate.waitAndLock(s.outctx); err != nil { + return n, err + } + // Successfully returning from waitAndLockGate means we are no longer + // write blocked. (Unlike traditional condition variables, gates do not + // have spurious wakeups.) + } + if err := s.writeErrorLocked(); err != nil { + s.outUnlock() + return n, err + } + if len(b) == 0 { + break + } + // Write limit is our send buffer limit. + // This is a stream offset. + lim := s.out.start + s.outmaxbuf + // Amount to write is min(the full buffer, data up to the write limit). + // This is a number of bytes. + nn := min(int64(len(b)), lim-s.out.end) + // Copy the data into the output buffer. + s.out.writeAt(b[:nn], s.out.end) + b = b[nn:] + n += int(nn) + // Possibly flush the output buffer. + // We automatically flush if: + // - We have enough data to consume the send window. + // Sending this data may cause the peer to extend the window. + // - We have buffered as much data as we're willing do. + // We need to send data to clear out buffer space. + // - We have enough data to fill a 1-RTT packet using the smallest + // possible maximum datagram size (1200 bytes, less header byte, + // connection ID, packet number, and AEAD overhead). + const autoFlushSize = smallestMaxDatagramSize - 1 - connIDLen - 1 - aeadOverhead + shouldFlush := s.out.end >= s.outwin || // peer send window is full + s.out.end >= lim || // local send buffer is full + (s.out.end-s.outflushed) >= autoFlushSize // enough data buffered + if shouldFlush { + s.flushLocked() + } + if s.out.end > s.outwin { + // We're blocked by flow control. + // Send a STREAM_DATA_BLOCKED frame to let the peer know. + s.outblocked.set() + } + // If we have bytes left to send, we're blocked. + canWrite = false + } + if lim := s.out.start + s.outmaxbuf - s.out.end - 1; lim > 0 { + // If s.out has space allocated and available to be written into, + // then reference it in s.outbuf for fast-path writes. + // + // It's perhaps a bit pointless to limit s.outbuf to the send buffer limit. + // We've already allocated this buffer so we aren't saving any memory + // by not using it. + // For now, we limit it anyway to make it easier to reason about limits. + // + // We set the limit to one less than the send buffer limit (the -1 above) + // so that a write which completely fills the buffer will overflow + // s.outbuf and trigger a flush. + s.outbuf = s.out.availableBuffer() + if int64(len(s.outbuf)) > lim { + s.outbuf = s.outbuf[:lim] + } + } + s.outUnlock() + return n, nil +} + +// WriteByte writes a single byte to the stream. +func (s *Stream) WriteByte(c byte) error { + if s.outbufoff < len(s.outbuf) { + s.outbuf[s.outbufoff] = c + s.outbufoff++ + return nil + } + b := [1]byte{c} + _, err := s.Write(b[:]) + return err +} + +func (s *Stream) flushFastOutputBuffer() { + if s.outbuf == nil { + return + } + // Commit data previously written to s.outbuf. + // s.outbuf is a reference to a buffer in s.out, so we just need to record + // that the output buffer has been extended. + s.out.end += int64(s.outbufoff) + s.outbuf = nil + s.outbufoff = 0 +} + +// Flush flushes data written to the stream. +// It does not wait for the peer to acknowledge receipt of the data. +// Use Close to wait for the peer's acknowledgement. +func (s *Stream) Flush() error { + if s.IsReadOnly() { + return errors.New("flush of read-only stream") + } + s.outgate.lock() + defer s.outUnlock() + if err := s.writeErrorLocked(); err != nil { + return err + } + s.flushLocked() + return nil +} + +// writeErrorLocked returns the error (if any) which should be returned by write operations +// due to the stream being reset or closed. +func (s *Stream) writeErrorLocked() error { + if s.outreset.isSet() { + if s.outresetcode == streamResetByConnClose { + if err := s.conn.finalError(); err != nil { + return err + } + } + return errors.New("write to reset stream") + } + if s.outclosed.isSet() { + return errors.New("write to closed stream") + } + return nil +} + +func (s *Stream) flushLocked() { + s.flushFastOutputBuffer() + s.outopened.set() + if s.outflushed < s.outwin { + s.outunsent.add(s.outflushed, min(s.outwin, s.out.end)) + } + s.outflushed = s.out.end +} + +// Close closes the stream. +// Any blocked stream operations will be unblocked and return errors. +// +// Close flushes any data in the stream write buffer and waits for the peer to +// acknowledge receipt of the data. +// If the stream has been reset, it waits for the peer to acknowledge the reset. +// If the context expires before the peer receives the stream's data, +// Close discards the buffer and returns the context error. +func (s *Stream) Close() error { + s.CloseRead() + if s.IsReadOnly() { + return nil + } + s.CloseWrite() + // TODO: Return code from peer's RESET_STREAM frame? + if err := s.conn.waitOnDone(s.outctx, s.outdone); err != nil { + return err + } + s.outgate.lock() + defer s.outUnlock() + if s.outclosed.isReceived() && s.outacked.isrange(0, s.out.end) { + return nil + } + return errors.New("stream reset") +} + +// CloseRead aborts reads on the stream. +// Any blocked reads will be unblocked and return errors. +// +// CloseRead notifies the peer that the stream has been closed for reading. +// It does not wait for the peer to acknowledge the closure. +// Use Close to wait for the peer's acknowledgement. +func (s *Stream) CloseRead() { + if s.IsWriteOnly() { + return + } + s.ingate.lock() + if s.inset.isrange(0, s.insize) || s.inresetcode != -1 { + // We've already received all data from the peer, + // so there's no need to send STOP_SENDING. + // This is the same as saying we sent one and they got it. + s.inclosed.setReceived() + } else { + s.inclosed.set() + } + discarded := s.in.end - s.in.start + s.in.discardBefore(s.in.end) + s.inUnlock() + s.conn.handleStreamBytesReadOffLoop(discarded) // must be done with ingate unlocked +} + +// CloseWrite aborts writes on the stream. +// Any blocked writes will be unblocked and return errors. +// +// CloseWrite sends any data in the stream write buffer to the peer. +// It does not wait for the peer to acknowledge receipt of the data. +// Use Close to wait for the peer's acknowledgement. +func (s *Stream) CloseWrite() { + if s.IsReadOnly() { + return + } + s.outgate.lock() + defer s.outUnlock() + s.outclosed.set() + s.flushLocked() +} + +// Reset aborts writes on the stream and notifies the peer +// that the stream was terminated abruptly. +// Any blocked writes will be unblocked and return errors. +// +// Reset sends the application protocol error code, which must be +// less than 2^62, to the peer. +// It does not wait for the peer to acknowledge receipt of the error. +// Use Close to wait for the peer's acknowledgement. +// +// Reset does not affect reads. +// Use CloseRead to abort reads on the stream. +func (s *Stream) Reset(code uint64) { + const userClosed = true + s.resetInternal(code, userClosed) +} + +// resetInternal resets the send side of the stream. +// +// If userClosed is true, this is s.Reset. +// If userClosed is false, this is a reaction to a STOP_SENDING frame. +func (s *Stream) resetInternal(code uint64, userClosed bool) { + s.outgate.lock() + defer s.outUnlock() + if s.IsReadOnly() { + return + } + if userClosed { + // Mark that the user closed the stream. + s.outclosed.set() + } + if s.outreset.isSet() { + return + } + if code > quicwire.MaxVarint { + code = quicwire.MaxVarint + } + // We could check here to see if the stream is closed and the + // peer has acked all the data and the FIN, but sending an + // extra RESET_STREAM in this case is harmless. + s.outreset.set() + s.outresetcode = code + s.outbuf = nil + s.outbufoff = 0 + s.out.discardBefore(s.out.end) + s.outunsent = rangeset[int64]{} + s.outblocked.clear() +} + +// connHasClosed indicates the stream's conn has closed. +func (s *Stream) connHasClosed() { + // If we're in the closing state, the user closed the conn. + // Otherwise, we the peer initiated the close. + // This only matters for the error we're going to return from stream operations. + localClose := s.conn.lifetime.state == connStateClosing + + s.ingate.lock() + if !s.inset.isrange(0, s.insize) && s.inresetcode == -1 { + if localClose { + s.inclosed.set() + } else { + s.inresetcode = streamResetByConnClose + } + } + s.inUnlock() + + s.outgate.lock() + if localClose { + s.outclosed.set() + s.outreset.set() + } else { + s.outresetcode = streamResetByConnClose + s.outreset.setReceived() + } + s.outUnlock() +} + +// inUnlock unlocks s.ingate. +// It sets the gate condition if reads from s will not block. +// If s has receive-related frames to write or if both directions +// are done and the stream should be removed, it notifies the Conn. +func (s *Stream) inUnlock() { + state := s.inUnlockNoQueue() + s.conn.maybeQueueStreamForSend(s, state) +} + +// inUnlockNoQueue is inUnlock, +// but reports whether s has frames to write rather than notifying the Conn. +func (s *Stream) inUnlockNoQueue() streamState { + nextByte := s.in.start + int64(len(s.inbuf)) + canRead := s.inset.contains(nextByte) || // data available to read + s.insize == s.in.start+int64(len(s.inbuf)) || // at EOF + s.inresetcode != -1 || // reset by peer + s.inclosed.isSet() // closed locally + defer s.ingate.unlock(canRead) + var state streamState + switch { + case s.IsWriteOnly(): + state = streamInDone + case s.inresetcode != -1: // reset by peer + fallthrough + case s.in.start == s.insize: // all data received and read + // We don't increase MAX_STREAMS until the user calls ReadClose or Close, + // so the receive side is not finished until inclosed is set. + if s.inclosed.isSet() { + state = streamInDone + } + case s.insendmax.shouldSend(): // STREAM_MAX_DATA + state = streamInSendMeta + case s.inclosed.shouldSend(): // STOP_SENDING + state = streamInSendMeta + } + const mask = streamInDone | streamInSendMeta + return s.state.set(state, mask) +} + +// outUnlock unlocks s.outgate. +// It sets the gate condition if writes to s will not block. +// If s has send-related frames to write or if both directions +// are done and the stream should be removed, it notifies the Conn. +func (s *Stream) outUnlock() { + state := s.outUnlockNoQueue() + s.conn.maybeQueueStreamForSend(s, state) +} + +// outUnlockNoQueue is outUnlock, +// but reports whether s has frames to write rather than notifying the Conn. +func (s *Stream) outUnlockNoQueue() streamState { + isDone := s.outclosed.isReceived() && s.outacked.isrange(0, s.out.end) || // all data acked + s.outreset.isSet() // reset locally + if isDone { + select { + case <-s.outdone: + default: + if !s.IsReadOnly() { + close(s.outdone) + } + } + } + lim := s.out.start + s.outmaxbuf + canWrite := lim > s.out.end || // available send buffer + s.outclosed.isSet() || // closed locally + s.outreset.isSet() // reset locally + defer s.outgate.unlock(canWrite) + var state streamState + switch { + case s.IsReadOnly(): + state = streamOutDone + case s.outclosed.isReceived() && s.outacked.isrange(0, s.out.end): // all data sent and acked + fallthrough + case s.outreset.isReceived(): // RESET_STREAM sent and acked + // We don't increase MAX_STREAMS until the user calls WriteClose or Close, + // so the send side is not finished until outclosed is set. + if s.outclosed.isSet() { + state = streamOutDone + } + case s.outreset.shouldSend(): // RESET_STREAM + state = streamOutSendMeta + case s.outreset.isSet(): // RESET_STREAM sent but not acknowledged + case s.outblocked.shouldSend(): // STREAM_DATA_BLOCKED + state = streamOutSendMeta + case len(s.outunsent) > 0: // STREAM frame with data + if s.outunsent.min() < s.outmaxsent { + state = streamOutSendMeta // resent data, will not consume flow control + } else { + state = streamOutSendData // new data, requires flow control + } + case s.outclosed.shouldSend() && s.out.end == s.outmaxsent: // empty STREAM frame with FIN bit + state = streamOutSendMeta + case s.outopened.shouldSend(): // STREAM frame with no data + state = streamOutSendMeta + } + const mask = streamOutDone | streamOutSendMeta | streamOutSendData + return s.state.set(state, mask) +} + +// handleData handles data received in a STREAM frame. +func (s *Stream) handleData(off int64, b []byte, fin bool) error { + s.ingate.lock() + defer s.inUnlock() + end := off + int64(len(b)) + if err := s.checkStreamBounds(end, fin); err != nil { + return err + } + if s.inclosed.isSet() || s.inresetcode != -1 { + // The user read-closed the stream, or the peer reset it. + // Either way, we can discard this frame. + return nil + } + if s.insize == -1 && end > s.in.end { + added := end - s.in.end + if err := s.conn.handleStreamBytesReceived(added); err != nil { + return err + } + } + s.in.writeAt(b, off) + s.inset.add(off, end) + if fin { + s.insize = end + // The peer has enough flow control window to send the entire stream. + s.insendmax.clear() + } + return nil +} + +// handleReset handles a RESET_STREAM frame. +func (s *Stream) handleReset(code uint64, finalSize int64) error { + s.ingate.lock() + defer s.inUnlock() + const fin = true + if err := s.checkStreamBounds(finalSize, fin); err != nil { + return err + } + if s.inresetcode != -1 { + // The stream was already reset. + return nil + } + if s.insize == -1 { + added := finalSize - s.in.end + if err := s.conn.handleStreamBytesReceived(added); err != nil { + return err + } + } + s.conn.handleStreamBytesReadOnLoop(finalSize - s.in.start) + s.in.discardBefore(s.in.end) + s.inresetcode = int64(code) + s.insize = finalSize + return nil +} + +// checkStreamBounds validates the stream offset in a STREAM or RESET_STREAM frame. +func (s *Stream) checkStreamBounds(end int64, fin bool) error { + if end > s.inwin { + // The peer sent us data past the maximum flow control window we gave them. + return localTransportError{ + code: errFlowControl, + reason: "stream flow control window exceeded", + } + } + if s.insize != -1 && end > s.insize { + // The peer sent us data past the final size of the stream they previously gave us. + return localTransportError{ + code: errFinalSize, + reason: "data received past end of stream", + } + } + if fin && s.insize != -1 && end != s.insize { + // The peer changed the final size of the stream. + return localTransportError{ + code: errFinalSize, + reason: "final size of stream changed", + } + } + if fin && end < s.in.end { + // The peer has previously sent us data past the final size. + return localTransportError{ + code: errFinalSize, + reason: "end of stream occurs before prior data", + } + } + return nil +} + +// handleStopSending handles a STOP_SENDING frame. +func (s *Stream) handleStopSending(code uint64) error { + // Peer requests that we reset this stream. + // https://www.rfc-editor.org/rfc/rfc9000#section-3.5-4 + const userReset = false + s.resetInternal(code, userReset) + return nil +} + +// handleMaxStreamData handles an update received in a MAX_STREAM_DATA frame. +func (s *Stream) handleMaxStreamData(maxStreamData int64) error { + s.outgate.lock() + defer s.outUnlock() + if maxStreamData <= s.outwin { + return nil + } + if s.outflushed > s.outwin { + s.outunsent.add(s.outwin, min(maxStreamData, s.outflushed)) + } + s.outwin = maxStreamData + if s.out.end > s.outwin { + // We've still got more data than flow control window. + s.outblocked.setUnsent() + } else { + s.outblocked.clear() + } + return nil +} + +// ackOrLoss handles the fate of stream frames other than STREAM. +func (s *Stream) ackOrLoss(pnum packetNumber, ftype byte, fate packetFate) { + // Frames which carry new information each time they are sent + // (MAX_STREAM_DATA, STREAM_DATA_BLOCKED) must only be marked + // as received if the most recent packet carrying this frame is acked. + // + // Frames which are always the same (STOP_SENDING, RESET_STREAM) + // can be marked as received if any packet carrying this frame is acked. + switch ftype { + case frameTypeResetStream: + s.outgate.lock() + s.outreset.ackOrLoss(pnum, fate) + s.outUnlock() + case frameTypeStopSending: + s.ingate.lock() + s.inclosed.ackOrLoss(pnum, fate) + s.inUnlock() + case frameTypeMaxStreamData: + s.ingate.lock() + s.insendmax.ackLatestOrLoss(pnum, fate) + s.inUnlock() + case frameTypeStreamDataBlocked: + s.outgate.lock() + s.outblocked.ackLatestOrLoss(pnum, fate) + s.outUnlock() + default: + panic("unhandled frame type") + } +} + +// ackOrLossData handles the fate of a STREAM frame. +func (s *Stream) ackOrLossData(pnum packetNumber, start, end int64, fin bool, fate packetFate) { + s.outgate.lock() + defer s.outUnlock() + s.outopened.ackOrLoss(pnum, fate) + if fin { + s.outclosed.ackOrLoss(pnum, fate) + } + if s.outreset.isSet() { + // If the stream has been reset, we don't care any more. + return + } + switch fate { + case packetAcked: + s.outacked.add(start, end) + s.outunsent.sub(start, end) + // If this ack is for data at the start of the send buffer, we can now discard it. + if s.outacked.contains(s.out.start) { + s.out.discardBefore(s.outacked[0].end) + } + case packetLost: + // Mark everything lost, but not previously acked, as needing retransmission. + // We do this by adding all the lost bytes to outunsent, and then + // removing everything already acked. + s.outunsent.add(start, end) + for _, a := range s.outacked { + s.outunsent.sub(a.start, a.end) + } + } +} + +// appendInFramesLocked appends STOP_SENDING and MAX_STREAM_DATA frames +// to the current packet. +// +// It returns true if no more frames need appending, +// false if not everything fit in the current packet. +func (s *Stream) appendInFramesLocked(w *packetWriter, pnum packetNumber, pto bool) bool { + if s.inclosed.shouldSendPTO(pto) { + // We don't currently have an API for setting the error code. + // Just send zero. + code := uint64(0) + if !w.appendStopSendingFrame(s.id, code) { + return false + } + s.inclosed.setSent(pnum) + } + // TODO: STOP_SENDING + if s.insendmax.shouldSendPTO(pto) { + // MAX_STREAM_DATA + maxStreamData := s.in.start + s.inmaxbuf + if !w.appendMaxStreamDataFrame(s.id, maxStreamData) { + return false + } + s.inwin = maxStreamData + s.insendmax.setSent(pnum) + } + return true +} + +// appendOutFramesLocked appends RESET_STREAM, STREAM_DATA_BLOCKED, and STREAM frames +// to the current packet. +// +// It returns true if no more frames need appending, +// false if not everything fit in the current packet. +func (s *Stream) appendOutFramesLocked(w *packetWriter, pnum packetNumber, pto bool) bool { + if s.outreset.isSet() { + // RESET_STREAM + if s.outreset.shouldSendPTO(pto) { + if !w.appendResetStreamFrame(s.id, s.outresetcode, min(s.outwin, s.out.end)) { + return false + } + s.outreset.setSent(pnum) + s.frameOpensStream(pnum) + } + return true + } + if s.outblocked.shouldSendPTO(pto) { + // STREAM_DATA_BLOCKED + if !w.appendStreamDataBlockedFrame(s.id, s.outwin) { + return false + } + s.outblocked.setSent(pnum) + s.frameOpensStream(pnum) + } + for { + // STREAM + off, size := dataToSend(min(s.out.start, s.outwin), min(s.outflushed, s.outwin), s.outunsent, s.outacked, pto) + if end := off + size; end > s.outmaxsent { + // This will require connection-level flow control to send. + end = min(end, s.outmaxsent+s.conn.streams.outflow.avail()) + end = max(end, off) + size = end - off + } + fin := s.outclosed.isSet() && off+size == s.out.end + shouldSend := size > 0 || // have data to send + s.outopened.shouldSendPTO(pto) || // should open the stream + (fin && s.outclosed.shouldSendPTO(pto)) // should close the stream + if !shouldSend { + return true + } + b, added := w.appendStreamFrame(s.id, off, int(size), fin) + if !added { + return false + } + s.out.copy(off, b) + end := off + int64(len(b)) + if end > s.outmaxsent { + s.conn.streams.outflow.consume(end - s.outmaxsent) + s.outmaxsent = end + } + s.outunsent.sub(off, end) + s.frameOpensStream(pnum) + if fin { + s.outclosed.setSent(pnum) + } + if pto { + return true + } + if int64(len(b)) < size { + return false + } + } +} + +// frameOpensStream records that we're sending a frame that will open the stream. +// +// If we don't have an acknowledgement from the peer for a previous frame opening the stream, +// record this packet as being the latest one to open it. +func (s *Stream) frameOpensStream(pnum packetNumber) { + if !s.outopened.isReceived() { + s.outopened.setSent(pnum) + } +} + +// dataToSend returns the next range of data to send in a STREAM or CRYPTO_STREAM. +func dataToSend(start, end int64, outunsent, outacked rangeset[int64], pto bool) (sendStart, size int64) { + switch { + case pto: + // On PTO, resend unacked data that fits in the probe packet. + // For simplicity, we send the range starting at s.out.start + // (which is definitely unacked, or else we would have discarded it) + // up to the next acked byte (if any). + // + // This may miss unacked data starting after that acked byte, + // but avoids resending data the peer has acked. + for _, r := range outacked { + if r.start > start { + return start, r.start - start + } + } + return start, end - start + case outunsent.numRanges() > 0: + return outunsent.min(), outunsent[0].size() + default: + return end, 0 + } +} diff --git a/src/vendor/golang.org/x/net/quic/stream_limits.go b/src/vendor/golang.org/x/net/quic/stream_limits.go new file mode 100644 index 0000000000..f1abcae99c --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/stream_limits.go @@ -0,0 +1,125 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "context" +) + +// Limits on the number of open streams. +// Every connection has separate limits for bidirectional and unidirectional streams. +// +// Note that the MAX_STREAMS limit includes closed as well as open streams. +// Closing a stream doesn't enable an endpoint to open a new one; +// only an increase in the MAX_STREAMS limit does. + +// localStreamLimits are limits on the number of open streams created by us. +type localStreamLimits struct { + gate gate + max int64 // peer-provided MAX_STREAMS + opened int64 // number of streams opened by us, -1 when conn is closed +} + +func (lim *localStreamLimits) init() { + lim.gate = newGate() +} + +// open creates a new local stream, blocking until MAX_STREAMS quota is available. +func (lim *localStreamLimits) open(ctx context.Context, c *Conn) (num int64, err error) { + // TODO: Send a STREAMS_BLOCKED when blocked. + if err := lim.gate.waitAndLock(ctx); err != nil { + return 0, err + } + if lim.opened < 0 { + lim.gate.unlock(true) + return 0, errConnClosed + } + num = lim.opened + lim.opened++ + lim.gate.unlock(lim.opened < lim.max) + return num, nil +} + +// connHasClosed indicates the connection has been closed, locally or by the peer. +func (lim *localStreamLimits) connHasClosed() { + lim.gate.lock() + lim.opened = -1 + lim.gate.unlock(true) +} + +// setMax sets the MAX_STREAMS provided by the peer. +func (lim *localStreamLimits) setMax(maxStreams int64) { + lim.gate.lock() + lim.max = max(lim.max, maxStreams) + lim.gate.unlock(lim.opened < lim.max) +} + +// remoteStreamLimits are limits on the number of open streams created by the peer. +type remoteStreamLimits struct { + max int64 // last MAX_STREAMS sent to the peer + opened int64 // number of streams opened by the peer (including subsequently closed ones) + closed int64 // number of peer streams in the "closed" state + maxOpen int64 // how many streams we want to let the peer simultaneously open + sendMax sentVal // set when we should send MAX_STREAMS +} + +func (lim *remoteStreamLimits) init(maxOpen int64) { + lim.maxOpen = maxOpen + lim.max = min(maxOpen, implicitStreamLimit) // initial limit sent in transport parameters + lim.opened = 0 +} + +// open handles the peer opening a new stream. +func (lim *remoteStreamLimits) open(id streamID) error { + num := id.num() + if num >= lim.max { + return localTransportError{ + code: errStreamLimit, + reason: "stream limit exceeded", + } + } + if num >= lim.opened { + lim.opened = num + 1 + lim.maybeUpdateMax() + } + return nil +} + +// close handles the peer closing an open stream. +func (lim *remoteStreamLimits) close() { + lim.closed++ + lim.maybeUpdateMax() +} + +// maybeUpdateMax updates the MAX_STREAMS value we will send to the peer. +func (lim *remoteStreamLimits) maybeUpdateMax() { + newMax := min( + // Max streams the peer can have open at once. + lim.closed+lim.maxOpen, + // Max streams the peer can open with a single frame. + lim.opened+implicitStreamLimit, + ) + avail := lim.max - lim.opened + if newMax > lim.max && (avail < 8 || newMax-lim.max >= 2*avail) { + // If the peer has less than 8 streams, or if increasing the peer's + // stream limit would double it, then send a MAX_STREAMS. + lim.max = newMax + lim.sendMax.setUnsent() + } +} + +// appendFrame appends a MAX_STREAMS frame to the current packet, if necessary. +// +// It returns true if no more frames need appending, +// false if not everything fit in the current packet. +func (lim *remoteStreamLimits) appendFrame(w *packetWriter, typ streamType, pnum packetNumber, pto bool) bool { + if lim.sendMax.shouldSendPTO(pto) { + if !w.appendMaxStreamsFrame(typ, lim.max) { + return false + } + lim.sendMax.setSent(pnum) + } + return true +} diff --git a/src/vendor/golang.org/x/net/quic/tls.go b/src/vendor/golang.org/x/net/quic/tls.go new file mode 100644 index 0000000000..9f6e0bc29a --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/tls.go @@ -0,0 +1,123 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net" + "time" +) + +// startTLS starts the TLS handshake. +func (c *Conn) startTLS(now time.Time, initialConnID []byte, peerHostname string, params transportParameters) error { + tlsConfig := c.config.TLSConfig + if a, _, err := net.SplitHostPort(peerHostname); err == nil { + peerHostname = a + } + if tlsConfig.ServerName == "" && peerHostname != "" { + tlsConfig = tlsConfig.Clone() + tlsConfig.ServerName = peerHostname + } + + c.keysInitial = initialKeys(initialConnID, c.side) + + qconfig := &tls.QUICConfig{TLSConfig: tlsConfig} + if c.side == clientSide { + c.tls = tls.QUICClient(qconfig) + } else { + c.tls = tls.QUICServer(qconfig) + } + c.tls.SetTransportParameters(marshalTransportParameters(params)) + // TODO: We don't need or want a context for cancellation here, + // but users can use a context to plumb values through to hooks defined + // in the tls.Config. Pass through a context. + if err := c.tls.Start(context.TODO()); err != nil { + return err + } + return c.handleTLSEvents(now) +} + +func (c *Conn) handleTLSEvents(now time.Time) error { + for { + e := c.tls.NextEvent() + if c.testHooks != nil { + c.testHooks.handleTLSEvent(e) + } + switch e.Kind { + case tls.QUICNoEvent: + return nil + case tls.QUICSetReadSecret: + if err := checkCipherSuite(e.Suite); err != nil { + return err + } + switch e.Level { + case tls.QUICEncryptionLevelHandshake: + c.keysHandshake.r.init(e.Suite, e.Data) + case tls.QUICEncryptionLevelApplication: + c.keysAppData.r.init(e.Suite, e.Data) + } + case tls.QUICSetWriteSecret: + if err := checkCipherSuite(e.Suite); err != nil { + return err + } + switch e.Level { + case tls.QUICEncryptionLevelHandshake: + c.keysHandshake.w.init(e.Suite, e.Data) + case tls.QUICEncryptionLevelApplication: + c.keysAppData.w.init(e.Suite, e.Data) + } + case tls.QUICWriteData: + var space numberSpace + switch e.Level { + case tls.QUICEncryptionLevelInitial: + space = initialSpace + case tls.QUICEncryptionLevelHandshake: + space = handshakeSpace + case tls.QUICEncryptionLevelApplication: + space = appDataSpace + default: + return fmt.Errorf("quic: internal error: write handshake data at level %v", e.Level) + } + c.crypto[space].write(e.Data) + case tls.QUICHandshakeDone: + if c.side == serverSide { + // "[...] the TLS handshake is considered confirmed + // at the server when the handshake completes." + // https://www.rfc-editor.org/rfc/rfc9001#section-4.1.2-1 + c.confirmHandshake(now) + } + c.handshakeDone() + case tls.QUICTransportParameters: + params, err := unmarshalTransportParams(e.Data) + if err != nil { + return err + } + if err := c.receiveTransportParameters(params); err != nil { + return err + } + } + } +} + +// handleCrypto processes data received in a CRYPTO frame. +func (c *Conn) handleCrypto(now time.Time, space numberSpace, off int64, data []byte) error { + var level tls.QUICEncryptionLevel + switch space { + case initialSpace: + level = tls.QUICEncryptionLevelInitial + case handshakeSpace: + level = tls.QUICEncryptionLevelHandshake + case appDataSpace: + level = tls.QUICEncryptionLevelApplication + default: + return errors.New("quic: internal error: received CRYPTO frame in unexpected number space") + } + return c.crypto[space].handleCrypto(off, data, func(b []byte) error { + return c.tls.HandleData(level, b) + }) +} diff --git a/src/vendor/golang.org/x/net/quic/transport_params.go b/src/vendor/golang.org/x/net/quic/transport_params.go new file mode 100644 index 0000000000..2734c586de --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/transport_params.go @@ -0,0 +1,283 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "encoding/binary" + "net/netip" + "time" + + "golang.org/x/net/internal/quic/quicwire" +) + +// transportParameters transferred in the quic_transport_parameters TLS extension. +// https://www.rfc-editor.org/rfc/rfc9000.html#section-18.2 +type transportParameters struct { + originalDstConnID []byte + maxIdleTimeout time.Duration + statelessResetToken []byte + maxUDPPayloadSize int64 + initialMaxData int64 + initialMaxStreamDataBidiLocal int64 + initialMaxStreamDataBidiRemote int64 + initialMaxStreamDataUni int64 + initialMaxStreamsBidi int64 + initialMaxStreamsUni int64 + ackDelayExponent int8 + maxAckDelay time.Duration + disableActiveMigration bool + preferredAddrV4 netip.AddrPort + preferredAddrV6 netip.AddrPort + preferredAddrConnID []byte + preferredAddrResetToken []byte + activeConnIDLimit int64 + initialSrcConnID []byte + retrySrcConnID []byte +} + +const ( + defaultParamMaxUDPPayloadSize = 65527 + defaultParamAckDelayExponent = 3 + defaultParamMaxAckDelayMilliseconds = 25 + defaultParamActiveConnIDLimit = 2 +) + +// defaultTransportParameters is initialized to the RFC 9000 default values. +func defaultTransportParameters() transportParameters { + return transportParameters{ + maxUDPPayloadSize: defaultParamMaxUDPPayloadSize, + ackDelayExponent: defaultParamAckDelayExponent, + maxAckDelay: defaultParamMaxAckDelayMilliseconds * time.Millisecond, + activeConnIDLimit: defaultParamActiveConnIDLimit, + } +} + +const ( + paramOriginalDestinationConnectionID = 0x00 + paramMaxIdleTimeout = 0x01 + paramStatelessResetToken = 0x02 + paramMaxUDPPayloadSize = 0x03 + paramInitialMaxData = 0x04 + paramInitialMaxStreamDataBidiLocal = 0x05 + paramInitialMaxStreamDataBidiRemote = 0x06 + paramInitialMaxStreamDataUni = 0x07 + paramInitialMaxStreamsBidi = 0x08 + paramInitialMaxStreamsUni = 0x09 + paramAckDelayExponent = 0x0a + paramMaxAckDelay = 0x0b + paramDisableActiveMigration = 0x0c + paramPreferredAddress = 0x0d + paramActiveConnectionIDLimit = 0x0e + paramInitialSourceConnectionID = 0x0f + paramRetrySourceConnectionID = 0x10 +) + +func marshalTransportParameters(p transportParameters) []byte { + var b []byte + if v := p.originalDstConnID; v != nil { + b = quicwire.AppendVarint(b, paramOriginalDestinationConnectionID) + b = quicwire.AppendVarintBytes(b, v) + } + if v := uint64(p.maxIdleTimeout / time.Millisecond); v != 0 { + b = quicwire.AppendVarint(b, paramMaxIdleTimeout) + b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(v))) + b = quicwire.AppendVarint(b, uint64(v)) + } + if v := p.statelessResetToken; v != nil { + b = quicwire.AppendVarint(b, paramStatelessResetToken) + b = quicwire.AppendVarintBytes(b, v) + } + if v := p.maxUDPPayloadSize; v != defaultParamMaxUDPPayloadSize { + b = quicwire.AppendVarint(b, paramMaxUDPPayloadSize) + b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v)))) + b = quicwire.AppendVarint(b, uint64(v)) + } + if v := p.initialMaxData; v != 0 { + b = quicwire.AppendVarint(b, paramInitialMaxData) + b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v)))) + b = quicwire.AppendVarint(b, uint64(v)) + } + if v := p.initialMaxStreamDataBidiLocal; v != 0 { + b = quicwire.AppendVarint(b, paramInitialMaxStreamDataBidiLocal) + b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v)))) + b = quicwire.AppendVarint(b, uint64(v)) + } + if v := p.initialMaxStreamDataBidiRemote; v != 0 { + b = quicwire.AppendVarint(b, paramInitialMaxStreamDataBidiRemote) + b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v)))) + b = quicwire.AppendVarint(b, uint64(v)) + } + if v := p.initialMaxStreamDataUni; v != 0 { + b = quicwire.AppendVarint(b, paramInitialMaxStreamDataUni) + b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v)))) + b = quicwire.AppendVarint(b, uint64(v)) + } + if v := p.initialMaxStreamsBidi; v != 0 { + b = quicwire.AppendVarint(b, paramInitialMaxStreamsBidi) + b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v)))) + b = quicwire.AppendVarint(b, uint64(v)) + } + if v := p.initialMaxStreamsUni; v != 0 { + b = quicwire.AppendVarint(b, paramInitialMaxStreamsUni) + b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v)))) + b = quicwire.AppendVarint(b, uint64(v)) + } + if v := p.ackDelayExponent; v != defaultParamAckDelayExponent { + b = quicwire.AppendVarint(b, paramAckDelayExponent) + b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v)))) + b = quicwire.AppendVarint(b, uint64(v)) + } + if v := uint64(p.maxAckDelay / time.Millisecond); v != defaultParamMaxAckDelayMilliseconds { + b = quicwire.AppendVarint(b, paramMaxAckDelay) + b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(v))) + b = quicwire.AppendVarint(b, v) + } + if p.disableActiveMigration { + b = quicwire.AppendVarint(b, paramDisableActiveMigration) + b = append(b, 0) // 0-length value + } + if p.preferredAddrConnID != nil { + b = append(b, paramPreferredAddress) + b = quicwire.AppendVarint(b, uint64(4+2+16+2+1+len(p.preferredAddrConnID)+16)) + b = append(b, p.preferredAddrV4.Addr().AsSlice()...) // 4 bytes + b = binary.BigEndian.AppendUint16(b, p.preferredAddrV4.Port()) // 2 bytes + b = append(b, p.preferredAddrV6.Addr().AsSlice()...) // 16 bytes + b = binary.BigEndian.AppendUint16(b, p.preferredAddrV6.Port()) // 2 bytes + b = quicwire.AppendUint8Bytes(b, p.preferredAddrConnID) // 1 byte + len(conn_id) + b = append(b, p.preferredAddrResetToken...) // 16 bytes + } + if v := p.activeConnIDLimit; v != defaultParamActiveConnIDLimit { + b = quicwire.AppendVarint(b, paramActiveConnectionIDLimit) + b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v)))) + b = quicwire.AppendVarint(b, uint64(v)) + } + if v := p.initialSrcConnID; v != nil { + b = quicwire.AppendVarint(b, paramInitialSourceConnectionID) + b = quicwire.AppendVarintBytes(b, v) + } + if v := p.retrySrcConnID; v != nil { + b = quicwire.AppendVarint(b, paramRetrySourceConnectionID) + b = quicwire.AppendVarintBytes(b, v) + } + return b +} + +func unmarshalTransportParams(params []byte) (transportParameters, error) { + p := defaultTransportParameters() + for len(params) > 0 { + id, n := quicwire.ConsumeVarint(params) + if n < 0 { + return p, localTransportError{code: errTransportParameter} + } + params = params[n:] + val, n := quicwire.ConsumeVarintBytes(params) + if n < 0 { + return p, localTransportError{code: errTransportParameter} + } + params = params[n:] + n = 0 + switch id { + case paramOriginalDestinationConnectionID: + p.originalDstConnID = val + n = len(val) + case paramMaxIdleTimeout: + var v uint64 + v, n = quicwire.ConsumeVarint(val) + // If this is unreasonably large, consider it as no timeout to avoid + // time.Duration overflows. + if v > 1<<32 { + v = 0 + } + p.maxIdleTimeout = time.Duration(v) * time.Millisecond + case paramStatelessResetToken: + if len(val) != 16 { + return p, localTransportError{code: errTransportParameter} + } + p.statelessResetToken = val + n = 16 + case paramMaxUDPPayloadSize: + p.maxUDPPayloadSize, n = quicwire.ConsumeVarintInt64(val) + if p.maxUDPPayloadSize < 1200 { + return p, localTransportError{code: errTransportParameter} + } + case paramInitialMaxData: + p.initialMaxData, n = quicwire.ConsumeVarintInt64(val) + case paramInitialMaxStreamDataBidiLocal: + p.initialMaxStreamDataBidiLocal, n = quicwire.ConsumeVarintInt64(val) + case paramInitialMaxStreamDataBidiRemote: + p.initialMaxStreamDataBidiRemote, n = quicwire.ConsumeVarintInt64(val) + case paramInitialMaxStreamDataUni: + p.initialMaxStreamDataUni, n = quicwire.ConsumeVarintInt64(val) + case paramInitialMaxStreamsBidi: + p.initialMaxStreamsBidi, n = quicwire.ConsumeVarintInt64(val) + if p.initialMaxStreamsBidi > maxStreamsLimit { + return p, localTransportError{code: errTransportParameter} + } + case paramInitialMaxStreamsUni: + p.initialMaxStreamsUni, n = quicwire.ConsumeVarintInt64(val) + if p.initialMaxStreamsUni > maxStreamsLimit { + return p, localTransportError{code: errTransportParameter} + } + case paramAckDelayExponent: + var v uint64 + v, n = quicwire.ConsumeVarint(val) + if v > 20 { + return p, localTransportError{code: errTransportParameter} + } + p.ackDelayExponent = int8(v) + case paramMaxAckDelay: + var v uint64 + v, n = quicwire.ConsumeVarint(val) + if v >= 1<<14 { + return p, localTransportError{code: errTransportParameter} + } + p.maxAckDelay = time.Duration(v) * time.Millisecond + case paramDisableActiveMigration: + p.disableActiveMigration = true + case paramPreferredAddress: + if len(val) < 4+2+16+2+1 { + return p, localTransportError{code: errTransportParameter} + } + p.preferredAddrV4 = netip.AddrPortFrom( + netip.AddrFrom4(*(*[4]byte)(val[:4])), + binary.BigEndian.Uint16(val[4:][:2]), + ) + val = val[4+2:] + p.preferredAddrV6 = netip.AddrPortFrom( + netip.AddrFrom16(*(*[16]byte)(val[:16])), + binary.BigEndian.Uint16(val[16:][:2]), + ) + val = val[16+2:] + var nn int + p.preferredAddrConnID, nn = quicwire.ConsumeUint8Bytes(val) + if nn < 0 { + return p, localTransportError{code: errTransportParameter} + } + val = val[nn:] + if len(val) != 16 { + return p, localTransportError{code: errTransportParameter} + } + p.preferredAddrResetToken = val + val = nil + case paramActiveConnectionIDLimit: + p.activeConnIDLimit, n = quicwire.ConsumeVarintInt64(val) + if p.activeConnIDLimit < 2 { + return p, localTransportError{code: errTransportParameter} + } + case paramInitialSourceConnectionID: + p.initialSrcConnID = val + n = len(val) + case paramRetrySourceConnectionID: + p.retrySrcConnID = val + n = len(val) + default: + n = len(val) + } + if n != len(val) { + return p, localTransportError{code: errTransportParameter} + } + } + return p, nil +} diff --git a/src/vendor/golang.org/x/net/quic/udp.go b/src/vendor/golang.org/x/net/quic/udp.go new file mode 100644 index 0000000000..cf23c5ce88 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/udp.go @@ -0,0 +1,28 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import "net/netip" + +// Per-plaform consts describing support for various features. +// +// const udpECNSupport indicates whether the platform supports setting +// the ECN (Explicit Congestion Notification) IP header bits. +// +// const udpInvalidLocalAddrIsError indicates whether sending a packet +// from an local address not associated with the system is an error. +// For example, assuming 127.0.0.2 is not a local address, does sending +// from it (using IP_PKTINFO or some other such feature) result in an error? + +// unmapAddrPort returns a with any IPv4-mapped IPv6 address prefix removed. +func unmapAddrPort(a netip.AddrPort) netip.AddrPort { + if a.Addr().Is4In6() { + return netip.AddrPortFrom( + a.Addr().Unmap(), + a.Port(), + ) + } + return a +} diff --git a/src/vendor/golang.org/x/net/quic/udp_darwin.go b/src/vendor/golang.org/x/net/quic/udp_darwin.go new file mode 100644 index 0000000000..91e8e81c17 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/udp_darwin.go @@ -0,0 +1,45 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build darwin + +package quic + +import ( + "encoding/binary" + "syscall" +) + +// These socket options are available on darwin, but are not in the syscall +// package. Since syscall package is frozen, just define them manually here. +const ( + ip_recvtos = 0x1b + ipv6_recvpktinfo = 0x3d + ipv6_pktinfo = 0x2e +) + +// See udp.go. +const ( + udpECNSupport = true + udpInvalidLocalAddrIsError = true +) + +// Confusingly, on Darwin the contents of the IP_TOS option differ depending on whether +// it is used as an inbound or outbound cmsg. + +func parseIPTOS(b []byte) (ecnBits, bool) { + // Single byte. The low two bits are the ECN field. + if len(b) != 1 { + return 0, false + } + return ecnBits(b[0] & ecnMask), true +} + +func appendCmsgECNv4(b []byte, ecn ecnBits) []byte { + // 32-bit integer. + // https://github.com/apple/darwin-xnu/blob/2ff845c2e033bd0ff64b5b6aa6063a1f8f65aa32/bsd/netinet/in_tclass.c#L1062-L1073 + b, data := appendCmsg(b, syscall.IPPROTO_IP, syscall.IP_TOS, 4) + binary.NativeEndian.PutUint32(data, uint32(ecn)) + return b +} diff --git a/src/vendor/golang.org/x/net/quic/udp_linux.go b/src/vendor/golang.org/x/net/quic/udp_linux.go new file mode 100644 index 0000000000..08deaf9829 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/udp_linux.go @@ -0,0 +1,39 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build linux + +package quic + +import ( + "syscall" +) + +const ( + ip_recvtos = syscall.IP_RECVTOS + ipv6_recvpktinfo = syscall.IPV6_RECVPKTINFO + ipv6_pktinfo = syscall.IPV6_PKTINFO +) + +// See udp.go. +const ( + udpECNSupport = true + udpInvalidLocalAddrIsError = false +) + +// The IP_TOS socket option is a single byte containing the IP TOS field. +// The low two bits are the ECN field. + +func parseIPTOS(b []byte) (ecnBits, bool) { + if len(b) != 1 { + return 0, false + } + return ecnBits(b[0] & ecnMask), true +} + +func appendCmsgECNv4(b []byte, ecn ecnBits) []byte { + b, data := appendCmsg(b, syscall.IPPROTO_IP, syscall.IP_TOS, 1) + data[0] = byte(ecn) + return b +} diff --git a/src/vendor/golang.org/x/net/quic/udp_msg.go b/src/vendor/golang.org/x/net/quic/udp_msg.go new file mode 100644 index 0000000000..10909044fe --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/udp_msg.go @@ -0,0 +1,245 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !quicbasicnet && (darwin || linux) + +package quic + +import ( + "encoding/binary" + "net" + "net/netip" + "sync" + "syscall" + "unsafe" +) + +// Network interface for platforms using sendmsg/recvmsg with cmsgs. + +type netUDPConn struct { + c *net.UDPConn + localAddr netip.AddrPort +} + +func newNetUDPConn(uc *net.UDPConn) (*netUDPConn, error) { + a, _ := uc.LocalAddr().(*net.UDPAddr) + localAddr := a.AddrPort() + if localAddr.Addr().IsUnspecified() { + // If the conn is not bound to a specified (non-wildcard) address, + // then set localAddr.Addr to an invalid netip.Addr. + // This better conveys that this is not an address we should be using, + // and is a bit more efficient to test against. + localAddr = netip.AddrPortFrom(netip.Addr{}, localAddr.Port()) + } + + sc, err := uc.SyscallConn() + if err != nil { + return nil, err + } + sc.Control(func(fd uintptr) { + // Ask for ECN info and (when we aren't bound to a fixed local address) + // destination info. + // + // If any of these calls fail, we won't get the requested information. + // That's fine, we'll gracefully handle the lack. + syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, ip_recvtos, 1) + syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_RECVTCLASS, 1) + if !localAddr.IsValid() { + syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_PKTINFO, 1) + syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, ipv6_recvpktinfo, 1) + } + }) + + return &netUDPConn{ + c: uc, + localAddr: localAddr, + }, nil +} + +func (c *netUDPConn) Close() error { return c.c.Close() } + +func (c *netUDPConn) LocalAddr() netip.AddrPort { + a, _ := c.c.LocalAddr().(*net.UDPAddr) + return a.AddrPort() +} + +func (c *netUDPConn) Read(f func(*datagram)) { + // We shouldn't ever see all of these messages at the same time, + // but the total is small so just allocate enough space for everything we use. + const ( + inPktinfoSize = 12 // int + in_addr + in_addr + in6PktinfoSize = 20 // in6_addr + int + ipTOSSize = 4 + ipv6TclassSize = 4 + ) + control := make([]byte, 0+ + syscall.CmsgSpace(inPktinfoSize)+ + syscall.CmsgSpace(in6PktinfoSize)+ + syscall.CmsgSpace(ipTOSSize)+ + syscall.CmsgSpace(ipv6TclassSize)) + + for { + d := newDatagram() + n, controlLen, _, peerAddr, err := c.c.ReadMsgUDPAddrPort(d.b, control) + if err != nil { + return + } + if n == 0 { + continue + } + d.localAddr = c.localAddr + d.peerAddr = unmapAddrPort(peerAddr) + d.b = d.b[:n] + parseControl(d, control[:controlLen]) + f(d) + } +} + +var cmsgPool = sync.Pool{ + New: func() any { + return new([]byte) + }, +} + +func (c *netUDPConn) Write(dgram datagram) error { + controlp := cmsgPool.Get().(*[]byte) + control := *controlp + defer func() { + *controlp = control[:0] + cmsgPool.Put(controlp) + }() + + localIP := dgram.localAddr.Addr() + if localIP.IsValid() { + if localIP.Is4() { + control = appendCmsgIPSourceAddrV4(control, localIP) + } else { + control = appendCmsgIPSourceAddrV6(control, localIP) + } + } + if dgram.ecn != ecnNotECT { + if dgram.peerAddr.Addr().Is4() { + control = appendCmsgECNv4(control, dgram.ecn) + } else { + control = appendCmsgECNv6(control, dgram.ecn) + } + } + + _, _, err := c.c.WriteMsgUDPAddrPort(dgram.b, control, dgram.peerAddr) + return err +} + +func parseControl(d *datagram, control []byte) { + msgs, err := syscall.ParseSocketControlMessage(control) + if err != nil { + return + } + for _, m := range msgs { + switch m.Header.Level { + case syscall.IPPROTO_IP: + switch m.Header.Type { + case syscall.IP_TOS, ip_recvtos: + // (Linux sets the type to IP_TOS, Darwin to IP_RECVTOS, + // just check for both.) + if ecn, ok := parseIPTOS(m.Data); ok { + d.ecn = ecn + } + case syscall.IP_PKTINFO: + if a, ok := parseInPktinfo(m.Data); ok { + d.localAddr = netip.AddrPortFrom(a, d.localAddr.Port()) + } + } + case syscall.IPPROTO_IPV6: + switch m.Header.Type { + case syscall.IPV6_TCLASS: + // 32-bit integer containing the traffic class field. + // The low two bits are the ECN field. + if ecn, ok := parseIPv6TCLASS(m.Data); ok { + d.ecn = ecn + } + case ipv6_pktinfo: + if a, ok := parseIn6Pktinfo(m.Data); ok { + d.localAddr = netip.AddrPortFrom(a, d.localAddr.Port()) + } + } + } + } +} + +// IPV6_TCLASS is specified by RFC 3542 as an int. + +func parseIPv6TCLASS(b []byte) (ecnBits, bool) { + if len(b) != 4 { + return 0, false + } + return ecnBits(binary.NativeEndian.Uint32(b) & ecnMask), true +} + +func appendCmsgECNv6(b []byte, ecn ecnBits) []byte { + b, data := appendCmsg(b, syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS, 4) + binary.NativeEndian.PutUint32(data, uint32(ecn)) + return b +} + +// struct in_pktinfo { +// unsigned int ipi_ifindex; /* send/recv interface index */ +// struct in_addr ipi_spec_dst; /* Local address */ +// struct in_addr ipi_addr; /* IP Header dst address */ +// }; + +// parseInPktinfo returns the destination address from an IP_PKTINFO. +func parseInPktinfo(b []byte) (dst netip.Addr, ok bool) { + if len(b) != 12 { + return netip.Addr{}, false + } + return netip.AddrFrom4([4]byte(b[8:][:4])), true +} + +// appendCmsgIPSourceAddrV4 appends an IP_PKTINFO setting the source address +// for an outbound datagram. +func appendCmsgIPSourceAddrV4(b []byte, src netip.Addr) []byte { + // struct in_pktinfo { + // unsigned int ipi_ifindex; /* send/recv interface index */ + // struct in_addr ipi_spec_dst; /* Local address */ + // struct in_addr ipi_addr; /* IP Header dst address */ + // }; + b, data := appendCmsg(b, syscall.IPPROTO_IP, syscall.IP_PKTINFO, 12) + ip := src.As4() + copy(data[4:], ip[:]) + return b +} + +// struct in6_pktinfo { +// struct in6_addr ipi6_addr; /* src/dst IPv6 address */ +// unsigned int ipi6_ifindex; /* send/recv interface index */ +// }; + +// parseIn6Pktinfo returns the destination address from an IPV6_PKTINFO. +func parseIn6Pktinfo(b []byte) (netip.Addr, bool) { + if len(b) != 20 { + return netip.Addr{}, false + } + return netip.AddrFrom16([16]byte(b[:16])).Unmap(), true +} + +// appendCmsgIPSourceAddrV6 appends an IPV6_PKTINFO setting the source address +// for an outbound datagram. +func appendCmsgIPSourceAddrV6(b []byte, src netip.Addr) []byte { + b, data := appendCmsg(b, syscall.IPPROTO_IPV6, ipv6_pktinfo, 20) + ip := src.As16() + copy(data[0:], ip[:]) + return b +} + +// appendCmsg appends a cmsg with the given level, type, and size to b. +// It returns the new buffer, and the data section of the cmsg. +func appendCmsg(b []byte, level, typ int32, size int) (_, data []byte) { + off := len(b) + b = append(b, make([]byte, syscall.CmsgSpace(size))...) + h := (*syscall.Cmsghdr)(unsafe.Pointer(&b[off])) + h.Level = level + h.Type = typ + h.SetLen(syscall.CmsgLen(size)) + return b, b[off+syscall.CmsgSpace(0):][:size] +} diff --git a/src/vendor/golang.org/x/net/quic/udp_other.go b/src/vendor/golang.org/x/net/quic/udp_other.go new file mode 100644 index 0000000000..02e4a5fc23 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/udp_other.go @@ -0,0 +1,62 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build quicbasicnet || !(darwin || linux) + +package quic + +import ( + "net" + "net/netip" +) + +// Lowest common denominator network interface: Basic net.UDPConn, no cmsgs. +// We will not be able to send or receive ECN bits, +// and we will not know what our local address is. +// +// The quicbasicnet build tag allows selecting this interface on any platform. + +// See udp.go. +const ( + udpECNSupport = false + udpInvalidLocalAddrIsError = false +) + +type netUDPConn struct { + c *net.UDPConn +} + +func newNetUDPConn(uc *net.UDPConn) (*netUDPConn, error) { + return &netUDPConn{ + c: uc, + }, nil +} + +func (c *netUDPConn) Close() error { return c.c.Close() } + +func (c *netUDPConn) LocalAddr() netip.AddrPort { + a, _ := c.c.LocalAddr().(*net.UDPAddr) + return a.AddrPort() +} + +func (c *netUDPConn) Read(f func(*datagram)) { + for { + dgram := newDatagram() + n, peerAddr, err := c.c.ReadFromUDPAddrPort(dgram.b) + if err != nil { + return + } + if n == 0 { + continue + } + dgram.peerAddr = unmapAddrPort(peerAddr) + dgram.b = dgram.b[:n] + f(dgram) + } +} + +func (c *netUDPConn) Write(dgram datagram) error { + _, err := c.c.WriteToUDPAddrPort(dgram.b, dgram.peerAddr) + return err +} diff --git a/src/vendor/golang.org/x/net/quic/udp_packetconn.go b/src/vendor/golang.org/x/net/quic/udp_packetconn.go new file mode 100644 index 0000000000..2c7e71cf61 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/udp_packetconn.go @@ -0,0 +1,67 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "net" + "net/netip" +) + +// netPacketConn is a packetConn implementation wrapping a net.PacketConn. +// +// This is mostly useful for tests, since PacketConn doesn't provide access to +// important features such as identifying the local address packets were received on. +type netPacketConn struct { + c net.PacketConn + localAddr netip.AddrPort +} + +func newNetPacketConn(pc net.PacketConn) (*netPacketConn, error) { + addr, err := addrPortFromAddr(pc.LocalAddr()) + if err != nil { + return nil, err + } + return &netPacketConn{ + c: pc, + localAddr: addr, + }, nil +} + +func (c *netPacketConn) Close() error { + return c.c.Close() +} + +func (c *netPacketConn) LocalAddr() netip.AddrPort { + return c.localAddr +} + +func (c *netPacketConn) Read(f func(*datagram)) { + for { + dgram := newDatagram() + n, peerAddr, err := c.c.ReadFrom(dgram.b) + if err != nil { + return + } + dgram.peerAddr, err = addrPortFromAddr(peerAddr) + if err != nil { + continue + } + dgram.b = dgram.b[:n] + f(dgram) + } +} + +func (c *netPacketConn) Write(dgram datagram) error { + _, err := c.c.WriteTo(dgram.b, net.UDPAddrFromAddrPort(dgram.peerAddr)) + return err +} + +func addrPortFromAddr(addr net.Addr) (netip.AddrPort, error) { + switch a := addr.(type) { + case *net.UDPAddr: + return a.AddrPort(), nil + } + return netip.ParseAddrPort(addr.String()) +} diff --git a/src/vendor/modules.txt b/src/vendor/modules.txt index 35468f6f1d..c7220124e2 100644 --- a/src/vendor/modules.txt +++ b/src/vendor/modules.txt @@ -4,6 +4,7 @@ golang.org/x/crypto/chacha20 golang.org/x/crypto/chacha20poly1305 golang.org/x/crypto/cryptobyte golang.org/x/crypto/cryptobyte/asn1 +golang.org/x/crypto/hkdf golang.org/x/crypto/internal/alias golang.org/x/crypto/internal/poly1305 # golang.org/x/net v0.52.1-0.20260406204716-056ac742146a @@ -12,9 +13,14 @@ golang.org/x/net/dns/dnsmessage golang.org/x/net/http/httpguts golang.org/x/net/http/httpproxy golang.org/x/net/http2/hpack +golang.org/x/net/http3 golang.org/x/net/idna +golang.org/x/net/internal/http3 +golang.org/x/net/internal/httpcommon +golang.org/x/net/internal/quic/quicwire golang.org/x/net/lif golang.org/x/net/nettest +golang.org/x/net/quic # golang.org/x/sys v0.42.1-0.20260320201212-a76ec62d6c53 ## explicit; go 1.25.0 golang.org/x/sys/cpu |
