diff options
Diffstat (limited to 'src/vendor/golang.org/x/net/internal/http3/stream.go')
| -rw-r--r-- | src/vendor/golang.org/x/net/internal/http3/stream.go | 260 |
1 files changed, 260 insertions, 0 deletions
diff --git a/src/vendor/golang.org/x/net/internal/http3/stream.go b/src/vendor/golang.org/x/net/internal/http3/stream.go new file mode 100644 index 0000000000..93294d43de --- /dev/null +++ b/src/vendor/golang.org/x/net/internal/http3/stream.go @@ -0,0 +1,260 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http3 + +import ( + "context" + "io" + + "golang.org/x/net/quic" +) + +// A stream wraps a QUIC stream, providing methods to read/write various values. +type stream struct { + stream *quic.Stream + + // lim is the current read limit. + // Reading a frame header sets the limit to the end of the frame. + // Reading past the limit or reading less than the limit and ending the frame + // results in an error. + // -1 indicates no limit. + lim int64 +} + +// newConnStream creates a new stream on a connection. +// It writes the stream header for unidirectional streams. +// +// The stream returned by newStream is not flushed, +// and will not be sent to the peer until the caller calls +// Flush or writes enough data to the stream. +func newConnStream(ctx context.Context, qconn *quic.Conn, stype streamType) (*stream, error) { + var qs *quic.Stream + var err error + if stype == streamTypeRequest { + // Request streams are bidirectional. + qs, err = qconn.NewStream(ctx) + } else { + // All other streams are unidirectional. + qs, err = qconn.NewSendOnlyStream(ctx) + } + if err != nil { + return nil, err + } + st := &stream{ + stream: qs, + lim: -1, // no limit + } + if stype != streamTypeRequest { + // Unidirectional stream header. + st.writeVarint(int64(stype)) + } + return st, err +} + +func newStream(qs *quic.Stream) *stream { + return &stream{ + stream: qs, + lim: -1, // no limit + } +} + +// readFrameHeader reads the type and length fields of an HTTP/3 frame. +// It sets the read limit to the end of the frame. +// +// https://www.rfc-editor.org/rfc/rfc9114.html#section-7.1 +func (st *stream) readFrameHeader() (ftype frameType, err error) { + if st.lim >= 0 { + // We shouldn't call readFrameHeader before ending the previous frame. + return 0, errH3FrameError + } + ftype, err = readVarint[frameType](st) + if err != nil { + return 0, err + } + size, err := st.readVarint() + if err != nil { + return 0, err + } + st.lim = size + return ftype, nil +} + +// endFrame is called after reading a frame to reset the read limit. +// It returns an error if the entire contents of a frame have not been read. +func (st *stream) endFrame() error { + if st.lim != 0 { + return &connectionError{ + code: errH3FrameError, + message: "invalid HTTP/3 frame", + } + } + st.lim = -1 + return nil +} + +// readFrameData returns the remaining data in the current frame. +func (st *stream) readFrameData() ([]byte, error) { + if st.lim < 0 { + return nil, errH3FrameError + } + // TODO: Pool buffers to avoid allocation here. + b := make([]byte, st.lim) + _, err := io.ReadFull(st, b) + if err != nil { + return nil, err + } + return b, nil +} + +// ReadByte reads one byte from the stream. +func (st *stream) ReadByte() (b byte, err error) { + if err := st.recordBytesRead(1); err != nil { + return 0, err + } + b, err = st.stream.ReadByte() + if err != nil { + if err == io.EOF && st.lim < 0 { + return 0, io.EOF + } + return 0, errH3FrameError + } + return b, nil +} + +// Read reads from the stream. +func (st *stream) Read(b []byte) (int, error) { + n, err := st.stream.Read(b) + if e2 := st.recordBytesRead(n); e2 != nil { + return 0, e2 + } + if err == io.EOF { + if st.lim == 0 { + // EOF at end of frame, ignore. + return n, nil + } else if st.lim > 0 { + // EOF inside frame, error. + return 0, errH3FrameError + } else { + // EOF outside of frame, surface to caller. + return n, io.EOF + } + } + if err != nil { + return 0, errH3FrameError + } + return n, nil +} + +// discardUnknownFrame discards an unknown frame. +// +// HTTP/3 requires that unknown frames be ignored on all streams. +// However, a known frame appearing in an unexpected place is a fatal error, +// so this returns an error if the frame is one we know. +func (st *stream) discardUnknownFrame(ftype frameType) error { + switch ftype { + case frameTypeData, + frameTypeHeaders, + frameTypeCancelPush, + frameTypeSettings, + frameTypePushPromise, + frameTypeGoaway, + frameTypeMaxPushID: + return &connectionError{ + code: errH3FrameUnexpected, + message: "unexpected " + ftype.String() + " frame", + } + } + return st.discardFrame() +} + +// discardFrame discards any remaining data in the current frame and resets the read limit. +func (st *stream) discardFrame() error { + // TODO: Consider adding a *quic.Stream method to discard some amount of data. + for range st.lim { + _, err := st.stream.ReadByte() + if err != nil { + return &streamError{errH3FrameError, err.Error()} + } + } + st.lim = -1 + return nil +} + +// Write writes to the stream. +func (st *stream) Write(b []byte) (int, error) { return st.stream.Write(b) } + +// Flush commits data written to the stream. +func (st *stream) Flush() error { return st.stream.Flush() } + +// readVarint reads a QUIC variable-length integer from the stream. +func (st *stream) readVarint() (v int64, err error) { + b, err := st.stream.ReadByte() + if err != nil { + return 0, err + } + v = int64(b & 0x3f) + n := 1 << (b >> 6) + for i := 1; i < n; i++ { + b, err := st.stream.ReadByte() + if err != nil { + return 0, errH3FrameError + } + v = (v << 8) | int64(b) + } + if err := st.recordBytesRead(n); err != nil { + return 0, err + } + return v, nil +} + +// readVarint reads a varint of a particular type. +func readVarint[T ~int64 | ~uint64](st *stream) (T, error) { + v, err := st.readVarint() + return T(v), err +} + +// writeVarint writes a QUIC variable-length integer to the stream. +func (st *stream) writeVarint(v int64) { + switch { + case v <= (1<<6)-1: + st.stream.WriteByte(byte(v)) + case v <= (1<<14)-1: + st.stream.WriteByte((1 << 6) | byte(v>>8)) + st.stream.WriteByte(byte(v)) + case v <= (1<<30)-1: + st.stream.WriteByte((2 << 6) | byte(v>>24)) + st.stream.WriteByte(byte(v >> 16)) + st.stream.WriteByte(byte(v >> 8)) + st.stream.WriteByte(byte(v)) + case v <= (1<<62)-1: + st.stream.WriteByte((3 << 6) | byte(v>>56)) + st.stream.WriteByte(byte(v >> 48)) + st.stream.WriteByte(byte(v >> 40)) + st.stream.WriteByte(byte(v >> 32)) + st.stream.WriteByte(byte(v >> 24)) + st.stream.WriteByte(byte(v >> 16)) + st.stream.WriteByte(byte(v >> 8)) + st.stream.WriteByte(byte(v)) + default: + panic("varint too large") + } +} + +// recordBytesRead records that n bytes have been read. +// It returns an error if the read passes the current limit. +func (st *stream) recordBytesRead(n int) error { + if st.lim < 0 { + return nil + } + st.lim -= int64(n) + if st.lim < 0 { + st.stream = nil // panic if we try to read again + return &connectionError{ + code: errH3FrameError, + message: "invalid HTTP/3 frame", + } + } + return nil +} |
