diff options
Diffstat (limited to 'src/vendor/golang.org/x/net/quic')
54 files changed, 12121 insertions, 0 deletions
diff --git a/src/vendor/golang.org/x/net/quic/ack_delay.go b/src/vendor/golang.org/x/net/quic/ack_delay.go new file mode 100644 index 0000000000..029ce6faec --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/ack_delay.go @@ -0,0 +1,26 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "math" + "time" +) + +// An unscaledAckDelay is an ACK Delay field value from an ACK packet, +// without the ack_delay_exponent scaling applied. +type unscaledAckDelay int64 + +func unscaledAckDelayFromDuration(d time.Duration, ackDelayExponent uint8) unscaledAckDelay { + return unscaledAckDelay(d.Microseconds() >> ackDelayExponent) +} + +func (d unscaledAckDelay) Duration(ackDelayExponent uint8) time.Duration { + if int64(d) > (math.MaxInt64>>ackDelayExponent)/int64(time.Microsecond) { + // If scaling the delay would overflow, ignore the delay. + return 0 + } + return time.Duration(d<<ackDelayExponent) * time.Microsecond +} diff --git a/src/vendor/golang.org/x/net/quic/acks.go b/src/vendor/golang.org/x/net/quic/acks.go new file mode 100644 index 0000000000..90f82bed03 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/acks.go @@ -0,0 +1,213 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "time" +) + +// ackState tracks packets received from a peer within a number space. +// It handles packet deduplication (don't process the same packet twice) and +// determines the timing and content of ACK frames. +type ackState struct { + seen rangeset[packetNumber] + + // The time at which we must send an ACK frame, even if we have no other data to send. + nextAck time.Time + + // The time we received the largest-numbered packet in seen. 
+ maxRecvTime time.Time + + // The largest-numbered ack-eliciting packet in seen. + maxAckEliciting packetNumber + + // The number of ack-eliciting packets in seen that we have not yet acknowledged. + unackedAckEliciting int + + // Total ECN counters for this packet number space. + ecn ecnCounts +} + +type ecnCounts struct { + t0 int + t1 int + ce int +} + +// shouldProcess reports whether a packet should be handled or discarded. +func (acks *ackState) shouldProcess(num packetNumber) bool { + if packetNumber(acks.seen.min()) > num { + // We've discarded the state for this range of packet numbers. + // Discard the packet rather than potentially processing a duplicate. + // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.2.3-5 + return false + } + if acks.seen.contains(num) { + // Discard duplicate packets. + return false + } + return true +} + +// receive records receipt of a packet. +func (acks *ackState) receive(now time.Time, space numberSpace, num packetNumber, ackEliciting bool, ecn ecnBits) { + if ackEliciting { + acks.unackedAckEliciting++ + if acks.mustAckImmediately(space, num, ecn) { + acks.nextAck = now + } else if acks.nextAck.IsZero() { + // This packet does not need to be acknowledged immediately, + // but the ack must not be intentionally delayed by more than + // the max_ack_delay transport parameter we sent to the peer. + // + // We always delay acks by the maximum allowed, less the timer + // granularity. 
("[max_ack_delay] SHOULD include the receiver's + // expected delays in alarms firing.") + // + // https://www.rfc-editor.org/rfc/rfc9000#section-18.2-4.28.1 + acks.nextAck = now.Add(maxAckDelay - timerGranularity) + } + if num > acks.maxAckEliciting { + acks.maxAckEliciting = num + } + } + + acks.seen.add(num, num+1) + if num == acks.seen.max() { + acks.maxRecvTime = now + } + + switch ecn { + case ecnECT0: + acks.ecn.t0++ + case ecnECT1: + acks.ecn.t1++ + case ecnCE: + acks.ecn.ce++ + } + + // Limit the total number of ACK ranges by dropping older ranges. + // + // Remembering more ranges results in larger ACK frames. + // + // Remembering a large number of ranges could result in ACK frames becoming + // too large to fit in a packet, in which case we will silently drop older + // ranges during packet construction. + // + // Remembering fewer ranges can result in unnecessary retransmissions, + // since we cannot accept packets older than the oldest remembered range. + // + // The limit here is completely arbitrary. If it seems wrong, it probably is. + // + // https://www.rfc-editor.org/rfc/rfc9000#section-13.2.3 + const maxAckRanges = 8 + if overflow := acks.seen.numRanges() - maxAckRanges; overflow > 0 { + acks.seen.removeranges(0, overflow) + } +} + +// mustAckImmediately reports whether an ack-eliciting packet must be acknowledged immediately, +// or whether the ack may be deferred. +func (acks *ackState) mustAckImmediately(space numberSpace, num packetNumber, ecn ecnBits) bool { + // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.2.1 + if space != appDataSpace { + // "[...] all ack-eliciting Initial and Handshake packets [...]" + // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.2.1-2 + return true + } + if num < acks.maxAckEliciting { + // "[...] 
when the received packet has a packet number less than another + // ack-eliciting packet that has been received [...]" + // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.2.1-8.1 + return true + } + if acks.seen.rangeContaining(acks.maxAckEliciting).end != num { + // "[...] when the packet has a packet number larger than the highest-numbered + // ack-eliciting packet that has been received and there are missing packets + // between that packet and this packet." + // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.2.1-8.2 + // + // This case is a bit tricky. Let's say we've received: + // 0, ack-eliciting + // 1, ack-eliciting + // 3, NOT ack eliciting + // + // We have sent ACKs for 0 and 1. If we receive ack-eliciting packet 2, + // we do not need to send an immediate ACK, because there are no missing + // packets between it and the highest-numbered ack-eliciting packet (1). + // If we receive ack-eliciting packet 4, we do need to send an immediate ACK, + // because there's a gap (the missing packet 2). + // + // We check for this by looking up the ACK range which contains the + // highest-numbered ack-eliciting packet: [0, 1) in the above example. + // If the range ends just before the packet we are now processing, + // there are no gaps. If it does not, there must be a gap. + return true + } + // "[...] packets marked with the ECN Congestion Experienced (CE) codepoint + // in the IP header SHOULD be acknowledged immediately [...]" + // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.2.1-9 + if ecn == ecnCE { + return true + } + // "[...] SHOULD send an ACK frame after receiving at least two ack-eliciting packets." + // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.2.2 + // + // This ack frequency takes a substantial toll on performance, however. + // Follow the behavior of Google QUICHE: + // Ack every other packet for the first 100 packets, and then ack every 10th packet. 
+ // This keeps ack frequency high during the beginning of slow start when CWND is + // increasing rapidly. + packetsBeforeAck := 2 + if acks.seen.max() > 100 { + packetsBeforeAck = 10 + } + return acks.unackedAckEliciting >= packetsBeforeAck +} + +// shouldSendAck reports whether the connection should send an ACK frame at this time, +// in an ACK-only packet if necessary. +func (acks *ackState) shouldSendAck(now time.Time) bool { + return !acks.nextAck.IsZero() && !acks.nextAck.After(now) +} + +// acksToSend returns the set of packet numbers to ACK at this time, and the current ack delay. +// It may return acks even if shouldSendAck returns false, when there are unacked +// ack-eliciting packets whose ack is being delayed. +func (acks *ackState) acksToSend(now time.Time) (nums rangeset[packetNumber], ackDelay time.Duration) { + if acks.nextAck.IsZero() && acks.unackedAckEliciting == 0 { + return nil, 0 + } + // "[...] the delays intentionally introduced between the time the packet with the + // largest packet number is received and the time an acknowledgement is sent." + // https://www.rfc-editor.org/rfc/rfc9000#section-13.2.5-1 + delay := now.Sub(acks.maxRecvTime) + if delay < 0 { + delay = 0 + } + return acks.seen, delay +} + +// sentAck records that an ACK frame has been sent. +func (acks *ackState) sentAck() { + acks.nextAck = time.Time{} + acks.unackedAckEliciting = 0 +} + +// handleAck records that an ack has been received for a ACK frame we sent +// containing the given Largest Acknowledged field. +func (acks *ackState) handleAck(largestAcked packetNumber) { + // We can stop acking packets less or equal to largestAcked. + // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.2.4-1 + // + // We rely on acks.seen containing the largest packet number that has been successfully + // processed, so we retain the range containing largestAcked and discard previous ones. 
+ acks.seen.sub(0, acks.seen.rangeContaining(largestAcked).start) +} + +// largestSeen reports the largest seen packet. +func (acks *ackState) largestSeen() packetNumber { + return acks.seen.max() +} diff --git a/src/vendor/golang.org/x/net/quic/atomic_bits.go b/src/vendor/golang.org/x/net/quic/atomic_bits.go new file mode 100644 index 0000000000..5244e04201 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/atomic_bits.go @@ -0,0 +1,31 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import "sync/atomic" + +// atomicBits is an atomic uint32 that supports setting individual bits. +type atomicBits[T ~uint32] struct { + bits atomic.Uint32 +} + +// set sets the bits in mask to the corresponding bits in v. +// It returns the new value. +func (a *atomicBits[T]) set(v, mask T) T { + if v&^mask != 0 { + panic("BUG: bits in v are not in mask") + } + for { + o := a.bits.Load() + n := (o &^ uint32(mask)) | uint32(v) + if a.bits.CompareAndSwap(o, n) { + return T(n) + } + } +} + +func (a *atomicBits[T]) load() T { + return T(a.bits.Load()) +} diff --git a/src/vendor/golang.org/x/net/quic/config.go b/src/vendor/golang.org/x/net/quic/config.go new file mode 100644 index 0000000000..a9ec4bc437 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/config.go @@ -0,0 +1,158 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "crypto/tls" + "log/slog" + "math" + "time" + + "golang.org/x/net/internal/quic/quicwire" +) + +// A Config structure configures a QUIC endpoint. +// A Config must not be modified after it has been passed to a QUIC function. +// A Config may be reused; the quic package will also not modify it. +type Config struct { + // TLSConfig is the endpoint's TLS configuration. 
+ // It must be non-nil and include at least one certificate or else set GetCertificate. + TLSConfig *tls.Config + + // MaxBidiRemoteStreams limits the number of simultaneous bidirectional streams + // a peer may open. + // If zero, the default value of 100 is used. + // If negative, the limit is zero. + MaxBidiRemoteStreams int64 + + // MaxUniRemoteStreams limits the number of simultaneous unidirectional streams + // a peer may open. + // If zero, the default value of 100 is used. + // If negative, the limit is zero. + MaxUniRemoteStreams int64 + + // MaxStreamReadBufferSize is the maximum amount of data sent by the peer that a + // stream will buffer for reading. + // If zero, the default value of 1MiB is used. + // If negative, the limit is zero. + MaxStreamReadBufferSize int64 + + // MaxStreamWriteBufferSize is the maximum amount of data a stream will buffer for + // sending to the peer. + // If zero, the default value of 1MiB is used. + // If negative, the limit is zero. + MaxStreamWriteBufferSize int64 + + // MaxConnReadBufferSize is the maximum amount of data sent by the peer that a + // connection will buffer for reading, across all streams. + // If zero, the default value of 1MiB is used. + // If negative, the limit is zero. + MaxConnReadBufferSize int64 + + // RequireAddressValidation may be set to true to enable address validation + // of client connections prior to starting the handshake. + // + // Enabling this setting reduces the amount of work packets with spoofed + // source address information can cause a server to perform, + // at the cost of increased handshake latency. + RequireAddressValidation bool + + // StatelessResetKey is used to provide stateless reset of connections. + // A restart may leave an endpoint without access to the state of + // existing connections. Stateless reset permits an endpoint to respond + // to a packet for a connection it does not recognize. + // + // This field should be filled with random bytes. 
+ // The contents should remain stable across restarts, + // to permit an endpoint to send a reset for + // connections created before a restart. + // + // The contents of the StatelessResetKey should not be exposed. + // An attacker can use knowledge of this field's value to + // reset existing connections. + // + // If this field is left as zero, stateless reset is disabled. + StatelessResetKey [32]byte + + // HandshakeTimeout is the maximum time in which a connection handshake must complete. + // If zero, the default of 10 seconds is used. + // If negative, there is no handshake timeout. + HandshakeTimeout time.Duration + + // MaxIdleTimeout is the maximum time after which an idle connection will be closed. + // If zero, the default of 30 seconds is used. + // If negative, idle connections are never closed. + // + // The idle timeout for a connection is the minimum of the maximum idle timeouts + // of the endpoints. + MaxIdleTimeout time.Duration + + // KeepAlivePeriod is the time after which a packet will be sent to keep + // an idle connection alive. + // If zero, keep alive packets are not sent. + // If greater than zero, the keep alive period is the smaller of KeepAlivePeriod and + // half the connection idle timeout. + KeepAlivePeriod time.Duration + + // QLogLogger receives qlog events. + // + // Events currently correspond to the definitions in draft-ietf-qlog-quic-events-03. + // This is not the latest version of the draft, but is the latest version supported + // by common event log viewers as of the time this paragraph was written. + // + // The qlog package contains a slog.Handler which serializes qlog events + // to a standard JSON representation. + QLogLogger *slog.Logger +} + +// Clone returns a shallow clone of c, or nil if c is nil. +// It is safe to clone a [Config] that is being used concurrently by a QUIC endpoint. 
+func (c *Config) Clone() *Config { + n := *c + return &n +} + +func configDefault[T ~int64](v, def, limit T) T { + switch { + case v == 0: + return def + case v < 0: + return 0 + default: + return min(v, limit) + } +} + +func (c *Config) maxBidiRemoteStreams() int64 { + return configDefault(c.MaxBidiRemoteStreams, 100, maxStreamsLimit) +} + +func (c *Config) maxUniRemoteStreams() int64 { + return configDefault(c.MaxUniRemoteStreams, 100, maxStreamsLimit) +} + +func (c *Config) maxStreamReadBufferSize() int64 { + return configDefault(c.MaxStreamReadBufferSize, 1<<20, quicwire.MaxVarint) +} + +func (c *Config) maxStreamWriteBufferSize() int64 { + return configDefault(c.MaxStreamWriteBufferSize, 1<<20, quicwire.MaxVarint) +} + +func (c *Config) maxConnReadBufferSize() int64 { + return configDefault(c.MaxConnReadBufferSize, 1<<20, quicwire.MaxVarint) +} + +func (c *Config) handshakeTimeout() time.Duration { + return configDefault(c.HandshakeTimeout, defaultHandshakeTimeout, math.MaxInt64) +} + +func (c *Config) maxIdleTimeout() time.Duration { + return configDefault(c.MaxIdleTimeout, defaultMaxIdleTimeout, math.MaxInt64) +} + +func (c *Config) keepAlivePeriod() time.Duration { + return configDefault(c.KeepAlivePeriod, defaultKeepAlivePeriod, math.MaxInt64) +} diff --git a/src/vendor/golang.org/x/net/quic/congestion_reno.go b/src/vendor/golang.org/x/net/quic/congestion_reno.go new file mode 100644 index 0000000000..028a2ed6c3 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/congestion_reno.go @@ -0,0 +1,306 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "context" + "log/slog" + "math" + "time" +) + +// ccReno is the NewReno-based congestion controller defined in RFC 9002. 
+// https://www.rfc-editor.org/rfc/rfc9002.html#section-7 +type ccReno struct { + maxDatagramSize int + + // Maximum number of bytes allowed to be in flight. + congestionWindow int + + // Sum of size of all packets that contain at least one ack-eliciting + // or PADDING frame (i.e., any non-ACK frame), and have neither been + // acknowledged nor declared lost. + bytesInFlight int + + // When the congestion window is below the slow start threshold, + // the controller is in slow start. + slowStartThreshold int + + // The time the current recovery period started, or zero when not + // in a recovery period. + recoveryStartTime time.Time + + // Accumulated count of bytes acknowledged in congestion avoidance. + congestionPendingAcks int + + // When entering a recovery period, we are allowed to send one packet + // before reducing the congestion window. sendOnePacketInRecovery is + // true if we haven't sent that packet yet. + sendOnePacketInRecovery bool + + // inRecovery is set when we are in the recovery state. + inRecovery bool + + // underutilized is set if the congestion window is underutilized + // due to insufficient application data, flow control limits, or + // anti-amplification limits. + underutilized bool + + // ackLastLoss is the sent time of the newest lost packet processed + // in the current batch. + ackLastLoss time.Time + + // Data tracking the duration of the most recently handled sequence of + // contiguous lost packets. If this exceeds the persistent congestion duration, + // persistent congestion is declared. 
+ // + // https://www.rfc-editor.org/rfc/rfc9002#section-7.6 + persistentCongestion [numberSpaceCount]struct { + start time.Time // send time of first lost packet + end time.Time // send time of last lost packet + next packetNumber // one plus the number of the last lost packet + } +} + +func newReno(maxDatagramSize int) *ccReno { + c := &ccReno{ + maxDatagramSize: maxDatagramSize, + } + + // https://www.rfc-editor.org/rfc/rfc9002.html#section-7.2-1 + c.congestionWindow = min(10*maxDatagramSize, max(14720, c.minimumCongestionWindow())) + + // https://www.rfc-editor.org/rfc/rfc9002.html#section-7.3.1-1 + c.slowStartThreshold = math.MaxInt + + for space := range c.persistentCongestion { + c.persistentCongestion[space].next = -1 + } + return c +} + +// canSend reports whether the congestion controller permits sending +// a maximum-size datagram at this time. +// +// "An endpoint MUST NOT send a packet if it would cause bytes_in_flight [...] +// to be larger than the congestion window [...]" +// https://www.rfc-editor.org/rfc/rfc9002#section-7-7 +// +// For simplicity and efficiency, we don't permit sending undersized datagrams. +func (c *ccReno) canSend() bool { + if c.sendOnePacketInRecovery { + return true + } + return c.bytesInFlight+c.maxDatagramSize <= c.congestionWindow +} + +// setUnderutilized indicates that the congestion window is underutilized. +// +// The congestion window is underutilized if bytes in flight is smaller than +// the congestion window and sending is not pacing limited; that is, the +// congestion controller permits sending data, but no data is sent. +// +// https://www.rfc-editor.org/rfc/rfc9002#section-7.8 +func (c *ccReno) setUnderutilized(log *slog.Logger, v bool) { + if c.underutilized == v { + return + } + oldState := c.state() + c.underutilized = v + if logEnabled(log, QLogLevelPacket) { + logCongestionStateUpdated(log, oldState, c.state()) + } +} + +// packetSent indicates that a packet has been sent. 
+func (c *ccReno) packetSent(now time.Time, log *slog.Logger, space numberSpace, sent *sentPacket) { + if !sent.inFlight { + return + } + c.bytesInFlight += sent.size + if c.sendOnePacketInRecovery { + c.sendOnePacketInRecovery = false + } +} + +// Acked and lost packets are processed in batches +// resulting from either a received ACK frame or +// the loss detection timer expiring. +// +// A batch consists of zero or more calls to packetAcked and packetLost, +// followed by a single call to packetBatchEnd. +// +// Acks may be reported in any order, but lost packets must +// be reported in strictly increasing order. + +// packetAcked indicates that a packet has been newly acknowledged. +func (c *ccReno) packetAcked(now time.Time, sent *sentPacket) { + if !sent.inFlight { + return + } + c.bytesInFlight -= sent.size + + if c.underutilized { + // https://www.rfc-editor.org/rfc/rfc9002.html#section-7.8 + return + } + if sent.time.Before(c.recoveryStartTime) { + // In recovery, and this packet was sent before we entered recovery. + // (If this packet was sent after we entered recovery, receiving an ack + // for it moves us out of recovery into congestion avoidance.) + // https://www.rfc-editor.org/rfc/rfc9002.html#section-7.3.2 + return + } + c.congestionPendingAcks += sent.size +} + +// packetLost indicates that a packet has been newly marked as lost. +// Lost packets must be reported in increasing order. +func (c *ccReno) packetLost(now time.Time, space numberSpace, sent *sentPacket, rtt *rttState) { + // Record state to check for persistent congestion. + // https://www.rfc-editor.org/rfc/rfc9002#section-7.6 + // + // Note that this relies on always receiving loss events in increasing order: + // All packets prior to the one we're examining now have either been + // acknowledged or declared lost. 
+ isValidPersistentCongestionSample := (sent.ackEliciting && + !rtt.firstSampleTime.IsZero() && + !sent.time.Before(rtt.firstSampleTime)) + if isValidPersistentCongestionSample { + // This packet either extends an existing range of lost packets, + // or starts a new one. + if sent.num != c.persistentCongestion[space].next { + c.persistentCongestion[space].start = sent.time + } + c.persistentCongestion[space].end = sent.time + c.persistentCongestion[space].next = sent.num + 1 + } else { + // This packet cannot establish persistent congestion on its own. + // However, if we have an existing range of lost packets, + // this does not break it. + if sent.num == c.persistentCongestion[space].next { + c.persistentCongestion[space].next = sent.num + 1 + } + } + + if !sent.inFlight { + return + } + c.bytesInFlight -= sent.size + if sent.time.After(c.ackLastLoss) { + c.ackLastLoss = sent.time + } +} + +// packetBatchEnd is called at the end of processing a batch of acked or lost packets. +func (c *ccReno) packetBatchEnd(now time.Time, log *slog.Logger, space numberSpace, rtt *rttState, maxAckDelay time.Duration) { + if logEnabled(log, QLogLevelPacket) { + oldState := c.state() + defer func() { logCongestionStateUpdated(log, oldState, c.state()) }() + } + if !c.ackLastLoss.IsZero() && !c.ackLastLoss.Before(c.recoveryStartTime) { + // Enter the recovery state. + // https://www.rfc-editor.org/rfc/rfc9002.html#section-7.3.2 + c.recoveryStartTime = now + c.slowStartThreshold = c.congestionWindow / 2 + c.congestionWindow = max(c.slowStartThreshold, c.minimumCongestionWindow()) + c.sendOnePacketInRecovery = true + // Clear congestionPendingAcks to avoid increasing the congestion + // window based on acks in a frame that sends us into recovery. + c.congestionPendingAcks = 0 + c.inRecovery = true + } else if c.congestionPendingAcks > 0 { + // We are in slow start or congestion avoidance. 
+ c.inRecovery = false + if c.congestionWindow < c.slowStartThreshold { + // When the congestion window is less than the slow start threshold, + // we are in slow start and increase the window by the number of + // bytes acknowledged. + d := min(c.slowStartThreshold-c.congestionWindow, c.congestionPendingAcks) + c.congestionWindow += d + c.congestionPendingAcks -= d + } + // When the congestion window is at or above the slow start threshold, + // we are in congestion avoidance. + // + // RFC 9002 does not specify an algorithm here. The following is + // the recommended algorithm from RFC 5681, in which we increment + // the window by the maximum datagram size every time the number + // of bytes acknowledged reaches cwnd. + for c.congestionPendingAcks > c.congestionWindow { + c.congestionPendingAcks -= c.congestionWindow + c.congestionWindow += c.maxDatagramSize + } + } + if !c.ackLastLoss.IsZero() { + // Check for persistent congestion. + // https://www.rfc-editor.org/rfc/rfc9002.html#section-7.6 + // + // "A sender [...] MAY use state for just the packet number space that + // was acknowledged." + // https://www.rfc-editor.org/rfc/rfc9002#section-7.6.2-5 + // + // For simplicity, we consider each number space independently. + const persistentCongestionThreshold = 3 + d := (rtt.smoothedRTT + max(4*rtt.rttvar, timerGranularity) + maxAckDelay) * + persistentCongestionThreshold + start := c.persistentCongestion[space].start + end := c.persistentCongestion[space].end + if end.Sub(start) >= d { + c.congestionWindow = c.minimumCongestionWindow() + c.recoveryStartTime = time.Time{} + rtt.establishPersistentCongestion() + } + } + c.ackLastLoss = time.Time{} +} + +// packetDiscarded indicates that the keys for a packet's space have been discarded. 
+func (c *ccReno) packetDiscarded(sent *sentPacket) { + // https://www.rfc-editor.org/rfc/rfc9002#section-6.2.2-3 + if sent.inFlight { + c.bytesInFlight -= sent.size + } +} + +func (c *ccReno) minimumCongestionWindow() int { + // https://www.rfc-editor.org/rfc/rfc9002.html#section-7.2-4 + return 2 * c.maxDatagramSize +} + +func logCongestionStateUpdated(log *slog.Logger, oldState, newState congestionState) { + if oldState == newState { + return + } + log.LogAttrs(context.Background(), QLogLevelPacket, + "recovery:congestion_state_updated", + slog.String("old", oldState.String()), + slog.String("new", newState.String()), + ) +} + +type congestionState string + +func (s congestionState) String() string { return string(s) } + +const ( + congestionSlowStart = congestionState("slow_start") + congestionCongestionAvoidance = congestionState("congestion_avoidance") + congestionApplicationLimited = congestionState("application_limited") + congestionRecovery = congestionState("recovery") +) + +func (c *ccReno) state() congestionState { + switch { + case c.inRecovery: + return congestionRecovery + case c.underutilized: + return congestionApplicationLimited + case c.congestionWindow < c.slowStartThreshold: + return congestionSlowStart + default: + return congestionCongestionAvoidance + } +} diff --git a/src/vendor/golang.org/x/net/quic/conn.go b/src/vendor/golang.org/x/net/quic/conn.go new file mode 100644 index 0000000000..fd812b8a28 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/conn.go @@ -0,0 +1,434 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "context" + cryptorand "crypto/rand" + "crypto/tls" + "errors" + "fmt" + "log/slog" + "math/rand/v2" + "net/netip" + "time" +) + +// A Conn is a QUIC connection. +// +// Multiple goroutines may invoke methods on a Conn simultaneously. 
+type Conn struct { + side connSide + endpoint *Endpoint + config *Config + testHooks connTestHooks + peerAddr netip.AddrPort + localAddr netip.AddrPort + prng *rand.Rand + + msgc chan any + donec chan struct{} // closed when conn loop exits + + w packetWriter + acks [numberSpaceCount]ackState // indexed by number space + lifetime lifetimeState + idle idleState + connIDState connIDState + loss lossState + streams streamsState + path pathState + skip skipState + + // Packet protection keys, CRYPTO streams, and TLS state. + keysInitial fixedKeyPair + keysHandshake fixedKeyPair + keysAppData updatingKeyPair + crypto [numberSpaceCount]cryptoStream + tls *tls.QUICConn + + // retryToken is the token provided by the peer in a Retry packet. + retryToken []byte + + // handshakeConfirmed is set when the handshake is confirmed. + // For server connections, it tracks sending HANDSHAKE_DONE. + handshakeConfirmed sentVal + + peerAckDelayExponent int8 // -1 when unknown + + // Tests only: Send a PING in a specific number space. + testSendPingSpace numberSpace + testSendPing sentVal + + log *slog.Logger +} + +// connTestHooks override conn behavior in tests. +type connTestHooks interface { + // init is called after a conn is created. + init(first bool) + + // handleTLSEvent is called with each TLS event. + handleTLSEvent(tls.QUICEvent) + + // newConnID is called to generate a new connection ID. + // Permits tests to generate consistent connection IDs rather than random ones. + newConnID(seq int64) ([]byte, error) +} + +// newServerConnIDs is connection IDs associated with a new server connection. 
+type newServerConnIDs struct { + srcConnID []byte // source from client's current Initial + dstConnID []byte // destination from client's current Initial + originalDstConnID []byte // destination from client's first Initial + retrySrcConnID []byte // source from server's Retry +} + +func newConn(now time.Time, side connSide, cids newServerConnIDs, peerHostname string, peerAddr netip.AddrPort, config *Config, e *Endpoint) (conn *Conn, _ error) { + c := &Conn{ + side: side, + endpoint: e, + config: config, + peerAddr: unmapAddrPort(peerAddr), + donec: make(chan struct{}), + peerAckDelayExponent: -1, + } + defer func() { + // If we hit an error in newConn, close donec so tests don't get stuck waiting for it. + // This is only relevant if we've got a bug, but it makes tracking that bug down + // much easier. + if conn == nil { + close(c.donec) + } + }() + + // A one-element buffer allows us to wake a Conn's event loop as a + // non-blocking operation. + c.msgc = make(chan any, 1) + + if e.testHooks != nil { + e.testHooks.newConn(c) + } + + // initialConnID is the connection ID used to generate Initial packet protection keys. + var initialConnID []byte + if c.side == clientSide { + if err := c.connIDState.initClient(c); err != nil { + return nil, err + } + initialConnID, _ = c.connIDState.dstConnID() + } else { + initialConnID = cids.originalDstConnID + if cids.retrySrcConnID != nil { + initialConnID = cids.retrySrcConnID + } + if err := c.connIDState.initServer(c, cids); err != nil { + return nil, err + } + } + + // A per-conn ChaCha8 PRNG is probably more than we need, + // but at least it's fairly small. + var seed [32]byte + if _, err := cryptorand.Read(seed[:]); err != nil { + panic(err) + } + c.prng = rand.New(rand.NewChaCha8(seed)) + + // TODO: PMTU discovery. 
+ c.logConnectionStarted(cids.originalDstConnID, peerAddr) + c.keysAppData.init() + c.loss.init(c.side, smallestMaxDatagramSize, now) + c.streamsInit() + c.lifetimeInit() + c.restartIdleTimer(now) + c.skip.init(c) + + if err := c.startTLS(now, initialConnID, peerHostname, transportParameters{ + initialSrcConnID: c.connIDState.srcConnID(), + originalDstConnID: cids.originalDstConnID, + retrySrcConnID: cids.retrySrcConnID, + ackDelayExponent: ackDelayExponent, + maxUDPPayloadSize: maxUDPPayloadSize, + maxAckDelay: maxAckDelay, + disableActiveMigration: true, + initialMaxData: config.maxConnReadBufferSize(), + initialMaxStreamDataBidiLocal: config.maxStreamReadBufferSize(), + initialMaxStreamDataBidiRemote: config.maxStreamReadBufferSize(), + initialMaxStreamDataUni: config.maxStreamReadBufferSize(), + initialMaxStreamsBidi: c.streams.remoteLimit[bidiStream].max, + initialMaxStreamsUni: c.streams.remoteLimit[uniStream].max, + activeConnIDLimit: activeConnIDLimit, + }); err != nil { + return nil, err + } + + if c.testHooks != nil { + c.testHooks.init(true) + } + go c.loop(now) + return c, nil +} + +func (c *Conn) String() string { + return fmt.Sprintf("quic.Conn(%v,->%v)", c.side, c.peerAddr) +} + +// LocalAddr returns the local network address, if known. +func (c *Conn) LocalAddr() netip.AddrPort { + return c.localAddr +} + +// RemoteAddr returns the remote network address, if known. +func (c *Conn) RemoteAddr() netip.AddrPort { + return c.peerAddr +} + +// ConnectionState returns basic TLS details about the connection. +func (c *Conn) ConnectionState() tls.ConnectionState { + return c.tls.ConnectionState() +} + +// confirmHandshake is called when the handshake is confirmed. +// https://www.rfc-editor.org/rfc/rfc9001#section-4.1.2 +func (c *Conn) confirmHandshake(now time.Time) { + // If handshakeConfirmed is unset, the handshake is not confirmed. + // If it is unsent, the handshake is confirmed and we need to send a HANDSHAKE_DONE. 
+	// If it is sent, we have sent a HANDSHAKE_DONE.
+	// If it is received, the handshake is confirmed and we do not need to send anything.
+	if c.handshakeConfirmed.isSet() {
+		return // already confirmed
+	}
+	if c.side == serverSide {
+		// When the server confirms the handshake, it sends a HANDSHAKE_DONE.
+		c.handshakeConfirmed.setUnsent()
+		c.endpoint.serverConnEstablished(c)
+	} else {
+		// The client never sends a HANDSHAKE_DONE, so we set handshakeConfirmed
+		// to the received state, indicating that the handshake is confirmed and we
+		// don't need to send anything.
+		c.handshakeConfirmed.setReceived()
+	}
+	c.restartIdleTimer(now)
+	c.loss.confirmHandshake()
+	// "An endpoint MUST discard its Handshake keys when the TLS handshake is confirmed"
+	// https://www.rfc-editor.org/rfc/rfc9001#section-4.9.2-1
+	c.discardKeys(now, handshakeSpace)
+}
+
+// discardKeys discards unused packet protection keys.
+// https://www.rfc-editor.org/rfc/rfc9001#section-4.9
+//
+// Dropping the keys for a number space also discards the loss-recovery
+// state for that space (see the c.loss.discardKeys call below).
+func (c *Conn) discardKeys(now time.Time, space numberSpace) {
+	if err := c.crypto[space].discardKeys(); err != nil {
+		c.abort(now, err)
+	}
+	switch space {
+	case initialSpace:
+		c.keysInitial.discard()
+	case handshakeSpace:
+		c.keysHandshake.discard()
+	}
+	c.loss.discardKeys(now, c.log, space)
+}
+
+// receiveTransportParameters applies transport parameters sent by the peer.
+func (c *Conn) receiveTransportParameters(p transportParameters) error {
+	isRetry := c.retryToken != nil
+	if err := c.connIDState.validateTransportParameters(c, isRetry, p); err != nil {
+		return err
+	}
+	c.streams.outflow.setMaxData(p.initialMaxData)
+	c.streams.localLimit[bidiStream].setMax(p.initialMaxStreamsBidi)
+	c.streams.localLimit[uniStream].setMax(p.initialMaxStreamsUni)
+	c.streams.peerInitialMaxStreamDataBidiLocal = p.initialMaxStreamDataBidiLocal
+	c.streams.peerInitialMaxStreamDataRemote[bidiStream] = p.initialMaxStreamDataBidiRemote
+	c.streams.peerInitialMaxStreamDataRemote[uniStream] = p.initialMaxStreamDataUni
+	c.receivePeerMaxIdleTimeout(p.maxIdleTimeout)
+	c.peerAckDelayExponent = p.ackDelayExponent
+	c.loss.setMaxAckDelay(p.maxAckDelay)
+	if err := c.connIDState.setPeerActiveConnIDLimit(c, p.activeConnIDLimit); err != nil {
+		return err
+	}
+	if p.preferredAddrConnID != nil {
+		var (
+			seq           int64 = 1 // sequence number of this conn id is 1
+			retirePriorTo int64 = 0 // retire nothing
+			resetToken    [16]byte
+		)
+		copy(resetToken[:], p.preferredAddrResetToken)
+		if err := c.connIDState.handleNewConnID(c, seq, retirePriorTo, p.preferredAddrConnID, resetToken); err != nil {
+			return err
+		}
+	}
+	// TODO: stateless_reset_token
+	// TODO: max_udp_payload_size
+	// TODO: disable_active_migration
+	// TODO: preferred_address
+	return nil
+}
+
+type (
+	// timerEvent tells the conn loop that a connection timer may have expired.
+	timerEvent struct{}
+	// wakeEvent wakes the conn loop to try sending pending frames.
+	wakeEvent struct{}
+)
+
+var errIdleTimeout = errors.New("idle timeout")
+
+// loop is the connection main loop.
+//
+// Except where otherwise noted, all connection state is owned by the loop goroutine.
+//
+// The loop processes messages from c.msgc and timer events.
+// Other goroutines may examine or modify conn state by sending the loop funcs to execute.
+func (c *Conn) loop(now time.Time) {
+	defer c.cleanup()
+
+	// The connection timer sends a message to the connection loop on expiry.
+	// We need to give it an expiry when creating it, so set the initial timeout to
+	// an arbitrary large value. The timer will be reset before this expires (and it
+	// isn't a problem if it does anyway).
+	var lastTimeout time.Time
+	timer := time.AfterFunc(1*time.Hour, func() {
+		c.sendMsg(timerEvent{})
+	})
+	defer timer.Stop()
+
+	for c.lifetime.state != connStateDone {
+		sendTimeout := c.maybeSend(now) // try sending
+
+		// Note that we only need to consider the ack timer for the App Data space,
+		// since the Initial and Handshake spaces always ack immediately.
+		nextTimeout := sendTimeout
+		nextTimeout = firstTime(nextTimeout, c.idle.nextTimeout)
+		if c.isAlive() {
+			nextTimeout = firstTime(nextTimeout, c.loss.timer)
+			nextTimeout = firstTime(nextTimeout, c.acks[appDataSpace].nextAck)
+		} else {
+			nextTimeout = firstTime(nextTimeout, c.lifetime.drainEndTime)
+		}
+
+		// m is the next message to process: a received datagram, a timer or
+		// wake event, or a func sent to the loop for execution.
+		var m any
+		if !nextTimeout.IsZero() && nextTimeout.Before(now) {
+			// A connection timer has expired.
+			now = time.Now()
+			m = timerEvent{}
+		} else {
+			// Reschedule the connection timer if necessary
+			// and wait for the next event.
+			if !nextTimeout.Equal(lastTimeout) && !nextTimeout.IsZero() {
+				// Resetting a timer created with time.AfterFunc guarantees
+				// that the timer will run again. We might generate a spurious
+				// timer event under some circumstances, but that's okay.
+				timer.Reset(nextTimeout.Sub(now))
+				lastTimeout = nextTimeout
+			}
+			m = <-c.msgc
+			now = time.Now()
+		}
+		switch m := m.(type) {
+		case *datagram:
+			if !c.handleDatagram(now, m) {
+				if c.logEnabled(QLogLevelPacket) {
+					c.logPacketDropped(m)
+				}
+			}
+			m.recycle()
+		case timerEvent:
+			// A connection timer has expired.
+			if c.idleAdvance(now) {
+				// The connection idle timer has expired.
+				c.abortImmediately(now, errIdleTimeout)
+				return
+			}
+			c.loss.advance(now, c.handleAckOrLoss)
+			if c.lifetimeAdvance(now) {
+				// The connection has completed the draining period,
+				// and may be shut down.
+				return
+			}
+		case wakeEvent:
+			// We're being woken up to try sending some frames.
+		case func(time.Time, *Conn):
+			// Send a func to msgc to run it on the main Conn goroutine
+			m(now, c)
+		case func(now, next time.Time, _ *Conn):
+			// Send a func to msgc to run it on the main Conn goroutine
+			m(now, nextTimeout, c)
+		default:
+			panic(fmt.Sprintf("quic: unrecognized conn message %T", m))
+		}
+	}
+}
+
+// cleanup runs when the conn loop exits: it logs the close, detaches the
+// conn from the endpoint, shuts down TLS, and closes donec.
+func (c *Conn) cleanup() {
+	c.logConnectionClosed()
+	c.endpoint.connDrained(c)
+	c.tls.Close()
+	close(c.donec)
+}
+
+// sendMsg sends a message to the conn's loop.
+// It does not wait for the message to be processed.
+// The conn may close before processing the message, in which case it is lost.
+func (c *Conn) sendMsg(m any) {
+	select {
+	case c.msgc <- m:
+	case <-c.donec:
+	}
+}
+
+// wake wakes up the conn's loop.
+func (c *Conn) wake() {
+	select {
+	case c.msgc <- wakeEvent{}:
+	default:
+	}
+}
+
+// runOnLoop executes a function within the conn's loop goroutine.
+// It blocks until f has run or the conn has closed.
+func (c *Conn) runOnLoop(ctx context.Context, f func(now time.Time, c *Conn)) error {
+	donec := make(chan struct{})
+	msg := func(now time.Time, c *Conn) {
+		defer close(donec)
+		f(now, c)
+	}
+	c.sendMsg(msg)
+	select {
+	case <-donec:
+	case <-c.donec:
+		return errors.New("quic: connection closed")
+	}
+	return nil
+}
+
+// waitOnDone waits for ch to be closed or ctx to be canceled.
+// A closed ch takes priority over an already-canceled ctx.
+func (c *Conn) waitOnDone(ctx context.Context, ch <-chan struct{}) error {
+	// Check the channel before the context.
+	// We always prefer to return results when available,
+	// even when provided with an already-canceled context.
+	select {
+	case <-ch:
+		return nil
+	default:
+	}
+	select {
+	case <-ch:
+	case <-ctx.Done():
+		return ctx.Err()
+	}
+	return nil
+}
+
+// firstTime returns the earliest non-zero time, or zero if both times are zero.
+// It is used by the conn loop to combine candidate timer deadlines.
+func firstTime(a, b time.Time) time.Time {
+	switch {
+	case a.IsZero():
+		return b
+	case b.IsZero():
+		return a
+	case a.Before(b):
+		return a
+	default:
+		return b
+	}
+}
diff --git a/src/vendor/golang.org/x/net/quic/conn_close.go b/src/vendor/golang.org/x/net/quic/conn_close.go
new file mode 100644
index 0000000000..d22f3df5c8
--- /dev/null
+++ b/src/vendor/golang.org/x/net/quic/conn_close.go
@@ -0,0 +1,340 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package quic
+
+import (
+	"context"
+	"errors"
+	"time"
+)
+
+// connState is the state of a connection.
+type connState int
+
+const (
+	// A connection is alive when it is first created.
+	connStateAlive = connState(iota)
+
+	// The connection has received a CONNECTION_CLOSE frame from the peer,
+	// and has not yet sent a CONNECTION_CLOSE in response.
+	//
+	// We will send a CONNECTION_CLOSE, and then enter the draining state.
+	connStatePeerClosed
+
+	// The connection is in the closing state.
+	//
+	// We will send CONNECTION_CLOSE frames to the peer
+	// (once upon entering the closing state, and possibly again in response to peer packets).
+	//
+	// If we receive a CONNECTION_CLOSE from the peer, we will enter the draining state.
+	// Otherwise, we will eventually time out and move to the done state.
+	//
+	// https://www.rfc-editor.org/rfc/rfc9000#section-10.2.1
+	connStateClosing
+
+	// The connection is in the draining state.
+	//
+	// We will neither send packets nor process received packets.
+	// When the drain timer expires, we move to the done state.
+	//
+	// https://www.rfc-editor.org/rfc/rfc9000#section-10.2.2
+	connStateDraining
+
+	// The connection is done, and the conn loop will exit.
+	connStateDone
+)
+
+// lifetimeState tracks the state of a connection.
+//
+// This is fairly coupled to the rest of a Conn, but putting it in a struct of its own helps
+// reason about operations that cause state transitions.
+type lifetimeState struct {
+	state connState
+
+	readyc chan struct{} // closed when TLS handshake completes
+	donec  chan struct{} // closed when finalErr is set
+
+	localErr error // error sent to the peer
+	finalErr error // error sent by the peer, or transport error; set before closing donec
+
+	connCloseSentTime time.Time     // send time of last CONNECTION_CLOSE frame
+	connCloseDelay    time.Duration // delay until next CONNECTION_CLOSE frame sent
+	drainEndTime      time.Time     // time the connection exits the draining state
+}
+
+func (c *Conn) lifetimeInit() {
+	c.lifetime.readyc = make(chan struct{})
+	c.lifetime.donec = make(chan struct{})
+}
+
+var (
+	errNoPeerResponse = errors.New("peer did not respond to CONNECTION_CLOSE")
+	errConnClosed     = errors.New("connection closed")
+)
+
+// lifetimeAdvance is called when time passes.
+// It reports whether the connection has finished draining and may shut down.
+func (c *Conn) lifetimeAdvance(now time.Time) (done bool) {
+	if c.lifetime.drainEndTime.IsZero() || c.lifetime.drainEndTime.After(now) {
+		return false
+	}
+	// The connection drain period has ended, and we can shut down.
+	// https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2-7
+	c.lifetime.drainEndTime = time.Time{}
+	if c.lifetime.state != connStateDraining {
+		// We were in the closing state, waiting for a CONNECTION_CLOSE from the peer.
+		c.setFinalError(errNoPeerResponse)
+	}
+	c.setState(now, connStateDone)
+	return true
+}
+
+// setState sets the conn state.
+// Entering the closing or draining state starts the drain timer (three PTO
+// base periods); leaving the alive state cleans up streams.
+func (c *Conn) setState(now time.Time, state connState) {
+	if c.lifetime.state == state {
+		return
+	}
+	c.lifetime.state = state
+	switch state {
+	case connStateClosing, connStateDraining:
+		if c.lifetime.drainEndTime.IsZero() {
+			c.lifetime.drainEndTime = now.Add(3 * c.loss.ptoBasePeriod())
+		}
+	case connStateDone:
+		c.setFinalError(nil)
+	}
+	if state != connStateAlive {
+		c.streamsCleanup()
+	}
+}
+
+// handshakeDone is called when the TLS handshake completes.
+func (c *Conn) handshakeDone() {
+	close(c.lifetime.readyc)
+}
+
+// isDraining reports whether the conn is in the draining state.
+//
+// The draining state is entered once an endpoint receives a CONNECTION_CLOSE frame.
+// The endpoint will no longer send any packets, but we retain knowledge of the connection
+// until the end of the drain period to ensure we discard packets for the connection
+// rather than treating them as starting a new connection.
+//
+// https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2.2
+func (c *Conn) isDraining() bool {
+	switch c.lifetime.state {
+	case connStateDraining, connStateDone:
+		return true
+	}
+	return false
+}
+
+// isAlive reports whether the conn is handling packets.
+func (c *Conn) isAlive() bool {
+	return c.lifetime.state == connStateAlive
+}
+
+// sendOK reports whether the conn can send frames at this time.
+func (c *Conn) sendOK(now time.Time) bool {
+	switch c.lifetime.state {
+	case connStateAlive:
+		return true
+	case connStatePeerClosed:
+		if c.lifetime.localErr == nil {
+			// We're waiting for the user to close the connection, providing us with
+			// a final status to send to the peer.
+			return false
+		}
+		// We should send a CONNECTION_CLOSE.
+		return true
+	case connStateClosing:
+		if c.lifetime.connCloseSentTime.IsZero() {
+			return true
+		}
+		// Find the most recent receive time across all number spaces.
+		maxRecvTime := c.acks[initialSpace].maxRecvTime
+		if t := c.acks[handshakeSpace].maxRecvTime; t.After(maxRecvTime) {
+			maxRecvTime = t
+		}
+		if t := c.acks[appDataSpace].maxRecvTime; t.After(maxRecvTime) {
+			maxRecvTime = t
+		}
+		if maxRecvTime.Before(c.lifetime.connCloseSentTime.Add(c.lifetime.connCloseDelay)) {
+			// After sending CONNECTION_CLOSE, ignore packets from the peer for
+			// a delay. On the next packet received after the delay, send another
+			// CONNECTION_CLOSE.
+			return false
+		}
+		return true
+	case connStateDraining:
+		// We are in the draining state, and will send no more packets.
+		return false
+	case connStateDone:
+		return false
+	default:
+		panic("BUG: unhandled connection state")
+	}
+}
+
+// sentConnectionClose reports that the conn has sent a CONNECTION_CLOSE to the peer.
+func (c *Conn) sentConnectionClose(now time.Time) {
+	switch c.lifetime.state {
+	case connStatePeerClosed:
+		// Our CONNECTION_CLOSE was a reply to the peer's; the exchange is
+		// complete and we can move straight to draining.
+		c.enterDraining(now)
+	}
+	if c.lifetime.connCloseSentTime.IsZero() {
+		// Set the initial delay before we will send another CONNECTION_CLOSE.
+		//
+		// RFC 9000 states that we should rate limit CONNECTION_CLOSE frames,
+		// but leaves the implementation of the limit up to us. Here, we start
+		// with the same delay as the PTO timer (RFC 9002, Section 6.2.1),
+		// not including max_ack_delay, and double it on every CONNECTION_CLOSE sent.
+		c.lifetime.connCloseDelay = c.loss.rtt.smoothedRTT + max(4*c.loss.rtt.rttvar, timerGranularity)
+	} else if !c.lifetime.connCloseSentTime.Equal(now) {
+		// If connCloseSentTime == now, we're sending two CONNECTION_CLOSE frames
+		// coalesced into the same datagram. We only want to increase the delay once.
+		c.lifetime.connCloseDelay *= 2
+	}
+	c.lifetime.connCloseSentTime = now
+}
+
+// handlePeerConnectionClose handles a CONNECTION_CLOSE from the peer.
+func (c *Conn) handlePeerConnectionClose(now time.Time, err error) {
+	c.setFinalError(err)
+	switch c.lifetime.state {
+	case connStateAlive:
+		c.setState(now, connStatePeerClosed)
+	case connStatePeerClosed:
+		// Duplicate CONNECTION_CLOSE, ignore.
+	case connStateClosing:
+		if c.lifetime.connCloseSentTime.IsZero() {
+			// We haven't sent our CONNECTION_CLOSE yet; do so before draining.
+			c.setState(now, connStatePeerClosed)
+		} else {
+			c.setState(now, connStateDraining)
+		}
+	case connStateDraining:
+	case connStateDone:
+	}
+}
+
+// setFinalError records the final connection status we report to the user.
+// Only the first call has any effect; later calls are ignored.
+func (c *Conn) setFinalError(err error) {
+	select {
+	case <-c.lifetime.donec:
+		return // already set
+	default:
+	}
+	c.lifetime.finalErr = err
+	close(c.lifetime.donec)
+}
+
+// finalError returns the final connection status reported to the user,
+// or nil if a final status has not yet been set.
+func (c *Conn) finalError() error {
+	select {
+	case <-c.lifetime.donec:
+		return c.lifetime.finalErr
+	default:
+	}
+	return nil
+}
+
+// waitReady waits for the TLS handshake to complete (readyc) or for the
+// connection to reach a final status (donec), preferring either over an
+// already-canceled context.
+func (c *Conn) waitReady(ctx context.Context) error {
+	select {
+	case <-c.lifetime.readyc:
+		return nil
+	case <-c.lifetime.donec:
+		return c.lifetime.finalErr
+	default:
+	}
+	select {
+	case <-c.lifetime.readyc:
+		return nil
+	case <-c.lifetime.donec:
+		return c.lifetime.finalErr
+	case <-ctx.Done():
+		return ctx.Err()
+	}
+}
+
+// Close closes the connection.
+//
+// Close is equivalent to:
+//
+//	conn.Abort(nil)
+//	err := conn.Wait(context.Background())
+func (c *Conn) Close() error {
+	c.Abort(nil)
+	<-c.lifetime.donec
+	return c.lifetime.finalErr
+}
+
+// Wait waits for the peer to close the connection.
+//
+// If the connection is closed locally and the peer does not close its end of the connection,
+// Wait will return with a non-nil error after the drain period expires.
+//
+// If the peer closes the connection with a NO_ERROR transport error, Wait returns nil.
+// If the peer closes the connection with an application error, Wait returns an ApplicationError
+// containing the peer's error code and reason.
+// If the peer closes the connection with any other status, Wait returns a non-nil error.
+func (c *Conn) Wait(ctx context.Context) error {
+	if err := c.waitOnDone(ctx, c.lifetime.donec); err != nil {
+		return err
+	}
+	return c.lifetime.finalErr
+}
+
+// Abort closes the connection and returns immediately.
+//
+// If err is nil, Abort sends a transport error of NO_ERROR to the peer.
+// If err is an ApplicationError, Abort sends its error code and text.
+// Otherwise, Abort sends a transport error of APPLICATION_ERROR with the error's text.
+func (c *Conn) Abort(err error) {
+	if err == nil {
+		err = localTransportError{code: errNo}
+	}
+	// Run the close on the conn's loop goroutine.
+	c.sendMsg(func(now time.Time, c *Conn) {
+		c.enterClosing(now, err)
+	})
+}
+
+// abort terminates a connection with an error.
+// Unlike Abort, it runs directly rather than being queued to the conn loop.
+func (c *Conn) abort(now time.Time, err error) {
+	c.setFinalError(err) // this error takes precedence over the peer's CONNECTION_CLOSE
+	c.enterClosing(now, err)
+}
+
+// abortImmediately terminates a connection.
+// The connection does not send a CONNECTION_CLOSE, and skips the draining period.
+func (c *Conn) abortImmediately(now time.Time, err error) {
+	c.setFinalError(err)
+	c.setState(now, connStateDone)
+}
+
+// enterClosing starts an immediate close.
+// We will send a CONNECTION_CLOSE to the peer and wait for their response.
+func (c *Conn) enterClosing(now time.Time, err error) {
+	switch c.lifetime.state {
+	case connStateAlive:
+		c.lifetime.localErr = err
+		c.setState(now, connStateClosing)
+	case connStatePeerClosed:
+		// The peer closed first; recording localErr lets sendOK permit the
+		// CONNECTION_CLOSE reply.
+		c.lifetime.localErr = err
+	}
+}
+
+// enterDraining moves directly to the draining state, without sending a CONNECTION_CLOSE.
+func (c *Conn) enterDraining(now time.Time) {
+	switch c.lifetime.state {
+	case connStateAlive, connStatePeerClosed, connStateClosing:
+		c.setState(now, connStateDraining)
+	}
+}
+
+// exit fully terminates a connection immediately:
+// no CONNECTION_CLOSE is sent and the draining period is skipped.
+func (c *Conn) exit() {
+	c.sendMsg(func(now time.Time, c *Conn) {
+		c.abortImmediately(now, errors.New("connection closed"))
+	})
+}
diff --git a/src/vendor/golang.org/x/net/quic/conn_flow.go b/src/vendor/golang.org/x/net/quic/conn_flow.go
new file mode 100644
index 0000000000..1d04f45545
--- /dev/null
+++ b/src/vendor/golang.org/x/net/quic/conn_flow.go
@@ -0,0 +1,142 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package quic
+
+import (
+	"sync/atomic"
+	"time"
+)
+
+// connInflow tracks connection-level flow control for data sent by the peer to us.
+//
+// There are four byte offsets of significance in the stream of data received from the peer,
+// each >= to the previous:
+//
+// - bytes read by the user
+// - bytes received from the peer
+// - limit sent to the peer in a MAX_DATA frame
+// - potential new limit to send to the peer
+//
+// We maintain a flow control window, so as bytes are read by the user
+// the potential limit is extended correspondingly.
+//
+// We keep an atomic counter of bytes read by the user and not yet applied to the
+// potential limit (credit). When this count grows large enough, we update the
+// new limit to send and mark that we need to send a new MAX_DATA frame.
+type connInflow struct {
+	sent      sentVal // set when we need to send a MAX_DATA update to the peer
+	usedLimit int64   // total bytes sent by the peer, must be less than sentLimit
+	sentLimit int64   // last MAX_DATA sent to the peer
+	newLimit  int64   // new MAX_DATA to send
+
+	credit atomic.Int64 // bytes read but not yet applied to extending the flow-control window
+}
+
+func (c *Conn) inflowInit() {
+	// The initial MAX_DATA limit is sent as a transport parameter.
+	c.streams.inflow.sentLimit = c.config.maxConnReadBufferSize()
+	c.streams.inflow.newLimit = c.streams.inflow.sentLimit
+}
+
+// handleStreamBytesReadOffLoop records that the user has consumed bytes from a stream.
+// We may extend the peer's flow control window.
+//
+// This is called indirectly by the user, via Read or CloseRead.
+func (c *Conn) handleStreamBytesReadOffLoop(n int64) {
+	if n == 0 {
+		return
+	}
+	if c.shouldUpdateFlowControl(c.streams.inflow.credit.Add(n)) {
+		// We should send a MAX_DATA update to the peer.
+		// Record this on the Conn's main loop.
+		c.sendMsg(func(now time.Time, c *Conn) {
+			// A MAX_DATA update may have already happened, so check again.
+			if c.shouldUpdateFlowControl(c.streams.inflow.credit.Load()) {
+				c.sendMaxDataUpdate()
+			}
+		})
+	}
+}
+
+// handleStreamBytesReadOnLoop extends the peer's flow control window after
+// data has been discarded due to a RESET_STREAM frame.
+//
+// This is called on the conn's loop.
+func (c *Conn) handleStreamBytesReadOnLoop(n int64) {
+	if c.shouldUpdateFlowControl(c.streams.inflow.credit.Add(n)) {
+		c.sendMaxDataUpdate()
+	}
+}
+
+// sendMaxDataUpdate marks that a MAX_DATA frame carrying the latest limit
+// needs to be sent. Called on the conn's loop.
+func (c *Conn) sendMaxDataUpdate() {
+	c.streams.inflow.sent.setUnsent()
+	// Apply current credit to the limit.
+	// We don't strictly need to do this here
+	// since appendMaxDataFrame will do so as well,
+	// but this avoids redundant trips down this path
+	// if the MAX_DATA frame doesn't go out right away.
+ c.streams.inflow.newLimit += c.streams.inflow.credit.Swap(0) +} + +func (c *Conn) shouldUpdateFlowControl(credit int64) bool { + return shouldUpdateFlowControl(c.config.maxConnReadBufferSize(), credit) +} + +// handleStreamBytesReceived records that the peer has sent us stream data. +func (c *Conn) handleStreamBytesReceived(n int64) error { + c.streams.inflow.usedLimit += n + if c.streams.inflow.usedLimit > c.streams.inflow.sentLimit { + return localTransportError{ + code: errFlowControl, + reason: "stream exceeded flow control limit", + } + } + return nil +} + +// appendMaxDataFrame appends a MAX_DATA frame to the current packet. +// +// It returns true if no more frames need appending, +// false if it could not fit a frame in the current packet. +func (c *Conn) appendMaxDataFrame(w *packetWriter, pnum packetNumber, pto bool) bool { + if c.streams.inflow.sent.shouldSendPTO(pto) { + // Add any unapplied credit to the new limit now. + c.streams.inflow.newLimit += c.streams.inflow.credit.Swap(0) + if !w.appendMaxDataFrame(c.streams.inflow.newLimit) { + return false + } + c.streams.inflow.sentLimit += c.streams.inflow.newLimit + c.streams.inflow.sent.setSent(pnum) + } + return true +} + +// ackOrLossMaxData records the fate of a MAX_DATA frame. +func (c *Conn) ackOrLossMaxData(pnum packetNumber, fate packetFate) { + c.streams.inflow.sent.ackLatestOrLoss(pnum, fate) +} + +// connOutflow tracks connection-level flow control for data sent by us to the peer. +type connOutflow struct { + max int64 // largest MAX_DATA received from peer + used int64 // total bytes of STREAM data sent to peer +} + +// setMaxData updates the connection-level flow control limit +// with the initial limit conveyed in transport parameters +// or an update from a MAX_DATA frame. +func (f *connOutflow) setMaxData(maxData int64) { + f.max = max(f.max, maxData) +} + +// avail returns the number of connection-level flow control bytes available. 
+func (f *connOutflow) avail() int64 {
+	return f.max - f.used
+}
+
+// consume records consumption of n bytes of flow.
+func (f *connOutflow) consume(n int64) {
+	f.used += n
+}
diff --git a/src/vendor/golang.org/x/net/quic/conn_id.go b/src/vendor/golang.org/x/net/quic/conn_id.go
new file mode 100644
index 0000000000..8749e52b79
--- /dev/null
+++ b/src/vendor/golang.org/x/net/quic/conn_id.go
@@ -0,0 +1,502 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package quic
+
+import (
+	"bytes"
+	"crypto/rand"
+	"slices"
+)
+
+// connIDState is a conn's connection IDs.
+type connIDState struct {
+	// The destination connection IDs of packets we receive are local.
+	// The destination connection IDs of packets we send are remote.
+	//
+	// Local IDs are usually issued by us, and remote IDs by the peer.
+	// The exception is the transient destination connection ID sent in
+	// a client's Initial packets, which is chosen by the client.
+	//
+	// These are []connID rather than []*connID to minimize allocations.
+	local  []connID
+	remote []remoteConnID
+
+	nextLocalSeq          int64
+	peerActiveConnIDLimit int64 // peer's active_connection_id_limit
+
+	// Handling of retirement of remote connection IDs.
+	// The rangesets track ID sequence numbers.
+	// IDs in need of retirement are added to remoteRetiring,
+	// moved to remoteRetiringSent once we send a RETIRE_CONNECTION_ID frame,
+	// and removed from the set once retirement completes.
+	retireRemotePriorTo int64           // largest Retire Prior To value sent by the peer
+	remoteRetiring      rangeset[int64] // remote IDs in need of retirement
+	remoteRetiringSent  rangeset[int64] // remote IDs waiting for ack of retirement
+
+	originalDstConnID []byte // expected original_destination_connection_id param
+	retrySrcConnID    []byte // expected retry_source_connection_id param
+
+	needSend bool
+}
+
+// A connID is a connection ID and associated metadata.
+type connID struct {
+	// cid is the connection ID itself.
+	cid []byte
+
+	// seq is the connection ID's sequence number:
+	// https://www.rfc-editor.org/rfc/rfc9000.html#section-5.1.1-1
+	//
+	// For the transient destination ID in a client's Initial packet, this is -1.
+	seq int64
+
+	// send is set when the connection ID's state needs to be sent to the peer.
+	//
+	// For local IDs, this indicates a new ID that should be sent
+	// in a NEW_CONNECTION_ID frame.
+	//
+	// For remote IDs, this indicates a retired ID that should be sent
+	// in a RETIRE_CONNECTION_ID frame.
+	send sentVal
+}
+
+// A remoteConnID is a connection ID and stateless reset token.
+type remoteConnID struct {
+	connID
+	resetToken statelessResetToken
+}
+
+// initClient initializes the connection ID state for a client connection:
+// it issues local ID 0 and installs the transient (seq -1) destination ID
+// used to protect the first Initial packets.
+func (s *connIDState) initClient(c *Conn) error {
+	// Client chooses its initial connection ID, and sends it
+	// in the Source Connection ID field of the first Initial packet.
+	locid, err := c.newConnID(0)
+	if err != nil {
+		return err
+	}
+	s.local = append(s.local, connID{
+		seq: 0,
+		cid: locid,
+	})
+	s.nextLocalSeq = 1
+	c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
+		conns.addConnID(c, locid)
+	})
+
+	// Client chooses an initial, transient connection ID for the server,
+	// and sends it in the Destination Connection ID field of the first Initial packet.
+	remid, err := c.newConnID(-1)
+	if err != nil {
+		return err
+	}
+	s.remote = append(s.remote, remoteConnID{
+		connID: connID{
+			seq: -1,
+			cid: remid,
+		},
+	})
+	// Remember the transient ID so we can verify the peer's
+	// original_destination_connection_id transport parameter against it.
+	s.originalDstConnID = remid
+	return nil
+}
+
+// initServer initializes the connection ID state for a server connection,
+// recording the client's transient destination ID and the client's chosen
+// source ID, and issuing our own local ID 0.
+func (s *connIDState) initServer(c *Conn, cids newServerConnIDs) error {
+	dstConnID := cloneBytes(cids.dstConnID)
+	// Client-chosen, transient connection ID received in the first Initial packet.
+	// The server will not use this as the Source Connection ID of packets it sends,
+	// but remembers it because it may receive packets sent to this destination.
+	s.local = append(s.local, connID{
+		seq: -1,
+		cid: dstConnID,
+	})
+
+	// Server chooses a connection ID, and sends it in the Source Connection ID of
+	// the response to the client.
+	locid, err := c.newConnID(0)
+	if err != nil {
+		return err
+	}
+	s.local = append(s.local, connID{
+		seq: 0,
+		cid: locid,
+	})
+	s.nextLocalSeq = 1
+	c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
+		conns.addConnID(c, dstConnID)
+		conns.addConnID(c, locid)
+	})
+
+	// Client chose its own connection ID.
+	s.remote = append(s.remote, remoteConnID{
+		connID: connID{
+			seq: 0,
+			cid: cloneBytes(cids.srcConnID),
+		},
+	})
+	return nil
+}
+
+// srcConnID is the Source Connection ID to use in a sent packet.
+func (s *connIDState) srcConnID() []byte {
+	if s.local[0].seq == -1 && len(s.local) > 1 {
+		// Don't use the transient connection ID if another is available.
+		return s.local[1].cid
+	}
+	return s.local[0].cid
+}
+
+// dstConnID is the Destination Connection ID to use in a sent packet.
+func (s *connIDState) dstConnID() (cid []byte, ok bool) {
+	for i := range s.remote {
+		return s.remote[i].cid, true
+	}
+	return nil, false
+}
+
+// isValidStatelessResetToken reports whether the given reset token is
+// associated with a non-retired connection ID which we have used.
+func (s *connIDState) isValidStatelessResetToken(resetToken statelessResetToken) bool {
+	if len(s.remote) == 0 {
+		return false
+	}
+	// We currently only use the first available remote connection ID,
+	// so any other reset token is not valid.
+	return s.remote[0].resetToken == resetToken
+}
+
+// setPeerActiveConnIDLimit sets the active_connection_id_limit
+// transport parameter received from the peer.
+func (s *connIDState) setPeerActiveConnIDLimit(c *Conn, lim int64) error {
+	s.peerActiveConnIDLimit = lim
+	return s.issueLocalIDs(c)
+}
+
+// issueLocalIDs issues new local connection IDs until the number of active
+// (non-transient) local IDs reaches the peer's active_connection_id_limit,
+// capped at maxPeerActiveConnIDLimit. New IDs are queued for delivery in
+// NEW_CONNECTION_ID frames and registered with the endpoint's conns map.
+func (s *connIDState) issueLocalIDs(c *Conn) error {
+	toIssue := min(int(s.peerActiveConnIDLimit), maxPeerActiveConnIDLimit)
+	for i := range s.local {
+		if s.local[i].seq != -1 {
+			toIssue--
+		}
+	}
+	var newIDs [][]byte
+	for toIssue > 0 {
+		cid, err := c.newConnID(s.nextLocalSeq)
+		if err != nil {
+			return err
+		}
+		newIDs = append(newIDs, cid)
+		s.local = append(s.local, connID{
+			seq: s.nextLocalSeq,
+			cid: cid,
+		})
+		s.local[len(s.local)-1].send.setUnsent()
+		s.nextLocalSeq++
+		s.needSend = true
+		toIssue--
+	}
+	c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
+		for _, cid := range newIDs {
+			conns.addConnID(c, cid)
+		}
+	})
+	return nil
+}
+
+// validateTransportParameters verifies the original_destination_connection_id and
+// initial_source_connection_id transport parameters match the expected values.
+// It also records the peer's stateless_reset_token, when present (clients only).
+func (s *connIDState) validateTransportParameters(c *Conn, isRetry bool, p transportParameters) error {
+	// TODO: Consider returning more detailed errors, for debugging.
+	// Verify original_destination_connection_id matches
+	// the transient remote connection ID we chose (client)
+	// or is empty (server).
+	if !bytes.Equal(s.originalDstConnID, p.originalDstConnID) {
+		return localTransportError{
+			code:   errTransportParameter,
+			reason: "original_destination_connection_id mismatch",
+		}
+	}
+	s.originalDstConnID = nil // we have no further need for this
+	// Verify retry_source_connection_id matches the value from
+	// the server's Retry packet (when one was sent), or is empty.
+	if !bytes.Equal(p.retrySrcConnID, s.retrySrcConnID) {
+		return localTransportError{
+			code:   errTransportParameter,
+			reason: "retry_source_connection_id mismatch",
+		}
+	}
+	s.retrySrcConnID = nil // we have no further need for this
+	// Verify initial_source_connection_id matches the first remote connection ID.
+	if len(s.remote) == 0 || s.remote[0].seq != 0 {
+		return localTransportError{
+			code:   errInternal,
+			reason: "remote connection id missing",
+		}
+	}
+	if !bytes.Equal(p.initialSrcConnID, s.remote[0].cid) {
+		return localTransportError{
+			code:   errTransportParameter,
+			reason: "initial_source_connection_id mismatch",
+		}
+	}
+	if len(p.statelessResetToken) > 0 {
+		// Only servers may provide a stateless_reset_token parameter.
+		if c.side == serverSide {
+			return localTransportError{
+				code:   errTransportParameter,
+				reason: "client sent stateless_reset_token",
+			}
+		}
+		token := statelessResetToken(p.statelessResetToken)
+		s.remote[0].resetToken = token
+		c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
+			conns.addResetToken(c, token)
+		})
+	}
+	return nil
+}
+
+// handlePacket updates the connection ID state during the handshake
+// (Initial and Handshake packets).
+func (s *connIDState) handlePacket(c *Conn, ptype packetType, srcConnID []byte) {
+	switch {
+	case ptype == packetTypeInitial && c.side == clientSide:
+		if len(s.remote) == 1 && s.remote[0].seq == -1 {
+			// We're a client connection processing the first Initial packet
+			// from the server. Replace the transient remote connection ID
+			// with the Source Connection ID from the packet.
+ s.remote[0] = remoteConnID{ + connID: connID{ + seq: 0, + cid: cloneBytes(srcConnID), + }, + } + } + case ptype == packetTypeHandshake && c.side == serverSide: + if len(s.local) > 0 && s.local[0].seq == -1 { + // We're a server connection processing the first Handshake packet from + // the client. Discard the transient, client-chosen connection ID used + // for Initial packets; the client will never send it again. + cid := s.local[0].cid + c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { + conns.retireConnID(c, cid) + }) + s.local = append(s.local[:0], s.local[1:]...) + } + } +} + +func (s *connIDState) handleRetryPacket(srcConnID []byte) { + if len(s.remote) != 1 || s.remote[0].seq != -1 { + panic("BUG: handling retry with non-transient remote conn id") + } + s.retrySrcConnID = cloneBytes(srcConnID) + s.remote[0].cid = s.retrySrcConnID +} + +func (s *connIDState) handleNewConnID(c *Conn, seq, retire int64, cid []byte, resetToken statelessResetToken) error { + if len(s.remote[0].cid) == 0 { + // "An endpoint that is sending packets with a zero-length + // Destination Connection ID MUST treat receipt of a NEW_CONNECTION_ID + // frame as a connection error of type PROTOCOL_VIOLATION." + // https://www.rfc-editor.org/rfc/rfc9000.html#section-19.15-6 + return localTransportError{ + code: errProtocolViolation, + reason: "NEW_CONNECTION_ID from peer with zero-length DCID", + } + } + + if seq < s.retireRemotePriorTo { + // This ID was already retired by a previous NEW_CONNECTION_ID frame. + // Nothing to do. + return nil + } + + if retire > s.retireRemotePriorTo { + // Add newly-retired connection IDs to the set we need to send + // RETIRE_CONNECTION_ID frames for, and remove them from s.remote. + // + // (This might cause us to send a RETIRE_CONNECTION_ID for an ID we've + // never seen. That's fine.) 
+ s.remoteRetiring.add(s.retireRemotePriorTo, retire) + s.retireRemotePriorTo = retire + s.needSend = true + s.remote = slices.DeleteFunc(s.remote, func(rcid remoteConnID) bool { + return rcid.seq < s.retireRemotePriorTo + }) + } + + have := false // do we already have this connection ID? + for i := range s.remote { + rcid := &s.remote[i] + if rcid.seq == seq { + if !bytes.Equal(rcid.cid, cid) { + return localTransportError{ + code: errProtocolViolation, + reason: "NEW_CONNECTION_ID does not match prior id", + } + } + have = true // yes, we've seen this sequence number + break + } + } + + if !have { + // This is a new connection ID that we have not seen before. + // + // We could take steps to keep the list of remote connection IDs + // sorted by sequence number, but there's no particular need + // so we don't bother. + s.remote = append(s.remote, remoteConnID{ + connID: connID{ + seq: seq, + cid: cloneBytes(cid), + }, + resetToken: resetToken, + }) + c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { + conns.addResetToken(c, resetToken) + }) + } + + if len(s.remote) > activeConnIDLimit { + // Retired connection IDs (including newly-retired ones) do not count + // against the limit. + // https://www.rfc-editor.org/rfc/rfc9000.html#section-5.1.1-5 + return localTransportError{ + code: errConnectionIDLimit, + reason: "active_connection_id_limit exceeded", + } + } + + // "An endpoint SHOULD limit the number of connection IDs it has retired locally + // for which RETIRE_CONNECTION_ID frames have not yet been acknowledged." + // https://www.rfc-editor.org/rfc/rfc9000#section-5.1.2-6 + // + // Set a limit of three times the active_connection_id_limit for + // the total number of remote connection IDs we keep retirement state for. 
+ if s.remoteRetiring.size()+s.remoteRetiringSent.size() > 3*activeConnIDLimit { + return localTransportError{ + code: errConnectionIDLimit, + reason: "too many unacknowledged retired connection ids", + } + } + + return nil +} + +func (s *connIDState) handleRetireConnID(c *Conn, seq int64) error { + if seq >= s.nextLocalSeq { + return localTransportError{ + code: errProtocolViolation, + reason: "RETIRE_CONNECTION_ID for unissued sequence number", + } + } + for i := range s.local { + if s.local[i].seq == seq { + cid := s.local[i].cid + c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { + conns.retireConnID(c, cid) + }) + s.local = append(s.local[:i], s.local[i+1:]...) + break + } + } + s.issueLocalIDs(c) + return nil +} + +func (s *connIDState) ackOrLossNewConnectionID(pnum packetNumber, seq int64, fate packetFate) { + for i := range s.local { + if s.local[i].seq != seq { + continue + } + s.local[i].send.ackOrLoss(pnum, fate) + if fate != packetAcked { + s.needSend = true + } + return + } +} + +func (s *connIDState) ackOrLossRetireConnectionID(pnum packetNumber, seq int64, fate packetFate) { + s.remoteRetiringSent.sub(seq, seq+1) + if fate == packetLost { + // RETIRE_CONNECTION_ID frame was lost, mark for retransmission. + s.remoteRetiring.add(seq, seq+1) + s.needSend = true + } +} + +// appendFrames appends NEW_CONNECTION_ID and RETIRE_CONNECTION_ID frames +// to the current packet. +// +// It returns true if no more frames need appending, +// false if not everything fit in the current packet. +func (s *connIDState) appendFrames(c *Conn, pnum packetNumber, pto bool) bool { + if !s.needSend && !pto { + // Fast path: We don't need to send anything. 
+ return true + } + retireBefore := int64(0) + if s.local[0].seq != -1 { + retireBefore = s.local[0].seq + } + for i := range s.local { + if !s.local[i].send.shouldSendPTO(pto) { + continue + } + if !c.w.appendNewConnectionIDFrame( + s.local[i].seq, + retireBefore, + s.local[i].cid, + c.endpoint.resetGen.tokenForConnID(s.local[i].cid), + ) { + return false + } + s.local[i].send.setSent(pnum) + } + if pto { + for _, r := range s.remoteRetiringSent { + for cid := r.start; cid < r.end; cid++ { + if !c.w.appendRetireConnectionIDFrame(cid) { + return false + } + } + } + } + for s.remoteRetiring.numRanges() > 0 { + cid := s.remoteRetiring.min() + if !c.w.appendRetireConnectionIDFrame(cid) { + return false + } + s.remoteRetiring.sub(cid, cid+1) + s.remoteRetiringSent.add(cid, cid+1) + } + s.needSend = false + return true +} + +func cloneBytes(b []byte) []byte { + n := make([]byte, len(b)) + copy(n, b) + return n +} + +func (c *Conn) newConnID(seq int64) ([]byte, error) { + if c.testHooks != nil { + return c.testHooks.newConnID(seq) + } + return newRandomConnID(seq) +} + +func newRandomConnID(_ int64) ([]byte, error) { + // It is not necessary for connection IDs to be cryptographically secure, + // but it doesn't hurt. + id := make([]byte, connIDLen) + if _, err := rand.Read(id); err != nil { + // TODO: Surface this error as a metric or log event or something. + // rand.Read really shouldn't ever fail, but if it does, we should + // have a way to inform the user. + return nil, err + } + return id, nil +} diff --git a/src/vendor/golang.org/x/net/quic/conn_loss.go b/src/vendor/golang.org/x/net/quic/conn_loss.go new file mode 100644 index 0000000000..bc6d106601 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/conn_loss.go @@ -0,0 +1,85 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+
+package quic
+
+import "fmt"
+
+// handleAckOrLoss deals with the final fate of a packet we sent:
+// Either the peer acknowledges it, or we declare it lost.
+//
+// In order to handle packet loss, we must retain any information sent to the peer
+// until the peer has acknowledged it.
+//
+// When information is acknowledged, we can discard it.
+//
+// When information is lost, we mark it for retransmission.
+// See RFC 9000, Section 13.3 for a complete list of information which is retransmitted on loss.
+// https://www.rfc-editor.org/rfc/rfc9000#section-13.3
+func (c *Conn) handleAckOrLoss(space numberSpace, sent *sentPacket, fate packetFate) {
+	if fate == packetLost && c.logEnabled(QLogLevelPacket) {
+		c.logPacketLost(space, sent)
+	}
+
+	// The list of frames in a sent packet is marshaled into a buffer in the sentPacket
+	// by the packetWriter. Unmarshal that buffer here. This code must be kept in sync with
+	// packetWriter.append*.
+	//
+	// A sent packet meets its fate (acked or lost) only once, so it's okay to consume
+	// the sentPacket's buffer here.
+	for !sent.done() {
+		switch f := sent.next(); f {
+		default:
+			panic(fmt.Sprintf("BUG: unhandled acked/lost frame type %x", f))
+		case frameTypeAck, frameTypeAckECN:
+			// Unlike most information, loss of an ACK frame does not trigger
+			// retransmission. ACKs are sent in response to ack-eliciting packets,
+			// and always contain the latest information available.
+			//
+			// Acknowledgement of an ACK frame may allow us to discard information
+			// about older packets.
+			largest := packetNumber(sent.nextInt())
+			if fate == packetAcked {
+				c.acks[space].handleAck(largest)
+			}
+		case frameTypeCrypto:
+			start, end := sent.nextRange()
+			c.crypto[space].ackOrLoss(start, end, fate)
+		case frameTypeMaxData:
+			c.ackOrLossMaxData(sent.num, fate)
+		case frameTypeResetStream,
+			frameTypeStopSending,
+			frameTypeMaxStreamData,
+			frameTypeStreamDataBlocked:
+			id := streamID(sent.nextInt())
+			s := c.streamForID(id)
+			if s == nil {
+				// No state for this stream remains; the frame's fate is moot.
+				continue
+			}
+			s.ackOrLoss(sent.num, f, fate)
+		case frameTypeStreamBase,
+			frameTypeStreamBase | streamFinBit:
+			// STREAM frames record the stream ID and data range;
+			// the FIN flag is carried in the low bit of the frame type.
+			id := streamID(sent.nextInt())
+			start, end := sent.nextRange()
+			s := c.streamForID(id)
+			if s == nil {
+				// No state for this stream remains; the frame's fate is moot.
+				continue
+			}
+			fin := f&streamFinBit != 0
+			s.ackOrLossData(sent.num, start, end, fin, fate)
+		case frameTypeMaxStreamsBidi:
+			c.streams.remoteLimit[bidiStream].sendMax.ackLatestOrLoss(sent.num, fate)
+		case frameTypeMaxStreamsUni:
+			c.streams.remoteLimit[uniStream].sendMax.ackLatestOrLoss(sent.num, fate)
+		case frameTypeNewConnectionID:
+			seq := int64(sent.nextInt())
+			c.connIDState.ackOrLossNewConnectionID(sent.num, seq, fate)
+		case frameTypeRetireConnectionID:
+			seq := int64(sent.nextInt())
+			c.connIDState.ackOrLossRetireConnectionID(sent.num, seq, fate)
+		case frameTypeHandshakeDone:
+			c.handshakeConfirmed.ackOrLoss(sent.num, fate)
+		}
+	}
+}
diff --git a/src/vendor/golang.org/x/net/quic/conn_recv.go b/src/vendor/golang.org/x/net/quic/conn_recv.go
new file mode 100644
index 0000000000..2bf127a479
--- /dev/null
+++ b/src/vendor/golang.org/x/net/quic/conn_recv.go
@@ -0,0 +1,631 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package quic
+
+import (
+	"bytes"
+	"encoding/binary"
+	"errors"
+	"time"
+)
+
+// handleDatagram processes a datagram received from the peer.
+// A datagram may coalesce several packets; handleDatagram reports
+// whether at least one packet in it was accepted.
+func (c *Conn) handleDatagram(now time.Time, dgram *datagram) (handled bool) {
+	if !c.localAddr.IsValid() {
+		// We don't have any way to tell in the general case what address we're
+		// sending packets from. Set our address from the destination address of
+		// the first packet received from the peer.
+		c.localAddr = dgram.localAddr
+	}
+	if dgram.peerAddr.IsValid() && dgram.peerAddr != c.peerAddr {
+		if c.side == clientSide {
+			// "If a client receives packets from an unknown server address,
+			// the client MUST discard these packets."
+			// https://www.rfc-editor.org/rfc/rfc9000#section-9-6
+			return false
+		}
+		// We currently don't support connection migration,
+		// so for now the server also drops packets from an unknown address.
+		return false
+	}
+	buf := dgram.b
+	c.loss.datagramReceived(now, len(buf))
+	if c.isDraining() {
+		return false
+	}
+	for len(buf) > 0 {
+		var n int
+		ptype := getPacketType(buf)
+		switch ptype {
+		case packetTypeInitial:
+			if c.side == serverSide && len(dgram.b) < paddedInitialDatagramSize {
+				// Discard client-sent Initial packets in too-short datagrams.
+				// https://www.rfc-editor.org/rfc/rfc9000#section-14.1-4
+				return false
+			}
+			n = c.handleLongHeader(now, dgram, ptype, initialSpace, c.keysInitial.r, buf)
+		case packetTypeHandshake:
+			n = c.handleLongHeader(now, dgram, ptype, handshakeSpace, c.keysHandshake.r, buf)
+		case packetType1RTT:
+			n = c.handle1RTT(now, dgram, buf)
+		case packetTypeRetry:
+			c.handleRetry(now, buf)
+			return true
+		case packetTypeVersionNegotiation:
+			c.handleVersionNegotiation(now, buf)
+			return true
+		default:
+			n = -1
+		}
+		if n <= 0 {
+			// We don't expect to get a stateless reset with a valid
+			// destination connection ID, since the sender of a stateless
+			// reset doesn't know what the connection ID is.
+			//
+			// We're required to perform this check anyway.
+			//
+			// "[...] the comparison MUST be performed when the first packet
+			// in an incoming datagram [...] cannot be decrypted."
+			// https://www.rfc-editor.org/rfc/rfc9000#section-10.3.1-2
+			if len(buf) == len(dgram.b) && len(buf) > statelessResetTokenLen {
+				var token statelessResetToken
+				copy(token[:], buf[len(buf)-len(token):])
+				if c.handleStatelessReset(now, token) {
+					return true
+				}
+			}
+			// Invalid data at the end of a datagram is ignored.
+			return false
+		}
+		c.idleHandlePacketReceived(now)
+		buf = buf[n:]
+	}
+	return true
+}
+
+// handleLongHeader processes a long-header (Initial or Handshake) packet.
+// It returns the number of bytes consumed from buf, or a negative value
+// if the packet could not be parsed or decrypted.
+func (c *Conn) handleLongHeader(now time.Time, dgram *datagram, ptype packetType, space numberSpace, k fixedKeys, buf []byte) int {
+	if !k.isSet() {
+		// We don't have read keys for this space yet; skip the packet.
+		return skipLongHeaderPacket(buf)
+	}
+
+	pnumMax := c.acks[space].largestSeen()
+	p, n := parseLongHeaderPacket(buf, k, pnumMax)
+	if n < 0 {
+		return -1
+	}
+	if buf[0]&reservedLongBits != 0 {
+		// Reserved header bits must be 0.
+		// https://www.rfc-editor.org/rfc/rfc9000#section-17.2-8.2.1
+		c.abort(now, localTransportError{
+			code:   errProtocolViolation,
+			reason: "reserved header bits are not zero",
+		})
+		return -1
+	}
+	if p.version != quicVersion1 {
+		// The peer has changed versions on us mid-handshake?
+		c.abort(now, localTransportError{
+			code:   errProtocolViolation,
+			reason: "protocol version changed during handshake",
+		})
+		return -1
+	}
+
+	if !c.acks[space].shouldProcess(p.num) {
+		// Duplicate packet; skip it without reprocessing.
+		return n
+	}
+
+	if logPackets {
+		logInboundLongPacket(c, p)
+	}
+	if c.logEnabled(QLogLevelPacket) {
+		c.logLongPacketReceived(p, buf[:n])
+	}
+	c.connIDState.handlePacket(c, p.ptype, p.srcConnID)
+	ackEliciting := c.handleFrames(now, dgram, ptype, space, p.payload)
+	c.acks[space].receive(now, space, p.num, ackEliciting, dgram.ecn)
+	if p.ptype == packetTypeHandshake && c.side == serverSide {
+		c.loss.validateClientAddress()
+
+		// "[...] a server MUST discard Initial keys when it first successfully
+		// processes a Handshake packet [...]"
+		// https://www.rfc-editor.org/rfc/rfc9001#section-4.9.1-2
+		c.discardKeys(now, initialSpace)
+	}
+	return n
+}
+
+// handle1RTT processes a 1-RTT (short header) packet. 1-RTT packets always
+// extend to the end of the datagram, so the return value consumes all of buf.
+func (c *Conn) handle1RTT(now time.Time, dgram *datagram, buf []byte) int {
+	if !c.keysAppData.canRead() {
+		// 1-RTT packets extend to the end of the datagram,
+		// so skip the remainder of the datagram if we can't parse this.
+		return len(buf)
+	}
+
+	pnumMax := c.acks[appDataSpace].largestSeen()
+	p, err := parse1RTTPacket(buf, &c.keysAppData, connIDLen, pnumMax)
+	if err != nil {
+		// A localTransportError terminates the connection.
+		// Other errors indicate an unparsable packet, but otherwise may be ignored.
+		if _, ok := err.(localTransportError); ok {
+			c.abort(now, err)
+		}
+		return -1
+	}
+	if buf[0]&reserved1RTTBits != 0 {
+		// Reserved header bits must be 0.
+		// https://www.rfc-editor.org/rfc/rfc9000#section-17.3.1-4.8.1
+		c.abort(now, localTransportError{
+			code:   errProtocolViolation,
+			reason: "reserved header bits are not zero",
+		})
+		return -1
+	}
+
+	if !c.acks[appDataSpace].shouldProcess(p.num) {
+		// Duplicate packet; skip it without reprocessing.
+		return len(buf)
+	}
+
+	if logPackets {
+		logInboundShortPacket(c, p)
+	}
+	if c.logEnabled(QLogLevelPacket) {
+		c.log1RTTPacketReceived(p, buf)
+	}
+	ackEliciting := c.handleFrames(now, dgram, packetType1RTT, appDataSpace, p.payload)
+	c.acks[appDataSpace].receive(now, appDataSpace, p.num, ackEliciting, dgram.ecn)
+	return len(buf)
+}
+
+// handleRetry processes a server Retry packet, validating it per RFC 9000
+// Section 17.2.5.2 and restarting the Initial flight with the Retry token.
+func (c *Conn) handleRetry(now time.Time, pkt []byte) {
+	if c.side != clientSide {
+		return // clients don't send Retry packets
+	}
+	// "After the client has received and processed an Initial or Retry packet
+	// from the server, it MUST discard any subsequent Retry packets that it receives."
+	// https://www.rfc-editor.org/rfc/rfc9000#section-17.2.5.2-1
+	if !c.keysInitial.canRead() {
+		return // discarded Initial keys, connection is already established
+	}
+	if c.acks[initialSpace].seen.numRanges() != 0 {
+		return // processed at least one packet
+	}
+	if c.retryToken != nil {
+		return // received a Retry already
+	}
+	// "Clients MUST discard Retry packets that have a Retry Integrity Tag
+	// that cannot be validated."
+	// https://www.rfc-editor.org/rfc/rfc9000#section-17.2.5.2-2
+	p, ok := parseRetryPacket(pkt, c.connIDState.originalDstConnID)
+	if !ok {
+		return
+	}
+	// "A client MUST discard a Retry packet with a zero-length Retry Token field."
+	// https://www.rfc-editor.org/rfc/rfc9000#section-17.2.5.2-2
+	if len(p.token) == 0 {
+		return
+	}
+	c.retryToken = cloneBytes(p.token)
+	c.connIDState.handleRetryPacket(p.srcConnID)
+	// The Retry changes the server's connection ID, so Initial keys
+	// must be rederived from it.
+	c.keysInitial = initialKeys(p.srcConnID, c.side)
+	// We need to resend any data we've already sent in Initial packets.
+	// We must not reuse already sent packet numbers.
+	c.loss.discardPackets(initialSpace, c.log, c.handleAckOrLoss)
+	// TODO: Discard 0-RTT packets as well, once we support 0-RTT.
+	if c.testHooks != nil {
+		c.testHooks.init(false)
+	}
+}
+
+var errVersionNegotiation = errors.New("server does not support QUIC version 1")
+
+// handleVersionNegotiation processes a Version Negotiation packet,
+// aborting the connection attempt unless one of the RFC 9000 Section 6.2
+// discard conditions applies.
+func (c *Conn) handleVersionNegotiation(now time.Time, pkt []byte) {
+	if c.side != clientSide {
+		return // servers don't handle Version Negotiation packets
+	}
+	// "A client MUST discard any Version Negotiation packet if it has
+	// received and successfully processed any other packet [...]"
+	// https://www.rfc-editor.org/rfc/rfc9000#section-6.2-2
+	if !c.keysInitial.canRead() {
+		return // discarded Initial keys, connection is already established
+	}
+	if c.acks[initialSpace].seen.numRanges() != 0 {
+		return // processed at least one packet
+	}
+	_, srcConnID, versions := parseVersionNegotiation(pkt)
+	if len(c.connIDState.remote) < 1 || !bytes.Equal(c.connIDState.remote[0].cid, srcConnID) {
+		return // Source Connection ID doesn't match what we sent
+	}
+	for len(versions) >= 4 {
+		ver := binary.BigEndian.Uint32(versions)
+		if ver == 1 {
+			// "A client MUST discard a Version Negotiation packet that lists
+			// the QUIC version selected by the client."
+			// https://www.rfc-editor.org/rfc/rfc9000#section-6.2-2
+			return
+		}
+		versions = versions[4:]
+	}
+	// "A client that supports only this version of QUIC MUST
+	// abandon the current connection attempt if it receives
+	// a Version Negotiation packet, [with the two exceptions handled above]."
+	// https://www.rfc-editor.org/rfc/rfc9000#section-6.2-2
+	c.abortImmediately(now, errVersionNegotiation)
+}
+
+// handleFrames processes every frame in a packet payload, enforcing the
+// RFC 9000 Table 3 rules for which frames may appear in which packet types.
+// It reports whether any frame in the packet was ack-eliciting.
+func (c *Conn) handleFrames(now time.Time, dgram *datagram, ptype packetType, space numberSpace, payload []byte) (ackEliciting bool) {
+	if len(payload) == 0 {
+		// "An endpoint MUST treat receipt of a packet containing no frames
+		// as a connection error of type PROTOCOL_VIOLATION."
+		// https://www.rfc-editor.org/rfc/rfc9000#section-12.4-3
+		c.abort(now, localTransportError{
+			code:   errProtocolViolation,
+			reason: "packet contains no frames",
+		})
+		return false
+	}
+	// frameOK verifies that ptype is one of the packets in mask.
+	frameOK := func(c *Conn, ptype, mask packetType) (ok bool) {
+		if ptype&mask == 0 {
+			// "An endpoint MUST treat receipt of a frame in a packet type
+			// that is not permitted as a connection error of type
+			// PROTOCOL_VIOLATION."
+			// https://www.rfc-editor.org/rfc/rfc9000#section-12.4-3
+			c.abort(now, localTransportError{
+				code:   errProtocolViolation,
+				reason: "frame not allowed in packet",
+			})
+			return false
+		}
+		return true
+	}
+	// Packet masks from RFC 9000 Table 3.
+	// https://www.rfc-editor.org/rfc/rfc9000#table-3
+	const (
+		IH_1 = packetTypeInitial | packetTypeHandshake | packetType1RTT
+		__01 = packetType0RTT | packetType1RTT
+		___1 = packetType1RTT
+	)
+	hasCrypto := false
+	for len(payload) > 0 {
+		switch payload[0] {
+		case frameTypePadding, frameTypeAck, frameTypeAckECN,
+			frameTypeConnectionCloseTransport, frameTypeConnectionCloseApplication:
+		default:
+			ackEliciting = true
+		}
+		n := -1
+		switch payload[0] {
+		case frameTypePadding:
+			// PADDING is OK in all spaces.
+			n = 1
+		case frameTypePing:
+			// PING is OK in all spaces.
+			//
+			// A PING frame causes us to respond with an ACK by virtue of being
+			// an ack-eliciting frame, but requires no other action.
+			n = 1
+		case frameTypeAck, frameTypeAckECN:
+			if !frameOK(c, ptype, IH_1) {
+				return
+			}
+			n = c.handleAckFrame(now, space, payload)
+		case frameTypeResetStream:
+			if !frameOK(c, ptype, __01) {
+				return
+			}
+			n = c.handleResetStreamFrame(now, space, payload)
+		case frameTypeStopSending:
+			if !frameOK(c, ptype, __01) {
+				return
+			}
+			n = c.handleStopSendingFrame(now, space, payload)
+		case frameTypeCrypto:
+			if !frameOK(c, ptype, IH_1) {
+				return
+			}
+			hasCrypto = true
+			n = c.handleCryptoFrame(now, space, payload)
+		case frameTypeNewToken:
+			if !frameOK(c, ptype, ___1) {
+				return
+			}
+			_, n = consumeNewTokenFrame(payload)
+		case 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f: // STREAM
+			if !frameOK(c, ptype, __01) {
+				return
+			}
+			n = c.handleStreamFrame(now, space, payload)
+		case frameTypeMaxData:
+			if !frameOK(c, ptype, __01) {
+				return
+			}
+			n = c.handleMaxDataFrame(now, payload)
+		case frameTypeMaxStreamData:
+			if !frameOK(c, ptype, __01) {
+				return
+			}
+			n = c.handleMaxStreamDataFrame(now, payload)
+		case frameTypeMaxStreamsBidi, frameTypeMaxStreamsUni:
+			if !frameOK(c, ptype, __01) {
+				return
+			}
+			n = c.handleMaxStreamsFrame(now, payload)
+		case frameTypeDataBlocked:
+			if !frameOK(c, ptype, __01) {
+				return
+			}
+			_, n = consumeDataBlockedFrame(payload)
+		case frameTypeStreamsBlockedBidi, frameTypeStreamsBlockedUni:
+			if !frameOK(c, ptype, __01) {
+				return
+			}
+			_, _, n = consumeStreamsBlockedFrame(payload)
+		case frameTypeStreamDataBlocked:
+			if !frameOK(c, ptype, __01) {
+				return
+			}
+			_, _, n = consumeStreamDataBlockedFrame(payload)
+		case frameTypeNewConnectionID:
+			if !frameOK(c, ptype, __01) {
+				return
+			}
+			n = c.handleNewConnectionIDFrame(now, space, payload)
+		case frameTypeRetireConnectionID:
+			if !frameOK(c, ptype, __01) {
+				return
+			}
+			n = c.handleRetireConnectionIDFrame(now, space, payload)
+		case frameTypePathChallenge:
+			if !frameOK(c, ptype, __01) {
+				return
+			}
+			n = c.handlePathChallengeFrame(now, dgram, space, payload)
+		case frameTypePathResponse:
+			if !frameOK(c, ptype, ___1) {
+				return
+			}
+			n = c.handlePathResponseFrame(now, space, payload)
+		case frameTypeConnectionCloseTransport:
+			// Transport CONNECTION_CLOSE is OK in all spaces.
+			n = c.handleConnectionCloseTransportFrame(now, payload)
+		case frameTypeConnectionCloseApplication:
+			if !frameOK(c, ptype, __01) {
+				return
+			}
+			n = c.handleConnectionCloseApplicationFrame(now, payload)
+		case frameTypeHandshakeDone:
+			if !frameOK(c, ptype, ___1) {
+				return
+			}
+			n = c.handleHandshakeDoneFrame(now, space, payload)
+		}
+		if n < 0 {
+			c.abort(now, localTransportError{
+				code:   errFrameEncoding,
+				reason: "frame encoding error",
+			})
+			return false
+		}
+		payload = payload[n:]
+	}
+	if hasCrypto {
+		// Process TLS events after handling all frames in a packet.
+		// TLS events can cause us to drop state for a number space,
+		// so do that last, to avoid handling frames differently
+		// depending on whether they come before or after a CRYPTO frame.
+		if err := c.handleTLSEvents(now); err != nil {
+			c.abort(now, err)
+		}
+	}
+	return ackEliciting
+}
+
+// handleAckFrame processes an ACK or ACK_ECN frame, feeding each
+// acknowledged range to the loss detector. It returns the number of
+// bytes consumed, or -1 if the frame is malformed.
+func (c *Conn) handleAckFrame(now time.Time, space numberSpace, payload []byte) int {
+	c.loss.receiveAckStart()
+	largest, ackDelay, ecn, n := consumeAckFrame(payload, func(rangeIndex int, start, end packetNumber) {
+		if err := c.loss.receiveAckRange(now, space, rangeIndex, start, end, c.handleAckOrLoss); err != nil {
+			c.abort(now, err)
+			return
+		}
+	})
+	// TODO: Make use of ECN feedback.
+	// https://www.rfc-editor.org/rfc/rfc9000.html#section-19.3.2
+	_ = ecn
+	// Prior to receiving the peer's transport parameters, we cannot
+	// interpret the ACK Delay field because we don't know the ack_delay_exponent
+	// to apply.
+	//
+	// For servers, we should always know the ack_delay_exponent because the
+	// client's transport parameters are carried in its Initial packets and we
+	// won't send an ack-eliciting Initial packet until after receiving the last
+	// client Initial packet.
+	//
+	// For clients, we won't receive the server's transport parameters until handling
+	// its Handshake flight, which will probably happen after reading its ACK for our
+	// Initial packet(s). However, the peer's acknowledgement delay cannot reduce our
+	// adjusted RTT sample below min_rtt, and min_rtt is generally going to be set
+	// by the packet containing the ACK for our Initial flight. Therefore, the
+	// ACK Delay for an ACK in the Initial space is likely to be ignored anyway.
+	//
+	// Long story short, setting the delay to 0 prior to reading transport parameters
+	// is usually going to have no effect, will have only a minor effect in the rare
+	// cases when it happens, and there aren't any good alternatives anyway since we
+	// can't interpret the ACK Delay field without knowing the exponent.
+	var delay time.Duration
+	if c.peerAckDelayExponent >= 0 {
+		delay = ackDelay.Duration(uint8(c.peerAckDelayExponent))
+	}
+	c.loss.receiveAckEnd(now, c.log, space, delay, c.handleAckOrLoss)
+	if space == appDataSpace {
+		c.keysAppData.handleAckFor(largest)
+	}
+	return n
+}
+
+// handleMaxDataFrame processes a MAX_DATA frame,
+// raising the connection-level flow control send limit.
+func (c *Conn) handleMaxDataFrame(now time.Time, payload []byte) int {
+	maxData, n := consumeMaxDataFrame(payload)
+	if n < 0 {
+		return -1
+	}
+	c.streams.outflow.setMaxData(maxData)
+	return n
+}
+
+// handleMaxStreamDataFrame processes a MAX_STREAM_DATA frame,
+// raising the flow control send limit of the identified stream.
+func (c *Conn) handleMaxStreamDataFrame(now time.Time, payload []byte) int {
+	id, maxStreamData, n := consumeMaxStreamDataFrame(payload)
+	if n < 0 {
+		return -1
+	}
+	if s := c.streamForFrame(now, id, sendStream); s != nil {
+		if err := s.handleMaxStreamData(maxStreamData); err != nil {
+			c.abort(now, err)
+			return -1
+		}
+	}
+	return n
+}
+
+// handleMaxStreamsFrame processes a MAX_STREAMS frame,
+// raising the limit on locally-opened streams of the given type.
+func (c *Conn) handleMaxStreamsFrame(now time.Time, payload []byte) int {
+	styp, max, n := consumeMaxStreamsFrame(payload)
+	if n < 0 {
+		return -1
+	}
+	c.streams.localLimit[styp].setMax(max)
+	return n
+}
+
+// handleResetStreamFrame processes a RESET_STREAM frame,
+// delivering the reset to the affected receive stream.
+func (c *Conn) handleResetStreamFrame(now time.Time, space numberSpace, payload []byte) int {
+	id, code, finalSize, n := consumeResetStreamFrame(payload)
+	if n < 0 {
+		return -1
+	}
+	if s := c.streamForFrame(now, id, recvStream); s != nil {
+		if err := s.handleReset(code, finalSize); err != nil {
+			c.abort(now, err)
+		}
+	}
+	return n
+}
+
+// handleStopSendingFrame processes a STOP_SENDING frame,
+// delivering the request to the affected send stream.
+func (c *Conn) handleStopSendingFrame(now time.Time, space numberSpace, payload []byte) int {
+	id, code, n := consumeStopSendingFrame(payload)
+	if n < 0 {
+		return -1
+	}
+	if s := c.streamForFrame(now, id, sendStream); s != nil {
+		if err := s.handleStopSending(code); err != nil {
+			c.abort(now, err)
+		}
+	}
+	return n
+}
+
+// handleCryptoFrame processes a CRYPTO frame,
+// delivering its data to the handshake layer.
+func (c *Conn) handleCryptoFrame(now time.Time, space numberSpace, payload []byte) int {
+	off, data, n := consumeCryptoFrame(payload)
+	err := c.handleCrypto(now, space, off, data)
+	if err != nil {
+		c.abort(now, err)
+		return -1
+	}
+	return n
+}
+
+// handleStreamFrame processes a STREAM frame,
+// delivering data (and any FIN) to the stream's receive side.
+func (c *Conn) handleStreamFrame(now time.Time, space numberSpace, payload []byte) int {
+	id, off, fin, b, n := consumeStreamFrame(payload)
+	if n < 0 {
+		return -1
+	}
+	if s := c.streamForFrame(now, id, recvStream); s != nil {
+		if err := s.handleData(off, b, fin); err != nil {
+			c.abort(now, err)
+		}
+	}
+	return n
+}
+
+// handleNewConnectionIDFrame processes a NEW_CONNECTION_ID frame,
+// delegating validation and state updates to the connection ID state.
+func (c *Conn) handleNewConnectionIDFrame(now time.Time, space numberSpace, payload []byte) int {
+	seq, retire, connID, resetToken, n := consumeNewConnectionIDFrame(payload)
+	if n < 0 {
+		return -1
+	}
+	if err := c.connIDState.handleNewConnID(c, seq, retire, connID, resetToken); err != nil {
+		c.abort(now, err)
+	}
+	return n
+}
+
+// handleRetireConnectionIDFrame processes a RETIRE_CONNECTION_ID frame,
+// delegating to the connection ID state.
+func (c *Conn) handleRetireConnectionIDFrame(now time.Time, space numberSpace, payload []byte) int {
+	seq, n := consumeRetireConnectionIDFrame(payload)
+	if n < 0 {
+		return -1
+	}
+	if err := c.connIDState.handleRetireConnID(c, seq); err != nil {
+		c.abort(now, err)
+	}
+	return n
+}
+
+// handlePathChallengeFrame processes a PATH_CHALLENGE frame.
+func (c *Conn) handlePathChallengeFrame(now time.Time, dgram *datagram, space numberSpace, payload []byte) int {
+	data, n := consumePathChallengeFrame(payload)
+	if n < 0 {
+		return -1
+	}
+	c.handlePathChallenge(now, dgram, data)
+	return n
+}
+
+// handlePathResponseFrame processes a PATH_RESPONSE frame.
+func (c *Conn) handlePathResponseFrame(now time.Time, space numberSpace, payload []byte) int {
+	data, n := consumePathResponseFrame(payload)
+	if n < 0 {
+		return -1
+	}
+	c.handlePathResponse(now, data)
+	return n
+}
+
+// handleConnectionCloseTransportFrame processes a transport-level
+// CONNECTION_CLOSE frame, moving the connection toward the draining state.
+func (c *Conn) handleConnectionCloseTransportFrame(now time.Time, payload []byte) int {
+	code, _, reason, n := consumeConnectionCloseTransportFrame(payload)
+	if n < 0 {
+		return -1
+	}
+	c.handlePeerConnectionClose(now, peerTransportError{code: code, reason: reason})
+	return n
+}
+
+// handleConnectionCloseApplicationFrame processes an application-level
+// CONNECTION_CLOSE frame, moving the connection toward the draining state.
+func (c *Conn) handleConnectionCloseApplicationFrame(now time.Time, payload []byte) int {
+	code, reason, n := consumeConnectionCloseApplicationFrame(payload)
+	if n < 0 {
+		return -1
+	}
+	c.handlePeerConnectionClose(now, &ApplicationError{Code: code, Reason: reason})
+	return n
+}
+
+// handleHandshakeDoneFrame processes a HANDSHAKE_DONE frame,
+// confirming the handshake on the client side.
+func (c *Conn) handleHandshakeDoneFrame(now time.Time, space numberSpace, payload []byte) int {
+	if c.side == serverSide {
+		// Clients should never send HANDSHAKE_DONE.
+		// https://www.rfc-editor.org/rfc/rfc9000#section-19.20-4
+		c.abort(now, localTransportError{
+			code:   errProtocolViolation,
+			reason: "client sent HANDSHAKE_DONE",
+		})
+		return -1
+	}
+	if c.isAlive() {
+		c.confirmHandshake(now)
+	}
+	return 1
+}
+
+var errStatelessReset = errors.New("received stateless reset")
+
+// handleStatelessReset reports whether resetToken is a valid stateless reset
+// token for this connection; if so, the connection enters the draining state.
+func (c *Conn) handleStatelessReset(now time.Time, resetToken statelessResetToken) (valid bool) {
+	if !c.connIDState.isValidStatelessResetToken(resetToken) {
+		return false
+	}
+	c.setFinalError(errStatelessReset)
+	c.enterDraining(now)
+	return true
+}
diff --git a/src/vendor/golang.org/x/net/quic/conn_send.go b/src/vendor/golang.org/x/net/quic/conn_send.go
new file mode 100644
index 0000000000..3e8cf526b5
--- /dev/null
+++ b/src/vendor/golang.org/x/net/quic/conn_send.go
@@ -0,0 +1,407 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.

package quic

import (
	"crypto/tls"
	"errors"
	"time"
)

// maybeSend sends datagrams, if possible.
//
// If sending is blocked by pacing, it returns the next time
// a datagram may be sent.
//
// If sending is blocked indefinitely, it returns the zero Time.
func (c *Conn) maybeSend(now time.Time) (next time.Time) {
	// Assumption: The congestion window is not underutilized.
	// If congestion control, pacing, and anti-amplification all permit sending,
	// but we have no packet to send, then we will declare the window underutilized.
	underutilized := false
	defer func() {
		c.loss.cc.setUnderutilized(c.log, underutilized)
	}()

	// Send one datagram on each iteration of this loop,
	// until we hit a limit or run out of data to send.
	//
	// For each number space where we have write keys,
	// attempt to construct a packet in that space.
	// If the packet contains no frames (we have no data in need of sending),
	// abandon the packet.
	//
	// Speculatively constructing packets means we don't need
	// separate code paths for "do we have data to send?" and
	// "send the data" that need to be kept in sync.
	for {
		limit, next := c.loss.sendLimit(now)
		if limit == ccBlocked {
			// If anti-amplification blocks sending, then no packet can be sent.
			return next
		}
		if !c.sendOK(now) {
			return time.Time{}
		}
		// We may still send ACKs, even if congestion control or pacing limit sending.

		// Prepare to write a datagram of at most maxSendSize bytes.
		c.w.reset(c.loss.maxSendSize())

		dstConnID, ok := c.connIDState.dstConnID()
		if !ok {
			// It is currently not possible for us to end up without a connection ID,
			// but handle the case anyway.
			return time.Time{}
		}

		// Initial packet.
		pad := false
		var sentInitial *sentPacket
		if c.keysInitial.canWrite() {
			pnumMaxAcked := c.loss.spaces[initialSpace].maxAcked
			pnum := c.loss.nextNumber(initialSpace)
			p := longPacket{
				ptype:     packetTypeInitial,
				version:   quicVersion1,
				num:       pnum,
				dstConnID: dstConnID,
				srcConnID: c.connIDState.srcConnID(),
				extra:     c.retryToken,
			}
			c.w.startProtectedLongHeaderPacket(pnumMaxAcked, p)
			c.appendFrames(now, initialSpace, pnum, limit)
			if logPackets {
				logSentPacket(c, packetTypeInitial, pnum, p.srcConnID, p.dstConnID, c.w.payload())
			}
			if c.logEnabled(QLogLevelPacket) && len(c.w.payload()) > 0 {
				c.logPacketSent(packetTypeInitial, pnum, p.srcConnID, p.dstConnID, c.w.packetLen(), c.w.payload())
			}
			sentInitial = c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, c.keysInitial.w, p)
			if sentInitial != nil {
				// Client initial packets and ack-eliciting server initial packets
				// need to be sent in a datagram padded to at least 1200 bytes.
				// We can't add the padding yet, however, since we may want to
				// coalesce additional packets with this one.
				if c.side == clientSide || sentInitial.ackEliciting {
					pad = true
				}
			}
		}

		// Handshake packet.
		if c.keysHandshake.canWrite() {
			pnumMaxAcked := c.loss.spaces[handshakeSpace].maxAcked
			pnum := c.loss.nextNumber(handshakeSpace)
			p := longPacket{
				ptype:     packetTypeHandshake,
				version:   quicVersion1,
				num:       pnum,
				dstConnID: dstConnID,
				srcConnID: c.connIDState.srcConnID(),
			}
			c.w.startProtectedLongHeaderPacket(pnumMaxAcked, p)
			c.appendFrames(now, handshakeSpace, pnum, limit)
			if logPackets {
				logSentPacket(c, packetTypeHandshake, pnum, p.srcConnID, p.dstConnID, c.w.payload())
			}
			if c.logEnabled(QLogLevelPacket) && len(c.w.payload()) > 0 {
				c.logPacketSent(packetTypeHandshake, pnum, p.srcConnID, p.dstConnID, c.w.packetLen(), c.w.payload())
			}
			if sent := c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, c.keysHandshake.w, p); sent != nil {
				c.packetSent(now, handshakeSpace, sent)
				if c.side == clientSide {
					// "[...] a client MUST discard Initial keys when it first
					// sends a Handshake packet [...]"
					// https://www.rfc-editor.org/rfc/rfc9001.html#section-4.9.1-2
					c.discardKeys(now, initialSpace)
				}
			}
		}

		// 1-RTT packet.
		if c.keysAppData.canWrite() {
			pnumMaxAcked := c.loss.spaces[appDataSpace].maxAcked
			pnum := c.loss.nextNumber(appDataSpace)
			c.w.start1RTTPacket(pnum, pnumMaxAcked, dstConnID)
			c.appendFrames(now, appDataSpace, pnum, limit)
			if pad && len(c.w.payload()) > 0 {
				// 1-RTT packets have no length field and extend to the end
				// of the datagram, so if we're sending a datagram that needs
				// padding we need to add it inside the 1-RTT packet.
				c.w.appendPaddingTo(paddedInitialDatagramSize)
				pad = false
			}
			if logPackets {
				logSentPacket(c, packetType1RTT, pnum, nil, dstConnID, c.w.payload())
			}
			if c.logEnabled(QLogLevelPacket) && len(c.w.payload()) > 0 {
				c.logPacketSent(packetType1RTT, pnum, nil, dstConnID, c.w.packetLen(), c.w.payload())
			}
			if sent := c.w.finish1RTTPacket(pnum, pnumMaxAcked, dstConnID, &c.keysAppData); sent != nil {
				c.packetSent(now, appDataSpace, sent)
				if c.skip.shouldSkip(pnum + 1) {
					c.loss.skipNumber(now, appDataSpace)
					c.skip.updateNumberSkip(c)
				}
			}
		}

		buf := c.w.datagram()
		if len(buf) == 0 {
			if limit == ccOK {
				// We have nothing to send, and congestion control does not
				// block sending. The congestion window is underutilized.
				underutilized = true
			}
			return next
		}

		if sentInitial != nil {
			if pad {
				// Pad out the datagram with zeros, coalescing the Initial
				// packet with invalid packets that will be ignored by the peer.
				// https://www.rfc-editor.org/rfc/rfc9000.html#section-14.1-1
				for len(buf) < paddedInitialDatagramSize {
					buf = append(buf, 0)
					// Technically this padding isn't in any packet, but
					// account it to the Initial packet in this datagram
					// for purposes of flow control and loss recovery.
					sentInitial.size++
					sentInitial.inFlight = true
				}
			}
			// If we're a client and this Initial packet is coalesced
			// with a Handshake packet, then we've discarded Initial keys
			// since constructing the packet and shouldn't record it as in-flight.
			if c.keysInitial.canWrite() {
				c.packetSent(now, initialSpace, sentInitial)
			}
		}

		c.endpoint.sendDatagram(datagram{
			b:        buf,
			peerAddr: c.peerAddr,
		})
	}
}

// packetSent records bookkeeping for a packet we have just sent:
// it feeds the idle-timeout tracker and the loss/congestion-control state.
func (c *Conn) packetSent(now time.Time, space numberSpace, sent *sentPacket) {
	c.idleHandlePacketSent(now, sent)
	c.loss.packetSent(now, c.log, space, sent)
}

// appendFrames appends frames for the given number space to the packet
// currently under construction in c.w. The congestion-control verdict in
// limit determines how much (if anything) beyond ACKs may be sent.
func (c *Conn) appendFrames(now time.Time, space numberSpace, pnum packetNumber, limit ccLimit) {
	if c.lifetime.localErr != nil {
		c.appendConnectionCloseFrame(now, space, c.lifetime.localErr)
		return
	}

	shouldSendAck := c.acks[space].shouldSendAck(now)
	if limit != ccOK {
		// ACKs are not limited by congestion control.
		if shouldSendAck && c.appendAckFrame(now, space) {
			c.acks[space].sentAck()
		}
		return
	}
	// We want to send an ACK frame if the ack controller wants to send a frame now,
	// OR if we are sending a packet anyway and have ack-eliciting packets which we
	// have not yet acked.
	//
	// We speculatively add ACK frames here, to put them at the front of the packet
	// to avoid truncation.
	//
	// After adding all frames, if we don't need to send an ACK frame and have not
	// added any other frames, we abandon the packet.
	if c.appendAckFrame(now, space) {
		defer func() {
			// All frames other than ACK and PADDING are ack-eliciting,
			// so if the packet is ack-eliciting we've added additional
			// frames to it.
			if !shouldSendAck && !c.w.sent.ackEliciting {
				// There's nothing in this packet but ACK frames, and
				// we don't want to send an ACK-only packet at this time.
				// Abandoning the packet means we wrote an ACK frame for
				// nothing, but constructing the frame is cheap.
				c.w.abandonPacket()
				return
			}
			// Either we are willing to send an ACK-only packet,
			// or we've added additional frames.
			c.acks[space].sentAck()
			if !c.w.sent.ackEliciting && c.shouldMakePacketAckEliciting() {
				c.w.appendPingFrame()
			}
		}()
	}
	// NOTE(review): limit was already checked above and that branch returns,
	// so this guard appears to be purely defensive at present.
	if limit != ccOK {
		return
	}
	pto := c.loss.ptoExpired

	// TODO: Add all the other frames we can send.

	// CRYPTO
	c.crypto[space].dataToSend(pto, func(off, size int64) int64 {
		b, _ := c.w.appendCryptoFrame(off, int(size))
		c.crypto[space].sendData(off, b)
		return int64(len(b))
	})

	// Test-only PING frames.
	if space == c.testSendPingSpace && c.testSendPing.shouldSendPTO(pto) {
		if !c.w.appendPingFrame() {
			return
		}
		c.testSendPing.setSent(pnum)
	}

	if space == appDataSpace {
		// HANDSHAKE_DONE
		if c.handshakeConfirmed.shouldSendPTO(pto) {
			if !c.w.appendHandshakeDoneFrame() {
				return
			}
			c.handshakeConfirmed.setSent(pnum)
		}

		// NEW_CONNECTION_ID, RETIRE_CONNECTION_ID
		if !c.connIDState.appendFrames(c, pnum, pto) {
			return
		}

		// PATH_RESPONSE
		if pad, ok := c.appendPathFrames(); !ok {
			return
		} else if pad {
			defer c.w.appendPaddingTo(smallestMaxDatagramSize)
		}

		// All stream-related frames. This should come last in the packet,
		// so large amounts of STREAM data don't crowd out other frames
		// we may need to send.
		if !c.appendStreamFrames(&c.w, pnum, pto) {
			return
		}

		if !c.appendKeepAlive(now) {
			return
		}
	}

	// If this is a PTO probe and we haven't added an ack-eliciting frame yet,
	// add a PING to make this an ack-eliciting probe.
	//
	// Technically, there are separate PTO timers for each number space.
	// When a PTO timer expires, we MUST send an ack-eliciting packet in the
	// timer's space. We SHOULD send ack-eliciting packets in every other space
	// with in-flight data. (RFC 9002, section 6.2.4)
	//
	// What we actually do is send a single datagram containing an ack-eliciting packet
	// for every space for which we have keys.
	//
	// We fill the PTO probe packets with new or unacknowledged data. For example,
	// a PTO probe sent for the Initial space will generally retransmit previously
	// sent but unacknowledged CRYPTO data.
	//
	// When sending a PTO probe datagram containing multiple packets, it is
	// possible that an earlier packet will fill up the datagram, leaving no
	// space for the remaining probe packet(s). This is not a problem in practice.
	//
	// A client discards Initial keys when it first sends a Handshake packet
	// (RFC 9001 Section 4.9.1). Handshake keys are discarded when the handshake
	// is confirmed (RFC 9001 Section 4.9.2). The PTO timer is not set for the
	// Application Data packet number space until the handshake is confirmed
	// (RFC 9002 Section 6.2.1). Therefore, the only times a PTO probe can fire
	// while data for multiple spaces is in flight are:
	//
	// - a server's Initial or Handshake timers can fire while Initial and Handshake
	//   data is in flight; and
	//
	// - a client's Handshake timer can fire while Handshake and Application Data
	//   data is in flight.
	//
	// It is theoretically possible for a server's Initial CRYPTO data to overflow
	// the maximum datagram size, but unlikely in practice; this space contains
	// only the ServerHello TLS message, which is small. It's also unlikely that
	// the Handshake PTO probe will fire while Initial data is in flight (this
	// requires not just that the Initial CRYPTO data completely fill a datagram,
	// but a quite specific arrangement of lost and retransmitted packets.)
	// We don't bother worrying about this case here, since the worst case is
	// that we send a PTO probe for the in-flight Initial data and drop the
	// Handshake probe.
	//
	// If a client's Handshake PTO timer fires while Application Data data is in
	// flight, it is possible that the resent Handshake CRYPTO data will crowd
	// out the probe for the Application Data space. However, since this probe is
	// optional (recall that the Application Data PTO timer is never set until
	// after Handshake keys have been discarded), dropping it is acceptable.
	if pto && !c.w.sent.ackEliciting {
		c.w.appendPingFrame()
	}
}

// shouldMakePacketAckEliciting is called when sending a packet containing nothing but an ACK frame.
// It reports whether we should add a PING frame to the packet to make it ack-eliciting.
func (c *Conn) shouldMakePacketAckEliciting() bool {
	if c.keysAppData.needAckEliciting() {
		// The peer has initiated a key update.
		// We haven't sent them any packets yet in the new phase.
		// Make this an ack-eliciting packet.
		// Their ack of this packet will complete the key update.
		return true
	}
	if c.loss.consecutiveNonAckElicitingPackets >= 19 {
		// We've sent a run of non-ack-eliciting packets.
		// Add in an ack-eliciting one every once in a while so the peer
		// lets us know which ones have arrived.
		//
		// Google QUICHE injects a PING after sending 19 packets. We do the same.
		//
		// https://www.rfc-editor.org/rfc/rfc9000#section-13.2.4-2
		return true
	}
	// TODO: Consider making every packet sent when in PTO ack-eliciting to speed up recovery.
	return false
}

// appendAckFrame appends an ACK frame for the given number space to the
// packet under construction, if there is anything to acknowledge.
// It reports whether a frame was written.
func (c *Conn) appendAckFrame(now time.Time, space numberSpace) bool {
	seen, delay := c.acks[space].acksToSend(now)
	if len(seen) == 0 {
		return false
	}
	d := unscaledAckDelayFromDuration(delay, ackDelayExponent)
	return c.w.appendAckFrame(seen, d, c.acks[space].ecn)
}

// appendConnectionCloseFrame appends a CONNECTION_CLOSE frame conveying err,
// choosing the transport (0x1c) or application (0x1d) variant as appropriate
// for the error type and number space.
func (c *Conn) appendConnectionCloseFrame(now time.Time, space numberSpace, err error) {
	c.sentConnectionClose(now)
	switch e := err.(type) {
	case localTransportError:
		c.w.appendConnectionCloseTransportFrame(e.code, 0, e.reason)
	case *ApplicationError:
		if space != appDataSpace {
			// "CONNECTION_CLOSE frames signaling application errors (type 0x1d)
			// MUST only appear in the application data packet number space."
			// https://www.rfc-editor.org/rfc/rfc9000#section-12.5-2.2
			c.w.appendConnectionCloseTransportFrame(errApplicationError, 0, "")
		} else {
			c.w.appendConnectionCloseApplicationFrame(e.Code, e.Reason)
		}
	default:
		// TLS alerts are sent using error codes [0x0100,0x01ff).
		// https://www.rfc-editor.org/rfc/rfc9000#section-20.1-2.36.1
		var alert tls.AlertError
		switch {
		case errors.As(err, &alert):
			// tls.AlertError is a uint8, so this can't exceed 0x01ff.
			code := errTLSBase + transportError(alert)
			c.w.appendConnectionCloseTransportFrame(code, 0, "")
		default:
			c.w.appendConnectionCloseTransportFrame(errInternal, 0, "")
		}
	}
}
diff --git a/src/vendor/golang.org/x/net/quic/conn_streams.go b/src/vendor/golang.org/x/net/quic/conn_streams.go
new file mode 100644
index 0000000000..0e4bf50094
--- /dev/null
+++ b/src/vendor/golang.org/x/net/quic/conn_streams.go
@@ -0,0 +1,483 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package quic

import (
	"context"
	"sync"
	"sync/atomic"
	"time"
)

// streamsState holds all per-connection stream bookkeeping:
// the set of streams, stream-count limits, connection-level flow control,
// and the queues of streams with frames waiting to be sent.
type streamsState struct {
	queue queue[*Stream] // new, peer-created streams

	// All peer-created streams.
	//
	// Implicitly created streams are included as an empty entry in the map.
	// (For example, if we receive a frame for stream 4, we implicitly create stream 0 and
	// insert an empty entry for it to the map.)
	//
	// The map value is maybeStream rather than *Stream as a reminder that values can be nil.
	streams map[streamID]maybeStream

	// Limits on the number of streams, indexed by streamType.
	localLimit  [streamTypeCount]localStreamLimits
	remoteLimit [streamTypeCount]remoteStreamLimits

	// Peer configuration provided in transport parameters.
	peerInitialMaxStreamDataRemote    [streamTypeCount]int64 // streams opened by us
	peerInitialMaxStreamDataBidiLocal int64                  // streams opened by them

	// Connection-level flow control.
	inflow  connInflow
	outflow connOutflow

	// Streams with frames to send are stored in one of two circular linked lists,
	// depending on whether they require connection-level flow control.
	needSend  atomic.Bool
	sendMu    sync.Mutex
	queueMeta streamRing // streams with any non-flow-controlled frames
	queueData streamRing // streams with only flow-controlled frames
}

// maybeStream is a possibly nil *Stream. See streamsState.streams.
type maybeStream struct {
	s *Stream
}

// streamsInit initializes the connection's stream state,
// applying the stream-count limits from the local configuration.
func (c *Conn) streamsInit() {
	c.streams.streams = make(map[streamID]maybeStream)
	c.streams.queue = newQueue[*Stream]()
	c.streams.localLimit[bidiStream].init()
	c.streams.localLimit[uniStream].init()
	c.streams.remoteLimit[bidiStream].init(c.config.maxBidiRemoteStreams())
	c.streams.remoteLimit[uniStream].init(c.config.maxUniRemoteStreams())
	c.inflowInit()
}

// streamsCleanup releases stream state when the connection closes:
// it unblocks pending AcceptStream and NewStream callers and notifies
// every remaining stream that the connection is gone.
func (c *Conn) streamsCleanup() {
	c.streams.queue.close(errConnClosed)
	c.streams.localLimit[bidiStream].connHasClosed()
	c.streams.localLimit[uniStream].connHasClosed()
	for _, s := range c.streams.streams {
		if s.s != nil {
			s.s.connHasClosed()
		}
	}
}

// AcceptStream waits for and returns the next stream created by the peer.
func (c *Conn) AcceptStream(ctx context.Context) (*Stream, error) {
	return c.streams.queue.get(ctx)
}

// NewStream creates a stream.
//
// If the peer's maximum stream limit for the connection has been reached,
// NewStream blocks until the limit is increased or the context expires.
func (c *Conn) NewStream(ctx context.Context) (*Stream, error) {
	return c.newLocalStream(ctx, bidiStream)
}

// NewSendOnlyStream creates a unidirectional, send-only stream.
//
// If the peer's maximum stream limit for the connection has been reached,
// NewSendOnlyStream blocks until the limit is increased or the context expires.
func (c *Conn) NewSendOnlyStream(ctx context.Context) (*Stream, error) {
	return c.newLocalStream(ctx, uniStream)
}

// newLocalStream creates a new locally-initiated stream of the given type,
// blocking until the peer's stream-count limit permits it (or ctx expires).
// The new stream's flow-control windows are set from the local configuration
// and the peer's transport parameters.
func (c *Conn) newLocalStream(ctx context.Context, styp streamType) (*Stream, error) {
	num, err := c.streams.localLimit[styp].open(ctx, c)
	if err != nil {
		return nil, err
	}

	s := newStream(c, newStreamID(c.side, styp, num))
	s.outmaxbuf = c.config.maxStreamWriteBufferSize()
	s.outwin = c.streams.peerInitialMaxStreamDataRemote[styp]
	if styp == bidiStream {
		s.inmaxbuf = c.config.maxStreamReadBufferSize()
		s.inwin = c.config.maxStreamReadBufferSize()
	}
	s.inUnlock()
	s.outUnlock()

	// Modify c.streams on the conn's loop.
	if err := c.runOnLoop(ctx, func(now time.Time, c *Conn) {
		c.streams.streams[s.id] = maybeStream{s}
	}); err != nil {
		return nil, err
	}
	return s, nil
}

// streamFrameType identifies which direction of a stream,
// from the local perspective, a frame is associated with.
//
// For example, STREAM is a recvStream frame,
// because it carries data from the peer to us.
type streamFrameType uint8

const (
	sendStream = streamFrameType(iota) // for example, MAX_DATA
	recvStream                         // for example, STREAM_DATA_BLOCKED
)

// streamForID returns the stream with the given id.
// If the stream does not exist, it returns nil.
func (c *Conn) streamForID(id streamID) *Stream {
	return c.streams.streams[id].s
}

// streamForFrame returns the stream with the given id.
// If the stream does not exist, it may be created.
//
// streamForFrame aborts the connection if the stream id, state, and frame type don't align.
// For example, it aborts the connection with a STREAM_STATE error if a MAX_DATA frame
// is received for a receive-only stream, or if the peer attempts to create a stream that
// should be originated locally.
//
// streamForFrame returns nil if the stream no longer exists or if an error occurred.
func (c *Conn) streamForFrame(now time.Time, id streamID, ftype streamFrameType) *Stream {
	if id.streamType() == uniStream {
		if (id.initiator() == c.side) != (ftype == sendStream) {
			// Received an invalid frame for unidirectional stream.
			// For example, a RESET_STREAM frame for a send-only stream.
			c.abort(now, localTransportError{
				code:   errStreamState,
				reason: "invalid frame for unidirectional stream",
			})
			return nil
		}
	}

	ms, isOpen := c.streams.streams[id]
	if ms.s != nil {
		return ms.s
	}

	num := id.num()
	styp := id.streamType()
	if id.initiator() == c.side {
		if num < c.streams.localLimit[styp].opened {
			// This stream was created by us, and has been closed.
			return nil
		}
		// Received a frame for a stream that should be originated by us,
		// but which we never created.
		c.abort(now, localTransportError{
			code:   errStreamState,
			reason: "received frame for unknown stream",
		})
		return nil
	} else {
		// if isOpen, this is a stream that was implicitly opened by a
		// previous frame for a larger-numbered stream, but we haven't
		// actually created it yet.
		if !isOpen && num < c.streams.remoteLimit[styp].opened {
			// This stream was created by the peer, and has been closed.
			return nil
		}
	}

	prevOpened := c.streams.remoteLimit[styp].opened
	if err := c.streams.remoteLimit[styp].open(id); err != nil {
		c.abort(now, err)
		return nil
	}

	// Receiving a frame for a stream implicitly creates all streams
	// with the same initiator and type and a lower number.
	// Add a nil entry to the streams map for each implicitly created stream.
	for n := newStreamID(id.initiator(), id.streamType(), prevOpened); n < id; n += 4 {
		c.streams.streams[n] = maybeStream{}
	}

	s := newStream(c, id)
	s.inmaxbuf = c.config.maxStreamReadBufferSize()
	s.inwin = c.config.maxStreamReadBufferSize()
	if id.streamType() == bidiStream {
		s.outmaxbuf = c.config.maxStreamWriteBufferSize()
		s.outwin = c.streams.peerInitialMaxStreamDataBidiLocal
	}
	s.inUnlock()
	s.outUnlock()

	c.streams.streams[id] = maybeStream{s}
	c.streams.queue.put(s)
	return s
}

// maybeQueueStreamForSend marks a stream as containing frames that need sending.
func (c *Conn) maybeQueueStreamForSend(s *Stream, state streamState) {
	if state.wantQueue() == state.inQueue() {
		return // already on the right queue
	}
	c.streams.sendMu.Lock()
	defer c.streams.sendMu.Unlock()
	state = s.state.load() // may have changed while waiting
	c.queueStreamForSendLocked(s, state)

	c.streams.needSend.Store(true)
	c.wake()
}

// queueStreamForSendLocked moves a stream to the correct send queue,
// or removes it from all queues.
//
// state is the last known stream state.
func (c *Conn) queueStreamForSendLocked(s *Stream, state streamState) {
	for {
		wantQueue := state.wantQueue()
		inQueue := state.inQueue()
		if inQueue == wantQueue {
			return // already on the right queue
		}

		switch inQueue {
		case metaQueue:
			c.streams.queueMeta.remove(s)
		case dataQueue:
			c.streams.queueData.remove(s)
		}

		switch wantQueue {
		case metaQueue:
			c.streams.queueMeta.append(s)
			state = s.state.set(streamQueueMeta, streamQueueMeta|streamQueueData)
		case dataQueue:
			c.streams.queueData.append(s)
			state = s.state.set(streamQueueData, streamQueueMeta|streamQueueData)
		case noQueue:
			state = s.state.set(0, streamQueueMeta|streamQueueData)
		}

		// If the stream state changed while we were moving the stream,
		// we might now be on the wrong queue.
		//
		// For example:
		//   - stream has data to send: streamOutSendData|streamQueueData
		//   - appendStreamFrames sends all the data: streamQueueData
		//   - concurrently, more data is written: streamOutSendData|streamQueueData
		//   - appendStreamFrames calls us with the last state it observed
		//     (streamQueueData).
		//   - We remove the stream from the queue and observe the updated state:
		//     streamOutSendData
		//   - We realize that the stream needs to go back on the data queue.
		//
		// Go back around the loop to confirm we're on the correct queue.
	}
}

// appendStreamFrames writes stream-related frames to the current packet.
//
// It returns true if no more frames need appending,
// false if not everything fit in the current packet.
func (c *Conn) appendStreamFrames(w *packetWriter, pnum packetNumber, pto bool) bool {
	// MAX_DATA
	if !c.appendMaxDataFrame(w, pnum, pto) {
		return false
	}

	if pto {
		return c.appendStreamFramesPTO(w, pnum)
	}
	if !c.streams.needSend.Load() {
		// If queueMeta includes newly-finished streams, we may extend the peer's
		// stream limits. When there are no streams to process, add MAX_STREAMS
		// frames here. Otherwise, wait until after we've processed queueMeta.
		return c.appendMaxStreams(w, pnum, pto)
	}
	c.streams.sendMu.Lock()
	defer c.streams.sendMu.Unlock()
	// queueMeta contains streams with non-flow-controlled frames to send.
	for c.streams.queueMeta.head != nil {
		s := c.streams.queueMeta.head
		state := s.state.load()
		if state&(streamQueueMeta|streamConnRemoved) != streamQueueMeta {
			panic("BUG: queueMeta stream is not streamQueueMeta")
		}
		if state&streamInSendMeta != 0 {
			s.ingate.lock()
			ok := s.appendInFramesLocked(w, pnum, pto)
			state = s.inUnlockNoQueue()
			if !ok {
				return false
			}
			if state&streamInSendMeta != 0 {
				panic("BUG: streamInSendMeta set after successfully appending frames")
			}
		}
		if state&streamOutSendMeta != 0 {
			s.outgate.lock()
			// This might also append flow-controlled frames if we have any
			// and available conn-level quota. That's fine.
			ok := s.appendOutFramesLocked(w, pnum, pto)
			state = s.outUnlockNoQueue()
			// We're checking both ok and state, because appendOutFramesLocked
			// might have filled up the packet with flow-controlled data.
			// If so, we want to move the stream to queueData for any remaining frames.
			if !ok && state&streamOutSendMeta != 0 {
				return false
			}
			if state&streamOutSendMeta != 0 {
				panic("BUG: streamOutSendMeta set after successfully appending frames")
			}
		}
		// We've sent all frames for this stream, so remove it from the send queue.
		c.streams.queueMeta.remove(s)
		if state&(streamInDone|streamOutDone) == streamInDone|streamOutDone {
			// Stream is finished, remove it from the conn.
			state = s.state.set(streamConnRemoved, streamQueueMeta|streamConnRemoved)
			delete(c.streams.streams, s.id)

			// Record finalization of remote streams, to know when
			// to extend the peer's stream limit.
			if s.id.initiator() != c.side {
				c.streams.remoteLimit[s.id.streamType()].close()
			}
		} else {
			state = s.state.set(0, streamQueueMeta|streamConnRemoved)
		}
		// The stream may have flow-controlled data to send,
		// or something might have added non-flow-controlled frames after we
		// unlocked the stream.
		// If so, put the stream back on a queue.
		c.queueStreamForSendLocked(s, state)
	}

	// MAX_STREAMS (possibly triggered by finalization of remote streams above).
	if !c.appendMaxStreams(w, pnum, pto) {
		return false
	}

	// queueData contains streams with flow-controlled frames.
	for c.streams.queueData.head != nil {
		avail := c.streams.outflow.avail()
		if avail == 0 {
			break // no flow control quota available
		}
		s := c.streams.queueData.head
		s.outgate.lock()
		ok := s.appendOutFramesLocked(w, pnum, pto)
		state := s.outUnlockNoQueue()
		if !ok {
			// We've sent some data for this stream, but it still has more to send.
			// If the stream got a reasonable chance to put data in a packet,
			// advance sendHead to the next stream in line, to avoid starvation.
			// We'll come back to this stream after going through the others.
			//
			// If the packet was already mostly out of space, leave sendHead alone
			// and come back to this stream again on the next packet.
			if avail > 512 {
				c.streams.queueData.head = s.next
			}
			return false
		}
		if state&streamQueueData == 0 {
			panic("BUG: queueData stream is not streamQueueData")
		}
		if state&streamOutSendData != 0 {
			// We must have run out of connection-level flow control:
			// appendOutFramesLocked says it wrote all it can, but there's
			// still data to send.
			//
			// Advance sendHead to the next stream in line to avoid starvation.
			if c.streams.outflow.avail() != 0 {
				panic("BUG: streamOutSendData set and flow control available after send")
			}
			c.streams.queueData.head = s.next
			return true
		}
		c.streams.queueData.remove(s)
		state = s.state.set(0, streamQueueData)
		c.queueStreamForSendLocked(s, state)
	}
	if c.streams.queueMeta.head == nil && c.streams.queueData.head == nil {
		c.streams.needSend.Store(false)
	}
	return true
}

// appendStreamFramesPTO writes stream-related frames to the current packet
// for a PTO probe.
//
// It returns true if no more frames need appending,
// false if not everything fit in the current packet.
func (c *Conn) appendStreamFramesPTO(w *packetWriter, pnum packetNumber) bool {
	const pto = true
	if !c.appendMaxStreams(w, pnum, pto) {
		return false
	}
	c.streams.sendMu.Lock()
	defer c.streams.sendMu.Unlock()
	// Unlike the normal send path, a PTO probe walks every stream rather
	// than the send queues, retransmitting unacknowledged frames.
	for _, ms := range c.streams.streams {
		s := ms.s
		if s == nil {
			continue
		}
		const pto = true
		s.ingate.lock()
		inOK := s.appendInFramesLocked(w, pnum, pto)
		s.inUnlockNoQueue()
		if !inOK {
			return false
		}

		s.outgate.lock()
		outOK := s.appendOutFramesLocked(w, pnum, pto)
		s.outUnlockNoQueue()
		if !outOK {
			return false
		}
	}
	return true
}

// appendMaxStreams appends MAX_STREAMS frames (for both stream types, as
// needed) to the current packet. It returns false if a frame did not fit.
func (c *Conn) appendMaxStreams(w *packetWriter, pnum packetNumber, pto bool) bool {
	if !c.streams.remoteLimit[uniStream].appendFrame(w, uniStream, pnum, pto) {
		return false
	}
	if !c.streams.remoteLimit[bidiStream].appendFrame(w, bidiStream, pnum, pto) {
		return false
	}
	return true
}

// A streamRing is a circular linked list of streams.
type streamRing struct {
	head *Stream
}

// remove removes s from the ring.
// s must be on the ring.
func (r *streamRing) remove(s *Stream) {
	if s.next == s {
		r.head = nil // s was the last stream in the ring
	} else {
		s.prev.next = s.next
		s.next.prev = s.prev
		if r.head == s {
			r.head = s.next
		}
	}
}

// append places s at the last position in the ring.
// s must not be attached to any ring.
func (r *streamRing) append(s *Stream) {
	if r.head == nil {
		r.head = s
		s.next = s
		s.prev = s
	} else {
		s.prev = r.head.prev
		s.next = r.head
		s.prev.next = s
		s.next.prev = s
	}
}
diff --git a/src/vendor/golang.org/x/net/quic/crypto_stream.go b/src/vendor/golang.org/x/net/quic/crypto_stream.go
new file mode 100644
index 0000000000..a5b9818296
--- /dev/null
+++ b/src/vendor/golang.org/x/net/quic/crypto_stream.go
@@ -0,0 +1,157 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package quic

// "Implementations MUST support buffering at least 4096 bytes of data
// received in out-of-order CRYPTO frames."
// https://www.rfc-editor.org/rfc/rfc9000.html#section-7.5-2
//
// 4096 is too small for real-world cases, however, so we allow more.
const cryptoBufferSize = 1 << 20

// A cryptoStream is the stream of data passed in CRYPTO frames.
// There is one cryptoStream per packet number space.
type cryptoStream struct {
	// CRYPTO data received from the peer.
	in    pipe
	inset rangeset[int64] // bytes received

	// CRYPTO data queued for transmission to the peer.
	out       pipe
	outunsent rangeset[int64] // bytes in need of sending
	outacked  rangeset[int64] // bytes acked by peer
}

// handleCrypto processes data received in a CRYPTO frame.
// Contiguous in-order data is passed to f; out-of-order data is buffered
// (subject to cryptoBufferSize) until the gap before it is filled.
func (s *cryptoStream) handleCrypto(off int64, b []byte, f func([]byte) error) error {
	end := off + int64(len(b))
	if end-s.inset.min() > cryptoBufferSize {
		return localTransportError{
			code:   errCryptoBufferExceeded,
			reason: "crypto buffer exceeded",
		}
	}
	s.inset.add(off, end)
	if off == s.in.start {
		// Fast path: This is the next chunk of data in the stream,
		// so just handle it immediately.
		if err := f(b); err != nil {
			return err
		}
		s.in.discardBefore(end)
	} else {
		// This is either data we've already processed,
		// data we can't process yet, or a mix of both.
		s.in.writeAt(b, off)
	}
	// s.in.start is the next byte in sequence.
	// If it's in s.inset, we have bytes to provide.
	// If it isn't, we don't--we're either out of data,
	// or only have data that comes after the next byte.
	if !s.inset.contains(s.in.start) {
		return nil
	}
	// size is the size of the first contiguous chunk of bytes
	// that have not been processed yet.
	size := int(s.inset[0].end - s.in.start)
	if size <= 0 {
		return nil
	}
	err := s.in.read(s.in.start, size, f)
	s.in.discardBefore(s.inset[0].end)
	return err
}

// write queues data for sending to the peer.
// It does not block or limit the amount of buffered data.
// QUIC connections don't communicate the amount of CRYPTO data they are willing to buffer,
// so we send what we have and the peer can close the connection if it is too much.
func (s *cryptoStream) write(b []byte) {
	start := s.out.end
	s.out.writeAt(b, start)
	s.outunsent.add(start, s.out.end)
}

// ackOrLoss reports that a CRYPTO frame sent by us has been acknowledged by the peer, or lost.
func (s *cryptoStream) ackOrLoss(start, end int64, fate packetFate) {
	switch fate {
	case packetAcked:
		s.outacked.add(start, end)
		s.outunsent.sub(start, end)
		// If this ack is for data at the start of the send buffer, we can now discard it.
		if s.outacked.contains(s.out.start) {
			s.out.discardBefore(s.outacked[0].end)
		}
	case packetLost:
		// Mark everything lost, but not previously acked, as needing retransmission.
		// We do this by adding all the lost bytes to outunsent, and then
		// removing everything already acked.
		s.outunsent.add(start, end)
		for _, a := range s.outacked {
			s.outunsent.sub(a.start, a.end)
		}
	}
}

// dataToSend reports what data should be sent in CRYPTO frames to the peer.
// It calls f with each range of data to send.
// f uses sendData to get the bytes to send, and returns the number of bytes sent.
// dataToSend calls f until no data is left, or f returns 0.
//
// This function is unusually indirect (why not just return a []byte,
// or implement io.Reader?).
+// +// Returning a []byte to the caller either requires that we store the +// data to send contiguously (which we don't), allocate a temporary buffer +// and copy into it (inefficient), or return less data than we have available +// (requires complexity to avoid unnecessarily breaking data across frames). +// +// Accepting a []byte from the caller (io.Reader) makes packet construction +// difficult. Since CRYPTO data is encoded with a varint length prefix, the +// location of the data depends on the length of the data. (We could hardcode +// a 2-byte length, of course.) +// +// Instead, we tell the caller how much data is, the caller figures out where +// to put it (and possibly decides that it doesn't have space for this data +// in the packet after all), and the caller then makes a separate call to +// copy the data it wants into position. +func (s *cryptoStream) dataToSend(pto bool, f func(off, size int64) (sent int64)) { + for { + off, size := dataToSend(s.out.start, s.out.end, s.outunsent, s.outacked, pto) + if size == 0 { + return + } + n := f(off, size) + if n == 0 || pto { + return + } + } +} + +// sendData fills b with data to send to the peer, starting at off, +// and marks the data as sent. The caller must have already ascertained +// that there is data to send in this region using dataToSend. +func (s *cryptoStream) sendData(off int64, b []byte) { + s.out.copy(off, b) + s.outunsent.sub(off, off+int64(len(b))) +} + +// discardKeys is called when the packet protection keys for the stream are dropped. +func (s *cryptoStream) discardKeys() error { + if s.in.end-s.in.start != 0 { + // The peer sent some unprocessed CRYPTO data that we're about to discard. + // Close the connection with a TLS unexpected_message alert. 
+ // https://www.rfc-editor.org/rfc/rfc5246#section-7.2.2 + const unexpectedMessage = 10 + return localTransportError{ + code: errTLSBase + unexpectedMessage, + reason: "excess crypto data", + } + } + // Discard any unacked (but presumably received) data in our output buffer. + s.out.discardBefore(s.out.end) + *s = cryptoStream{} + return nil +} diff --git a/src/vendor/golang.org/x/net/quic/dgram.go b/src/vendor/golang.org/x/net/quic/dgram.go new file mode 100644 index 0000000000..cea03694ee --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/dgram.go @@ -0,0 +1,53 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "net/netip" + "sync" +) + +type datagram struct { + b []byte + localAddr netip.AddrPort + peerAddr netip.AddrPort + ecn ecnBits +} + +// Explicit Congestion Notification bits. +// +// https://www.rfc-editor.org/rfc/rfc3168.html#section-5 +type ecnBits byte + +const ( + ecnMask = 0b000000_11 + ecnNotECT = 0b000000_00 + ecnECT1 = 0b000000_01 + ecnECT0 = 0b000000_10 + ecnCE = 0b000000_11 +) + +var datagramPool = sync.Pool{ + New: func() any { + return &datagram{ + b: make([]byte, maxUDPPayloadSize), + } + }, +} + +func newDatagram() *datagram { + m := datagramPool.Get().(*datagram) + *m = datagram{ + b: m.b[:cap(m.b)], + } + return m +} + +func (m *datagram) recycle() { + if cap(m.b) != maxUDPPayloadSize { + return + } + datagramPool.Put(m) +} diff --git a/src/vendor/golang.org/x/net/quic/doc.go b/src/vendor/golang.org/x/net/quic/doc.go new file mode 100644 index 0000000000..8d5a78f8a8 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/doc.go @@ -0,0 +1,50 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package quic implements the QUIC protocol. +// +// This package is a work in progress. 
+// It is not ready for production usage. +// Its API is subject to change without notice. +// +// This package is low-level. +// Most users will use it indirectly through an HTTP/3 implementation. +// +// # Usage +// +// An [Endpoint] sends and receives traffic on a network address. +// Create an Endpoint to either accept inbound QUIC connections +// or create outbound ones. +// +// A [Conn] is a QUIC connection. +// +// A [Stream] is a QUIC stream, an ordered, reliable byte stream. +// +// # Cancellation +// +// All blocking operations may be canceled using a context.Context. +// When performing an operation with a canceled context, the operation +// will succeed if doing so does not require blocking. For example, +// reading from a stream will return data when buffered data is available, +// even if the stream context is canceled. +// +// # Limitations +// +// This package is a work in progress. +// Known limitations include: +// +// - Performance is untuned. +// - 0-RTT is not supported. +// - Address migration is not supported. +// - Server preferred addresses are not supported. +// - The latency spin bit is not supported. +// - Stream send/receive windows are configurable, +// but are fixed and do not adapt to available throughput. +// - Path MTU discovery is not implemented. +// +// # Security Policy +// +// This package is a work in progress, +// and not yet covered by the Go security policy. +package quic diff --git a/src/vendor/golang.org/x/net/quic/endpoint.go b/src/vendor/golang.org/x/net/quic/endpoint.go new file mode 100644 index 0000000000..3d68073cd6 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/endpoint.go @@ -0,0 +1,472 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package quic + +import ( + "context" + "crypto/rand" + "errors" + "net" + "net/netip" + "sync" + "sync/atomic" + "time" +) + +// An Endpoint handles QUIC traffic on a network address. +// It can accept inbound connections or create outbound ones. +// +// Multiple goroutines may invoke methods on an Endpoint simultaneously. +type Endpoint struct { + listenConfig *Config + packetConn packetConn + testHooks endpointTestHooks + resetGen statelessResetTokenGenerator + retry retryState + + acceptQueue queue[*Conn] // new inbound connections + connsMap connsMap // only accessed by the listen loop + + connsMu sync.Mutex + conns map[*Conn]struct{} + closing bool // set when Close is called + closec chan struct{} // closed when the listen loop exits +} + +type endpointTestHooks interface { + newConn(c *Conn) +} + +// A packetConn is the interface to sending and receiving UDP packets. +type packetConn interface { + Close() error + LocalAddr() netip.AddrPort + Read(f func(*datagram)) + Write(datagram) error +} + +// Listen listens on a local network address. +// +// The config is used to for connections accepted by the endpoint. +// If the config is nil, the endpoint will not accept connections. +func Listen(network, address string, listenConfig *Config) (*Endpoint, error) { + if listenConfig != nil && listenConfig.TLSConfig == nil { + return nil, errors.New("TLSConfig is not set") + } + a, err := net.ResolveUDPAddr(network, address) + if err != nil { + return nil, err + } + udpConn, err := net.ListenUDP(network, a) + if err != nil { + return nil, err + } + pc, err := newNetUDPConn(udpConn) + if err != nil { + return nil, err + } + return newEndpoint(pc, listenConfig, nil) +} + +// NewEndpoint creates an endpoint using a net.PacketConn as the underlying transport. +// +// If the PacketConn is not a *net.UDPConn, the endpoint may be slower and lack +// access to some features of the network. 
+func NewEndpoint(conn net.PacketConn, config *Config) (*Endpoint, error) { + var pc packetConn + var err error + switch conn := conn.(type) { + case *net.UDPConn: + pc, err = newNetUDPConn(conn) + default: + pc, err = newNetPacketConn(conn) + } + if err != nil { + return nil, err + } + return newEndpoint(pc, config, nil) +} + +func newEndpoint(pc packetConn, config *Config, hooks endpointTestHooks) (*Endpoint, error) { + e := &Endpoint{ + listenConfig: config, + packetConn: pc, + testHooks: hooks, + conns: make(map[*Conn]struct{}), + acceptQueue: newQueue[*Conn](), + closec: make(chan struct{}), + } + var statelessResetKey [32]byte + if config != nil { + statelessResetKey = config.StatelessResetKey + } + e.resetGen.init(statelessResetKey) + e.connsMap.init() + if config != nil && config.RequireAddressValidation { + if err := e.retry.init(); err != nil { + return nil, err + } + } + go e.listen() + return e, nil +} + +// LocalAddr returns the local network address. +func (e *Endpoint) LocalAddr() netip.AddrPort { + return e.packetConn.LocalAddr() +} + +// Close closes the Endpoint. +// Any blocked operations on the Endpoint or associated Conns and Stream will be unblocked +// and return errors. +// +// Close aborts every open connection. +// Data in stream read and write buffers is discarded. +// It waits for the peers of any open connection to acknowledge the connection has been closed. +func (e *Endpoint) Close(ctx context.Context) error { + e.acceptQueue.close(errors.New("endpoint closed")) + + // It isn't safe to call Conn.Abort or conn.exit with connsMu held, + // so copy the list of conns. 
+ var conns []*Conn + e.connsMu.Lock() + if !e.closing { + e.closing = true // setting e.closing prevents new conns from being created + for c := range e.conns { + conns = append(conns, c) + } + if len(e.conns) == 0 { + e.packetConn.Close() + } + } + e.connsMu.Unlock() + + for _, c := range conns { + c.Abort(localTransportError{code: errNo}) + } + select { + case <-e.closec: + case <-ctx.Done(): + for _, c := range conns { + c.exit() + } + return ctx.Err() + } + return nil +} + +// Accept waits for and returns the next connection. +func (e *Endpoint) Accept(ctx context.Context) (*Conn, error) { + return e.acceptQueue.get(ctx) +} + +// Dial creates and returns a connection to a network address. +// The config cannot be nil. +func (e *Endpoint) Dial(ctx context.Context, network, address string, config *Config) (*Conn, error) { + u, err := net.ResolveUDPAddr(network, address) + if err != nil { + return nil, err + } + addr := u.AddrPort() + addr = netip.AddrPortFrom(addr.Addr().Unmap(), addr.Port()) + c, err := e.newConn(time.Now(), config, clientSide, newServerConnIDs{}, address, addr) + if err != nil { + return nil, err + } + if err := c.waitReady(ctx); err != nil { + c.Abort(nil) + return nil, err + } + return c, nil +} + +func (e *Endpoint) newConn(now time.Time, config *Config, side connSide, cids newServerConnIDs, peerHostname string, peerAddr netip.AddrPort) (*Conn, error) { + e.connsMu.Lock() + defer e.connsMu.Unlock() + if e.closing { + return nil, errors.New("endpoint closed") + } + c, err := newConn(now, side, cids, peerHostname, peerAddr, config, e) + if err != nil { + return nil, err + } + e.conns[c] = struct{}{} + return c, nil +} + +// serverConnEstablished is called by a conn when the handshake completes +// for an inbound (serverSide) connection. 
+func (e *Endpoint) serverConnEstablished(c *Conn) { + e.acceptQueue.put(c) +} + +// connDrained is called by a conn when it leaves the draining state, +// either when the peer acknowledges connection closure or the drain timeout expires. +func (e *Endpoint) connDrained(c *Conn) { + var cids [][]byte + for i := range c.connIDState.local { + cids = append(cids, c.connIDState.local[i].cid) + } + var tokens []statelessResetToken + for i := range c.connIDState.remote { + tokens = append(tokens, c.connIDState.remote[i].resetToken) + } + e.connsMap.updateConnIDs(func(conns *connsMap) { + for _, cid := range cids { + conns.retireConnID(c, cid) + } + for _, token := range tokens { + conns.retireResetToken(c, token) + } + }) + e.connsMu.Lock() + defer e.connsMu.Unlock() + delete(e.conns, c) + if e.closing && len(e.conns) == 0 { + e.packetConn.Close() + } +} + +func (e *Endpoint) listen() { + defer close(e.closec) + e.packetConn.Read(func(m *datagram) { + if e.connsMap.updateNeeded.Load() { + e.connsMap.applyUpdates() + } + e.handleDatagram(m) + }) +} + +func (e *Endpoint) handleDatagram(m *datagram) { + dstConnID, ok := dstConnIDForDatagram(m.b) + if !ok { + m.recycle() + return + } + c := e.connsMap.byConnID[string(dstConnID)] + if c == nil { + // TODO: Move this branch into a separate goroutine to avoid blocking + // the endpoint while processing packets. + e.handleUnknownDestinationDatagram(m) + return + } + + // TODO: This can block the endpoint while waiting for the conn to accept the dgram. + // Think about buffering between the receive loop and the conn. + c.sendMsg(m) +} + +func (e *Endpoint) handleUnknownDestinationDatagram(m *datagram) { + defer func() { + if m != nil { + m.recycle() + } + }() + const minimumValidPacketSize = 21 + if len(m.b) < minimumValidPacketSize { + return + } + now := time.Now() + // Check to see if this is a stateless reset. 
+ var token statelessResetToken + copy(token[:], m.b[len(m.b)-len(token):]) + if c := e.connsMap.byResetToken[token]; c != nil { + c.sendMsg(func(now time.Time, c *Conn) { + c.handleStatelessReset(now, token) + }) + return + } + // If this is a 1-RTT packet, there's nothing productive we can do with it. + // Send a stateless reset if possible. + if !isLongHeader(m.b[0]) { + e.maybeSendStatelessReset(m.b, m.peerAddr) + return + } + p, ok := parseGenericLongHeaderPacket(m.b) + if !ok || len(m.b) < paddedInitialDatagramSize { + return + } + switch p.version { + case quicVersion1: + case 0: + // Version Negotiation for an unknown connection. + return + default: + // Unknown version. + e.sendVersionNegotiation(p, m.peerAddr) + return + } + if getPacketType(m.b) != packetTypeInitial { + // This packet isn't trying to create a new connection. + // It might be associated with some connection we've lost state for. + // We are technically permitted to send a stateless reset for + // a long-header packet, but this isn't generally useful. See: + // https://www.rfc-editor.org/rfc/rfc9000#section-10.3-16 + return + } + if e.listenConfig == nil { + // We are not configured to accept connections. + return + } + cids := newServerConnIDs{ + srcConnID: p.srcConnID, + dstConnID: p.dstConnID, + } + if e.listenConfig.RequireAddressValidation { + var ok bool + cids.retrySrcConnID = p.dstConnID + cids.originalDstConnID, ok = e.validateInitialAddress(now, p, m.peerAddr) + if !ok { + return + } + } else { + cids.originalDstConnID = p.dstConnID + } + var err error + c, err := e.newConn(now, e.listenConfig, serverSide, cids, "", m.peerAddr) + if err != nil { + // The accept queue is probably full. + // We could send a CONNECTION_CLOSE to the peer to reject the connection. + // Currently, we just drop the datagram. 
+ // https://www.rfc-editor.org/rfc/rfc9000.html#section-5.2.2-5 + return + } + c.sendMsg(m) + m = nil // don't recycle, sendMsg takes ownership +} + +func (e *Endpoint) maybeSendStatelessReset(b []byte, peerAddr netip.AddrPort) { + if !e.resetGen.canReset { + // Config.StatelessResetKey isn't set, so we don't send stateless resets. + return + } + // The smallest possible valid packet a peer can send us is: + // 1 byte of header + // connIDLen bytes of destination connection ID + // 1 byte of packet number + // 1 byte of payload + // 16 bytes AEAD expansion + if len(b) < 1+connIDLen+1+1+16 { + return + } + // TODO: Rate limit stateless resets. + cid := b[1:][:connIDLen] + token := e.resetGen.tokenForConnID(cid) + // We want to generate a stateless reset that is as short as possible, + // but long enough to be difficult to distinguish from a 1-RTT packet. + // + // The minimal 1-RTT packet is: + // 1 byte of header + // 0-20 bytes of destination connection ID + // 1-4 bytes of packet number + // 1 byte of payload + // 16 bytes AEAD expansion + // + // Assuming the maximum possible connection ID and packet number size, + // this gives 1 + 20 + 4 + 1 + 16 = 42 bytes. + // + // We also must generate a stateless reset that is shorter than the datagram + // we are responding to, in order to ensure that reset loops terminate. + // + // See: https://www.rfc-editor.org/rfc/rfc9000#section-10.3 + size := min(len(b)-1, 42) + // Reuse the input buffer for generating the stateless reset. 
+ b = b[:size] + rand.Read(b[:len(b)-statelessResetTokenLen]) + b[0] &^= headerFormLong // clear long header bit + b[0] |= fixedBit // set fixed bit + copy(b[len(b)-statelessResetTokenLen:], token[:]) + e.sendDatagram(datagram{ + b: b, + peerAddr: peerAddr, + }) +} + +func (e *Endpoint) sendVersionNegotiation(p genericLongPacket, peerAddr netip.AddrPort) { + m := newDatagram() + m.b = appendVersionNegotiation(m.b[:0], p.srcConnID, p.dstConnID, quicVersion1) + m.peerAddr = peerAddr + e.sendDatagram(*m) + m.recycle() +} + +func (e *Endpoint) sendConnectionClose(in genericLongPacket, peerAddr netip.AddrPort, code transportError) { + keys := initialKeys(in.dstConnID, serverSide) + var w packetWriter + p := longPacket{ + ptype: packetTypeInitial, + version: quicVersion1, + num: 0, + dstConnID: in.srcConnID, + srcConnID: in.dstConnID, + } + const pnumMaxAcked = 0 + w.reset(paddedInitialDatagramSize) + w.startProtectedLongHeaderPacket(pnumMaxAcked, p) + w.appendConnectionCloseTransportFrame(code, 0, "") + w.finishProtectedLongHeaderPacket(pnumMaxAcked, keys.w, p) + buf := w.datagram() + if len(buf) == 0 { + return + } + e.sendDatagram(datagram{ + b: buf, + peerAddr: peerAddr, + }) +} + +func (e *Endpoint) sendDatagram(dgram datagram) error { + return e.packetConn.Write(dgram) +} + +// A connsMap is an endpoint's mapping of conn ids and reset tokens to conns. 
+type connsMap struct { + byConnID map[string]*Conn + byResetToken map[statelessResetToken]*Conn + + updateMu sync.Mutex + updateNeeded atomic.Bool + updates []func(*connsMap) +} + +func (m *connsMap) init() { + m.byConnID = map[string]*Conn{} + m.byResetToken = map[statelessResetToken]*Conn{} +} + +func (m *connsMap) addConnID(c *Conn, cid []byte) { + m.byConnID[string(cid)] = c +} + +func (m *connsMap) retireConnID(c *Conn, cid []byte) { + delete(m.byConnID, string(cid)) +} + +func (m *connsMap) addResetToken(c *Conn, token statelessResetToken) { + m.byResetToken[token] = c +} + +func (m *connsMap) retireResetToken(c *Conn, token statelessResetToken) { + delete(m.byResetToken, token) +} + +func (m *connsMap) updateConnIDs(f func(*connsMap)) { + m.updateMu.Lock() + defer m.updateMu.Unlock() + m.updates = append(m.updates, f) + m.updateNeeded.Store(true) +} + +// applyUpdates is called by the datagram receive loop to update its connection ID map. +func (m *connsMap) applyUpdates() { + m.updateMu.Lock() + defer m.updateMu.Unlock() + for _, f := range m.updates { + f(m) + } + clear(m.updates) + m.updates = m.updates[:0] + m.updateNeeded.Store(false) +} diff --git a/src/vendor/golang.org/x/net/quic/errors.go b/src/vendor/golang.org/x/net/quic/errors.go new file mode 100644 index 0000000000..1226370d26 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/errors.go @@ -0,0 +1,129 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "fmt" +) + +// A transportError is a transport error code from RFC 9000 Section 20.1. +// +// The transportError type doesn't implement the error interface to ensure we always +// distinguish between errors sent to and received from the peer. +// See the localTransportError and peerTransportError types below. 
+type transportError uint64 + +// https://www.rfc-editor.org/rfc/rfc9000.html#section-20.1 +const ( + errNo = transportError(0x00) + errInternal = transportError(0x01) + errConnectionRefused = transportError(0x02) + errFlowControl = transportError(0x03) + errStreamLimit = transportError(0x04) + errStreamState = transportError(0x05) + errFinalSize = transportError(0x06) + errFrameEncoding = transportError(0x07) + errTransportParameter = transportError(0x08) + errConnectionIDLimit = transportError(0x09) + errProtocolViolation = transportError(0x0a) + errInvalidToken = transportError(0x0b) + errApplicationError = transportError(0x0c) + errCryptoBufferExceeded = transportError(0x0d) + errKeyUpdateError = transportError(0x0e) + errAEADLimitReached = transportError(0x0f) + errNoViablePath = transportError(0x10) + errTLSBase = transportError(0x0100) // 0x0100-0x01ff; base + TLS code +) + +func (e transportError) String() string { + switch e { + case errNo: + return "NO_ERROR" + case errInternal: + return "INTERNAL_ERROR" + case errConnectionRefused: + return "CONNECTION_REFUSED" + case errFlowControl: + return "FLOW_CONTROL_ERROR" + case errStreamLimit: + return "STREAM_LIMIT_ERROR" + case errStreamState: + return "STREAM_STATE_ERROR" + case errFinalSize: + return "FINAL_SIZE_ERROR" + case errFrameEncoding: + return "FRAME_ENCODING_ERROR" + case errTransportParameter: + return "TRANSPORT_PARAMETER_ERROR" + case errConnectionIDLimit: + return "CONNECTION_ID_LIMIT_ERROR" + case errProtocolViolation: + return "PROTOCOL_VIOLATION" + case errInvalidToken: + return "INVALID_TOKEN" + case errApplicationError: + return "APPLICATION_ERROR" + case errCryptoBufferExceeded: + return "CRYPTO_BUFFER_EXCEEDED" + case errKeyUpdateError: + return "KEY_UPDATE_ERROR" + case errAEADLimitReached: + return "AEAD_LIMIT_REACHED" + case errNoViablePath: + return "NO_VIABLE_PATH" + } + if e >= 0x0100 && e <= 0x01ff { + return fmt.Sprintf("CRYPTO_ERROR(%v)", uint64(e)&0xff) + } + return 
fmt.Sprintf("ERROR %d", uint64(e)) +} + +// A localTransportError is an error sent to the peer. +type localTransportError struct { + code transportError + reason string +} + +func (e localTransportError) Error() string { + if e.reason == "" { + return fmt.Sprintf("closed connection: %v", e.code) + } + return fmt.Sprintf("closed connection: %v: %q", e.code, e.reason) +} + +// A peerTransportError is an error received from the peer. +type peerTransportError struct { + code transportError + reason string +} + +func (e peerTransportError) Error() string { + return fmt.Sprintf("peer closed connection: %v: %q", e.code, e.reason) +} + +// A StreamErrorCode is an application protocol error code (RFC 9000, Section 20.2) +// indicating why a stream is being closed. +type StreamErrorCode uint64 + +func (e StreamErrorCode) Error() string { + return fmt.Sprintf("stream error code %v", uint64(e)) +} + +// An ApplicationError is an application protocol error code (RFC 9000, Section 20.2). +// Application protocol errors may be sent when terminating a stream or connection. +type ApplicationError struct { + Code uint64 + Reason string +} + +func (e *ApplicationError) Error() string { + return fmt.Sprintf("peer closed connection: %v: %q", e.Code, e.Reason) +} + +// Is reports a match if err is an *ApplicationError with a matching Code. +func (e *ApplicationError) Is(err error) bool { + e2, ok := err.(*ApplicationError) + return ok && e2.Code == e.Code +} diff --git a/src/vendor/golang.org/x/net/quic/frame_debug.go b/src/vendor/golang.org/x/net/quic/frame_debug.go new file mode 100644 index 0000000000..8d8fd54517 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/frame_debug.go @@ -0,0 +1,730 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package quic + +import ( + "fmt" + "log/slog" + "strconv" + "time" +) + +// A debugFrame is a representation of the contents of a QUIC frame, +// used for debug logs and testing but not the primary serving path. +type debugFrame interface { + String() string + write(w *packetWriter) bool + LogValue() slog.Value +} + +func parseDebugFrame(b []byte) (f debugFrame, n int) { + if len(b) == 0 { + return nil, -1 + } + switch b[0] { + case frameTypePadding: + f, n = parseDebugFramePadding(b) + case frameTypePing: + f, n = parseDebugFramePing(b) + case frameTypeAck, frameTypeAckECN: + f, n = parseDebugFrameAck(b) + case frameTypeResetStream: + f, n = parseDebugFrameResetStream(b) + case frameTypeStopSending: + f, n = parseDebugFrameStopSending(b) + case frameTypeCrypto: + f, n = parseDebugFrameCrypto(b) + case frameTypeNewToken: + f, n = parseDebugFrameNewToken(b) + case frameTypeStreamBase, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f: + f, n = parseDebugFrameStream(b) + case frameTypeMaxData: + f, n = parseDebugFrameMaxData(b) + case frameTypeMaxStreamData: + f, n = parseDebugFrameMaxStreamData(b) + case frameTypeMaxStreamsBidi, frameTypeMaxStreamsUni: + f, n = parseDebugFrameMaxStreams(b) + case frameTypeDataBlocked: + f, n = parseDebugFrameDataBlocked(b) + case frameTypeStreamDataBlocked: + f, n = parseDebugFrameStreamDataBlocked(b) + case frameTypeStreamsBlockedBidi, frameTypeStreamsBlockedUni: + f, n = parseDebugFrameStreamsBlocked(b) + case frameTypeNewConnectionID: + f, n = parseDebugFrameNewConnectionID(b) + case frameTypeRetireConnectionID: + f, n = parseDebugFrameRetireConnectionID(b) + case frameTypePathChallenge: + f, n = parseDebugFramePathChallenge(b) + case frameTypePathResponse: + f, n = parseDebugFramePathResponse(b) + case frameTypeConnectionCloseTransport: + f, n = parseDebugFrameConnectionCloseTransport(b) + case frameTypeConnectionCloseApplication: + f, n = parseDebugFrameConnectionCloseApplication(b) + case frameTypeHandshakeDone: + f, n = 
parseDebugFrameHandshakeDone(b) + default: + return nil, -1 + } + return f, n +} + +// debugFramePadding is a sequence of PADDING frames. +type debugFramePadding struct { + size int + to int // alternate for writing packets: pad to +} + +func parseDebugFramePadding(b []byte) (f debugFramePadding, n int) { + for n < len(b) && b[n] == frameTypePadding { + n++ + } + f.size = n + return f, n +} + +func (f debugFramePadding) String() string { + return fmt.Sprintf("PADDING*%v", f.size) +} + +func (f debugFramePadding) write(w *packetWriter) bool { + if w.avail() == 0 { + return false + } + if f.to > 0 { + w.appendPaddingTo(f.to) + return true + } + for i := 0; i < f.size && w.avail() > 0; i++ { + w.b = append(w.b, frameTypePadding) + } + return true +} + +func (f debugFramePadding) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "padding"), + slog.Int("length", f.size), + ) +} + +// debugFramePing is a PING frame. +type debugFramePing struct{} + +func parseDebugFramePing(b []byte) (f debugFramePing, n int) { + return f, 1 +} + +func (f debugFramePing) String() string { + return "PING" +} + +func (f debugFramePing) write(w *packetWriter) bool { + return w.appendPingFrame() +} + +func (f debugFramePing) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "ping"), + ) +} + +// debugFrameAck is an ACK frame. +type debugFrameAck struct { + ackDelay unscaledAckDelay + ranges []i64range[packetNumber] + ecn ecnCounts +} + +func parseDebugFrameAck(b []byte) (f debugFrameAck, n int) { + f.ranges = nil + _, f.ackDelay, f.ecn, n = consumeAckFrame(b, func(_ int, start, end packetNumber) { + f.ranges = append(f.ranges, i64range[packetNumber]{ + start: start, + end: end, + }) + }) + // Ranges are parsed high to low; reverse ranges slice to order them low to high. 
+ for i := 0; i < len(f.ranges)/2; i++ { + j := len(f.ranges) - 1 + f.ranges[i], f.ranges[j] = f.ranges[j], f.ranges[i] + } + return f, n +} + +func (f debugFrameAck) String() string { + s := fmt.Sprintf("ACK Delay=%v", f.ackDelay) + for _, r := range f.ranges { + s += fmt.Sprintf(" [%v,%v)", r.start, r.end) + } + + if (f.ecn != ecnCounts{}) { + s += fmt.Sprintf(" ECN=[%d,%d,%d]", f.ecn.t0, f.ecn.t1, f.ecn.ce) + } + return s +} + +func (f debugFrameAck) write(w *packetWriter) bool { + return w.appendAckFrame(rangeset[packetNumber](f.ranges), f.ackDelay, f.ecn) +} + +func (f debugFrameAck) LogValue() slog.Value { + return slog.StringValue("error: debugFrameAck should not appear as a slog Value") +} + +// debugFrameScaledAck is an ACK frame with scaled ACK Delay. +// +// This type is used in qlog events, which need access to the delay as a duration. +type debugFrameScaledAck struct { + ackDelay time.Duration + ranges []i64range[packetNumber] +} + +func (f debugFrameScaledAck) LogValue() slog.Value { + var ackDelay slog.Attr + if f.ackDelay >= 0 { + ackDelay = slog.Duration("ack_delay", f.ackDelay) + } + return slog.GroupValue( + slog.String("frame_type", "ack"), + // Rather than trying to convert the ack ranges into the slog data model, + // pass a value that can JSON-encode itself. + slog.Any("acked_ranges", debugAckRanges(f.ranges)), + ackDelay, + ) +} + +type debugAckRanges []i64range[packetNumber] + +// AppendJSON appends a JSON encoding of the ack ranges to b, and returns it. +// This is different than the standard json.Marshaler, but more efficient. +// Since we only use this in cooperation with the qlog package, +// encoding/json compatibility is irrelevant. 
+func (r debugAckRanges) AppendJSON(b []byte) []byte { + b = append(b, '[') + for i, ar := range r { + start, end := ar.start, ar.end-1 // qlog ranges are closed-closed + if i != 0 { + b = append(b, ',') + } + b = append(b, '[') + b = strconv.AppendInt(b, int64(start), 10) + if start != end { + b = append(b, ',') + b = strconv.AppendInt(b, int64(end), 10) + } + b = append(b, ']') + } + b = append(b, ']') + return b +} + +func (r debugAckRanges) String() string { + return string(r.AppendJSON(nil)) +} + +// debugFrameResetStream is a RESET_STREAM frame. +type debugFrameResetStream struct { + id streamID + code uint64 + finalSize int64 +} + +func parseDebugFrameResetStream(b []byte) (f debugFrameResetStream, n int) { + f.id, f.code, f.finalSize, n = consumeResetStreamFrame(b) + return f, n +} + +func (f debugFrameResetStream) String() string { + return fmt.Sprintf("RESET_STREAM ID=%v Code=%v FinalSize=%v", f.id, f.code, f.finalSize) +} + +func (f debugFrameResetStream) write(w *packetWriter) bool { + return w.appendResetStreamFrame(f.id, f.code, f.finalSize) +} + +func (f debugFrameResetStream) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "reset_stream"), + slog.Uint64("stream_id", uint64(f.id)), + slog.Uint64("final_size", uint64(f.finalSize)), + ) +} + +// debugFrameStopSending is a STOP_SENDING frame. 
+type debugFrameStopSending struct { + id streamID + code uint64 +} + +func parseDebugFrameStopSending(b []byte) (f debugFrameStopSending, n int) { + f.id, f.code, n = consumeStopSendingFrame(b) + return f, n +} + +func (f debugFrameStopSending) String() string { + return fmt.Sprintf("STOP_SENDING ID=%v Code=%v", f.id, f.code) +} + +func (f debugFrameStopSending) write(w *packetWriter) bool { + return w.appendStopSendingFrame(f.id, f.code) +} + +func (f debugFrameStopSending) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "stop_sending"), + slog.Uint64("stream_id", uint64(f.id)), + slog.Uint64("error_code", uint64(f.code)), + ) +} + +// debugFrameCrypto is a CRYPTO frame. +type debugFrameCrypto struct { + off int64 + data []byte +} + +func parseDebugFrameCrypto(b []byte) (f debugFrameCrypto, n int) { + f.off, f.data, n = consumeCryptoFrame(b) + return f, n +} + +func (f debugFrameCrypto) String() string { + return fmt.Sprintf("CRYPTO Offset=%v Length=%v", f.off, len(f.data)) +} + +func (f debugFrameCrypto) write(w *packetWriter) bool { + b, added := w.appendCryptoFrame(f.off, len(f.data)) + copy(b, f.data) + return added +} + +func (f debugFrameCrypto) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "crypto"), + slog.Int64("offset", f.off), + slog.Int("length", len(f.data)), + ) +} + +// debugFrameNewToken is a NEW_TOKEN frame. 
+type debugFrameNewToken struct { + token []byte +} + +func parseDebugFrameNewToken(b []byte) (f debugFrameNewToken, n int) { + f.token, n = consumeNewTokenFrame(b) + return f, n +} + +func (f debugFrameNewToken) String() string { + return fmt.Sprintf("NEW_TOKEN Token=%x", f.token) +} + +func (f debugFrameNewToken) write(w *packetWriter) bool { + return w.appendNewTokenFrame(f.token) +} + +func (f debugFrameNewToken) LogValue() slog.Value { + return slog.GroupValue( + slog.String("frame_type", "new_token"), + slogHexstring("token", f.token), + ) +} + +// debugFrameStream is a STREAM frame. +type debugFrameStream struct { + id streamID + fin bool + off int64 + data []byte +} + +func parseDebugFrameStream(b []byte) (f debugFrameStream, n int) { + f.id, f.off, f.fin, f.data, n = consumeStreamFrame(b) + return f, n +} + +func (f debugFrameStream) String() string { + fin := "" + if f.fin { + fin = " FIN" + } + return fmt.Sprintf("STREAM ID=%v%v Offset=%v Length=%v", f.id, fin, f.off, len(f.data)) +} + +func (f debugFrameStream) write(w *packetWriter) bool { + b, added := w.appendStreamFrame(f.id, f.off, len(f.data), f.fin) + copy(b, f.data) + return added +} + +func (f debugFrameStream) LogValue() slog.Value { + var fin slog.Attr + if f.fin { + fin = slog.Bool("fin", true) + } + return slog.GroupValue( + slog.String("frame_type", "stream"), + slog.Uint64("stream_id", uint64(f.id)), + slog.Int64("offset", f.off), + slog.Int("length", len(f.data)), + fin, + ) +} + +// debugFrameMaxData is a MAX_DATA frame. 
type debugFrameMaxData struct {
	max int64
}

// parseDebugFrameMaxData parses a MAX_DATA frame from b.
func parseDebugFrameMaxData(b []byte) (f debugFrameMaxData, n int) {
	f.max, n = consumeMaxDataFrame(b)
	return f, n
}

func (f debugFrameMaxData) String() string {
	return fmt.Sprintf("MAX_DATA Max=%v", f.max)
}

func (f debugFrameMaxData) write(w *packetWriter) bool {
	return w.appendMaxDataFrame(f.max)
}

func (f debugFrameMaxData) LogValue() slog.Value {
	return slog.GroupValue(
		slog.String("frame_type", "max_data"),
		slog.Int64("maximum", f.max),
	)
}

// debugFrameMaxStreamData is a MAX_STREAM_DATA frame.
type debugFrameMaxStreamData struct {
	id  streamID
	max int64
}

// parseDebugFrameMaxStreamData parses a MAX_STREAM_DATA frame from b.
func parseDebugFrameMaxStreamData(b []byte) (f debugFrameMaxStreamData, n int) {
	f.id, f.max, n = consumeMaxStreamDataFrame(b)
	return f, n
}

func (f debugFrameMaxStreamData) String() string {
	return fmt.Sprintf("MAX_STREAM_DATA ID=%v Max=%v", f.id, f.max)
}

func (f debugFrameMaxStreamData) write(w *packetWriter) bool {
	return w.appendMaxStreamDataFrame(f.id, f.max)
}

func (f debugFrameMaxStreamData) LogValue() slog.Value {
	return slog.GroupValue(
		slog.String("frame_type", "max_stream_data"),
		slog.Uint64("stream_id", uint64(f.id)),
		slog.Int64("maximum", f.max),
	)
}

// debugFrameMaxStreams is a MAX_STREAMS frame.
type debugFrameMaxStreams struct {
	streamType streamType
	max        int64
}

// parseDebugFrameMaxStreams parses a MAX_STREAMS frame from b.
func parseDebugFrameMaxStreams(b []byte) (f debugFrameMaxStreams, n int) {
	f.streamType, f.max, n = consumeMaxStreamsFrame(b)
	return f, n
}

func (f debugFrameMaxStreams) String() string {
	return fmt.Sprintf("MAX_STREAMS Type=%v Max=%v", f.streamType, f.max)
}

func (f debugFrameMaxStreams) write(w *packetWriter) bool {
	return w.appendMaxStreamsFrame(f.streamType, f.max)
}

func (f debugFrameMaxStreams) LogValue() slog.Value {
	return slog.GroupValue(
		slog.String("frame_type", "max_streams"),
		slog.String("stream_type", f.streamType.qlogString()),
		slog.Int64("maximum", f.max),
	)
}

// debugFrameDataBlocked is a DATA_BLOCKED frame.
type debugFrameDataBlocked struct {
	max int64
}

// parseDebugFrameDataBlocked parses a DATA_BLOCKED frame from b.
func parseDebugFrameDataBlocked(b []byte) (f debugFrameDataBlocked, n int) {
	f.max, n = consumeDataBlockedFrame(b)
	return f, n
}

func (f debugFrameDataBlocked) String() string {
	return fmt.Sprintf("DATA_BLOCKED Max=%v", f.max)
}

func (f debugFrameDataBlocked) write(w *packetWriter) bool {
	return w.appendDataBlockedFrame(f.max)
}

func (f debugFrameDataBlocked) LogValue() slog.Value {
	return slog.GroupValue(
		slog.String("frame_type", "data_blocked"),
		slog.Int64("limit", f.max),
	)
}

// debugFrameStreamDataBlocked is a STREAM_DATA_BLOCKED frame.
type debugFrameStreamDataBlocked struct {
	id  streamID
	max int64
}

// parseDebugFrameStreamDataBlocked parses a STREAM_DATA_BLOCKED frame from b.
func parseDebugFrameStreamDataBlocked(b []byte) (f debugFrameStreamDataBlocked, n int) {
	f.id, f.max, n = consumeStreamDataBlockedFrame(b)
	return f, n
}

func (f debugFrameStreamDataBlocked) String() string {
	return fmt.Sprintf("STREAM_DATA_BLOCKED ID=%v Max=%v", f.id, f.max)
}

func (f debugFrameStreamDataBlocked) write(w *packetWriter) bool {
	return w.appendStreamDataBlockedFrame(f.id, f.max)
}

func (f debugFrameStreamDataBlocked) LogValue() slog.Value {
	return slog.GroupValue(
		slog.String("frame_type", "stream_data_blocked"),
		slog.Uint64("stream_id", uint64(f.id)),
		slog.Int64("limit", f.max),
	)
}

// debugFrameStreamsBlocked is a STREAMS_BLOCKED frame.
type debugFrameStreamsBlocked struct {
	streamType streamType
	max        int64
}

// parseDebugFrameStreamsBlocked parses a STREAMS_BLOCKED frame from b.
func parseDebugFrameStreamsBlocked(b []byte) (f debugFrameStreamsBlocked, n int) {
	f.streamType, f.max, n = consumeStreamsBlockedFrame(b)
	return f, n
}

func (f debugFrameStreamsBlocked) String() string {
	return fmt.Sprintf("STREAMS_BLOCKED Type=%v Max=%v", f.streamType, f.max)
}

func (f debugFrameStreamsBlocked) write(w *packetWriter) bool {
	return w.appendStreamsBlockedFrame(f.streamType, f.max)
}

func (f debugFrameStreamsBlocked) LogValue() slog.Value {
	return slog.GroupValue(
		slog.String("frame_type", "streams_blocked"),
		slog.String("stream_type", f.streamType.qlogString()),
		slog.Int64("limit", f.max),
	)
}

// debugFrameNewConnectionID is a NEW_CONNECTION_ID frame.
type debugFrameNewConnectionID struct {
	seq           int64
	retirePriorTo int64
	connID        []byte
	token         statelessResetToken
}

// parseDebugFrameNewConnectionID parses a NEW_CONNECTION_ID frame from b.
func parseDebugFrameNewConnectionID(b []byte) (f debugFrameNewConnectionID, n int) {
	f.seq, f.retirePriorTo, f.connID, f.token, n = consumeNewConnectionIDFrame(b)
	return f, n
}

func (f debugFrameNewConnectionID) String() string {
	return fmt.Sprintf("NEW_CONNECTION_ID Seq=%v Retire=%v ID=%x Token=%x", f.seq, f.retirePriorTo, f.connID, f.token[:])
}

func (f debugFrameNewConnectionID) write(w *packetWriter) bool {
	return w.appendNewConnectionIDFrame(f.seq, f.retirePriorTo, f.connID, f.token)
}

func (f debugFrameNewConnectionID) LogValue() slog.Value {
	return slog.GroupValue(
		slog.String("frame_type", "new_connection_id"),
		slog.Int64("sequence_number", f.seq),
		slog.Int64("retire_prior_to", f.retirePriorTo),
		slogHexstring("connection_id", f.connID),
		slogHexstring("stateless_reset_token", f.token[:]),
	)
}

// debugFrameRetireConnectionID is a RETIRE_CONNECTION_ID frame.
type debugFrameRetireConnectionID struct {
	seq int64
}

// parseDebugFrameRetireConnectionID parses a RETIRE_CONNECTION_ID frame from b.
func parseDebugFrameRetireConnectionID(b []byte) (f debugFrameRetireConnectionID, n int) {
	f.seq, n = consumeRetireConnectionIDFrame(b)
	return f, n
}

func (f debugFrameRetireConnectionID) String() string {
	return fmt.Sprintf("RETIRE_CONNECTION_ID Seq=%v", f.seq)
}

func (f debugFrameRetireConnectionID) write(w *packetWriter) bool {
	return w.appendRetireConnectionIDFrame(f.seq)
}

func (f debugFrameRetireConnectionID) LogValue() slog.Value {
	return slog.GroupValue(
		slog.String("frame_type", "retire_connection_id"),
		slog.Int64("sequence_number", f.seq),
	)
}

// debugFramePathChallenge is a PATH_CHALLENGE frame.
type debugFramePathChallenge struct {
	data pathChallengeData
}

// parseDebugFramePathChallenge parses a PATH_CHALLENGE frame from b.
func parseDebugFramePathChallenge(b []byte) (f debugFramePathChallenge, n int) {
	f.data, n = consumePathChallengeFrame(b)
	return f, n
}

func (f debugFramePathChallenge) String() string {
	return fmt.Sprintf("PATH_CHALLENGE Data=%x", f.data)
}

func (f debugFramePathChallenge) write(w *packetWriter) bool {
	return w.appendPathChallengeFrame(f.data)
}

func (f debugFramePathChallenge) LogValue() slog.Value {
	return slog.GroupValue(
		slog.String("frame_type", "path_challenge"),
		slog.String("data", fmt.Sprintf("%x", f.data)),
	)
}

// debugFramePathResponse is a PATH_RESPONSE frame.
type debugFramePathResponse struct {
	data pathChallengeData
}

// parseDebugFramePathResponse parses a PATH_RESPONSE frame from b.
func parseDebugFramePathResponse(b []byte) (f debugFramePathResponse, n int) {
	f.data, n = consumePathResponseFrame(b)
	return f, n
}

func (f debugFramePathResponse) String() string {
	return fmt.Sprintf("PATH_RESPONSE Data=%x", f.data)
}

func (f debugFramePathResponse) write(w *packetWriter) bool {
	return w.appendPathResponseFrame(f.data)
}

func (f debugFramePathResponse) LogValue() slog.Value {
	return slog.GroupValue(
		slog.String("frame_type", "path_response"),
		slog.String("data", fmt.Sprintf("%x", f.data)),
	)
}

// debugFrameConnectionCloseTransport is a CONNECTION_CLOSE frame carrying a transport error.
type debugFrameConnectionCloseTransport struct {
	code      transportError
	frameType uint64
	reason    string
}

// parseDebugFrameConnectionCloseTransport parses a CONNECTION_CLOSE (transport error) frame from b.
func parseDebugFrameConnectionCloseTransport(b []byte) (f debugFrameConnectionCloseTransport, n int) {
	f.code, f.frameType, f.reason, n = consumeConnectionCloseTransportFrame(b)
	return f, n
}

func (f debugFrameConnectionCloseTransport) String() string {
	// FrameType and Reason are omitted from the output when zero/empty.
	s := fmt.Sprintf("CONNECTION_CLOSE Code=%v", f.code)
	if f.frameType != 0 {
		s += fmt.Sprintf(" FrameType=%v", f.frameType)
	}
	if f.reason != "" {
		s += fmt.Sprintf(" Reason=%q", f.reason)
	}
	return s
}

func (f debugFrameConnectionCloseTransport) write(w *packetWriter) bool {
	return w.appendConnectionCloseTransportFrame(f.code, f.frameType, f.reason)
}

func (f debugFrameConnectionCloseTransport) LogValue() slog.Value {
	return slog.GroupValue(
		slog.String("frame_type", "connection_close"),
		slog.String("error_space", "transport"),
		slog.Uint64("error_code_value", uint64(f.code)),
		slog.String("reason", f.reason),
	)
}

// debugFrameConnectionCloseApplication is a CONNECTION_CLOSE frame carrying an application error.
type debugFrameConnectionCloseApplication struct {
	code   uint64
	reason string
}

// parseDebugFrameConnectionCloseApplication parses a CONNECTION_CLOSE (application error) frame from b.
func parseDebugFrameConnectionCloseApplication(b []byte) (f debugFrameConnectionCloseApplication, n int) {
	f.code, f.reason, n = consumeConnectionCloseApplicationFrame(b)
	return f, n
}

func (f debugFrameConnectionCloseApplication) String() string {
	// Reason is omitted from the output when empty.
	s := fmt.Sprintf("CONNECTION_CLOSE AppCode=%v", f.code)
	if f.reason != "" {
		s += fmt.Sprintf(" Reason=%q", f.reason)
	}
	return s
}

func (f debugFrameConnectionCloseApplication) write(w *packetWriter) bool {
	return w.appendConnectionCloseApplicationFrame(f.code, f.reason)
}

func (f debugFrameConnectionCloseApplication) LogValue() slog.Value {
	return slog.GroupValue(
		slog.String("frame_type", "connection_close"),
		slog.String("error_space", "application"),
		slog.Uint64("error_code_value", uint64(f.code)),
		slog.String("reason", f.reason),
	)
}

// debugFrameHandshakeDone is a HANDSHAKE_DONE frame.
type debugFrameHandshakeDone struct{}

// parseDebugFrameHandshakeDone parses a HANDSHAKE_DONE frame from b.
// The frame consists of only its type byte, so the parsed length is always 1.
func parseDebugFrameHandshakeDone(b []byte) (f debugFrameHandshakeDone, n int) {
	return f, 1
}

func (f debugFrameHandshakeDone) String() string {
	return "HANDSHAKE_DONE"
}

func (f debugFrameHandshakeDone) write(w *packetWriter) bool {
	return w.appendHandshakeDoneFrame()
}

func (f debugFrameHandshakeDone) LogValue() slog.Value {
	return slog.GroupValue(
		slog.String("frame_type", "handshake_done"),
	)
}
diff --git a/src/vendor/golang.org/x/net/quic/gate.go b/src/vendor/golang.org/x/net/quic/gate.go
new file mode 100644
index 0000000000..b8b8605e62
--- /dev/null
+++ b/src/vendor/golang.org/x/net/quic/gate.go
@@ -0,0 +1,86 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package quic

import "context"

// A gate is a monitor (mutex + condition variable) with one bit of state.
//
// The condition may be either set or unset.
// Lock operations may be unconditional, or wait for the condition to be set.
// Unlock operations record the new state of the condition.
type gate struct {
	// When unlocked, exactly one of set or unset contains a value.
	// When locked, neither chan contains a value.
	set   chan struct{}
	unset chan struct{}
}

// newGate returns a new, unlocked gate with the condition unset.
func newGate() gate {
	g := newLockedGate()
	g.unlock(false)
	return g
}

// newLockedGate returns a new, locked gate.
func newLockedGate() gate {
	return gate{
		set:   make(chan struct{}, 1),
		unset: make(chan struct{}, 1),
	}
}

// lock acquires the gate unconditionally.
// It reports whether the condition is set.
func (g *gate) lock() (set bool) {
	select {
	case <-g.set:
		return true
	case <-g.unset:
		return false
	}
}

// waitAndLock waits until the condition is set before acquiring the gate.
// If the context expires, waitAndLock returns an error and does not acquire the gate.
func (g *gate) waitAndLock(ctx context.Context) error {
	// Try a non-blocking acquire first, so that a set condition wins over
	// a context that is already done.
	select {
	case <-g.set:
		return nil
	default:
	}
	select {
	case <-g.set:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// lockIfSet acquires the gate if and only if the condition is set.
func (g *gate) lockIfSet() (acquired bool) {
	select {
	case <-g.set:
		return true
	default:
		return false
	}
}

// unlock sets the condition and releases the gate.
func (g *gate) unlock(set bool) {
	if set {
		g.set <- struct{}{}
	} else {
		g.unset <- struct{}{}
	}
}

// unlockFunc sets the condition to the result of f and releases the gate.
// Useful in defers.
func (g *gate) unlockFunc(f func() bool) {
	g.unlock(f())
}
diff --git a/src/vendor/golang.org/x/net/quic/idle.go b/src/vendor/golang.org/x/net/quic/idle.go
new file mode 100644
index 0000000000..6b1dfd1d25
--- /dev/null
+++ b/src/vendor/golang.org/x/net/quic/idle.go
@@ -0,0 +1,168 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package quic

import (
	"time"
)

// idleState tracks connection idle events.
//
// Before the handshake is confirmed, the idle timeout is Config.HandshakeTimeout.
//
// After the handshake is confirmed, the idle timeout is
// the minimum of Config.MaxIdleTimeout and the peer's max_idle_timeout transport parameter.
//
// If KeepAlivePeriod is set, keep-alive pings are sent.
// Keep-alives are only sent after the handshake is confirmed.
//
// https://www.rfc-editor.org/rfc/rfc9000#section-10.1
type idleState struct {
	// idleDuration is the negotiated idle timeout for the connection.
	idleDuration time.Duration

	// idleTimeout is the time at which the connection will be closed due to inactivity.
	idleTimeout time.Time

	// nextTimeout is the time of the next idle event.
	// If nextTimeout == idleTimeout, this is the idle timeout.
	// Otherwise, this is the keep-alive timeout.
	nextTimeout time.Time

	// sentSinceLastReceive is set if we have sent an ack-eliciting packet
	// since the last time we received and processed a packet from the peer.
	sentSinceLastReceive bool
}

// receivePeerMaxIdleTimeout handles the peer's max_idle_timeout transport parameter.
// A zero value on either side means that side imposes no idle timeout.
func (c *Conn) receivePeerMaxIdleTimeout(peerMaxIdleTimeout time.Duration) {
	localMaxIdleTimeout := c.config.maxIdleTimeout()
	switch {
	case localMaxIdleTimeout == 0:
		c.idle.idleDuration = peerMaxIdleTimeout
	case peerMaxIdleTimeout == 0:
		c.idle.idleDuration = localMaxIdleTimeout
	default:
		c.idle.idleDuration = min(localMaxIdleTimeout, peerMaxIdleTimeout)
	}
}

func (c *Conn) idleHandlePacketReceived(now time.Time) {
	if !c.handshakeConfirmed.isSet() {
		return
	}
	// "An endpoint restarts its idle timer when a packet from its peer is
	// received and processed successfully."
	// https://www.rfc-editor.org/rfc/rfc9000#section-10.1-3
	c.idle.sentSinceLastReceive = false
	c.restartIdleTimer(now)
}

func (c *Conn) idleHandlePacketSent(now time.Time, sent *sentPacket) {
	// "An endpoint also restarts its idle timer when sending an ack-eliciting packet
	// if no other ack-eliciting packets have been sent since
	// last receiving and processing a packet."
	// https://www.rfc-editor.org/rfc/rfc9000#section-10.1-3
	if c.idle.sentSinceLastReceive || !sent.ackEliciting || !c.handshakeConfirmed.isSet() {
		return
	}
	c.idle.sentSinceLastReceive = true
	c.restartIdleTimer(now)
}

func (c *Conn) restartIdleTimer(now time.Time) {
	if !c.isAlive() {
		// Connection is closing, disable timeouts.
		c.idle.idleTimeout = time.Time{}
		c.idle.nextTimeout = time.Time{}
		return
	}
	var idleDuration time.Duration
	if c.handshakeConfirmed.isSet() {
		idleDuration = c.idle.idleDuration
	} else {
		idleDuration = c.config.handshakeTimeout()
	}
	if idleDuration == 0 {
		// No idle timeout applies; a zero time leaves the timer unarmed.
		c.idle.idleTimeout = time.Time{}
	} else {
		// "[...] endpoints MUST increase the idle timeout period to be
		// at least three times the current Probe Timeout (PTO)."
		// https://www.rfc-editor.org/rfc/rfc9000#section-10.1-4
		idleDuration = max(idleDuration, 3*c.loss.ptoPeriod())
		c.idle.idleTimeout = now.Add(idleDuration)
	}
	// Set the time of our next event:
	// The idle timer if no keep-alive is set, or the keep-alive timer if one is.
	c.idle.nextTimeout = c.idle.idleTimeout
	keepAlive := c.config.keepAlivePeriod()
	switch {
	case !c.handshakeConfirmed.isSet():
		// We do not send keep-alives before the handshake is complete.
	case keepAlive <= 0:
		// Keep-alives are not enabled.
	case c.idle.sentSinceLastReceive:
		// We have sent an ack-eliciting packet to the peer.
		// If they don't acknowledge it, loss detection will follow up with PTO probes,
		// which will function as keep-alives.
		// We don't need to send further pings.
	case idleDuration == 0:
		// The connection does not have a negotiated idle timeout.
		// Send keep-alives anyway, since they may be required to keep middleboxes
		// from losing state.
		c.idle.nextTimeout = now.Add(keepAlive)
	default:
		// Schedule our next keep-alive.
		// If our configured keep-alive period is greater than half the negotiated
		// connection idle timeout, we reduce the keep-alive period to half
		// the idle timeout to ensure we have time for the ping to arrive.
		c.idle.nextTimeout = now.Add(min(keepAlive, idleDuration/2))
	}
}

// appendKeepAlive appends a PING frame to the packet under construction when a
// keep-alive is due. It reports whether the packet is ready to send as-is:
// false means a PING was needed but did not fit.
func (c *Conn) appendKeepAlive(now time.Time) bool {
	if c.idle.nextTimeout.IsZero() || c.idle.nextTimeout.After(now) {
		return true // timer has not expired
	}
	if c.idle.nextTimeout.Equal(c.idle.idleTimeout) {
		return true // no keepalive timer set, only idle
	}
	if c.idle.sentSinceLastReceive {
		return true // already sent an ack-eliciting packet
	}
	if c.w.sent.ackEliciting {
		return true // this packet is already ack-eliciting
	}
	// Send an ack-eliciting PING frame to the peer to keep the connection alive.
	return c.w.appendPingFrame()
}

// errHandshakeTimeout is the connection close reason used when the handshake
// does not complete before the handshake timeout expires.
var errHandshakeTimeout error = localTransportError{
	code:   errConnectionRefused,
	reason: "handshake timeout",
}

// idleAdvance handles expiry of the idle (or, pre-confirmation, handshake) timer.
// It reports whether the connection should be silently discarded.
func (c *Conn) idleAdvance(now time.Time) (shouldExit bool) {
	if c.idle.idleTimeout.IsZero() || now.Before(c.idle.idleTimeout) {
		return false
	}
	c.idle.idleTimeout = time.Time{}
	c.idle.nextTimeout = time.Time{}
	if !c.handshakeConfirmed.isSet() {
		// Handshake timeout has expired.
		// If we're a server, we're refusing the too-slow client.
		// If we're a client, we're giving up.
		// In either case, we're going to send a CONNECTION_CLOSE frame and
		// enter the closing state rather than unceremoniously dropping the connection,
		// since the peer might still be trying to complete the handshake.
		c.abort(now, errHandshakeTimeout)
		return false
	}
	// Idle timeout has expired.
	//
	// "[...] the connection is silently closed and its state is discarded [...]"
	// https://www.rfc-editor.org/rfc/rfc9000#section-10.1-1
	return true
}
diff --git a/src/vendor/golang.org/x/net/quic/log.go b/src/vendor/golang.org/x/net/quic/log.go
new file mode 100644
index 0000000000..eee2b5fd61
--- /dev/null
+++ b/src/vendor/golang.org/x/net/quic/log.go
@@ -0,0 +1,67 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package quic

import (
	"fmt"
	"os"
	"strings"
)

var logPackets bool

// Parse GODEBUG settings.
//
// GODEBUG=quiclogpackets=1 -- log every packet sent and received.
func init() {
	s := os.Getenv("GODEBUG")
	for len(s) > 0 {
		var opt string
		opt, s, _ = strings.Cut(s, ",")
		switch opt {
		case "quiclogpackets=1":
			logPackets = true
		}
	}
}

// logInboundLongPacket logs a received long-header packet and its frames.
func logInboundLongPacket(c *Conn, p longPacket) {
	if !logPackets {
		return
	}
	prefix := c.String()
	fmt.Printf("%v recv %v %v\n", prefix, p.ptype, p.num)
	logFrames(prefix+" <- ", p.payload)
}

// logInboundShortPacket logs a received 1-RTT packet and its frames.
func logInboundShortPacket(c *Conn, p shortPacket) {
	if !logPackets {
		return
	}
	prefix := c.String()
	fmt.Printf("%v recv 1-RTT %v\n", prefix, p.num)
	logFrames(prefix+" <- ", p.payload)
}

// logSentPacket logs a sent packet and its frames.
// NOTE(review): src and dst are accepted but not currently logged.
func logSentPacket(c *Conn, ptype packetType, pnum packetNumber, src, dst, payload []byte) {
	if !logPackets || len(payload) == 0 {
		return
	}
	prefix := c.String()
	fmt.Printf("%v send %v %v\n", prefix, ptype, pnum)
	logFrames(prefix+" -> ", payload)
}

// logFrames parses each frame in payload and logs it, one line per frame.
func logFrames(prefix string, payload []byte) {
	for len(payload) > 0 {
		f, n := parseDebugFrame(payload)
		if n < 0 {
			fmt.Printf("%vBAD DATA\n", prefix)
			break
		}
		payload = payload[n:]
		fmt.Printf("%v%v\n", prefix, f)
	}
}
diff --git a/src/vendor/golang.org/x/net/quic/loss.go b/src/vendor/golang.org/x/net/quic/loss.go
new file mode 100644
index 0000000000..95feaba2d4
--- /dev/null
+++ 
b/src/vendor/golang.org/x/net/quic/loss.go
@@ -0,0 +1,521 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package quic

import (
	"context"
	"log/slog"
	"math"
	"time"
)

type lossState struct {
	side connSide

	// True when the handshake is confirmed.
	// https://www.rfc-editor.org/rfc/rfc9001#section-4.1.2
	handshakeConfirmed bool

	// Peer's max_ack_delay transport parameter.
	// https://www.rfc-editor.org/rfc/rfc9000.html#section-18.2-4.28.1
	maxAckDelay time.Duration

	// Time of the next event: PTO expiration (if ptoTimerArmed is true),
	// or loss detection.
	// The connection must call lossState.advance when the timer expires.
	timer time.Time

	// True when the PTO timer is set.
	ptoTimerArmed bool

	// True when the PTO timer has expired and a probe packet has not yet been sent.
	ptoExpired bool

	// Count of PTO expirations since the last received acknowledgement.
	// https://www.rfc-editor.org/rfc/rfc9002#section-6.2.1-9
	ptoBackoffCount int

	// Anti-amplification limit: Three times the amount of data received from
	// the peer, less the amount of data sent.
	//
	// Set to antiAmplificationUnlimited (MaxInt) to disable the limit.
	// The limit is always disabled for clients, and for servers after the
	// peer's address is validated.
	//
	// Anti-amplification is per-address; this will need to change if/when we
	// support address migration.
	//
	// https://www.rfc-editor.org/rfc/rfc9000#section-8-2
	antiAmplificationLimit int

	// Count of non-ack-eliciting packets (ACKs) sent since the last ack-eliciting one.
	consecutiveNonAckElicitingPackets int

	rtt   rttState
	pacer pacerState
	cc    *ccReno

	// Per-space loss detection state.
	spaces [numberSpaceCount]struct {
		sentPacketList
		maxAcked         packetNumber
		lastAckEliciting packetNumber
	}

	// Temporary state used when processing an ACK frame.
	ackFrameRTT                  time.Duration // RTT from latest packet in frame
	ackFrameContainsAckEliciting bool          // newly acks an ack-eliciting packet?
}

const antiAmplificationUnlimited = math.MaxInt

func (c *lossState) init(side connSide, maxDatagramSize int, now time.Time) {
	c.side = side
	if side == clientSide {
		// Clients don't have an anti-amplification limit.
		c.antiAmplificationLimit = antiAmplificationUnlimited
	}
	c.rtt.init()
	c.cc = newReno(maxDatagramSize)
	c.pacer.init(now, c.cc.congestionWindow, timerGranularity)

	// Peer's assumed max_ack_delay, prior to receiving transport parameters.
	// https://www.rfc-editor.org/rfc/rfc9000#section-18.2
	c.maxAckDelay = 25 * time.Millisecond

	// -1 marks "nothing acked / no ack-eliciting packet sent yet" in each space.
	for space := range c.spaces {
		c.spaces[space].maxAcked = -1
		c.spaces[space].lastAckEliciting = -1
	}
}

// setMaxAckDelay sets the max_ack_delay transport parameter received from the peer.
func (c *lossState) setMaxAckDelay(d time.Duration) {
	if d >= (1<<14)*time.Millisecond {
		// Values of 2^14 or greater are invalid.
		// https://www.rfc-editor.org/rfc/rfc9000.html#section-18.2-4.28.1
		return
	}
	c.maxAckDelay = d
}

// confirmHandshake indicates the handshake has been confirmed.
func (c *lossState) confirmHandshake() {
	c.handshakeConfirmed = true
}

// validateClientAddress disables the anti-amplification limit after
// a server validates a client's address.
func (c *lossState) validateClientAddress() {
	c.antiAmplificationLimit = antiAmplificationUnlimited
}

// minPacketSize is the minimum datagram size permitted by
// anti-amplification protection.
//
// Defining a minimum size avoids the case where, say, anti-amplification
// technically allows us to send a 1-byte datagram, but no such datagram
// can be constructed.
const minPacketSize = 128

// A ccLimit describes whether, and why, sending is currently limited.
type ccLimit int

const (
	ccOK      = ccLimit(iota) // OK to send
	ccBlocked                 // sending blocked by anti-amplification
	ccLimited                 // sending blocked by congestion control
	ccPaced                   // sending allowed by congestion, but delayed by pacer
)

// sendLimit reports whether sending is possible at this time.
// When sending is pacing limited, it returns the next time a packet may be sent.
func (c *lossState) sendLimit(now time.Time) (limit ccLimit, next time.Time) {
	if c.antiAmplificationLimit < minPacketSize {
		// When at the anti-amplification limit, we may not send anything.
		return ccBlocked, time.Time{}
	}
	if c.ptoExpired {
		// On PTO expiry, send a probe.
		// This check precedes the congestion and pacing checks, so probes
		// are sent even when congestion- or pacing-limited.
		return ccOK, time.Time{}
	}
	if !c.cc.canSend() {
		// Congestion control blocks sending.
		return ccLimited, time.Time{}
	}
	if c.cc.bytesInFlight == 0 {
		// If no bytes are in flight, send packet unpaced.
		return ccOK, time.Time{}
	}
	canSend, next := c.pacer.canSend(now)
	if !canSend {
		// Pacer blocks sending.
		return ccPaced, next
	}
	return ccOK, time.Time{}
}

// maxSendSize reports the maximum datagram size that may be sent.
func (c *lossState) maxSendSize() int {
	return min(c.antiAmplificationLimit, c.cc.maxDatagramSize)
}

// advance is called when time passes.
// The lossf function is called for each packet newly detected as lost.
func (c *lossState) advance(now time.Time, lossf func(numberSpace, *sentPacket, packetFate)) {
	c.pacer.advance(now, c.cc.congestionWindow, c.rtt.smoothedRTT)
	if c.ptoTimerArmed && !c.timer.IsZero() && !c.timer.After(now) {
		c.ptoExpired = true
		c.timer = time.Time{}
		c.ptoBackoffCount++
	}
	c.detectLoss(now, lossf)
}

// nextNumber returns the next packet number to use in a space.
func (c *lossState) nextNumber(space numberSpace) packetNumber {
	return c.spaces[space].nextNum
}

// skipNumber skips a packet number as a defense against optimistic ACK attacks.
func (c *lossState) skipNumber(now time.Time, space numberSpace) {
	// Record a placeholder packet that is never sent; an acknowledgement of
	// this number is reported as a protocol violation by receiveAckRange.
	sent := newSentPacket()
	sent.num = c.spaces[space].nextNum
	sent.time = now
	sent.state = sentPacketUnsent
	c.spaces[space].add(sent)
}

// packetSent records a sent packet.
func (c *lossState) packetSent(now time.Time, log *slog.Logger, space numberSpace, sent *sentPacket) {
	sent.time = now
	c.spaces[space].add(sent)
	size := sent.size
	if c.antiAmplificationLimit != antiAmplificationUnlimited {
		c.antiAmplificationLimit = max(0, c.antiAmplificationLimit-size)
	}
	if sent.inFlight {
		c.cc.packetSent(now, log, space, sent)
		c.pacer.packetSent(now, size, c.cc.congestionWindow, c.rtt.smoothedRTT)
		if sent.ackEliciting {
			c.spaces[space].lastAckEliciting = sent.num
			c.ptoExpired = false // reset expired PTO timer after sending probe
		}
		c.scheduleTimer(now)
		if logEnabled(log, QLogLevelPacket) {
			logBytesInFlight(log, c.cc.bytesInFlight)
		}
	}
	if sent.ackEliciting {
		c.consecutiveNonAckElicitingPackets = 0
	} else {
		c.consecutiveNonAckElicitingPackets++
	}
}

// datagramReceived records a datagram (not packet!) received from the peer.
func (c *lossState) datagramReceived(now time.Time, size int) {
	if c.antiAmplificationLimit != antiAmplificationUnlimited {
		// Each received byte increases the send allowance by three bytes.
		c.antiAmplificationLimit += 3 * size
		// Reset the PTO timer, possibly to a point in the past, in which
		// case the caller should execute it immediately.
		// https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.2.1-2
		c.scheduleTimer(now)
		if c.ptoTimerArmed && !c.timer.IsZero() && !c.timer.After(now) {
			c.ptoExpired = true
			c.timer = time.Time{}
		}
	}
}

// receiveAckStart starts processing an ACK frame.
// Call receiveAckRange for each range in the frame.
// Call receiveAckEnd after all ranges are processed.
func (c *lossState) receiveAckStart() {
	c.ackFrameContainsAckEliciting = false
	c.ackFrameRTT = -1
}

// receiveAckRange processes a range within an ACK frame.
// The ackf function is called for each newly-acknowledged packet.
func (c *lossState) receiveAckRange(now time.Time, space numberSpace, rangeIndex int, start, end packetNumber, ackf func(numberSpace, *sentPacket, packetFate)) error {
	// Limit our range to the intersection of the ACK range and
	// the in-flight packets we have state for.
	// Acknowledgements below our tracked window are clipped (that state has
	// already been cleaned up), while acknowledgements beyond anything we
	// have recorded are a protocol violation.
	if s := c.spaces[space].start(); start < s {
		start = s
	}
	if e := c.spaces[space].end(); end > e {
		return localTransportError{
			code:   errProtocolViolation,
			reason: "acknowledgement for unsent packet",
		}
	}
	if start >= end {
		return nil
	}
	if rangeIndex == 0 {
		// If the latest packet in the ACK frame is newly-acked,
		// record the RTT in c.ackFrameRTT.
		sent := c.spaces[space].num(end - 1)
		if sent.state == sentPacketSent {
			c.ackFrameRTT = max(0, now.Sub(sent.time))
		}
	}
	for pnum := start; pnum < end; pnum++ {
		sent := c.spaces[space].num(pnum)
		if sent.state == sentPacketUnsent {
			// A skipped number (see skipNumber) was acknowledged.
			return localTransportError{
				code:   errProtocolViolation,
				reason: "acknowledgement for unsent packet",
			}
		}
		if sent.state != sentPacketSent {
			continue
		}
		// This is a newly-acknowledged packet.
		if pnum > c.spaces[space].maxAcked {
			c.spaces[space].maxAcked = pnum
		}
		sent.state = sentPacketAcked
		c.cc.packetAcked(now, sent)
		ackf(space, sent, packetAcked)
		if sent.ackEliciting {
			c.ackFrameContainsAckEliciting = true
		}
	}
	return nil
}

// receiveAckEnd finishes processing an ack frame.
// The lossf function is called for each packet newly detected as lost.
func (c *lossState) receiveAckEnd(now time.Time, log *slog.Logger, space numberSpace, ackDelay time.Duration, lossf func(numberSpace, *sentPacket, packetFate)) {
	c.spaces[space].sentPacketList.clean()
	// Update the RTT sample when the largest acknowledged packet in the ACK frame
	// is newly acknowledged, and at least one newly acknowledged packet is ack-eliciting.
	// https://www.rfc-editor.org/rfc/rfc9002.html#section-5.1-2.2
	if c.ackFrameRTT >= 0 && c.ackFrameContainsAckEliciting {
		c.rtt.updateSample(now, c.handshakeConfirmed, space, c.ackFrameRTT, ackDelay, c.maxAckDelay)
	}
	// Reset the PTO backoff.
	// Exception: A client does not reset the backoff on acks for Initial packets.
	// https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.1-9
	if !(c.side == clientSide && space == initialSpace) {
		c.ptoBackoffCount = 0
	}
	// If the client has set a PTO timer with no packets in flight
	// we want to restart that timer now. Clearing c.timer does this.
	// https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.2.1-3
	c.timer = time.Time{}
	c.detectLoss(now, lossf)
	c.cc.packetBatchEnd(now, log, space, &c.rtt, c.maxAckDelay)

	if logEnabled(log, QLogLevelPacket) {
		// ssthresh is only logged once it has been reduced from MaxInt;
		// until then it is left as the zero Attr.
		var ssthresh slog.Attr
		if c.cc.slowStartThreshold != math.MaxInt {
			ssthresh = slog.Int("ssthresh", c.cc.slowStartThreshold)
		}
		log.LogAttrs(context.Background(), QLogLevelPacket,
			"recovery:metrics_updated",
			slog.Duration("min_rtt", c.rtt.minRTT),
			slog.Duration("smoothed_rtt", c.rtt.smoothedRTT),
			slog.Duration("latest_rtt", c.rtt.latestRTT),
			slog.Duration("rtt_variance", c.rtt.rttvar),
			slog.Int("congestion_window", c.cc.congestionWindow),
			slog.Int("bytes_in_flight", c.cc.bytesInFlight),
			ssthresh,
		)
	}
}

// discardPackets declares that packets within a number space will not be delivered
// and that data contained in them should be resent.
// For example, after receiving a Retry packet we discard already-sent Initial packets.
func (c *lossState) discardPackets(space numberSpace, log *slog.Logger, lossf func(numberSpace, *sentPacket, packetFate)) {
	for i := 0; i < c.spaces[space].size; i++ {
		sent := c.spaces[space].nth(i)
		if sent.state != sentPacketSent {
			// This should not be possible, since we only discard packets
			// in spaces which have never received an ack, but check anyway.
+ continue + } + sent.state = sentPacketLost + c.cc.packetDiscarded(sent) + lossf(numberSpace(space), sent, packetLost) + } + c.spaces[space].clean() + if logEnabled(log, QLogLevelPacket) { + logBytesInFlight(log, c.cc.bytesInFlight) + } +} + +// discardKeys is called when dropping packet protection keys for a number space. +func (c *lossState) discardKeys(now time.Time, log *slog.Logger, space numberSpace) { + // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.4 + for i := 0; i < c.spaces[space].size; i++ { + sent := c.spaces[space].nth(i) + if sent.state != sentPacketSent { + continue + } + c.cc.packetDiscarded(sent) + } + c.spaces[space].discard() + c.spaces[space].maxAcked = -1 + c.spaces[space].lastAckEliciting = -1 + c.scheduleTimer(now) + if logEnabled(log, QLogLevelPacket) { + logBytesInFlight(log, c.cc.bytesInFlight) + } +} + +func (c *lossState) lossDuration() time.Duration { + // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.1.2 + return max((9*max(c.rtt.smoothedRTT, c.rtt.latestRTT))/8, timerGranularity) +} + +func (c *lossState) detectLoss(now time.Time, lossf func(numberSpace, *sentPacket, packetFate)) { + // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.1.1-1 + const lossThreshold = 3 + + lossTime := now.Add(-c.lossDuration()) + for space := numberSpace(0); space < numberSpaceCount; space++ { + for i := 0; i < c.spaces[space].size; i++ { + sent := c.spaces[space].nth(i) + if sent.state != sentPacketSent { + continue + } + // RFC 9002 Section 6.1 states that a packet is only declared lost if it + // is "in flight", which excludes packets that contain only ACK frames. + // However, we need some way to determine when to drop state for ACK-only + // packets, and the loss algorithm in Appendix A handles loss detection of + // not-in-flight packets identically to all others, so we do the same here. 
+ switch { + case c.spaces[space].maxAcked-sent.num >= lossThreshold: + // Packet threshold + // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.1.1 + fallthrough + case sent.num <= c.spaces[space].maxAcked && !sent.time.After(lossTime): + // Time threshold + // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.1.2 + sent.state = sentPacketLost + lossf(space, sent, packetLost) + if sent.inFlight { + c.cc.packetLost(now, space, sent, &c.rtt) + } + } + if sent.state != sentPacketLost { + break + } + } + c.spaces[space].clean() + } + c.scheduleTimer(now) +} + +// scheduleTimer sets the loss or PTO timer. +// +// The connection is responsible for arranging for advance to be called after +// the timer expires. +// +// The timer may be set to a point in the past, in which advance should be called +// immediately. We don't do this here, because executing the timer can cause +// packet loss events, and it's simpler for the connection if loss events only +// occur when advancing time. +func (c *lossState) scheduleTimer(now time.Time) { + c.ptoTimerArmed = false + + // Loss timer for sent packets. + // The loss timer is only started once a later packet has been acknowledged, + // and takes precedence over the PTO timer. + // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.1.2 + var oldestPotentiallyLost time.Time + for space := numberSpace(0); space < numberSpaceCount; space++ { + if c.spaces[space].size > 0 && c.spaces[space].start() <= c.spaces[space].maxAcked { + firstTime := c.spaces[space].nth(0).time + if oldestPotentiallyLost.IsZero() || firstTime.Before(oldestPotentiallyLost) { + oldestPotentiallyLost = firstTime + } + } + } + if !oldestPotentiallyLost.IsZero() { + c.timer = oldestPotentiallyLost.Add(c.lossDuration()) + return + } + + // PTO timer. + if c.ptoExpired { + // PTO timer has expired, don't restart it until we send a probe. 
+ c.timer = time.Time{} + return + } + if c.antiAmplificationLimit >= 0 && c.antiAmplificationLimit < minPacketSize { + // Server is at its anti-amplification limit and can't send any more data. + // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.2.1-1 + c.timer = time.Time{} + return + } + // Timer starts at the most recently sent ack-eliciting packet. + // Prior to confirming the handshake, we consider the Initial and Handshake + // number spaces; after, we consider only Application Data. + var last time.Time + if !c.handshakeConfirmed { + for space := initialSpace; space <= handshakeSpace; space++ { + sent := c.spaces[space].num(c.spaces[space].lastAckEliciting) + if sent == nil { + continue + } + if last.IsZero() || last.After(sent.time) { + last = sent.time + } + } + } else { + sent := c.spaces[appDataSpace].num(c.spaces[appDataSpace].lastAckEliciting) + if sent != nil { + last = sent.time + } + } + if last.IsZero() && + c.side == clientSide && + c.spaces[handshakeSpace].maxAcked < 0 && + !c.handshakeConfirmed { + // The client must always set a PTO timer prior to receiving an ack for a + // handshake packet or the handshake being confirmed. + // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.2.1 + if !c.timer.IsZero() { + // If c.timer is non-zero here, we've already set the PTO timer and + // should leave it as-is rather than moving it forward. 
+ c.ptoTimerArmed = true + return + } + last = now + } else if last.IsZero() { + c.timer = time.Time{} + return + } + c.timer = last.Add(c.ptoPeriod()) + c.ptoTimerArmed = true +} + +func (c *lossState) ptoPeriod() time.Duration { + // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.1 + return c.ptoBasePeriod() << c.ptoBackoffCount +} + +func (c *lossState) ptoBasePeriod() time.Duration { + // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.1 + pto := c.rtt.smoothedRTT + max(4*c.rtt.rttvar, timerGranularity) + if c.handshakeConfirmed { + // The max_ack_delay is the maximum amount of time the peer might delay sending + // an ack to us. We only take it into account for the Application Data space. + // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.1-4 + pto += c.maxAckDelay + } + return pto +} + +func logBytesInFlight(log *slog.Logger, bytesInFlight int) { + log.LogAttrs(context.Background(), QLogLevelPacket, + "recovery:metrics_updated", + slog.Int("bytes_in_flight", bytesInFlight), + ) +} diff --git a/src/vendor/golang.org/x/net/quic/math.go b/src/vendor/golang.org/x/net/quic/math.go new file mode 100644 index 0000000000..d1e8a80025 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/math.go @@ -0,0 +1,12 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +func abs[T ~int | ~int64](a T) T { + if a < 0 { + return -a + } + return a +} diff --git a/src/vendor/golang.org/x/net/quic/pacer.go b/src/vendor/golang.org/x/net/quic/pacer.go new file mode 100644 index 0000000000..5891f42597 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/pacer.go @@ -0,0 +1,129 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package quic + +import ( + "time" +) + +// A pacerState controls the rate at which packets are sent using a leaky-bucket rate limiter. +// +// The pacer limits the maximum size of a burst of packets. +// When a burst exceeds this limit, it spreads subsequent packets +// over time. +// +// The bucket is initialized to the maximum burst size (ten packets by default), +// and fills at the rate: +// +// 1.25 * congestion_window / smoothed_rtt +// +// A sender can send one congestion window of packets per RTT, +// since the congestion window consumed by each packet is returned +// one round-trip later by the responding ack. +// The pacer permits sending at slightly faster than this rate to +// avoid underutilizing the congestion window. +// +// The pacer permits the bucket to become negative, and permits +// sending when non-negative. This biases slightly in favor of +// sending packets over limiting them, and permits bursts one +// packet greater than the configured maximum, but permits the pacer +// to be ignorant of the maximum packet size. +// +// https://www.rfc-editor.org/rfc/rfc9002.html#section-7.7 +type pacerState struct { + bucket int // measured in bytes + maxBucket int + timerGranularity time.Duration + lastUpdate time.Time + nextSend time.Time +} + +func (p *pacerState) init(now time.Time, maxBurst int, timerGranularity time.Duration) { + // Bucket is limited to maximum burst size, which is the initial congestion window. + // https://www.rfc-editor.org/rfc/rfc9002#section-7.7-2 + p.maxBucket = maxBurst + p.bucket = p.maxBucket + p.timerGranularity = timerGranularity + p.lastUpdate = now + p.nextSend = now +} + +// pacerBytesForInterval returns the number of bytes permitted over an interval. 
+// +// rate = 1.25 * congestion_window / smoothed_rtt +// bytes = interval * rate +// +// https://www.rfc-editor.org/rfc/rfc9002#section-7.7-6 +func pacerBytesForInterval(interval time.Duration, congestionWindow int, rtt time.Duration) int { + bytes := (int64(interval) * int64(congestionWindow)) / int64(rtt) + bytes = (bytes * 5) / 4 // bytes *= 1.25 + return int(bytes) +} + +// pacerIntervalForBytes returns the amount of time required for a number of bytes. +// +// time_per_byte = (smoothed_rtt / congestion_window) / 1.25 +// interval = time_per_byte * bytes +// +// https://www.rfc-editor.org/rfc/rfc9002#section-7.7-8 +func pacerIntervalForBytes(bytes int, congestionWindow int, rtt time.Duration) time.Duration { + interval := (int64(rtt) * int64(bytes)) / int64(congestionWindow) + interval = (interval * 4) / 5 // interval /= 1.25 + return time.Duration(interval) +} + +// advance is called when time passes. +func (p *pacerState) advance(now time.Time, congestionWindow int, rtt time.Duration) { + elapsed := now.Sub(p.lastUpdate) + if elapsed < 0 { + // Time has gone backward? + elapsed = 0 + p.nextSend = now // allow a packet through to get back on track + if p.bucket < 0 { + p.bucket = 0 + } + } + p.lastUpdate = now + if rtt == 0 { + // Avoid divide by zero in the implausible case that we measure no RTT. + p.bucket = p.maxBucket + return + } + // Refill the bucket. + delta := pacerBytesForInterval(elapsed, congestionWindow, rtt) + p.bucket = min(p.bucket+delta, p.maxBucket) +} + +// packetSent is called to record transmission of a packet. +func (p *pacerState) packetSent(now time.Time, size, congestionWindow int, rtt time.Duration) { + p.bucket -= size + if p.bucket < -congestionWindow { + // Never allow the bucket to fall more than one congestion window in arrears. + // We can only fall this far behind if the sender is sending unpaced packets, + // the congestion window has been exceeded, or the RTT is less than the + // timer granularity. 
+ // + // Limiting the minimum bucket size limits the maximum pacer delay + // to RTT/1.25. + p.bucket = -congestionWindow + } + if p.bucket >= 0 { + p.nextSend = now + return + } + // Next send occurs when the bucket has refilled to 0. + delay := pacerIntervalForBytes(-p.bucket, congestionWindow, rtt) + p.nextSend = now.Add(delay) +} + +// canSend reports whether a packet can be sent now. +// If it returns false, next is the time when the next packet can be sent. +func (p *pacerState) canSend(now time.Time) (canSend bool, next time.Time) { + // If the next send time is within the timer granularity, send immediately. + if p.nextSend.After(now.Add(p.timerGranularity)) { + return false, p.nextSend + } + return true, time.Time{} +} diff --git a/src/vendor/golang.org/x/net/quic/packet.go b/src/vendor/golang.org/x/net/quic/packet.go new file mode 100644 index 0000000000..b9fa333c54 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/packet.go @@ -0,0 +1,267 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "encoding/binary" + "fmt" + + "golang.org/x/net/internal/quic/quicwire" +) + +// packetType is a QUIC packet type. 
+// https://www.rfc-editor.org/rfc/rfc9000.html#section-17 +type packetType byte + +const ( + packetTypeInvalid = packetType(iota) + packetTypeInitial + packetType0RTT + packetTypeHandshake + packetTypeRetry + packetType1RTT + packetTypeVersionNegotiation +) + +func (p packetType) String() string { + switch p { + case packetTypeInitial: + return "Initial" + case packetType0RTT: + return "0-RTT" + case packetTypeHandshake: + return "Handshake" + case packetTypeRetry: + return "Retry" + case packetType1RTT: + return "1-RTT" + } + return fmt.Sprintf("unknown packet type %v", byte(p)) +} + +func (p packetType) qlogString() string { + switch p { + case packetTypeInitial: + return "initial" + case packetType0RTT: + return "0RTT" + case packetTypeHandshake: + return "handshake" + case packetTypeRetry: + return "retry" + case packetType1RTT: + return "1RTT" + } + return "unknown" +} + +// Bits set in the first byte of a packet. +const ( + headerFormLong = 0x80 // https://www.rfc-editor.org/rfc/rfc9000.html#section-17.2-3.2.1 + headerFormShort = 0x00 // https://www.rfc-editor.org/rfc/rfc9000.html#section-17.3.1-4.2.1 + fixedBit = 0x40 // https://www.rfc-editor.org/rfc/rfc9000.html#section-17.2-3.4.1 + reservedLongBits = 0x0c // https://www.rfc-editor.org/rfc/rfc9000#section-17.2-8.2.1 + reserved1RTTBits = 0x18 // https://www.rfc-editor.org/rfc/rfc9000#section-17.3.1-4.8.1 + keyPhaseBit = 0x04 // https://www.rfc-editor.org/rfc/rfc9000#section-17.3.1-4.10.1 +) + +// Long Packet Type bits. +// https://www.rfc-editor.org/rfc/rfc9000.html#section-17.2-3.6.1 +const ( + longPacketTypeInitial = 0 << 4 + longPacketType0RTT = 1 << 4 + longPacketTypeHandshake = 2 << 4 + longPacketTypeRetry = 3 << 4 +) + +// Frame types. 
+// https://www.rfc-editor.org/rfc/rfc9000.html#section-19 +const ( + frameTypePadding = 0x00 + frameTypePing = 0x01 + frameTypeAck = 0x02 + frameTypeAckECN = 0x03 + frameTypeResetStream = 0x04 + frameTypeStopSending = 0x05 + frameTypeCrypto = 0x06 + frameTypeNewToken = 0x07 + frameTypeStreamBase = 0x08 // low three bits carry stream flags + frameTypeMaxData = 0x10 + frameTypeMaxStreamData = 0x11 + frameTypeMaxStreamsBidi = 0x12 + frameTypeMaxStreamsUni = 0x13 + frameTypeDataBlocked = 0x14 + frameTypeStreamDataBlocked = 0x15 + frameTypeStreamsBlockedBidi = 0x16 + frameTypeStreamsBlockedUni = 0x17 + frameTypeNewConnectionID = 0x18 + frameTypeRetireConnectionID = 0x19 + frameTypePathChallenge = 0x1a + frameTypePathResponse = 0x1b + frameTypeConnectionCloseTransport = 0x1c + frameTypeConnectionCloseApplication = 0x1d + frameTypeHandshakeDone = 0x1e +) + +// The low three bits of STREAM frames. +// https://www.rfc-editor.org/rfc/rfc9000.html#section-19.8 +const ( + streamOffBit = 0x04 + streamLenBit = 0x02 + streamFinBit = 0x01 +) + +// Maximum length of a connection ID. +const maxConnIDLen = 20 + +// isLongHeader returns true if b is the first byte of a long header. +func isLongHeader(b byte) bool { + return b&headerFormLong == headerFormLong +} + +// getPacketType returns the type of a packet. +func getPacketType(b []byte) packetType { + if len(b) == 0 { + return packetTypeInvalid + } + if !isLongHeader(b[0]) { + if b[0]&fixedBit != fixedBit { + return packetTypeInvalid + } + return packetType1RTT + } + if len(b) < 5 { + return packetTypeInvalid + } + if b[1] == 0 && b[2] == 0 && b[3] == 0 && b[4] == 0 { + // Version Negotiation packets don't necessarily set the fixed bit. 
+ return packetTypeVersionNegotiation + } + if b[0]&fixedBit != fixedBit { + return packetTypeInvalid + } + switch b[0] & 0x30 { + case longPacketTypeInitial: + return packetTypeInitial + case longPacketType0RTT: + return packetType0RTT + case longPacketTypeHandshake: + return packetTypeHandshake + case longPacketTypeRetry: + return packetTypeRetry + } + return packetTypeInvalid +} + +// dstConnIDForDatagram returns the destination connection ID field of the +// first QUIC packet in a datagram. +func dstConnIDForDatagram(pkt []byte) (id []byte, ok bool) { + if len(pkt) < 1 { + return nil, false + } + var n int + var b []byte + if isLongHeader(pkt[0]) { + if len(pkt) < 6 { + return nil, false + } + n = int(pkt[5]) + b = pkt[6:] + } else { + n = connIDLen + b = pkt[1:] + } + if len(b) < n { + return nil, false + } + return b[:n], true +} + +// parseVersionNegotiation parses a Version Negotiation packet. +// The returned versions is a slice of big-endian uint32s. +// It returns (nil, nil, nil) for an invalid packet. +func parseVersionNegotiation(pkt []byte) (dstConnID, srcConnID, versions []byte) { + p, ok := parseGenericLongHeaderPacket(pkt) + if !ok { + return nil, nil, nil + } + if len(p.data)%4 != 0 { + return nil, nil, nil + } + return p.dstConnID, p.srcConnID, p.data +} + +// appendVersionNegotiation appends a Version Negotiation packet to pkt, +// returning the result. +func appendVersionNegotiation(pkt, dstConnID, srcConnID []byte, versions ...uint32) []byte { + pkt = append(pkt, headerFormLong|fixedBit) // header byte + pkt = append(pkt, 0, 0, 0, 0) // Version (0 for Version Negotiation) + pkt = quicwire.AppendUint8Bytes(pkt, dstConnID) // Destination Connection ID + pkt = quicwire.AppendUint8Bytes(pkt, srcConnID) // Source Connection ID + for _, v := range versions { + pkt = binary.BigEndian.AppendUint32(pkt, v) // Supported Version + } + return pkt +} + +// A longPacket is a long header packet. 
+type longPacket struct { + ptype packetType + version uint32 + num packetNumber + dstConnID []byte + srcConnID []byte + payload []byte + + // The extra data depends on the packet type: + // Initial: Token. + // Retry: Retry token and integrity tag. + extra []byte +} + +// A shortPacket is a short header (1-RTT) packet. +type shortPacket struct { + num packetNumber + payload []byte +} + +// A genericLongPacket is a long header packet of an arbitrary QUIC version. +// https://www.rfc-editor.org/rfc/rfc8999#section-5.1 +type genericLongPacket struct { + version uint32 + dstConnID []byte + srcConnID []byte + data []byte +} + +func parseGenericLongHeaderPacket(b []byte) (p genericLongPacket, ok bool) { + if len(b) < 5 || !isLongHeader(b[0]) { + return genericLongPacket{}, false + } + b = b[1:] + // Version (32), + var n int + p.version, n = quicwire.ConsumeUint32(b) + if n < 0 { + return genericLongPacket{}, false + } + b = b[n:] + // Destination Connection ID Length (8), + // Destination Connection ID (0..2048), + p.dstConnID, n = quicwire.ConsumeUint8Bytes(b) + if n < 0 || len(p.dstConnID) > 2048/8 { + return genericLongPacket{}, false + } + b = b[n:] + // Source Connection ID Length (8), + // Source Connection ID (0..2048), + p.srcConnID, n = quicwire.ConsumeUint8Bytes(b) + if n < 0 || len(p.dstConnID) > 2048/8 { + return genericLongPacket{}, false + } + b = b[n:] + p.data = b + return p, true +} diff --git a/src/vendor/golang.org/x/net/quic/packet_number.go b/src/vendor/golang.org/x/net/quic/packet_number.go new file mode 100644 index 0000000000..9e9f0ad003 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/packet_number.go @@ -0,0 +1,72 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +// A packetNumber is a QUIC packet number. +// Packet numbers are integers in the range [0, 2^62-1]. 
+// +// https://www.rfc-editor.org/rfc/rfc9000.html#section-12.3 +type packetNumber int64 + +const maxPacketNumber = 1<<62 - 1 // https://www.rfc-editor.org/rfc/rfc9000.html#section-17.1-1 + +// decodePacketNumber decodes a truncated packet number, given +// the largest acknowledged packet number in this number space, +// the truncated number received in a packet, and the size of the +// number received in bytes. +// +// https://www.rfc-editor.org/rfc/rfc9000.html#section-17.1 +// https://www.rfc-editor.org/rfc/rfc9000.html#section-a.3 +func decodePacketNumber(largest, truncated packetNumber, numLenInBytes int) packetNumber { + expected := largest + 1 + win := packetNumber(1) << (uint(numLenInBytes) * 8) + hwin := win / 2 + mask := win - 1 + candidate := (expected &^ mask) | truncated + if candidate <= expected-hwin && candidate < (1<<62)-win { + return candidate + win + } + if candidate > expected+hwin && candidate >= win { + return candidate - win + } + return candidate +} + +// appendPacketNumber appends an encoded packet number to b. +// The packet number must be larger than the largest acknowledged packet number. +// When no packets have been acknowledged yet, largestAck is -1. +// +// https://www.rfc-editor.org/rfc/rfc9000.html#section-17.1-5 +func appendPacketNumber(b []byte, pnum, largestAck packetNumber) []byte { + switch packetNumberLength(pnum, largestAck) { + case 1: + return append(b, byte(pnum)) + case 2: + return append(b, byte(pnum>>8), byte(pnum)) + case 3: + return append(b, byte(pnum>>16), byte(pnum>>8), byte(pnum)) + default: + return append(b, byte(pnum>>24), byte(pnum>>16), byte(pnum>>8), byte(pnum)) + } +} + +// packetNumberLength returns the minimum length, in bytes, needed to encode +// a packet number given the largest acknowledged packet number. +// The packet number must be larger than the largest acknowledged packet number. 
+// +// https://www.rfc-editor.org/rfc/rfc9000.html#section-17.1-5 +func packetNumberLength(pnum, largestAck packetNumber) int { + d := pnum - largestAck + switch { + case d < 0x80: + return 1 + case d < 0x8000: + return 2 + case d < 0x800000: + return 3 + default: + return 4 + } +} diff --git a/src/vendor/golang.org/x/net/quic/packet_parser.go b/src/vendor/golang.org/x/net/quic/packet_parser.go new file mode 100644 index 0000000000..265c4aeb3a --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/packet_parser.go @@ -0,0 +1,517 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import "golang.org/x/net/internal/quic/quicwire" + +// parseLongHeaderPacket parses a QUIC long header packet. +// +// It does not parse Version Negotiation packets. +// +// On input, pkt contains a long header packet (possibly followed by more packets), +// k the decryption keys for the packet, and pnumMax the largest packet number seen +// in the number space of this packet. +// +// parseLongHeaderPacket returns the parsed packet with protection removed +// and its length in bytes. +// +// It returns an empty packet and -1 if the packet could not be parsed. +func parseLongHeaderPacket(pkt []byte, k fixedKeys, pnumMax packetNumber) (p longPacket, n int) { + if len(pkt) < 5 || !isLongHeader(pkt[0]) { + return longPacket{}, -1 + } + + // Header Form (1) = 1, + // Fixed Bit (1) = 1, + // Long Packet Type (2), + // Type-Specific Bits (4), + b := pkt + p.ptype = getPacketType(b) + if p.ptype == packetTypeInvalid { + return longPacket{}, -1 + } + b = b[1:] + // Version (32), + p.version, n = quicwire.ConsumeUint32(b) + if n < 0 { + return longPacket{}, -1 + } + b = b[n:] + if p.version == 0 { + // Version Negotiation packet; not handled here. 
+ return longPacket{}, -1 + } + + // Destination Connection ID Length (8), + // Destination Connection ID (0..160), + p.dstConnID, n = quicwire.ConsumeUint8Bytes(b) + if n < 0 || len(p.dstConnID) > maxConnIDLen { + return longPacket{}, -1 + } + b = b[n:] + + // Source Connection ID Length (8), + // Source Connection ID (0..160), + p.srcConnID, n = quicwire.ConsumeUint8Bytes(b) + if n < 0 || len(p.dstConnID) > maxConnIDLen { + return longPacket{}, -1 + } + b = b[n:] + + switch p.ptype { + case packetTypeInitial: + // Token Length (i), + // Token (..), + p.extra, n = quicwire.ConsumeVarintBytes(b) + if n < 0 { + return longPacket{}, -1 + } + b = b[n:] + case packetTypeRetry: + // Retry Token (..), + // Retry Integrity Tag (128), + p.extra = b + return p, len(pkt) + } + + // Length (i), + payLen, n := quicwire.ConsumeVarint(b) + if n < 0 { + return longPacket{}, -1 + } + b = b[n:] + if uint64(len(b)) < payLen { + return longPacket{}, -1 + } + + // Packet Number (8..32), + // Packet Payload (..), + pnumOff := len(pkt) - len(b) + pkt = pkt[:pnumOff+int(payLen)] + + if k.isSet() { + var err error + p.payload, p.num, err = k.unprotect(pkt, pnumOff, pnumMax) + if err != nil { + return longPacket{}, -1 + } + } + return p, len(pkt) +} + +// skipLongHeaderPacket returns the length of the long header packet at the start of pkt, +// or -1 if the buffer does not contain a valid packet. +func skipLongHeaderPacket(pkt []byte) int { + // Header byte, 4 bytes of version. + n := 5 + if len(pkt) <= n { + return -1 + } + // Destination connection ID length, destination connection ID. + n += 1 + int(pkt[n]) + if len(pkt) <= n { + return -1 + } + // Source connection ID length, source connection ID. + n += 1 + int(pkt[n]) + if len(pkt) <= n { + return -1 + } + if getPacketType(pkt) == packetTypeInitial { + // Token length, token. + _, nn := quicwire.ConsumeVarintBytes(pkt[n:]) + if nn < 0 { + return -1 + } + n += nn + } + // Length, packet number, payload. 
+ _, nn := quicwire.ConsumeVarintBytes(pkt[n:]) + if nn < 0 { + return -1 + } + n += nn + if len(pkt) < n { + return -1 + } + return n +} + +// parse1RTTPacket parses a QUIC 1-RTT (short header) packet. +// +// On input, pkt contains a short header packet, k the decryption keys for the packet, +// and pnumMax the largest packet number seen in the number space of this packet. +func parse1RTTPacket(pkt []byte, k *updatingKeyPair, dstConnIDLen int, pnumMax packetNumber) (p shortPacket, err error) { + pay, pnum, err := k.unprotect(pkt, 1+dstConnIDLen, pnumMax) + if err != nil { + return shortPacket{}, err + } + p.num = pnum + p.payload = pay + return p, nil +} + +// Consume functions return n=-1 on conditions which result in FRAME_ENCODING_ERROR, +// which includes both general parse failures and specific violations of frame +// constraints. + +func consumeAckFrame(frame []byte, f func(rangeIndex int, start, end packetNumber)) (largest packetNumber, ackDelay unscaledAckDelay, ecn ecnCounts, n int) { + b := frame[1:] // type + + largestAck, n := quicwire.ConsumeVarint(b) + if n < 0 { + return 0, 0, ecnCounts{}, -1 + } + b = b[n:] + + v, n := quicwire.ConsumeVarintInt64(b) + if n < 0 { + return 0, 0, ecnCounts{}, -1 + } + b = b[n:] + ackDelay = unscaledAckDelay(v) + + ackRangeCount, n := quicwire.ConsumeVarint(b) + if n < 0 { + return 0, 0, ecnCounts{}, -1 + } + b = b[n:] + + rangeMax := packetNumber(largestAck) + for i := uint64(0); ; i++ { + rangeLen, n := quicwire.ConsumeVarint(b) + if n < 0 { + return 0, 0, ecnCounts{}, -1 + } + b = b[n:] + rangeMin := rangeMax - packetNumber(rangeLen) + if rangeMin < 0 || rangeMin > rangeMax { + return 0, 0, ecnCounts{}, -1 + } + f(int(i), rangeMin, rangeMax+1) + + if i == ackRangeCount { + break + } + + gap, n := quicwire.ConsumeVarint(b) + if n < 0 { + return 0, 0, ecnCounts{}, -1 + } + b = b[n:] + + rangeMax = rangeMin - packetNumber(gap) - 2 + } + + if frame[0] != frameTypeAckECN { + return packetNumber(largestAck), ackDelay, 
ecnCounts{}, len(frame) - len(b) + } + + ect0Count, n := quicwire.ConsumeVarint(b) + if n < 0 { + return 0, 0, ecnCounts{}, -1 + } + b = b[n:] + ect1Count, n := quicwire.ConsumeVarint(b) + if n < 0 { + return 0, 0, ecnCounts{}, -1 + } + b = b[n:] + ecnCECount, n := quicwire.ConsumeVarint(b) + if n < 0 { + return 0, 0, ecnCounts{}, -1 + } + b = b[n:] + + ecn.t0 = int(ect0Count) + ecn.t1 = int(ect1Count) + ecn.ce = int(ecnCECount) + + return packetNumber(largestAck), ackDelay, ecn, len(frame) - len(b) +} + +func consumeResetStreamFrame(b []byte) (id streamID, code uint64, finalSize int64, n int) { + n = 1 + idInt, nn := quicwire.ConsumeVarint(b[n:]) + if nn < 0 { + return 0, 0, 0, -1 + } + n += nn + code, nn = quicwire.ConsumeVarint(b[n:]) + if nn < 0 { + return 0, 0, 0, -1 + } + n += nn + v, nn := quicwire.ConsumeVarint(b[n:]) + if nn < 0 { + return 0, 0, 0, -1 + } + n += nn + finalSize = int64(v) + return streamID(idInt), code, finalSize, n +} + +func consumeStopSendingFrame(b []byte) (id streamID, code uint64, n int) { + n = 1 + idInt, nn := quicwire.ConsumeVarint(b[n:]) + if nn < 0 { + return 0, 0, -1 + } + n += nn + code, nn = quicwire.ConsumeVarint(b[n:]) + if nn < 0 { + return 0, 0, -1 + } + n += nn + return streamID(idInt), code, n +} + +func consumeCryptoFrame(b []byte) (off int64, data []byte, n int) { + n = 1 + v, nn := quicwire.ConsumeVarint(b[n:]) + if nn < 0 { + return 0, nil, -1 + } + off = int64(v) + n += nn + data, nn = quicwire.ConsumeVarintBytes(b[n:]) + if nn < 0 { + return 0, nil, -1 + } + n += nn + return off, data, n +} + +func consumeNewTokenFrame(b []byte) (token []byte, n int) { + n = 1 + data, nn := quicwire.ConsumeVarintBytes(b[n:]) + if nn < 0 { + return nil, -1 + } + if len(data) == 0 { + return nil, -1 + } + n += nn + return data, n +} + +func consumeStreamFrame(b []byte) (id streamID, off int64, fin bool, data []byte, n int) { + fin = (b[0] & 0x01) != 0 + n = 1 + idInt, nn := quicwire.ConsumeVarint(b[n:]) + if nn < 0 { + return 0, 0, 
false, nil, -1 + } + n += nn + if b[0]&0x04 != 0 { + v, nn := quicwire.ConsumeVarint(b[n:]) + if nn < 0 { + return 0, 0, false, nil, -1 + } + n += nn + off = int64(v) + } + if b[0]&0x02 != 0 { + data, nn = quicwire.ConsumeVarintBytes(b[n:]) + if nn < 0 { + return 0, 0, false, nil, -1 + } + n += nn + } else { + data = b[n:] + n += len(data) + } + if off+int64(len(data)) >= 1<<62 { + return 0, 0, false, nil, -1 + } + return streamID(idInt), off, fin, data, n +} + +func consumeMaxDataFrame(b []byte) (max int64, n int) { + n = 1 + v, nn := quicwire.ConsumeVarint(b[n:]) + if nn < 0 { + return 0, -1 + } + n += nn + return int64(v), n +} + +func consumeMaxStreamDataFrame(b []byte) (id streamID, max int64, n int) { + n = 1 + v, nn := quicwire.ConsumeVarint(b[n:]) + if nn < 0 { + return 0, 0, -1 + } + n += nn + id = streamID(v) + v, nn = quicwire.ConsumeVarint(b[n:]) + if nn < 0 { + return 0, 0, -1 + } + n += nn + max = int64(v) + return id, max, n +} + +func consumeMaxStreamsFrame(b []byte) (typ streamType, max int64, n int) { + switch b[0] { + case frameTypeMaxStreamsBidi: + typ = bidiStream + case frameTypeMaxStreamsUni: + typ = uniStream + default: + return 0, 0, -1 + } + n = 1 + v, nn := quicwire.ConsumeVarint(b[n:]) + if nn < 0 { + return 0, 0, -1 + } + n += nn + if v > maxStreamsLimit { + return 0, 0, -1 + } + return typ, int64(v), n +} + +func consumeStreamDataBlockedFrame(b []byte) (id streamID, max int64, n int) { + n = 1 + v, nn := quicwire.ConsumeVarint(b[n:]) + if nn < 0 { + return 0, 0, -1 + } + n += nn + id = streamID(v) + max, nn = quicwire.ConsumeVarintInt64(b[n:]) + if nn < 0 { + return 0, 0, -1 + } + n += nn + return id, max, n +} + +func consumeDataBlockedFrame(b []byte) (max int64, n int) { + n = 1 + max, nn := quicwire.ConsumeVarintInt64(b[n:]) + if nn < 0 { + return 0, -1 + } + n += nn + return max, n +} + +func consumeStreamsBlockedFrame(b []byte) (typ streamType, max int64, n int) { + if b[0] == frameTypeStreamsBlockedBidi { + typ = bidiStream + } 
else { + typ = uniStream + } + n = 1 + max, nn := quicwire.ConsumeVarintInt64(b[n:]) + if nn < 0 { + return 0, 0, -1 + } + n += nn + return typ, max, n +} + +func consumeNewConnectionIDFrame(b []byte) (seq, retire int64, connID []byte, resetToken statelessResetToken, n int) { + n = 1 + var nn int + seq, nn = quicwire.ConsumeVarintInt64(b[n:]) + if nn < 0 { + return 0, 0, nil, statelessResetToken{}, -1 + } + n += nn + retire, nn = quicwire.ConsumeVarintInt64(b[n:]) + if nn < 0 { + return 0, 0, nil, statelessResetToken{}, -1 + } + n += nn + if seq < retire { + return 0, 0, nil, statelessResetToken{}, -1 + } + connID, nn = quicwire.ConsumeVarintBytes(b[n:]) + if nn < 0 { + return 0, 0, nil, statelessResetToken{}, -1 + } + if len(connID) < 1 || len(connID) > 20 { + return 0, 0, nil, statelessResetToken{}, -1 + } + n += nn + if len(b[n:]) < len(resetToken) { + return 0, 0, nil, statelessResetToken{}, -1 + } + copy(resetToken[:], b[n:]) + n += len(resetToken) + return seq, retire, connID, resetToken, n +} + +func consumeRetireConnectionIDFrame(b []byte) (seq int64, n int) { + n = 1 + var nn int + seq, nn = quicwire.ConsumeVarintInt64(b[n:]) + if nn < 0 { + return 0, -1 + } + n += nn + return seq, n +} + +func consumePathChallengeFrame(b []byte) (data pathChallengeData, n int) { + n = 1 + nn := copy(data[:], b[n:]) + if nn != len(data) { + return data, -1 + } + n += nn + return data, n +} + +func consumePathResponseFrame(b []byte) (data pathChallengeData, n int) { + return consumePathChallengeFrame(b) // identical frame format +} + +func consumeConnectionCloseTransportFrame(b []byte) (code transportError, frameType uint64, reason string, n int) { + n = 1 + var nn int + var codeInt uint64 + codeInt, nn = quicwire.ConsumeVarint(b[n:]) + if nn < 0 { + return 0, 0, "", -1 + } + code = transportError(codeInt) + n += nn + frameType, nn = quicwire.ConsumeVarint(b[n:]) + if nn < 0 { + return 0, 0, "", -1 + } + n += nn + reasonb, nn := quicwire.ConsumeVarintBytes(b[n:]) + if nn < 
0 {
		return 0, 0, "", -1
	}
	n += nn
	reason = string(reasonb)
	return code, frameType, reason, n
}

// consumeConnectionCloseApplicationFrame parses a CONNECTION_CLOSE frame carrying
// an application protocol error code (RFC 9000, Section 19.19).
// It returns n < 0 if the frame is malformed.
func consumeConnectionCloseApplicationFrame(b []byte) (code uint64, reason string, n int) {
	n = 1
	var nn int
	code, nn = quicwire.ConsumeVarint(b[n:])
	if nn < 0 {
		return 0, "", -1
	}
	n += nn
	reasonb, nn := quicwire.ConsumeVarintBytes(b[n:])
	if nn < 0 {
		return 0, "", -1
	}
	n += nn
	reason = string(reasonb)
	return code, reason, n
}
diff --git a/src/vendor/golang.org/x/net/quic/packet_protection.go b/src/vendor/golang.org/x/net/quic/packet_protection.go
new file mode 100644
index 0000000000..7856d6b5d8
--- /dev/null
+++ b/src/vendor/golang.org/x/net/quic/packet_protection.go
@@ -0,0 +1,539 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package quic

import (
	"crypto"
	"crypto/aes"
	"crypto/cipher"
	"crypto/sha256"
	"crypto/tls"
	"errors"
	"hash"

	"golang.org/x/crypto/chacha20"
	"golang.org/x/crypto/chacha20poly1305"
	"golang.org/x/crypto/cryptobyte"
	"golang.org/x/crypto/hkdf"
)

// errInvalidPacket is returned when a packet is too short or otherwise malformed.
var errInvalidPacket = errors.New("quic: invalid packet")

// headerProtectionSampleSize is the size of the ciphertext sample used for header protection.
// https://www.rfc-editor.org/rfc/rfc9001#section-5.4.2
const headerProtectionSampleSize = 16

// aeadOverhead is the difference in size between the AEAD output and input.
// All cipher suites defined for use with QUIC have 16 bytes of overhead.
const aeadOverhead = 16

// A headerKey applies or removes header protection.
// https://www.rfc-editor.org/rfc/rfc9001#section-5.4
type headerKey struct {
	hp headerProtection // nil until init is called
}

// isSet reports whether the key has been initialized.
func (k headerKey) isSet() bool {
	return k.hp != nil
}

// init derives the header protection key for the given cipher suite from
// secret and selects the matching protection algorithm.
func (k *headerKey) init(suite uint16, secret []byte) {
	h, keySize := hashForSuite(suite)
	hpKey := hkdfExpandLabel(h.New, secret, "quic hp", nil, keySize)
	switch suite {
	case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
		c, err := aes.NewCipher(hpKey)
		if err != nil {
			panic(err)
		}
		k.hp = &aesHeaderProtection{cipher: c}
	case tls.TLS_CHACHA20_POLY1305_SHA256:
		k.hp = chaCha20HeaderProtection{hpKey}
	default:
		panic("BUG: unknown cipher suite")
	}
}

// protect applies header protection.
// pnumOff is the offset of the packet number in the packet.
func (k headerKey) protect(hdr []byte, pnumOff int) {
	// Apply header protection.
	// Read the packet number length before masking the first byte.
	pnumSize := int(hdr[0]&0x03) + 1
	// The sample is taken 4 bytes after the start of the packet number field,
	// regardless of the actual (1-4 byte) packet number length.
	// https://www.rfc-editor.org/rfc/rfc9001#section-5.4.2
	sample := hdr[pnumOff+4:][:headerProtectionSampleSize]
	mask := k.hp.headerProtection(sample)
	if isLongHeader(hdr[0]) {
		// Long header: mask the low 4 bits (reserved + pnum length).
		hdr[0] ^= mask[0] & 0x0f
	} else {
		// Short header: mask the low 5 bits (reserved + key phase + pnum length).
		hdr[0] ^= mask[0] & 0x1f
	}
	for i := 0; i < pnumSize; i++ {
		hdr[pnumOff+i] ^= mask[1+i]
	}
}

// unprotect removes header protection.
// pnumOff is the offset of the packet number in the packet.
// pnumMax is the largest packet number seen in the number space of this packet.
func (k headerKey) unprotect(pkt []byte, pnumOff int, pnumMax packetNumber) (hdr, pay []byte, pnum packetNumber, _ error) {
	// The packet must be long enough to contain the header protection sample.
	if len(pkt) < pnumOff+4+headerProtectionSampleSize {
		return nil, nil, 0, errInvalidPacket
	}
	numpay := pkt[pnumOff:]
	sample := numpay[4:][:headerProtectionSampleSize]
	mask := k.hp.headerProtection(sample)
	if isLongHeader(pkt[0]) {
		pkt[0] ^= mask[0] & 0x0f
	} else {
		pkt[0] ^= mask[0] & 0x1f
	}
	// The packet number length is only readable after unmasking the first byte.
	pnumLen := int(pkt[0]&0x03) + 1
	pnum = packetNumber(0)
	for i := 0; i < pnumLen; i++ {
		numpay[i] ^= mask[1+i]
		pnum = (pnum << 8) | packetNumber(numpay[i])
	}
	// Recover the full packet number from its truncated wire encoding.
	pnum = decodePacketNumber(pnumMax, pnum, pnumLen)
	hdr = pkt[:pnumOff+pnumLen]
	pay = numpay[pnumLen:]
	return hdr, pay, pnum, nil
}

// headerProtection is the header_protection function as defined in:
// https://www.rfc-editor.org/rfc/rfc9001#section-5.4.1
//
// This function takes a sample of the packet ciphertext
// and returns a 5-byte mask which will be applied to the
// protected portions of the packet header.
type headerProtection interface {
	headerProtection(sample []byte) (mask [5]byte)
}

// AES-based header protection.
// https://www.rfc-editor.org/rfc/rfc9001#section-5.4.3
type aesHeaderProtection struct {
	cipher  cipher.Block
	scratch [aes.BlockSize]byte // reusable output block, avoids a per-packet allocation
}

// headerProtection encrypts the sample with AES and returns the first
// five bytes of the result as the mask.
func (hp *aesHeaderProtection) headerProtection(sample []byte) (mask [5]byte) {
	hp.cipher.Encrypt(hp.scratch[:], sample)
	copy(mask[:], hp.scratch[:])
	return mask
}

// ChaCha20-based header protection.
+// https://www.rfc-editor.org/rfc/rfc9001#section-5.4.4 +type chaCha20HeaderProtection struct { + key []byte +} + +func (hp chaCha20HeaderProtection) headerProtection(sample []byte) (mask [5]byte) { + counter := uint32(sample[3])<<24 | uint32(sample[2])<<16 | uint32(sample[1])<<8 | uint32(sample[0]) + nonce := sample[4:16] + c, err := chacha20.NewUnauthenticatedCipher(hp.key, nonce) + if err != nil { + panic(err) + } + c.SetCounter(counter) + c.XORKeyStream(mask[:], mask[:]) + return mask +} + +// A packetKey applies or removes packet protection. +// https://www.rfc-editor.org/rfc/rfc9001#section-5.1 +type packetKey struct { + aead cipher.AEAD // AEAD function used for packet protection. + iv []byte // IV used to construct the AEAD nonce. +} + +func (k *packetKey) init(suite uint16, secret []byte) { + // https://www.rfc-editor.org/rfc/rfc9001#section-5.1 + h, keySize := hashForSuite(suite) + key := hkdfExpandLabel(h.New, secret, "quic key", nil, keySize) + switch suite { + case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384: + k.aead = newAESAEAD(key) + case tls.TLS_CHACHA20_POLY1305_SHA256: + k.aead = newChaCha20AEAD(key) + default: + panic("BUG: unknown cipher suite") + } + k.iv = hkdfExpandLabel(h.New, secret, "quic iv", nil, k.aead.NonceSize()) +} + +func newAESAEAD(key []byte) cipher.AEAD { + c, err := aes.NewCipher(key) + if err != nil { + panic(err) + } + aead, err := cipher.NewGCM(c) + if err != nil { + panic(err) + } + return aead +} + +func newChaCha20AEAD(key []byte) cipher.AEAD { + var err error + aead, err := chacha20poly1305.New(key) + if err != nil { + panic(err) + } + return aead +} + +func (k packetKey) protect(hdr, pay []byte, pnum packetNumber) []byte { + k.xorIV(pnum) + defer k.xorIV(pnum) + return k.aead.Seal(hdr, k.iv, pay, hdr) +} + +func (k packetKey) unprotect(hdr, pay []byte, pnum packetNumber) (dec []byte, err error) { + k.xorIV(pnum) + defer k.xorIV(pnum) + return k.aead.Open(pay[:0], k.iv, pay, hdr) +} + +// xorIV xors the 
packet protection IV with the packet number. +func (k packetKey) xorIV(pnum packetNumber) { + k.iv[len(k.iv)-8] ^= uint8(pnum >> 56) + k.iv[len(k.iv)-7] ^= uint8(pnum >> 48) + k.iv[len(k.iv)-6] ^= uint8(pnum >> 40) + k.iv[len(k.iv)-5] ^= uint8(pnum >> 32) + k.iv[len(k.iv)-4] ^= uint8(pnum >> 24) + k.iv[len(k.iv)-3] ^= uint8(pnum >> 16) + k.iv[len(k.iv)-2] ^= uint8(pnum >> 8) + k.iv[len(k.iv)-1] ^= uint8(pnum) +} + +// A fixedKeys is a header protection key and fixed packet protection key. +// The packet protection key is fixed (it does not update). +// +// Fixed keys are used for Initial and Handshake keys, which do not update. +type fixedKeys struct { + hdr headerKey + pkt packetKey +} + +func (k *fixedKeys) init(suite uint16, secret []byte) { + k.hdr.init(suite, secret) + k.pkt.init(suite, secret) +} + +func (k fixedKeys) isSet() bool { + return k.hdr.hp != nil +} + +// protect applies packet protection to a packet. +// +// On input, hdr contains the packet header, pay the unencrypted payload, +// pnumOff the offset of the packet number in the header, and pnum the untruncated +// packet number. +// +// protect returns the result of appending the encrypted payload to hdr and +// applying header protection. +func (k fixedKeys) protect(hdr, pay []byte, pnumOff int, pnum packetNumber) []byte { + pkt := k.pkt.protect(hdr, pay, pnum) + k.hdr.protect(pkt, pnumOff) + return pkt +} + +// unprotect removes packet protection from a packet. +// +// On input, pkt contains the full protected packet, pnumOff the offset of +// the packet number in the header, and pnumMax the largest packet number +// seen in the number space of this packet. +// +// unprotect removes header protection from the header in pkt, and returns +// the unprotected payload and packet number. 
func (k fixedKeys) unprotect(pkt []byte, pnumOff int, pnumMax packetNumber) (pay []byte, num packetNumber, err error) {
	// Remove header protection first to recover the packet number,
	// which is needed as the AEAD nonce for payload decryption.
	hdr, pay, pnum, err := k.hdr.unprotect(pkt, pnumOff, pnumMax)
	if err != nil {
		return nil, 0, err
	}
	pay, err = k.pkt.unprotect(hdr, pay, pnum)
	if err != nil {
		return nil, 0, err
	}
	return pay, pnum, nil
}

// A fixedKeyPair is a read/write pair of fixed keys.
type fixedKeyPair struct {
	r, w fixedKeys
}

// discard drops both keys, returning the pair to its zero (unusable) state.
func (k *fixedKeyPair) discard() {
	*k = fixedKeyPair{}
}

// canRead reports whether the read key has been initialized.
func (k *fixedKeyPair) canRead() bool {
	return k.r.isSet()
}

// canWrite reports whether the write key has been initialized.
func (k *fixedKeyPair) canWrite() bool {
	return k.w.isSet()
}

// An updatingKeys is a header protection key and updatable packet protection key.
// updatingKeys are used for 1-RTT keys, where the packet protection key changes
// over the lifetime of a connection.
// https://www.rfc-editor.org/rfc/rfc9001#section-6
type updatingKeys struct {
	suite      uint16
	hdr        headerKey
	pkt        [2]packetKey // current, next
	nextSecret []byte       // secret used to generate pkt[1]
}

// init derives the phase-0 keys and pre-generates the phase-1 packet key.
func (k *updatingKeys) init(suite uint16, secret []byte) {
	k.suite = suite
	k.hdr.init(suite, secret)
	// Initialize pkt[1] with secret_0, and then call update to generate secret_1.
	k.pkt[1].init(suite, secret)
	k.nextSecret = secret
	k.update()
}

// update performs a key update.
// The current key in pkt[0] is discarded.
// The next key in pkt[1] becomes the current key.
// A new next key is generated in pkt[1].
func (k *updatingKeys) update() {
	k.nextSecret = updateSecret(k.suite, k.nextSecret)
	k.pkt[0] = k.pkt[1]
	k.pkt[1].init(k.suite, k.nextSecret)
}

// updateSecret derives secret_{n+1} from secret_n using the "quic ku" label.
func updateSecret(suite uint16, secret []byte) (nextSecret []byte) {
	h, _ := hashForSuite(suite)
	return hkdfExpandLabel(h.New, secret, "quic ku", nil, len(secret))
}

// An updatingKeyPair is a read/write pair of updating keys.
//
// We keep two keys (current and next) in both read and write directions.
// When an incoming packet's phase matches the current phase bit,
// we unprotect it using the current keys; otherwise we use the next keys.
//
// When updating=false, outgoing packets are protected using the current phase.
//
// An update is initiated and updating is set to true when:
//   - we decide to initiate a key update; or
//   - we successfully unprotect a packet using the next keys,
//     indicating the peer has initiated a key update.
//
// When updating=true, outgoing packets are protected using the next phase.
// We do not change the current phase bit or generate new keys yet.
//
// The update concludes when we receive an ACK frame for a packet sent
// with the next keys. At this time, we set updating to false, flip the
// phase bit, and update the keys. This permits us to handle up to 1-RTT
// of reordered packets before discarding the previous phase's keys after
// an update.
type updatingKeyPair struct {
	phase        uint8 // current key phase (r.pkt[0], w.pkt[0])
	updating     bool
	authFailures int64        // total packet unprotect failures
	minSent      packetNumber // min packet number sent since entering the updating state
	minReceived  packetNumber // min packet number received in the next phase
	updateAfter  packetNumber // packet number after which to initiate key update
	r, w         updatingKeys
}

func (k *updatingKeyPair) init() {
	// 1-RTT packets until the first key update.
	//
	// We perform the first key update early in the connection so a peer
	// which does not support key updates will fail rapidly,
	// rather than after the connection has been long established.
	//
	// The QUIC interop runner "keyupdate" test requires that the client
	// initiate a key rotation early in the connection. Increasing this
	// value may cause interop test failures; if we do want to increase it,
	// we should either skip the keyupdate test or provide a way to override
	// the setting in interop tests.
	k.updateAfter = 100
}

// canRead reports whether the read keys have been initialized.
func (k *updatingKeyPair) canRead() bool {
	return k.r.hdr.hp != nil
}

// canWrite reports whether the write keys have been initialized.
func (k *updatingKeyPair) canWrite() bool {
	return k.w.hdr.hp != nil
}

// handleAckFor finishes a key update after receiving an ACK for a packet in the next phase.
func (k *updatingKeyPair) handleAckFor(pnum packetNumber) {
	if k.updating && pnum >= k.minSent {
		k.updating = false
		k.phase ^= keyPhaseBit
		k.r.update()
		k.w.update()
	}
}

// needAckEliciting reports whether we should send an ack-eliciting packet in the next phase.
// The first packet sent in a phase is ack-eliciting, since the peer must acknowledge a
// packet in the new phase for us to finish the update.
func (k *updatingKeyPair) needAckEliciting() bool {
	// minSent == maxPacketNumber means no packet has been sent in the
	// new phase yet (it is reset to maxPacketNumber when an update starts).
	return k.updating && k.minSent == maxPacketNumber
}

// protect applies packet protection to a packet.
// Parameters and returns are as for fixedKeyPair.protect.
func (k *updatingKeyPair) protect(hdr, pay []byte, pnumOff int, pnum packetNumber) []byte {
	var pkt []byte
	if k.updating {
		// Mid-update: send with the flipped phase bit and the next keys.
		hdr[0] |= k.phase ^ keyPhaseBit
		pkt = k.w.pkt[1].protect(hdr, pay, pnum)
		k.minSent = min(pnum, k.minSent)
	} else {
		hdr[0] |= k.phase
		pkt = k.w.pkt[0].protect(hdr, pay, pnum)
		if pnum >= k.updateAfter {
			// Initiate a key update, starting with the next packet we send.
			//
			// We do this after protecting the current packet
			// to allow Conn.appendFrames to ensure that the first packet sent
			// in the new phase is ack-eliciting.
			k.updating = true
			k.minSent = maxPacketNumber
			k.minReceived = maxPacketNumber
			// The lowest confidentiality limit for a supported AEAD is 2^23 packets.
			// https://www.rfc-editor.org/rfc/rfc9001#section-6.6-5
			//
			// Schedule our next update for half that.
			k.updateAfter += (1 << 22)
		}
	}
	k.w.hdr.protect(pkt, pnumOff)
	return pkt
}

// unprotect removes packet protection from a packet.
// Parameters and returns are as for fixedKeyPair.unprotect.
func (k *updatingKeyPair) unprotect(pkt []byte, pnumOff int, pnumMax packetNumber) (pay []byte, pnum packetNumber, err error) {
	hdr, pay, pnum, err := k.r.hdr.unprotect(pkt, pnumOff, pnumMax)
	if err != nil {
		return nil, 0, err
	}
	// To avoid timing signals that might indicate the key phase bit is invalid,
	// we always attempt to unprotect the packet with one key.
	//
	// If the key phase bit matches and the packet number doesn't come after
	// the start of an in-progress update, use the current phase.
	// Otherwise, use the next phase.
	if hdr[0]&keyPhaseBit == k.phase && (!k.updating || pnum < k.minReceived) {
		pay, err = k.r.pkt[0].unprotect(hdr, pay, pnum)
	} else {
		pay, err = k.r.pkt[1].unprotect(hdr, pay, pnum)
		if err == nil {
			if !k.updating {
				// The peer has initiated a key update.
				k.updating = true
				k.minSent = maxPacketNumber
				k.minReceived = pnum
			} else {
				k.minReceived = min(pnum, k.minReceived)
			}
		}
	}
	if err != nil {
		// Count every authentication failure; past the AEAD integrity limit
		// the connection must be closed.
		k.authFailures++
		if k.authFailures >= aeadIntegrityLimit(k.r.suite) {
			return nil, 0, localTransportError{code: errAEADLimitReached}
		}
		return nil, 0, err
	}
	return pay, pnum, nil
}

// aeadIntegrityLimit returns the integrity limit for an AEAD:
// The maximum number of received packets that may fail authentication
// before closing the connection.
//
// https://www.rfc-editor.org/rfc/rfc9001#section-6.6-4
func aeadIntegrityLimit(suite uint16) int64 {
	switch suite {
	case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
		return 1 << 52
	case tls.TLS_CHACHA20_POLY1305_SHA256:
		return 1 << 36
	default:
		panic("BUG: unknown cipher suite")
	}
}

// initialSalt is the fixed HKDF salt for QUIC version 1 Initial keys.
// https://www.rfc-editor.org/rfc/rfc9001#section-5.2-2
var initialSalt = []byte{0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a}

// initialKeys returns the keys used to protect Initial packets.
+// +// The Initial packet keys are derived from the Destination Connection ID +// field in the client's first Initial packet. +// +// https://www.rfc-editor.org/rfc/rfc9001#section-5.2 +func initialKeys(cid []byte, side connSide) fixedKeyPair { + initialSecret := hkdf.Extract(sha256.New, cid, initialSalt) + var clientKeys fixedKeys + clientSecret := hkdfExpandLabel(sha256.New, initialSecret, "client in", nil, sha256.Size) + clientKeys.init(tls.TLS_AES_128_GCM_SHA256, clientSecret) + var serverKeys fixedKeys + serverSecret := hkdfExpandLabel(sha256.New, initialSecret, "server in", nil, sha256.Size) + serverKeys.init(tls.TLS_AES_128_GCM_SHA256, serverSecret) + if side == clientSide { + return fixedKeyPair{r: serverKeys, w: clientKeys} + } else { + return fixedKeyPair{w: serverKeys, r: clientKeys} + } +} + +// checkCipherSuite returns an error if suite is not a supported cipher suite. +func checkCipherSuite(suite uint16) error { + switch suite { + case tls.TLS_AES_128_GCM_SHA256: + case tls.TLS_AES_256_GCM_SHA384: + case tls.TLS_CHACHA20_POLY1305_SHA256: + default: + return errors.New("invalid cipher suite") + } + return nil +} + +func hashForSuite(suite uint16) (h crypto.Hash, keySize int) { + switch suite { + case tls.TLS_AES_128_GCM_SHA256: + return crypto.SHA256, 128 / 8 + case tls.TLS_AES_256_GCM_SHA384: + return crypto.SHA384, 256 / 8 + case tls.TLS_CHACHA20_POLY1305_SHA256: + return crypto.SHA256, chacha20.KeySize + default: + panic("BUG: unknown cipher suite") + } +} + +// hkdfExpandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1. +// +// Copied from crypto/tls/key_schedule.go. 
// hkdfExpandLabel derives length bytes of key material from secret,
// binding in the TLS 1.3 label and context.
func hkdfExpandLabel(hash func() hash.Hash, secret []byte, label string, context []byte, length int) []byte {
	var hkdfLabel cryptobyte.Builder
	hkdfLabel.AddUint16(uint16(length))
	hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
		// Labels are implicitly prefixed with "tls13 " (RFC 8446, Section 7.1).
		b.AddBytes([]byte("tls13 "))
		b.AddBytes([]byte(label))
	})
	hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
		b.AddBytes(context)
	})
	out := make([]byte, length)
	n, err := hkdf.Expand(hash, secret, hkdfLabel.BytesOrPanic()).Read(out)
	if err != nil || n != length {
		panic("quic: HKDF-Expand-Label invocation failed unexpectedly")
	}
	return out
}
diff --git a/src/vendor/golang.org/x/net/quic/packet_writer.go b/src/vendor/golang.org/x/net/quic/packet_writer.go
new file mode 100644
index 0000000000..f446521d2b
--- /dev/null
+++ b/src/vendor/golang.org/x/net/quic/packet_writer.go
@@ -0,0 +1,566 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package quic

import (
	"encoding/binary"

	"golang.org/x/net/internal/quic/quicwire"
)

// A packetWriter constructs QUIC datagrams.
//
// A datagram consists of one or more packets.
// A packet consists of a header followed by one or more frames.
//
// Packets are written in three steps:
//   - startProtectedLongHeaderPacket or start1RTT packet prepare the packet;
//   - append*Frame appends frames to the payload; and
//   - finishProtectedLongHeaderPacket or finish1RTT finalize the packet.
//
// The start functions are efficient, so we can start speculatively
// writing a packet before we know whether we have any frames to
// put in it. The finish functions will abandon the packet if the
// payload contains no data.
type packetWriter struct {
	dgramLim int // max datagram size
	pktLim   int // max packet size
	pktOff   int // offset of the start of the current packet
	payOff   int // offset of the payload of the current packet
	b        []byte      // the datagram under construction
	sent     *sentPacket // metadata for the packet under construction
}

// reset prepares to write a datagram of at most lim bytes.
func (w *packetWriter) reset(lim int) {
	if cap(w.b) < lim {
		w.b = make([]byte, 0, lim)
	}
	w.dgramLim = lim
	w.b = w.b[:0]
}

// datagram returns the current datagram.
func (w *packetWriter) datagram() []byte {
	return w.b
}

// packetLen returns the size of the current packet.
func (w *packetWriter) packetLen() int {
	// The AEAD tag has not been appended yet; account for it here.
	return len(w.b[w.pktOff:]) + aeadOverhead
}

// payload returns the payload of the current packet.
func (w *packetWriter) payload() []byte {
	return w.b[w.payOff:]
}

// abandonPacket discards any frames written to the current packet.
func (w *packetWriter) abandonPacket() {
	w.b = w.b[:w.payOff]
	w.sent.reset()
}

// startProtectedLongHeaderPacket starts writing an Initial, 0-RTT, or Handshake packet.
func (w *packetWriter) startProtectedLongHeaderPacket(pnumMaxAcked packetNumber, p longPacket) {
	if w.sent == nil {
		w.sent = newSentPacket()
	}
	w.pktOff = len(w.b)
	hdrSize := 1 // packet type
	hdrSize += 4 // version
	hdrSize += 1 + len(p.dstConnID)
	hdrSize += 1 + len(p.srcConnID)
	switch p.ptype {
	case packetTypeInitial:
		// Initial packets carry a token; p.extra holds it.
		hdrSize += quicwire.SizeVarint(uint64(len(p.extra))) + len(p.extra)
	}
	hdrSize += 2 // length, hardcoded to a 2-byte varint
	pnumOff := len(w.b) + hdrSize
	hdrSize += packetNumberLength(p.num, pnumMaxAcked)
	payOff := len(w.b) + hdrSize
	// Check if we have enough space to hold the packet, including the header,
	// header protection sample (RFC 9001, section 5.4.2), and encryption overhead.
	if pnumOff+4+headerProtectionSampleSize+aeadOverhead >= w.dgramLim {
		// Set the limit on the packet size to be the current write buffer length,
		// ensuring that any writes to the payload fail.
		w.payOff = len(w.b)
		w.pktLim = len(w.b)
		return
	}
	w.payOff = payOff
	w.pktLim = w.dgramLim - aeadOverhead
	// We hardcode the payload length field to be 2 bytes, which limits the payload
	// (including the packet number) to 16383 bytes (the largest 2-byte QUIC varint).
	//
	// Most networks don't support datagrams over 1472 bytes, and even Ethernet
	// jumbo frames are generally only about 9000 bytes.
	if lim := pnumOff + 16383 - aeadOverhead; lim < w.pktLim {
		w.pktLim = lim
	}
	w.b = w.b[:payOff]
}

// finishProtectedLongHeaderPacket finishes writing an Initial, 0-RTT, or Handshake packet,
// canceling the packet if it contains no payload.
// It returns a sentPacket describing the packet, or nil if no packet was written.
func (w *packetWriter) finishProtectedLongHeaderPacket(pnumMaxAcked packetNumber, k fixedKeys, p longPacket) *sentPacket {
	if len(w.b) == w.payOff {
		// The payload is empty, so just abandon the packet.
		w.b = w.b[:w.pktOff]
		return nil
	}
	pnumLen := packetNumberLength(p.num, pnumMaxAcked)
	plen := w.padPacketLength(pnumLen)
	// The header was not written during start; write it now that the
	// payload (and therefore the Length field) is known.
	hdr := w.b[:w.pktOff]
	var typeBits byte
	switch p.ptype {
	case packetTypeInitial:
		typeBits = longPacketTypeInitial
	case packetType0RTT:
		typeBits = longPacketType0RTT
	case packetTypeHandshake:
		typeBits = longPacketTypeHandshake
	case packetTypeRetry:
		typeBits = longPacketTypeRetry
	}
	hdr = append(hdr, headerFormLong|fixedBit|typeBits|byte(pnumLen-1))
	hdr = binary.BigEndian.AppendUint32(hdr, p.version)
	hdr = quicwire.AppendUint8Bytes(hdr, p.dstConnID)
	hdr = quicwire.AppendUint8Bytes(hdr, p.srcConnID)
	switch p.ptype {
	case packetTypeInitial:
		hdr = quicwire.AppendVarintBytes(hdr, p.extra) // token
	}

	// Packet length, always encoded as a 2-byte varint.
	hdr = append(hdr, 0x40|byte(plen>>8), byte(plen))

	pnumOff := len(hdr)
	hdr = appendPacketNumber(hdr, p.num, pnumMaxAcked)

	k.protect(hdr[w.pktOff:], w.b[len(hdr):], pnumOff-w.pktOff, p.num)
	return w.finish(p.ptype, p.num)
}

// start1RTTPacket starts writing a 1-RTT (short header) packet.
func (w *packetWriter) start1RTTPacket(pnum, pnumMaxAcked packetNumber, dstConnID []byte) {
	if w.sent == nil {
		w.sent = newSentPacket()
	}
	w.pktOff = len(w.b)
	hdrSize := 1 // packet type
	hdrSize += len(dstConnID)
	// Ensure we have enough space to hold the packet, including the header,
	// header protection sample (RFC 9001, section 5.4.2), and encryption overhead.
	if len(w.b)+hdrSize+4+headerProtectionSampleSize+aeadOverhead >= w.dgramLim {
		// Not enough room: make the limit equal to the current length so
		// any attempt to append a frame fails.
		w.payOff = len(w.b)
		w.pktLim = len(w.b)
		return
	}
	hdrSize += packetNumberLength(pnum, pnumMaxAcked)
	w.payOff = len(w.b) + hdrSize
	w.pktLim = w.dgramLim - aeadOverhead
	w.b = w.b[:w.payOff]
}

// finish1RTTPacket finishes writing a 1-RTT packet,
// canceling the packet if it contains no payload.
// It returns a sentPacket describing the packet, or nil if no packet was written.
func (w *packetWriter) finish1RTTPacket(pnum, pnumMaxAcked packetNumber, dstConnID []byte, k *updatingKeyPair) *sentPacket {
	if len(w.b) == w.payOff {
		// The payload is empty, so just abandon the packet.
		w.b = w.b[:w.pktOff]
		return nil
	}
	// TODO: Spin
	pnumLen := packetNumberLength(pnum, pnumMaxAcked)
	hdr := w.b[:w.pktOff]
	// Short header first byte: fixed bit (0x40) plus the encoded
	// packet number length; the key phase bit is set by k.protect.
	hdr = append(hdr, 0x40|byte(pnumLen-1))
	hdr = append(hdr, dstConnID...)
	pnumOff := len(hdr)
	hdr = appendPacketNumber(hdr, pnum, pnumMaxAcked)
	w.padPacketLength(pnumLen)
	k.protect(hdr[w.pktOff:], w.b[len(hdr):], pnumOff-w.pktOff, pnum)
	return w.finish(packetType1RTT, pnum)
}

// padPacketLength pads out the payload of the current packet to the minimum size,
// and returns the combined length of the packet number and payload (used for the Length
// field of long header packets).
func (w *packetWriter) padPacketLength(pnumLen int) int {
	plen := len(w.b) - w.payOff + pnumLen + aeadOverhead
	// "To ensure that sufficient data is available for sampling, packets are
	// padded so that the combined lengths of the encoded packet number and
	// protected payload is at least 4 bytes longer than the sample required
	// for header protection."
	// https://www.rfc-editor.org/rfc/rfc9001.html#section-5.4.2
	for plen < 4+headerProtectionSampleSize {
		w.b = append(w.b, 0)
		plen++
	}
	return plen
}

// finish finishes the current packet after protection is applied.
func (w *packetWriter) finish(ptype packetType, pnum packetNumber) *sentPacket {
	// Extend the buffer over the AEAD tag written by protect.
	w.b = w.b[:len(w.b)+aeadOverhead]
	w.sent.size = len(w.b) - w.pktOff
	w.sent.ptype = ptype
	w.sent.num = pnum
	sent := w.sent
	w.sent = nil
	return sent
}

// avail reports how many more bytes may be written to the current packet.
func (w *packetWriter) avail() int {
	return w.pktLim - len(w.b)
}

// appendPaddingTo appends PADDING frames until the total datagram size
// (including AEAD overhead of the current packet) is n.
func (w *packetWriter) appendPaddingTo(n int) {
	n -= aeadOverhead
	lim := w.pktLim
	if n < lim {
		lim = n
	}
	if len(w.b) >= lim {
		return
	}
	for len(w.b) < lim {
		w.b = append(w.b, frameTypePadding)
	}
	// Packets are considered in flight when they contain a PADDING frame.
	// https://www.rfc-editor.org/rfc/rfc9002.html#section-2-3.6.1
	w.sent.inFlight = true
}

// appendPingFrame appends a PING frame, used to elicit an acknowledgement.
func (w *packetWriter) appendPingFrame() (added bool) {
	if len(w.b) >= w.pktLim {
		return false
	}
	w.b = append(w.b, frameTypePing)
	w.sent.markAckEliciting() // no need to record the frame itself
	return true
}

// appendAckFrame appends an ACK frame to the payload.
// It includes at least the most recent range in the rangeset
// (the range with the largest packet numbers),
// followed by as many additional ranges as fit within the packet.
//
// We always place ACK frames at the start of packets,
// we limit the number of ack ranges retained, and
// we set a minimum packet payload size.
// As a result, appendAckFrame will rarely if ever drop ranges
// in practice.
//
// In the event that ranges are dropped, the impact is limited
// to the peer potentially failing to receive an acknowledgement
// for an older packet during a period of high packet loss or
// reordering. This may result in unnecessary retransmissions.
func (w *packetWriter) appendAckFrame(seen rangeset[packetNumber], delay unscaledAckDelay, ecn ecnCounts) (added bool) {
	if len(seen) == 0 {
		return false
	}
	var (
		largest    = uint64(seen.max())
		firstRange = uint64(seen[len(seen)-1].size() - 1)
	)
	var ecnLen int
	ackType := byte(frameTypeAck)
	if (ecn != ecnCounts{}) {
		// "Even if an endpoint does not set an ECT field in packets it sends,
		// the endpoint MUST provide feedback about ECN markings it receives, if
		// these are accessible."
		// https://www.rfc-editor.org/rfc/rfc9000.html#section-13.4.1-2
		ecnLen = quicwire.SizeVarint(uint64(ecn.ce)) + quicwire.SizeVarint(uint64(ecn.t0)) + quicwire.SizeVarint(uint64(ecn.t1))
		ackType = frameTypeAckECN
	}
	if w.avail() < 1+quicwire.SizeVarint(largest)+quicwire.SizeVarint(uint64(delay))+1+quicwire.SizeVarint(firstRange)+ecnLen {
		return false
	}
	w.b = append(w.b, ackType)
	w.b = quicwire.AppendVarint(w.b, largest)
	w.b = quicwire.AppendVarint(w.b, uint64(delay))
	// The range count is technically a varint, but we'll reserve a single byte for it
	// and never add more than 62 ranges (the maximum varint that fits in a byte).
	rangeCountOff := len(w.b)
	w.b = append(w.b, 0)
	w.b = quicwire.AppendVarint(w.b, firstRange)
	rangeCount := byte(0)
	// Walk the remaining ranges from largest to smallest packet numbers,
	// encoding each as a (gap, length) pair relative to its successor.
	for i := len(seen) - 2; i >= 0; i-- {
		gap := uint64(seen[i+1].start - seen[i].end - 1)
		size := uint64(seen[i].size() - 1)
		if w.avail() < quicwire.SizeVarint(gap)+quicwire.SizeVarint(size)+ecnLen || rangeCount > 62 {
			break
		}
		w.b = quicwire.AppendVarint(w.b, gap)
		w.b = quicwire.AppendVarint(w.b, size)
		rangeCount++
	}
	w.b[rangeCountOff] = rangeCount
	if ackType == frameTypeAckECN {
		w.b = quicwire.AppendVarint(w.b, uint64(ecn.t0))
		w.b = quicwire.AppendVarint(w.b, uint64(ecn.t1))
		w.b = quicwire.AppendVarint(w.b, uint64(ecn.ce))
	}
	w.sent.appendNonAckElicitingFrame(ackType)
	w.sent.appendInt(uint64(seen.max()))
	return true
}

// appendNewTokenFrame appends a NEW_TOKEN frame.
//
// NOTE(review): unlike the other append*Frame methods, this does not record
// the frame in w.sent. NEW_TOKEN is an ack-eliciting frame (RFC 9002,
// Section 2) — confirm the caller accounts for eliciting and retransmission.
func (w *packetWriter) appendNewTokenFrame(token []byte) (added bool) {
	if w.avail() < 1+quicwire.SizeVarint(uint64(len(token)))+len(token) {
		return false
	}
	w.b = append(w.b, frameTypeNewToken)
	w.b = quicwire.AppendVarintBytes(w.b, token)
	return true
}

// appendResetStreamFrame appends a RESET_STREAM frame.
func (w *packetWriter) appendResetStreamFrame(id streamID, code uint64, finalSize int64) (added bool) {
	if w.avail() < 1+quicwire.SizeVarint(uint64(id))+quicwire.SizeVarint(code)+quicwire.SizeVarint(uint64(finalSize)) {
		return false
	}
	w.b = append(w.b, frameTypeResetStream)
	w.b
= quicwire.AppendVarint(w.b, uint64(id))
	w.b = quicwire.AppendVarint(w.b, code)
	w.b = quicwire.AppendVarint(w.b, uint64(finalSize))
	w.sent.appendAckElicitingFrame(frameTypeResetStream)
	w.sent.appendInt(uint64(id))
	return true
}

// appendStopSendingFrame appends a STOP_SENDING frame.
func (w *packetWriter) appendStopSendingFrame(id streamID, code uint64) (added bool) {
	if w.avail() < 1+quicwire.SizeVarint(uint64(id))+quicwire.SizeVarint(code) {
		return false
	}
	w.b = append(w.b, frameTypeStopSending)
	w.b = quicwire.AppendVarint(w.b, uint64(id))
	w.b = quicwire.AppendVarint(w.b, code)
	w.sent.appendAckElicitingFrame(frameTypeStopSending)
	w.sent.appendInt(uint64(id))
	return true
}

// appendCryptoFrame appends a CRYPTO frame.
// It returns a []byte into which the data should be written and whether a frame was added.
// The returned []byte may be smaller than size if the packet cannot hold all the data.
func (w *packetWriter) appendCryptoFrame(off int64, size int) (_ []byte, added bool) {
	max := w.avail()
	max -= 1                                 // frame type
	max -= quicwire.SizeVarint(uint64(off))  // offset
	max -= quicwire.SizeVarint(uint64(size)) // maximum length
	if max <= 0 {
		return nil, false
	}
	if max < size {
		size = max
	}
	w.b = append(w.b, frameTypeCrypto)
	w.b = quicwire.AppendVarint(w.b, uint64(off))
	w.b = quicwire.AppendVarint(w.b, uint64(size))
	// Reserve size bytes for the caller to fill in.
	start := len(w.b)
	w.b = w.b[:start+size]
	w.sent.appendAckElicitingFrame(frameTypeCrypto)
	w.sent.appendOffAndSize(off, size)
	return w.b[start:][:size], true
}

// appendStreamFrame appends a STREAM frame.
// It returns a []byte into which the data should be written and whether a frame was added.
// The returned []byte may be smaller than size if the packet cannot hold all the data.
func (w *packetWriter) appendStreamFrame(id streamID, off int64, size int, fin bool) (_ []byte, added bool) {
	// Always include an explicit length field so frames can be packed.
	typ := uint8(frameTypeStreamBase | streamLenBit)
	max := w.avail()
	max -= 1 // frame type
	max -= quicwire.SizeVarint(uint64(id))
	if off != 0 {
		// A zero offset is implicit; only encode it when nonzero.
		max -= quicwire.SizeVarint(uint64(off))
		typ |= streamOffBit
	}
	max -= quicwire.SizeVarint(uint64(size)) // maximum length
	if max < 0 || (max == 0 && size > 0) {
		return nil, false
	}
	if max < size {
		size = max
	} else if fin {
		// FIN may only be set if the entire remaining data fits.
		typ |= streamFinBit
	}
	w.b = append(w.b, typ)
	w.b = quicwire.AppendVarint(w.b, uint64(id))
	if off != 0 {
		w.b = quicwire.AppendVarint(w.b, uint64(off))
	}
	w.b = quicwire.AppendVarint(w.b, uint64(size))
	start := len(w.b)
	w.b = w.b[:start+size]
	w.sent.appendAckElicitingFrame(typ & (frameTypeStreamBase | streamFinBit))
	w.sent.appendInt(uint64(id))
	w.sent.appendOffAndSize(off, size)
	return w.b[start:][:size], true
}

// appendMaxDataFrame appends a MAX_DATA frame.
func (w *packetWriter) appendMaxDataFrame(max int64) (added bool) {
	if w.avail() < 1+quicwire.SizeVarint(uint64(max)) {
		return false
	}
	w.b = append(w.b, frameTypeMaxData)
	w.b = quicwire.AppendVarint(w.b, uint64(max))
	w.sent.appendAckElicitingFrame(frameTypeMaxData)
	return true
}

// appendMaxStreamDataFrame appends a MAX_STREAM_DATA frame.
func (w *packetWriter) appendMaxStreamDataFrame(id streamID, max int64) (added bool) {
	if w.avail() < 1+quicwire.SizeVarint(uint64(id))+quicwire.SizeVarint(uint64(max)) {
		return false
	}
	w.b = append(w.b, frameTypeMaxStreamData)
	w.b = quicwire.AppendVarint(w.b, uint64(id))
	w.b = quicwire.AppendVarint(w.b, uint64(max))
	w.sent.appendAckElicitingFrame(frameTypeMaxStreamData)
	w.sent.appendInt(uint64(id))
	return true
}

// appendMaxStreamsFrame appends a MAX_STREAMS frame of the given stream type.
func (w *packetWriter) appendMaxStreamsFrame(streamType streamType, max int64) (added bool) {
	if w.avail() < 1+quicwire.SizeVarint(uint64(max)) {
		return false
	}
	var typ byte
	if streamType == bidiStream {
		typ = frameTypeMaxStreamsBidi
	} else {
		typ = frameTypeMaxStreamsUni
	}
	w.b = append(w.b,
typ) + w.b = quicwire.AppendVarint(w.b, uint64(max)) + w.sent.appendAckElicitingFrame(typ) + return true +} + +func (w *packetWriter) appendDataBlockedFrame(max int64) (added bool) { + if w.avail() < 1+quicwire.SizeVarint(uint64(max)) { + return false + } + w.b = append(w.b, frameTypeDataBlocked) + w.b = quicwire.AppendVarint(w.b, uint64(max)) + w.sent.appendAckElicitingFrame(frameTypeDataBlocked) + return true +} + +func (w *packetWriter) appendStreamDataBlockedFrame(id streamID, max int64) (added bool) { + if w.avail() < 1+quicwire.SizeVarint(uint64(id))+quicwire.SizeVarint(uint64(max)) { + return false + } + w.b = append(w.b, frameTypeStreamDataBlocked) + w.b = quicwire.AppendVarint(w.b, uint64(id)) + w.b = quicwire.AppendVarint(w.b, uint64(max)) + w.sent.appendAckElicitingFrame(frameTypeStreamDataBlocked) + w.sent.appendInt(uint64(id)) + return true +} + +func (w *packetWriter) appendStreamsBlockedFrame(typ streamType, max int64) (added bool) { + if w.avail() < 1+quicwire.SizeVarint(uint64(max)) { + return false + } + var ftype byte + if typ == bidiStream { + ftype = frameTypeStreamsBlockedBidi + } else { + ftype = frameTypeStreamsBlockedUni + } + w.b = append(w.b, ftype) + w.b = quicwire.AppendVarint(w.b, uint64(max)) + w.sent.appendAckElicitingFrame(ftype) + return true +} + +func (w *packetWriter) appendNewConnectionIDFrame(seq, retirePriorTo int64, connID []byte, token [16]byte) (added bool) { + if w.avail() < 1+quicwire.SizeVarint(uint64(seq))+quicwire.SizeVarint(uint64(retirePriorTo))+1+len(connID)+len(token) { + return false + } + w.b = append(w.b, frameTypeNewConnectionID) + w.b = quicwire.AppendVarint(w.b, uint64(seq)) + w.b = quicwire.AppendVarint(w.b, uint64(retirePriorTo)) + w.b = quicwire.AppendUint8Bytes(w.b, connID) + w.b = append(w.b, token[:]...) 
+ w.sent.appendAckElicitingFrame(frameTypeNewConnectionID) + w.sent.appendInt(uint64(seq)) + return true +} + +func (w *packetWriter) appendRetireConnectionIDFrame(seq int64) (added bool) { + if w.avail() < 1+quicwire.SizeVarint(uint64(seq)) { + return false + } + w.b = append(w.b, frameTypeRetireConnectionID) + w.b = quicwire.AppendVarint(w.b, uint64(seq)) + w.sent.appendAckElicitingFrame(frameTypeRetireConnectionID) + w.sent.appendInt(uint64(seq)) + return true +} + +func (w *packetWriter) appendPathChallengeFrame(data pathChallengeData) (added bool) { + if w.avail() < 1+8 { + return false + } + w.b = append(w.b, frameTypePathChallenge) + w.b = append(w.b, data[:]...) + w.sent.markAckEliciting() // no need to record the frame itself + return true +} + +func (w *packetWriter) appendPathResponseFrame(data pathChallengeData) (added bool) { + if w.avail() < 1+8 { + return false + } + w.b = append(w.b, frameTypePathResponse) + w.b = append(w.b, data[:]...) + w.sent.markAckEliciting() // no need to record the frame itself + return true +} + +// appendConnectionCloseTransportFrame appends a CONNECTION_CLOSE frame +// carrying a transport error code. +func (w *packetWriter) appendConnectionCloseTransportFrame(code transportError, frameType uint64, reason string) (added bool) { + if w.avail() < 1+quicwire.SizeVarint(uint64(code))+quicwire.SizeVarint(frameType)+quicwire.SizeVarint(uint64(len(reason)))+len(reason) { + return false + } + w.b = append(w.b, frameTypeConnectionCloseTransport) + w.b = quicwire.AppendVarint(w.b, uint64(code)) + w.b = quicwire.AppendVarint(w.b, frameType) + w.b = quicwire.AppendVarintBytes(w.b, []byte(reason)) + // We don't record CONNECTION_CLOSE frames in w.sent, since they are never acked or + // detected as lost. + return true +} + +// appendConnectionCloseApplicationFrame appends a CONNECTION_CLOSE frame +// carrying an application protocol error code. 
+func (w *packetWriter) appendConnectionCloseApplicationFrame(code uint64, reason string) (added bool) { + if w.avail() < 1+quicwire.SizeVarint(code)+quicwire.SizeVarint(uint64(len(reason)))+len(reason) { + return false + } + w.b = append(w.b, frameTypeConnectionCloseApplication) + w.b = quicwire.AppendVarint(w.b, code) + w.b = quicwire.AppendVarintBytes(w.b, []byte(reason)) + // We don't record CONNECTION_CLOSE frames in w.sent, since they are never acked or + // detected as lost. + return true +} + +func (w *packetWriter) appendHandshakeDoneFrame() (added bool) { + if w.avail() < 1 { + return false + } + w.b = append(w.b, frameTypeHandshakeDone) + w.sent.appendAckElicitingFrame(frameTypeHandshakeDone) + return true +} diff --git a/src/vendor/golang.org/x/net/quic/path.go b/src/vendor/golang.org/x/net/quic/path.go new file mode 100644 index 0000000000..5170562c74 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/path.go @@ -0,0 +1,87 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import "time" + +type pathState struct { + // Response to a peer's PATH_CHALLENGE. + // This is not a sentVal, because we don't resend lost PATH_RESPONSE frames. + // We only track the most recent PATH_CHALLENGE. + // If the peer sends a second PATH_CHALLENGE before we respond to the first, + // we'll drop the first response. + sendPathResponse pathResponseType + data pathChallengeData +} + +// pathChallengeData is data carried in a PATH_CHALLENGE or PATH_RESPONSE frame. 
+type pathChallengeData [64 / 8]byte + +type pathResponseType uint8 + +const ( + pathResponseNotNeeded = pathResponseType(iota) + pathResponseSmall // send PATH_RESPONSE, do not expand datagram + pathResponseExpanded // send PATH_RESPONSE, expand datagram to 1200 bytes +) + +func (c *Conn) handlePathChallenge(_ time.Time, dgram *datagram, data pathChallengeData) { + // A PATH_RESPONSE is sent in a datagram expanded to 1200 bytes, + // except when this would exceed the anti-amplification limit. + // + // Rather than maintaining anti-amplification state for each path + // we may be sending a PATH_RESPONSE on, follow the following heuristic: + // + // If we receive a PATH_CHALLENGE in an expanded datagram, + // respond with an expanded datagram. + // + // If we receive a PATH_CHALLENGE in a non-expanded datagram, + // then the peer is presumably blocked by its own anti-amplification limit. + // Respond with a non-expanded datagram. Receiving this PATH_RESPONSE + // will validate the path to the peer, remove its anti-amplification limit, + // and permit it to send a followup PATH_CHALLENGE in an expanded datagram. + // https://www.rfc-editor.org/rfc/rfc9000.html#section-8.2.1 + if len(dgram.b) >= smallestMaxDatagramSize { + c.path.sendPathResponse = pathResponseExpanded + } else { + c.path.sendPathResponse = pathResponseSmall + } + c.path.data = data +} + +func (c *Conn) handlePathResponse(now time.Time, _ pathChallengeData) { + // "If the content of a PATH_RESPONSE frame does not match the content of + // a PATH_CHALLENGE frame previously sent by the endpoint, + // the endpoint MAY generate a connection error of type PROTOCOL_VIOLATION." + // https://www.rfc-editor.org/rfc/rfc9000.html#section-19.18-4 + // + // We never send PATH_CHALLENGE frames. + c.abort(now, localTransportError{ + code: errProtocolViolation, + reason: "PATH_RESPONSE received when no PATH_CHALLENGE sent", + }) +} + +// appendPathFrames appends path validation related frames to the current packet. 
+// If the return value pad is true, then the packet should be padded to 1200 bytes. +func (c *Conn) appendPathFrames() (pad, ok bool) { + if c.path.sendPathResponse == pathResponseNotNeeded { + return pad, true + } + // We're required to send the PATH_RESPONSE on the path where the + // PATH_CHALLENGE was received (RFC 9000, Section 8.2.2). + // + // At the moment, we don't support path migration and reject packets if + // the peer changes its source address, so just sending the PATH_RESPONSE + // in a regular datagram is fine. + if !c.w.appendPathResponseFrame(c.path.data) { + return pad, false + } + if c.path.sendPathResponse == pathResponseExpanded { + pad = true + } + c.path.sendPathResponse = pathResponseNotNeeded + return pad, true +} diff --git a/src/vendor/golang.org/x/net/quic/ping.go b/src/vendor/golang.org/x/net/quic/ping.go new file mode 100644 index 0000000000..e604f014bf --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/ping.go @@ -0,0 +1,14 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import "time" + +func (c *Conn) ping(space numberSpace) { + c.sendMsg(func(now time.Time, c *Conn) { + c.testSendPing.setUnsent() + c.testSendPingSpace = space + }) +} diff --git a/src/vendor/golang.org/x/net/quic/pipe.go b/src/vendor/golang.org/x/net/quic/pipe.go new file mode 100644 index 0000000000..2ae651ea3e --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/pipe.go @@ -0,0 +1,173 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "sync" +) + +// A pipe is a byte buffer used in implementing streams. +// +// A pipe contains a window of stream data. +// Random access reads and writes are supported within the window. +// Writing past the end of the window extends it. 
// Data may be discarded from the start of the pipe, advancing the window.
type pipe struct {
	start int64    // stream position of first stored byte
	end   int64    // stream position just past the last stored byte
	head  *pipebuf // if non-nil, then head.off + len(head.b) > start
	tail  *pipebuf // if non-nil, then tail.off + len(tail.b) == end
}

// A pipebuf is one fixed-size link in the pipe's chain of buffers.
type pipebuf struct {
	off  int64 // stream position of b[0]
	b    []byte
	next *pipebuf
}

// end returns the stream position just past the last byte covered by pb.
func (pb *pipebuf) end() int64 {
	return pb.off + int64(len(pb.b))
}

var pipebufPool = sync.Pool{
	New: func() any {
		return &pipebuf{b: make([]byte, 4096)}
	},
}

func newPipebuf() *pipebuf {
	return pipebufPool.Get().(*pipebuf)
}

func (b *pipebuf) recycle() {
	b.off = 0
	b.next = nil
	pipebufPool.Put(b)
}

// writeAt writes len(b) bytes to the pipe at offset off.
//
// Writes to offsets before p.start are discarded.
// Writes to offsets after p.end extend the pipe window.
func (p *pipe) writeAt(b []byte, off int64) {
	writeEnd := off + int64(len(b))
	if writeEnd > p.end {
		p.end = writeEnd
	} else if writeEnd <= p.start {
		// Entirely before the window; nothing to store.
		return
	}

	if off < p.start {
		// Drop the portion of b which falls before p.start.
		drop := p.start - off
		b = b[drop:]
		off = p.start
	}

	if p.head == nil {
		p.head = newPipebuf()
		p.head.off = p.start
		p.tail = p.head
	}
	cur := p.head
	if off >= p.tail.off {
		// Common case: writing past the end of the pipe.
		cur = p.tail
	}
	for {
		if rel := off - cur.off; rel < int64(len(cur.b)) {
			n := copy(cur.b[rel:], b)
			if n == len(b) {
				return
			}
			b = b[n:]
			off += int64(n)
		}
		if cur.next == nil {
			// Extend the chain with a fresh buffer.
			nb := newPipebuf()
			nb.off = cur.end()
			cur.next = nb
			p.tail = nb
		}
		cur = cur.next
	}
}

// copy copies len(b) bytes into b starting from off.
// The pipe must contain [off, off+len(b)).
func (p *pipe) copy(off int64, b []byte) {
	dst := b[:0]
	p.read(off, len(b), func(src []byte) error {
		dst = append(dst, src...)
		return nil
	})
}

// read calls f with the data in [off, off+n).
// The data may be provided sequentially across multiple calls to f.
// Note that read (unlike an io.Reader) does not consume the read data.
func (p *pipe) read(off int64, n int, f func([]byte) error) error {
	if off < p.start {
		panic("invalid read range")
	}
	for pb := p.head; pb != nil && n > 0; pb = pb.next {
		if off >= pb.end() {
			continue
		}
		chunk := pb.b[off-pb.off:]
		if n < len(chunk) {
			chunk = chunk[:n]
		}
		off += int64(len(chunk))
		n -= len(chunk)
		if err := f(chunk); err != nil {
			return err
		}
	}
	if n > 0 {
		panic("invalid read range")
	}
	return nil
}

// peek returns a reference to up to n bytes of internal data buffer, starting at p.start.
// The returned slice is valid until the next call to discardBefore.
// The length of the returned slice will be in the range [0,n].
func (p *pipe) peek(n int64) []byte {
	pb := p.head
	if pb == nil {
		return nil
	}
	avail := pb.b[p.start-pb.off:]
	return avail[:min(int64(len(avail)), n)]
}

// availableBuffer returns the available contiguous, allocated buffer space
// following the pipe window.
//
// This is used by the stream write fast path, which makes multiple writes into the pipe buffer
// without a lock, and then adjusts p.end at a later time with a lock held.
func (p *pipe) availableBuffer() []byte {
	if p.tail == nil {
		return nil
	}
	return p.tail.b[p.end-p.tail.off:]
}

// discardBefore discards all data prior to off.
func (p *pipe) discardBefore(off int64) {
	for p.head != nil && p.head.end() < off {
		freed := p.head
		p.head = freed.next
		freed.recycle()
	}
	if p.head == nil {
		p.tail = nil
	}
	p.start = off
	p.end = max(p.end, off)
}

// diff --git a/src/vendor/golang.org/x/net/quic/qlog.go b/src/vendor/golang.org/x/net/quic/qlog.go
// new file mode 100644 index 0000000000..5d2fd0fc1e --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/qlog.go @@ -0,0 +1,272 @@

// Copyright 2023 The Go Authors.
All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "context" + "encoding/hex" + "log/slog" + "net/netip" + "time" +) + +// Log levels for qlog events. +const ( + // QLogLevelFrame includes per-frame information. + // When this level is enabled, packet_sent and packet_received events will + // contain information on individual frames sent/received. + QLogLevelFrame = slog.Level(-6) + + // QLogLevelPacket events occur at most once per packet sent or received. + // + // For example: packet_sent, packet_received. + QLogLevelPacket = slog.Level(-4) + + // QLogLevelConn events occur multiple times over a connection's lifetime, + // but less often than the frequency of individual packets. + // + // For example: connection_state_updated. + QLogLevelConn = slog.Level(-2) + + // QLogLevelEndpoint events occur at most once per connection. + // + // For example: connection_started, connection_closed. + QLogLevelEndpoint = slog.Level(0) +) + +func (c *Conn) logEnabled(level slog.Level) bool { + return logEnabled(c.log, level) +} + +func logEnabled(log *slog.Logger, level slog.Level) bool { + return log != nil && log.Enabled(context.Background(), level) +} + +// slogHexstring returns a slog.Attr for a value of the hexstring type. 
+// +// https://www.ietf.org/archive/id/draft-ietf-quic-qlog-main-schema-04.html#section-1.1.1 +func slogHexstring(key string, value []byte) slog.Attr { + return slog.String(key, hex.EncodeToString(value)) +} + +func slogAddr(key string, value netip.Addr) slog.Attr { + return slog.String(key, value.String()) +} + +func (c *Conn) logConnectionStarted(originalDstConnID []byte, peerAddr netip.AddrPort) { + if c.config.QLogLogger == nil || + !c.config.QLogLogger.Enabled(context.Background(), QLogLevelEndpoint) { + return + } + var vantage string + if c.side == clientSide { + vantage = "client" + originalDstConnID = c.connIDState.originalDstConnID + } else { + vantage = "server" + } + // A qlog Trace container includes some metadata (title, description, vantage_point) + // and a list of Events. The Trace also includes a common_fields field setting field + // values common to all events in the trace. + // + // Trace = { + // ? title: text + // ? description: text + // ? configuration: Configuration + // ? common_fields: CommonFields + // ? vantage_point: VantagePoint + // events: [* Event] + // } + // + // To map this into slog's data model, we start each per-connection trace with a With + // call that includes both the trace metadata and the common fields. + // + // This means that in slog's model, each trace event will also include + // the Trace metadata fields (vantage_point), which is a divergence from the qlog model. + c.log = c.config.QLogLogger.With( + // The group_id permits associating traces taken from different vantage points + // for the same connection. + // + // We use the original destination connection ID as the group ID. 
+ // + // https://www.ietf.org/archive/id/draft-ietf-quic-qlog-main-schema-04.html#section-3.4.6 + slogHexstring("group_id", originalDstConnID), + slog.Group("vantage_point", + slog.String("name", "go quic"), + slog.String("type", vantage), + ), + ) + localAddr := c.endpoint.LocalAddr() + // https://www.ietf.org/archive/id/draft-ietf-quic-qlog-quic-events-03.html#section-4.2 + c.log.LogAttrs(context.Background(), QLogLevelEndpoint, + "connectivity:connection_started", + slogAddr("src_ip", localAddr.Addr()), + slog.Int("src_port", int(localAddr.Port())), + slogHexstring("src_cid", c.connIDState.local[0].cid), + slogAddr("dst_ip", peerAddr.Addr()), + slog.Int("dst_port", int(peerAddr.Port())), + slogHexstring("dst_cid", c.connIDState.remote[0].cid), + ) +} + +func (c *Conn) logConnectionClosed() { + if !c.logEnabled(QLogLevelEndpoint) { + return + } + err := c.lifetime.finalErr + trigger := "error" + switch e := err.(type) { + case *ApplicationError: + // TODO: Distinguish between peer and locally-initiated close. 
+ trigger = "application" + case localTransportError: + switch err { + case errHandshakeTimeout: + trigger = "handshake_timeout" + default: + if e.code == errNo { + trigger = "clean" + } + } + case peerTransportError: + if e.code == errNo { + trigger = "clean" + } + default: + switch err { + case errIdleTimeout: + trigger = "idle_timeout" + case errStatelessReset: + trigger = "stateless_reset" + } + } + // https://www.ietf.org/archive/id/draft-ietf-quic-qlog-quic-events-03.html#section-4.3 + c.log.LogAttrs(context.Background(), QLogLevelEndpoint, + "connectivity:connection_closed", + slog.String("trigger", trigger), + ) +} + +func (c *Conn) logPacketDropped(dgram *datagram) { + c.log.LogAttrs(context.Background(), QLogLevelPacket, + "connectivity:packet_dropped", + ) +} + +func (c *Conn) logLongPacketReceived(p longPacket, pkt []byte) { + var frames slog.Attr + if c.logEnabled(QLogLevelFrame) { + frames = c.packetFramesAttr(p.payload) + } + c.log.LogAttrs(context.Background(), QLogLevelPacket, + "transport:packet_received", + slog.Group("header", + slog.String("packet_type", p.ptype.qlogString()), + slog.Uint64("packet_number", uint64(p.num)), + slog.Uint64("flags", uint64(pkt[0])), + slogHexstring("scid", p.srcConnID), + slogHexstring("dcid", p.dstConnID), + ), + slog.Group("raw", + slog.Int("length", len(pkt)), + ), + frames, + ) +} + +func (c *Conn) log1RTTPacketReceived(p shortPacket, pkt []byte) { + var frames slog.Attr + if c.logEnabled(QLogLevelFrame) { + frames = c.packetFramesAttr(p.payload) + } + dstConnID, _ := dstConnIDForDatagram(pkt) + c.log.LogAttrs(context.Background(), QLogLevelPacket, + "transport:packet_received", + slog.Group("header", + slog.String("packet_type", packetType1RTT.qlogString()), + slog.Uint64("packet_number", uint64(p.num)), + slog.Uint64("flags", uint64(pkt[0])), + slogHexstring("dcid", dstConnID), + ), + slog.Group("raw", + slog.Int("length", len(pkt)), + ), + frames, + ) +} + +func (c *Conn) logPacketSent(ptype packetType, pnum 
packetNumber, src, dst []byte, pktLen int, payload []byte) { + var frames slog.Attr + if c.logEnabled(QLogLevelFrame) { + frames = c.packetFramesAttr(payload) + } + var scid slog.Attr + if len(src) > 0 { + scid = slogHexstring("scid", src) + } + c.log.LogAttrs(context.Background(), QLogLevelPacket, + "transport:packet_sent", + slog.Group("header", + slog.String("packet_type", ptype.qlogString()), + slog.Uint64("packet_number", uint64(pnum)), + scid, + slogHexstring("dcid", dst), + ), + slog.Group("raw", + slog.Int("length", pktLen), + ), + frames, + ) +} + +// packetFramesAttr returns the "frames" attribute containing the frames in a packet. +// We currently pass this as a slog Any containing a []slog.Value, +// where each Value is a debugFrame that implements slog.LogValuer. +// +// This isn't tremendously efficient, but avoids the need to put a JSON encoder +// in the quic package or a frame parser in the qlog package. +func (c *Conn) packetFramesAttr(payload []byte) slog.Attr { + var frames []slog.Value + for len(payload) > 0 { + f, n := parseDebugFrame(payload) + if n < 0 { + break + } + payload = payload[n:] + switch f := f.(type) { + case debugFrameAck: + // The qlog ACK frame contains the ACK Delay field as a duration. + // Interpreting the contents of this field as a duration requires + // knowing the peer's ack_delay_exponent transport parameter, + // and it's possible for us to parse an ACK frame before we've + // received that parameter. + // + // We could plumb connection state down into the frame parser, + // but for now let's minimize the amount of code that needs to + // deal with this and convert the unscaled value into a scaled one here. 
+ ackDelay := time.Duration(-1) + if c.peerAckDelayExponent >= 0 { + ackDelay = f.ackDelay.Duration(uint8(c.peerAckDelayExponent)) + } + frames = append(frames, slog.AnyValue(debugFrameScaledAck{ + ranges: f.ranges, + ackDelay: ackDelay, + })) + default: + frames = append(frames, slog.AnyValue(f)) + } + } + return slog.Any("frames", frames) +} + +func (c *Conn) logPacketLost(space numberSpace, sent *sentPacket) { + c.log.LogAttrs(context.Background(), QLogLevelPacket, + "recovery:packet_lost", + slog.Group("header", + slog.String("packet_type", sent.ptype.qlogString()), + slog.Uint64("packet_number", uint64(sent.num)), + ), + ) +} diff --git a/src/vendor/golang.org/x/net/quic/queue.go b/src/vendor/golang.org/x/net/quic/queue.go new file mode 100644 index 0000000000..f2712f4012 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/queue.go @@ -0,0 +1,63 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import "context" + +// A queue is an unbounded queue of some item (new connections and streams). +type queue[T any] struct { + // The gate condition is set if the queue is non-empty or closed. + gate gate + err error + q []T +} + +func newQueue[T any]() queue[T] { + return queue[T]{gate: newGate()} +} + +// close closes the queue, causing pending and future pop operations +// to return immediately with err. +func (q *queue[T]) close(err error) { + q.gate.lock() + defer q.unlock() + if q.err == nil { + q.err = err + } +} + +// put appends an item to the queue. +// It returns true if the item was added, false if the queue is closed. +func (q *queue[T]) put(v T) bool { + q.gate.lock() + defer q.unlock() + if q.err != nil { + return false + } + q.q = append(q.q, v) + return true +} + +// get removes the first item from the queue, blocking until ctx is done, an item is available, +// or the queue is closed. 
+func (q *queue[T]) get(ctx context.Context) (T, error) { + var zero T + if err := q.gate.waitAndLock(ctx); err != nil { + return zero, err + } + defer q.unlock() + if q.err != nil { + return zero, q.err + } + v := q.q[0] + copy(q.q[:], q.q[1:]) + q.q[len(q.q)-1] = zero + q.q = q.q[:len(q.q)-1] + return v, nil +} + +func (q *queue[T]) unlock() { + q.gate.unlock(q.err != nil || len(q.q) > 0) +} diff --git a/src/vendor/golang.org/x/net/quic/quic.go b/src/vendor/golang.org/x/net/quic/quic.go new file mode 100644 index 0000000000..26256bf422 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/quic.go @@ -0,0 +1,221 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "time" +) + +// QUIC versions. +// We only support v1 at this time. +const ( + quicVersion1 = 1 + quicVersion2 = 0x6b3343cf // https://www.rfc-editor.org/rfc/rfc9369 +) + +// connIDLen is the length in bytes of connection IDs chosen by this package. +// Since 1-RTT packets don't include a connection ID length field, +// we use a consistent length for all our IDs. +// https://www.rfc-editor.org/rfc/rfc9000.html#section-5.1-6 +const connIDLen = 8 + +// Local values of various transport parameters. +// https://www.rfc-editor.org/rfc/rfc9000.html#section-18.2 +const ( + defaultMaxIdleTimeout = 30 * time.Second // max_idle_timeout + + // The max_udp_payload_size transport parameter is the size of our + // network receive buffer. + // + // Set this to the largest UDP packet that can be sent over + // Ethernet without using jumbo frames: 1500 byte Ethernet frame, + // minus 20 byte IPv4 header and 8 byte UDP header. + // + // The maximum possible UDP payload is 65527 bytes. Supporting this + // without wasting memory in unused receive buffers will require some + // care. For now, just limit ourselves to the most common case. 
+ maxUDPPayloadSize = 1472 + + ackDelayExponent = 3 // ack_delay_exponent + maxAckDelay = 25 * time.Millisecond // max_ack_delay + + // The active_conn_id_limit transport parameter is the maximum + // number of connection IDs from the peer we're willing to store. + // + // maxPeerActiveConnIDLimit is the maximum number of connection IDs + // we're willing to send to the peer. + // + // https://www.rfc-editor.org/rfc/rfc9000.html#section-18.2-6.2.1 + activeConnIDLimit = 2 + maxPeerActiveConnIDLimit = 4 +) + +// Time limit for completing the handshake. +const defaultHandshakeTimeout = 10 * time.Second + +// Keep-alive ping frequency. +const defaultKeepAlivePeriod = 0 + +// Local timer granularity. +// https://www.rfc-editor.org/rfc/rfc9002.html#section-6.1.2-6 +const timerGranularity = 1 * time.Millisecond + +// The smallest allowed maximum datagram size. +// https://www.rfc-editor.org/rfc/rfc9000#section-14 +const smallestMaxDatagramSize = 1200 + +// Minimum size of a UDP datagram sent by a client carrying an Initial packet, +// or a server containing an ack-eliciting Initial packet. +// https://www.rfc-editor.org/rfc/rfc9000#section-14.1 +const paddedInitialDatagramSize = smallestMaxDatagramSize + +// Maximum number of streams of a given type which may be created. +// https://www.rfc-editor.org/rfc/rfc9000.html#section-4.6-2 +const maxStreamsLimit = 1 << 60 + +// Maximum number of streams we will allow the peer to create implicitly. +// A stream ID that is used out of order results in all streams of that type +// with lower-numbered IDs also being opened. To limit the amount of work we +// will do in response to a single frame, we cap the peer's stream limit to +// this value. +const implicitStreamLimit = 100 + +// A connSide distinguishes between the client and server sides of a connection. 
+type connSide int8 + +const ( + clientSide = connSide(iota) + serverSide +) + +func (s connSide) String() string { + switch s { + case clientSide: + return "client" + case serverSide: + return "server" + default: + return "BUG" + } +} + +func (s connSide) peer() connSide { + if s == clientSide { + return serverSide + } else { + return clientSide + } +} + +// A numberSpace is the context in which a packet number applies. +// https://www.rfc-editor.org/rfc/rfc9000.html#section-12.3-7 +type numberSpace byte + +const ( + initialSpace = numberSpace(iota) + handshakeSpace + appDataSpace + numberSpaceCount +) + +func (n numberSpace) String() string { + switch n { + case initialSpace: + return "Initial" + case handshakeSpace: + return "Handshake" + case appDataSpace: + return "AppData" + default: + return "BUG" + } +} + +// A streamType is the type of a stream: bidirectional or unidirectional. +type streamType uint8 + +const ( + bidiStream = streamType(iota) + uniStream + streamTypeCount +) + +func (s streamType) qlogString() string { + switch s { + case bidiStream: + return "bidirectional" + case uniStream: + return "unidirectional" + default: + return "BUG" + } +} + +func (s streamType) String() string { + switch s { + case bidiStream: + return "bidi" + case uniStream: + return "uni" + default: + return "BUG" + } +} + +// A streamID is a QUIC stream ID. +// https://www.rfc-editor.org/rfc/rfc9000.html#section-2.1 +type streamID uint64 + +// The two least significant bits of a stream ID indicate the initiator +// and directionality of the stream. The upper bits are the stream number. +// Each of the four possible combinations of initiator and direction +// each has a distinct number space. 
+const ( + clientInitiatedStreamBit = 0x0 + serverInitiatedStreamBit = 0x1 + initiatorStreamBitMask = 0x1 + + bidiStreamBit = 0x0 + uniStreamBit = 0x2 + dirStreamBitMask = 0x2 +) + +func newStreamID(initiator connSide, typ streamType, num int64) streamID { + id := streamID(num << 2) + if typ == uniStream { + id |= uniStreamBit + } + if initiator == serverSide { + id |= serverInitiatedStreamBit + } + return id +} + +func (s streamID) initiator() connSide { + if s&initiatorStreamBitMask == serverInitiatedStreamBit { + return serverSide + } + return clientSide +} + +func (s streamID) num() int64 { + return int64(s) >> 2 +} + +func (s streamID) streamType() streamType { + if s&dirStreamBitMask == uniStreamBit { + return uniStream + } + return bidiStream +} + +// packetFate is the fate of a sent packet: Either acknowledged by the peer, +// or declared lost. +type packetFate byte + +const ( + packetLost = packetFate(iota) + packetAcked +) diff --git a/src/vendor/golang.org/x/net/quic/rangeset.go b/src/vendor/golang.org/x/net/quic/rangeset.go new file mode 100644 index 0000000000..3d6f5f9799 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/rangeset.go @@ -0,0 +1,193 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +// A rangeset is a set of int64s, stored as an ordered list of non-overlapping, +// non-empty ranges. +// +// Rangesets are efficient for small numbers of ranges, +// which is expected to be the common case. +type rangeset[T ~int64] []i64range[T] + +type i64range[T ~int64] struct { + start, end T // [start, end) +} + +// size returns the size of the range. +func (r i64range[T]) size() T { + return r.end - r.start +} + +// contains reports whether v is in the range. 
+func (r i64range[T]) contains(v T) bool { + return r.start <= v && v < r.end +} + +// add adds [start, end) to the set, combining it with existing ranges if necessary. +func (s *rangeset[T]) add(start, end T) { + if start == end { + return + } + for i := range *s { + r := &(*s)[i] + if r.start > end { + // The new range comes before range i. + s.insertrange(i, start, end) + return + } + if start > r.end { + // The new range comes after range i. + continue + } + // The new range is adjacent to or overlapping range i. + if start < r.start { + r.start = start + } + if end <= r.end { + return + } + // Possibly coalesce subsequent ranges into range i. + r.end = end + j := i + 1 + for ; j < len(*s) && r.end >= (*s)[j].start; j++ { + if e := (*s)[j].end; e > r.end { + // Range j ends after the new range. + r.end = e + } + } + s.removeranges(i+1, j) + return + } + *s = append(*s, i64range[T]{start, end}) +} + +// sub removes [start, end) from the set. +func (s *rangeset[T]) sub(start, end T) { + removefrom, removeto := -1, -1 + for i := range *s { + r := &(*s)[i] + if end < r.start { + break + } + if r.end < start { + continue + } + switch { + case start <= r.start && end >= r.end: + // Remove the entire range. + if removefrom == -1 { + removefrom = i + } + removeto = i + 1 + case start <= r.start: + // Remove a prefix. + r.start = end + case end >= r.end: + // Remove a suffix. + r.end = start + default: + // Remove the middle, leaving two new ranges. + rend := r.end + r.end = start + s.insertrange(i+1, end, rend) + return + } + } + if removefrom != -1 { + s.removeranges(removefrom, removeto) + } +} + +// contains reports whether s contains v. +func (s rangeset[T]) contains(v T) bool { + for _, r := range s { + if v >= r.end { + continue + } + if r.start <= v { + return true + } + return false + } + return false +} + +// rangeContaining returns the range containing v, or the range [0,0) if v is not in s. 
+func (s rangeset[T]) rangeContaining(v T) i64range[T] { + for _, r := range s { + if v >= r.end { + continue + } + if r.start <= v { + return r + } + break + } + return i64range[T]{0, 0} +} + +// min returns the minimum value in the set, or 0 if empty. +func (s rangeset[T]) min() T { + if len(s) == 0 { + return 0 + } + return s[0].start +} + +// max returns the maximum value in the set, or 0 if empty. +func (s rangeset[T]) max() T { + if len(s) == 0 { + return 0 + } + return s[len(s)-1].end - 1 +} + +// end returns the end of the last range in the set, or 0 if empty. +func (s rangeset[T]) end() T { + if len(s) == 0 { + return 0 + } + return s[len(s)-1].end +} + +// numRanges returns the number of ranges in the rangeset. +func (s rangeset[T]) numRanges() int { + return len(s) +} + +// size returns the size of all ranges in the rangeset. +func (s rangeset[T]) size() (total T) { + for _, r := range s { + total += r.size() + } + return total +} + +// isrange reports if the rangeset covers exactly the range [start, end). +func (s rangeset[T]) isrange(start, end T) bool { + switch len(s) { + case 0: + return start == 0 && end == 0 + case 1: + return s[0].start == start && s[0].end == end + } + return false +} + +// removeranges removes ranges [i,j). +func (s *rangeset[T]) removeranges(i, j int) { + if i == j { + return + } + copy((*s)[i:], (*s)[j:]) + *s = (*s)[:len(*s)-(j-i)] +} + +// insert adds a new range at index i. +func (s *rangeset[T]) insertrange(i int, start, end T) { + *s = append(*s, i64range[T]{}) + copy((*s)[i+1:], (*s)[i:]) + (*s)[i] = i64range[T]{start, end} +} diff --git a/src/vendor/golang.org/x/net/quic/retry.go b/src/vendor/golang.org/x/net/quic/retry.go new file mode 100644 index 0000000000..0392ca9159 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/retry.go @@ -0,0 +1,240 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package quic + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/binary" + "net/netip" + "time" + + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/net/internal/quic/quicwire" +) + +// AEAD and nonce used to compute the Retry Integrity Tag. +// https://www.rfc-editor.org/rfc/rfc9001#section-5.8 +var ( + retrySecret = []byte{0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e} + retryNonce = []byte{0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb} + retryAEAD = func() cipher.AEAD { + c, err := aes.NewCipher(retrySecret) + if err != nil { + panic(err) + } + aead, err := cipher.NewGCM(c) + if err != nil { + panic(err) + } + return aead + }() +) + +// retryTokenValidityPeriod is how long we accept a Retry packet token after sending it. +const retryTokenValidityPeriod = 5 * time.Second + +// retryState generates and validates an endpoint's retry tokens. +type retryState struct { + aead cipher.AEAD +} + +func (rs *retryState) init() error { + // Retry tokens are authenticated using a per-server key chosen at start time. + // TODO: Provide a way for the user to set this key. + secret := make([]byte, chacha20poly1305.KeySize) + if _, err := rand.Read(secret); err != nil { + return err + } + aead, err := chacha20poly1305.NewX(secret) + if err != nil { + panic(err) + } + rs.aead = aead + return nil +} + +// Retry tokens are encrypted with an AEAD. +// The plaintext contains the time the token was created and +// the original destination connection ID. +// The additional data contains the sender's source address and original source connection ID. +// The token nonce is randomly generated. +// We use the nonce as the Source Connection ID of the Retry packet. +// Since the 24-byte XChaCha20-Poly1305 nonce is too large to fit in a 20-byte connection ID, +// we include the remaining 4 bytes of nonce in the token. 
+// +// Token { +// Last 4 Bytes of Nonce (32), +// Ciphertext (..), +// } +// +// Plaintext { +// Timestamp (64), +// Original Destination Connection ID, +// } +// +// +// Additional Data { +// Original Source Connection ID Length (8), +// Original Source Connection ID (..), +// IP Address (32..128), +// Port (16), +// } +// +// TODO: Consider using AES-256-GCM-SIV once crypto/tls supports it. + +func (rs *retryState) makeToken(now time.Time, srcConnID, origDstConnID []byte, addr netip.AddrPort) (token, newDstConnID []byte, err error) { + nonce := make([]byte, rs.aead.NonceSize()) + if _, err := rand.Read(nonce); err != nil { + return nil, nil, err + } + + var plaintext []byte + plaintext = binary.BigEndian.AppendUint64(plaintext, uint64(now.Unix())) + plaintext = append(plaintext, origDstConnID...) + + token = append(token, nonce[maxConnIDLen:]...) + token = rs.aead.Seal(token, nonce, plaintext, rs.additionalData(srcConnID, addr)) + return token, nonce[:maxConnIDLen], nil +} + +func (rs *retryState) validateToken(now time.Time, token, srcConnID, dstConnID []byte, addr netip.AddrPort) (origDstConnID []byte, ok bool) { + tokenNonceLen := rs.aead.NonceSize() - maxConnIDLen + if len(token) < tokenNonceLen { + return nil, false + } + nonce := append([]byte{}, dstConnID...) + nonce = append(nonce, token[:tokenNonceLen]...) + ciphertext := token[tokenNonceLen:] + if len(nonce) != rs.aead.NonceSize() { + return nil, false + } + + plaintext, err := rs.aead.Open(nil, nonce, ciphertext, rs.additionalData(srcConnID, addr)) + if err != nil { + return nil, false + } + if len(plaintext) < 8 { + return nil, false + } + when := time.Unix(int64(binary.BigEndian.Uint64(plaintext)), 0) + origDstConnID = plaintext[8:] + + // We allow for tokens created in the future (up to the validity period), + // which likely indicates that the system clock was adjusted backwards. 
+ if d := abs(now.Sub(when)); d > retryTokenValidityPeriod { + return nil, false + } + + return origDstConnID, true +} + +func (rs *retryState) additionalData(srcConnID []byte, addr netip.AddrPort) []byte { + var additional []byte + additional = quicwire.AppendUint8Bytes(additional, srcConnID) + additional = append(additional, addr.Addr().AsSlice()...) + additional = binary.BigEndian.AppendUint16(additional, addr.Port()) + return additional +} + +func (e *Endpoint) validateInitialAddress(now time.Time, p genericLongPacket, peerAddr netip.AddrPort) (origDstConnID []byte, ok bool) { + // The retry token is at the start of an Initial packet's data. + token, n := quicwire.ConsumeUint8Bytes(p.data) + if n < 0 { + // We've already validated that the packet is at least 1200 bytes long, + // so there's no way for even a maximum size token to not fit. + // Check anyway. + return nil, false + } + if len(token) == 0 { + // The sender has not provided a token. + // Send a Retry packet to them with one. + e.sendRetry(now, p, peerAddr) + return nil, false + } + origDstConnID, ok = e.retry.validateToken(now, token, p.srcConnID, p.dstConnID, peerAddr) + if !ok { + // This does not seem to be a valid token. + // Close the connection with an INVALID_TOKEN error. 
+ // https://www.rfc-editor.org/rfc/rfc9000#section-8.1.2-5 + e.sendConnectionClose(p, peerAddr, errInvalidToken) + return nil, false + } + return origDstConnID, true +} + +func (e *Endpoint) sendRetry(now time.Time, p genericLongPacket, peerAddr netip.AddrPort) { + token, srcConnID, err := e.retry.makeToken(now, p.srcConnID, p.dstConnID, peerAddr) + if err != nil { + return + } + b := encodeRetryPacket(p.dstConnID, retryPacket{ + dstConnID: p.srcConnID, + srcConnID: srcConnID, + token: token, + }) + e.sendDatagram(datagram{ + b: b, + peerAddr: peerAddr, + }) +} + +type retryPacket struct { + dstConnID []byte + srcConnID []byte + token []byte +} + +func encodeRetryPacket(originalDstConnID []byte, p retryPacket) []byte { + // Retry packets include an integrity tag, computed by AEAD_AES_128_GCM over + // the original destination connection ID followed by the Retry packet + // (less the integrity tag itself). + // https://www.rfc-editor.org/rfc/rfc9001#section-5.8 + // + // Create the pseudo-packet (including the original DCID), append the tag, + // and return the Retry packet. + var b []byte + b = quicwire.AppendUint8Bytes(b, originalDstConnID) // Original Destination Connection ID + start := len(b) // start of the Retry packet + b = append(b, headerFormLong|fixedBit|longPacketTypeRetry) + b = binary.BigEndian.AppendUint32(b, quicVersion1) // Version + b = quicwire.AppendUint8Bytes(b, p.dstConnID) // Destination Connection ID + b = quicwire.AppendUint8Bytes(b, p.srcConnID) // Source Connection ID + b = append(b, p.token...) 
// Token + b = retryAEAD.Seal(b, retryNonce, nil, b) // Retry Integrity Tag + return b[start:] +} + +func parseRetryPacket(b, origDstConnID []byte) (p retryPacket, ok bool) { + const retryIntegrityTagLength = 128 / 8 + + lp, ok := parseGenericLongHeaderPacket(b) + if !ok { + return retryPacket{}, false + } + if len(lp.data) < retryIntegrityTagLength { + return retryPacket{}, false + } + gotTag := lp.data[len(lp.data)-retryIntegrityTagLength:] + + // Create the pseudo-packet consisting of the original destination connection ID + // followed by the Retry packet (less the integrity tag). + // Use this to validate the packet integrity tag. + pseudo := quicwire.AppendUint8Bytes(nil, origDstConnID) + pseudo = append(pseudo, b[:len(b)-retryIntegrityTagLength]...) + wantTag := retryAEAD.Seal(nil, retryNonce, nil, pseudo) + if !bytes.Equal(gotTag, wantTag) { + return retryPacket{}, false + } + + token := lp.data[:len(lp.data)-retryIntegrityTagLength] + return retryPacket{ + dstConnID: lp.dstConnID, + srcConnID: lp.srcConnID, + token: token, + }, true +} diff --git a/src/vendor/golang.org/x/net/quic/rtt.go b/src/vendor/golang.org/x/net/quic/rtt.go new file mode 100644 index 0000000000..0dc9bf5bf2 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/rtt.go @@ -0,0 +1,71 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "time" +) + +type rttState struct { + minRTT time.Duration + latestRTT time.Duration + smoothedRTT time.Duration + rttvar time.Duration // RTT variation + firstSampleTime time.Time // time of first RTT sample +} + +func (r *rttState) init() { + r.minRTT = -1 // -1 indicates the first sample has not been taken yet + + // "[...] the initial RTT SHOULD be set to 333 milliseconds." 
+ // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.2-1 + const initialRTT = 333 * time.Millisecond + + // https://www.rfc-editor.org/rfc/rfc9002.html#section-5.3-12 + r.smoothedRTT = initialRTT + r.rttvar = initialRTT / 2 +} + +func (r *rttState) establishPersistentCongestion() { + // "Endpoints SHOULD set the min_rtt to the newest RTT sample + // after persistent congestion is established." + // https://www.rfc-editor.org/rfc/rfc9002#section-5.2-5 + r.minRTT = r.latestRTT +} + +// updateSample is called when we generate a new RTT sample. +// https://www.rfc-editor.org/rfc/rfc9002.html#section-5 +func (r *rttState) updateSample(now time.Time, handshakeConfirmed bool, spaceID numberSpace, latestRTT, ackDelay, maxAckDelay time.Duration) { + r.latestRTT = latestRTT + + if r.minRTT < 0 { + // First RTT sample. + // "min_rtt MUST be set to the latest_rtt on the first RTT sample." + // https://www.rfc-editor.org/rfc/rfc9002.html#section-5.2-2 + r.minRTT = latestRTT + // https://www.rfc-editor.org/rfc/rfc9002.html#section-5.3-14 + r.smoothedRTT = latestRTT + r.rttvar = latestRTT / 2 + r.firstSampleTime = now + return + } + + // "min_rtt MUST be set to the lesser of min_rtt and latest_rtt [...] + // on all other samples." 
+ // https://www.rfc-editor.org/rfc/rfc9002.html#section-5.2-2 + r.minRTT = min(r.minRTT, latestRTT) + + // https://www.rfc-editor.org/rfc/rfc9002.html#section-5.3-16 + if handshakeConfirmed { + ackDelay = min(ackDelay, maxAckDelay) + } + adjustedRTT := latestRTT - ackDelay + if adjustedRTT < r.minRTT { + adjustedRTT = latestRTT + } + rttvarSample := abs(r.smoothedRTT - adjustedRTT) + r.rttvar = (3*r.rttvar + rttvarSample) / 4 + r.smoothedRTT = ((7 * r.smoothedRTT) + adjustedRTT) / 8 +} diff --git a/src/vendor/golang.org/x/net/quic/sent_packet.go b/src/vendor/golang.org/x/net/quic/sent_packet.go new file mode 100644 index 0000000000..f67606b353 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/sent_packet.go @@ -0,0 +1,119 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "sync" + "time" + + "golang.org/x/net/internal/quic/quicwire" +) + +// A sentPacket tracks state related to an in-flight packet we sent, +// to be committed when the peer acks it or resent if the packet is lost. +type sentPacket struct { + num packetNumber + size int // size in bytes + time time.Time // time sent + ptype packetType + + state sentPacketState + ackEliciting bool // https://www.rfc-editor.org/rfc/rfc9002.html#section-2-3.4.1 + inFlight bool // https://www.rfc-editor.org/rfc/rfc9002.html#section-2-3.6.1 + + // Frames sent in the packet. + // + // This is an abbreviated version of the packet payload, containing only the information + // we need to process an ack for or loss of this packet. + // For example, a CRYPTO frame is recorded as the frame type (0x06), offset, and length, + // but does not include the sent data. + // + // This buffer is written by packetWriter.append* and read by Conn.handleAckOrLoss. 
+ b []byte + n int // read offset into b +} + +type sentPacketState uint8 + +const ( + sentPacketSent = sentPacketState(iota) // sent but neither acked nor lost + sentPacketAcked // acked + sentPacketLost // declared lost + sentPacketUnsent // never sent +) + +var sentPool = sync.Pool{ + New: func() any { + return &sentPacket{} + }, +} + +func newSentPacket() *sentPacket { + sent := sentPool.Get().(*sentPacket) + sent.reset() + return sent +} + +// recycle returns a sentPacket to the pool. +func (sent *sentPacket) recycle() { + sentPool.Put(sent) +} + +func (sent *sentPacket) reset() { + *sent = sentPacket{ + b: sent.b[:0], + } +} + +// markAckEliciting marks the packet as containing an ack-eliciting frame. +func (sent *sentPacket) markAckEliciting() { + sent.ackEliciting = true + sent.inFlight = true +} + +// The append* methods record information about frames in the packet. + +func (sent *sentPacket) appendNonAckElicitingFrame(frameType byte) { + sent.b = append(sent.b, frameType) +} + +func (sent *sentPacket) appendAckElicitingFrame(frameType byte) { + sent.ackEliciting = true + sent.inFlight = true + sent.b = append(sent.b, frameType) +} + +func (sent *sentPacket) appendInt(v uint64) { + sent.b = quicwire.AppendVarint(sent.b, v) +} + +func (sent *sentPacket) appendOffAndSize(start int64, size int) { + sent.b = quicwire.AppendVarint(sent.b, uint64(start)) + sent.b = quicwire.AppendVarint(sent.b, uint64(size)) +} + +// The next* methods read back information about frames in the packet. 
+ +func (sent *sentPacket) next() (frameType byte) { + f := sent.b[sent.n] + sent.n++ + return f +} + +func (sent *sentPacket) nextInt() uint64 { + v, n := quicwire.ConsumeVarint(sent.b[sent.n:]) + sent.n += n + return v +} + +func (sent *sentPacket) nextRange() (start, end int64) { + start = int64(sent.nextInt()) + end = start + int64(sent.nextInt()) + return start, end +} + +func (sent *sentPacket) done() bool { + return sent.n == len(sent.b) +} diff --git a/src/vendor/golang.org/x/net/quic/sent_packet_list.go b/src/vendor/golang.org/x/net/quic/sent_packet_list.go new file mode 100644 index 0000000000..04116f2109 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/sent_packet_list.go @@ -0,0 +1,93 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +// A sentPacketList is a ring buffer of sentPackets. +// +// Processing an ack for a packet causes all older packets past a small threshold +// to be discarded (RFC 9002, Section 6.1.1), so the list of in-flight packets is +// not sparse and will contain at most a few acked/lost packets we no longer +// care about. +type sentPacketList struct { + nextNum packetNumber // next packet number to add to the buffer + off int // offset of first packet in the buffer + size int // number of packets + p []*sentPacket +} + +// start is the first packet in the list. +func (s *sentPacketList) start() packetNumber { + return s.nextNum - packetNumber(s.size) +} + +// end is one after the last packet in the list. +// If the list is empty, start == end. +func (s *sentPacketList) end() packetNumber { + return s.nextNum +} + +// discard clears the list. +func (s *sentPacketList) discard() { + *s = sentPacketList{} +} + +// add appends a packet to the list. 
func (s *sentPacketList) add(sent *sentPacket) {
	if s.nextNum != sent.num {
		// Packets must be appended in strictly increasing number order.
		panic("inserting out-of-order packet")
	}
	s.nextNum++
	if s.size >= len(s.p) {
		s.grow()
	}
	// Ring-buffer insertion: the slot after the current last element.
	i := (s.off + s.size) % len(s.p)
	s.size++
	s.p[i] = sent
}

// nth returns a packet by index.
func (s *sentPacketList) nth(n int) *sentPacket {
	index := (s.off + n) % len(s.p)
	return s.p[index]
}

// num returns a packet by number.
// It returns nil if the packet is not in the list.
func (s *sentPacketList) num(num packetNumber) *sentPacket {
	i := int(num - s.start())
	if i < 0 || i >= s.size {
		return nil
	}
	return s.nth(i)
}

// clean removes all acked or lost packets from the head of the list.
func (s *sentPacketList) clean() {
	for s.size > 0 {
		sent := s.p[s.off]
		if sent.state == sentPacketSent {
			// Stop at the first packet still awaiting an ack or loss decision.
			return
		}
		sent.recycle()
		s.p[s.off] = nil // drop the reference so the GC can reclaim it
		s.off = (s.off + 1) % len(s.p)
		s.size--
	}
	// List is now empty; reset the offset so future growth starts compact.
	s.off = 0
}

// grow increases the buffer to hold more packets.
func (s *sentPacketList) grow() {
	newSize := len(s.p) * 2
	if newSize == 0 {
		newSize = 64
	}
	p := make([]*sentPacket, newSize)
	// Re-linearize the ring into the new buffer, starting at offset 0.
	for i := 0; i < s.size; i++ {
		p[i] = s.nth(i)
	}
	s.p = p
	s.off = 0
}
diff --git a/src/vendor/golang.org/x/net/quic/sent_val.go b/src/vendor/golang.org/x/net/quic/sent_val.go
new file mode 100644
index 0000000000..f1682dbd78
--- /dev/null
+++ b/src/vendor/golang.org/x/net/quic/sent_val.go
@@ -0,0 +1,103 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package quic

// A sentVal tracks sending some piece of information to the peer.
// It tracks whether the information has been sent, acked, and
// (when in-flight) the most recent packet to carry it.
//
// For example, a sentVal can track sending of a RESET_STREAM frame.
//
//   - unset: stream is active, no need to send RESET_STREAM
//   - unsent: we should send a RESET_STREAM, but have not yet
//   - sent: we have sent a RESET_STREAM, but have not received an ack
//   - received: we have sent a RESET_STREAM, and the peer has acked the packet that contained it
//
// In the "sent" state, a sentVal also tracks the latest packet number to carry
// the information. (QUIC packet numbers are always at most 62 bits in size,
// so the sentVal keeps the number in the low 62 bits and the state in the high 2 bits.)
type sentVal uint64

const (
	sentValUnset    = 0       // unset
	sentValUnsent   = 1 << 62 // set, not sent to the peer
	sentValSent     = 2 << 62 // set, sent to the peer but not yet acked; pnum is set
	sentValReceived = 3 << 62 // set, peer acked receipt

	sentValStateMask = 3 << 62
)

// isSet reports whether the value is set.
func (s sentVal) isSet() bool { return s != 0 }

// shouldSend reports whether the value is set and has not been sent to the peer.
func (s sentVal) shouldSend() bool { return s.state() == sentValUnsent }

// shouldSendPTO reports whether the value needs to be sent to the peer.
// The value needs to be sent if it is set and has not been sent.
// If pto is true, indicating that we are sending a PTO probe, the value
// should also be sent if it is set and has not been acknowledged.
func (s sentVal) shouldSendPTO(pto bool) bool {
	st := s.state()
	return st == sentValUnsent || (pto && st == sentValSent)
}

// isReceived reports whether the value has been received by the peer.
func (s sentVal) isReceived() bool { return s == sentValReceived }

// set sets the value and records that it should be sent to the peer.
// If the value has already been sent, it is not resent.
func (s *sentVal) set() {
	if *s == 0 {
		*s = sentValUnsent
	}
}

// setUnsent sets the value to the unsent state.
func (s *sentVal) setUnsent() { *s = sentValUnsent }

// clear sets the value to the unset state.
func (s *sentVal) clear() { *s = sentValUnset }

// setSent sets the value to the sent state and records the number of the most recent
// packet containing the value.
func (s *sentVal) setSent(pnum packetNumber) {
	*s = sentValSent | sentVal(pnum)
}

// setReceived sets the value to the received state.
func (s *sentVal) setReceived() { *s = sentValReceived }

// ackOrLoss reports that an acknowledgement has been received for the value,
// or that the packet carrying the value has been lost.
func (s *sentVal) ackOrLoss(pnum packetNumber, fate packetFate) {
	if fate == packetAcked {
		*s = sentValReceived
	} else if *s == sentVal(pnum)|sentValSent {
		// Only revert to unsent if pnum is still the latest packet carrying
		// the value; a loss of an older packet is superseded by the resend.
		*s = sentValUnsent
	}
}

// ackLatestOrLoss reports that an acknowledgement has been received for the value,
// or that the packet carrying the value has been lost.
// The value is set to the acked state only if pnum is the latest packet containing it.
//
// We use this to handle acks for data that varies every time it is sent.
// For example, if we send a MAX_DATA frame followed by an updated MAX_DATA value in a
// second packet, we consider the data sent only upon receiving an ack for the most
// recent value.
func (s *sentVal) ackLatestOrLoss(pnum packetNumber, fate packetFate) {
	if fate == packetAcked {
		if *s == sentVal(pnum)|sentValSent {
			*s = sentValReceived
		}
	} else {
		if *s == sentVal(pnum)|sentValSent {
			*s = sentValUnsent
		}
	}
}

// state returns the two state bits (sentValUnset/Unsent/Sent/Received).
func (s sentVal) state() uint64 { return uint64(s) & sentValStateMask }
diff --git a/src/vendor/golang.org/x/net/quic/skip.go b/src/vendor/golang.org/x/net/quic/skip.go
new file mode 100644
index 0000000000..f0d0234ee6
--- /dev/null
+++ b/src/vendor/golang.org/x/net/quic/skip.go
@@ -0,0 +1,62 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package quic

// skipState is state for optimistic ACK defenses.
+// +// An endpoint performs an optimistic ACK attack by sending acknowledgements for packets +// which it has not received, potentially convincing the sender's congestion controller to +// send at rates beyond what the network supports. +// +// We defend against this by periodically skipping packet numbers. +// Receiving an ACK for an unsent packet number is a PROTOCOL_VIOLATION error. +// +// We only skip packet numbers in the Application Data number space. +// The total data sent in the Initial/Handshake spaces should generally fit into +// the initial congestion window. +// +// https://www.rfc-editor.org/rfc/rfc9000.html#section-21.4 +type skipState struct { + // skip is the next packet number (in the Application Data space) we should skip. + skip packetNumber + + // maxSkip is the maximum number of packets to send before skipping another number. + // Increases over time. + maxSkip int64 +} + +func (ss *skipState) init(c *Conn) { + ss.maxSkip = 256 // skip our first packet number within this range + ss.updateNumberSkip(c) +} + +// shouldSkip returns whether we should skip the given packet number. +func (ss *skipState) shouldSkip(num packetNumber) bool { + return ss.skip == num +} + +// updateNumberSkip schedules a packet to be skipped after skipping lastSkipped. +func (ss *skipState) updateNumberSkip(c *Conn) { + // Send at least this many packets before skipping. + // Limits the impact of skipping a little, + // plus allows most tests to ignore skipping. + const minSkip = 64 + + skip := minSkip + c.prng.Int64N(ss.maxSkip-minSkip) + ss.skip += packetNumber(skip) + + // Double the size of the skip each time until we reach 128k. + // The idea here is that an attacker needs to correctly ack ~N packets in order + // to send an optimistic ack for another ~N packets. 
+ // Skipping packet numbers comes with a small cost (it causes the receiver to + // send an immediate ACK rather than the usual delayed ACK), so we increase the + // time between skips as a connection's lifetime grows. + // + // The 128k cap is arbitrary, chosen so that we skip a packet number + // about once a second when sending full-size datagrams at 1Gbps. + if ss.maxSkip < 128*1024 { + ss.maxSkip *= 2 + } +} diff --git a/src/vendor/golang.org/x/net/quic/stateless_reset.go b/src/vendor/golang.org/x/net/quic/stateless_reset.go new file mode 100644 index 0000000000..8907e2e58b --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/stateless_reset.go @@ -0,0 +1,59 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "hash" + "sync" +) + +const statelessResetTokenLen = 128 / 8 + +// A statelessResetToken is a stateless reset token. +// https://www.rfc-editor.org/rfc/rfc9000#section-10.3 +type statelessResetToken [statelessResetTokenLen]byte + +type statelessResetTokenGenerator struct { + canReset bool + + // The hash.Hash interface is not concurrency safe, + // so we need a mutex here. + // + // There shouldn't be much contention on stateless reset token generation. + // If this proves to be a problem, we could avoid the mutex by using a separate + // generator per Conn, or by using a concurrency-safe generator. + mu sync.Mutex + mac hash.Hash +} + +func (g *statelessResetTokenGenerator) init(secret [32]byte) { + zero := true + for _, b := range secret { + if b != 0 { + zero = false + break + } + } + if zero { + // Generate tokens using a random secret, but don't send stateless resets. 
+ rand.Read(secret[:]) + g.canReset = false + } else { + g.canReset = true + } + g.mac = hmac.New(sha256.New, secret[:]) +} + +func (g *statelessResetTokenGenerator) tokenForConnID(cid []byte) (token statelessResetToken) { + g.mu.Lock() + defer g.mu.Unlock() + defer g.mac.Reset() + g.mac.Write(cid) + copy(token[:], g.mac.Sum(nil)) + return token +} diff --git a/src/vendor/golang.org/x/net/quic/stream.go b/src/vendor/golang.org/x/net/quic/stream.go new file mode 100644 index 0000000000..383a6c160a --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/stream.go @@ -0,0 +1,1041 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "context" + "errors" + "fmt" + "io" + "math" + + "golang.org/x/net/internal/quic/quicwire" +) + +// A Stream is an ordered byte stream. +// +// Streams may be bidirectional, read-only, or write-only. +// Methods inappropriate for a stream's direction +// (for example, [Write] to a read-only stream) +// return errors. +// +// It is not safe to perform concurrent reads from or writes to a stream. +// It is safe, however, to read and write at the same time. +// +// Reads and writes are buffered. +// It is generally not necessary to wrap a stream in a [bufio.ReadWriter] +// or otherwise apply additional buffering. +// +// To cancel reads or writes, use the [SetReadContext] and [SetWriteContext] methods. +type Stream struct { + id streamID + conn *Conn + + // Contexts used for read/write operations. + // Intentionally not mutex-guarded, to allow the race detector to catch concurrent access. + inctx context.Context + outctx context.Context + + // ingate's lock guards receive-related state. + // + // The gate condition is set if a read from the stream will not block, + // either because the stream has available data or because the read will fail. 
+ ingate gate + in pipe // received data + inwin int64 // last MAX_STREAM_DATA sent to the peer + insendmax sentVal // set when we should send MAX_STREAM_DATA to the peer + inmaxbuf int64 // maximum amount of data we will buffer + insize int64 // stream final size; -1 before this is known + inset rangeset[int64] // received ranges + inclosed sentVal // set by CloseRead + inresetcode int64 // RESET_STREAM code received from the peer; -1 if not reset + + // outgate's lock guards send-related state. + // + // The gate condition is set if a write to the stream will not block, + // either because the stream has available flow control or because + // the write will fail. + outgate gate + out pipe // buffered data to send + outflushed int64 // offset of last flush call + outwin int64 // maximum MAX_STREAM_DATA received from the peer + outmaxsent int64 // maximum data offset we've sent to the peer + outmaxbuf int64 // maximum amount of data we will buffer + outunsent rangeset[int64] // ranges buffered but not yet sent (only flushed data) + outacked rangeset[int64] // ranges sent and acknowledged + outopened sentVal // set if we should open the stream + outclosed sentVal // set by CloseWrite + outblocked sentVal // set when a write to the stream is blocked by flow control + outreset sentVal // set by Reset + outresetcode uint64 // reset code to send in RESET_STREAM + outdone chan struct{} // closed when all data sent + + // Unsynchronized buffers, used for lock-free fast path. + inbuf []byte // received data + inbufoff int // bytes of inbuf which have been consumed + outbuf []byte // written data + outbufoff int // bytes of outbuf which contain data to write + + // Atomic stream state bits. + // + // These bits provide a fast way to coordinate between the + // send and receive sides of the stream, and the conn's loop. + // + // streamIn* bits must be set with ingate held. + // streamOut* bits must be set with outgate held. + // streamConn* bits are set by the conn's loop. 
+ // streamQueue* bits must be set with streamsState.sendMu held. + state atomicBits[streamState] + + prev, next *Stream // guarded by streamsState.sendMu +} + +type streamState uint32 + +const ( + // streamInSendMeta is set when there are frames to send for the + // inbound side of the stream. For example, MAX_STREAM_DATA. + // Inbound frames are never flow-controlled. + streamInSendMeta = streamState(1 << iota) + + // streamOutSendMeta is set when there are non-flow-controlled frames + // to send for the outbound side of the stream. For example, STREAM_DATA_BLOCKED. + // streamOutSendData is set when there are no non-flow-controlled outbound frames + // and the stream has data to send. + // + // At most one of streamOutSendMeta and streamOutSendData is set at any time. + streamOutSendMeta + streamOutSendData + + // streamInDone and streamOutDone are set when the inbound or outbound + // sides of the stream are finished. When both are set, the stream + // can be removed from the Conn and forgotten. + streamInDone + streamOutDone + + // streamConnRemoved is set when the stream has been removed from the conn. + streamConnRemoved + + // streamQueueMeta and streamQueueData indicate which of the streamsState + // send queues the conn is currently on. + streamQueueMeta + streamQueueData +) + +type streamQueue int + +const ( + noQueue = streamQueue(iota) + metaQueue // streamsState.queueMeta + dataQueue // streamsState.queueData +) + +// streamResetByConnClose is assigned to Stream.inresetcode to indicate that a stream +// was implicitly reset when the connection closed. It's out of the range of +// possible reset codes the peer can send. +const streamResetByConnClose = math.MaxInt64 + +// wantQueue returns the send queue the stream should be on. 
+func (s streamState) wantQueue() streamQueue {
+	switch {
+	case s&(streamInSendMeta|streamOutSendMeta) != 0:
+		return metaQueue
+	case s&(streamInDone|streamOutDone|streamConnRemoved) == streamInDone|streamOutDone:
+		// Both sides are done but the stream has not been removed from
+		// the conn yet; queue it so the conn can remove it.
+		return metaQueue
+	case s&streamOutSendData != 0:
+		// The stream has no non-flow-controlled frames to send,
+		// but does have data. Put it on the data queue, which is only
+		// processed when flow control is available.
+		return dataQueue
+	}
+	return noQueue
+}
+
+// inQueue returns the send queue the stream is currently on.
+func (s streamState) inQueue() streamQueue {
+	switch {
+	case s&streamQueueMeta != 0:
+		return metaQueue
+	case s&streamQueueData != 0:
+		return dataQueue
+	}
+	return noQueue
+}
+
+// newStream returns a new stream.
+//
+// The stream's ingate and outgate are locked.
+// (We create the stream with locked gates so after the caller
+// initializes the flow control window,
+// unlocking outgate will set the stream writability state.)
+func newStream(c *Conn, id streamID) *Stream {
+	s := &Stream{
+		conn:        c,
+		id:          id,
+		insize:      -1, // -1 indicates the stream size is unknown
+		inresetcode: -1, // -1 indicates no RESET_STREAM received
+		ingate:      newLockedGate(),
+		outgate:     newLockedGate(),
+		inctx:       context.Background(),
+		outctx:      context.Background(),
+	}
+	if !s.IsReadOnly() {
+		s.outdone = make(chan struct{})
+	}
+	return s
+}
+
+// ID returns the QUIC stream ID of s.
+//
+// As specified in RFC 9000, the two least significant bits of a stream ID
+// indicate the initiator and directionality of the stream. The upper bits are
+// the stream number.
+func (s *Stream) ID() int64 {
+	return int64(s.id)
+}
+
+// SetReadContext sets the context used for reads from the stream.
+//
+// It is not safe to call SetReadContext concurrently.
+func (s *Stream) SetReadContext(ctx context.Context) {
+	s.inctx = ctx
+}
+
+// SetWriteContext sets the context used for writes to the stream.
+// The write context is also used by Close when waiting for writes to be
+// received by the peer.
+//
+// It is not safe to call SetWriteContext concurrently.
+func (s *Stream) SetWriteContext(ctx context.Context) {
+	s.outctx = ctx
+}
+
+// IsReadOnly reports whether the stream is read-only
+// (a unidirectional stream created by the peer).
+func (s *Stream) IsReadOnly() bool {
+	return s.id.streamType() == uniStream && s.id.initiator() != s.conn.side
+}
+
+// IsWriteOnly reports whether the stream is write-only
+// (a unidirectional stream created locally).
+func (s *Stream) IsWriteOnly() bool {
+	return s.id.streamType() == uniStream && s.id.initiator() == s.conn.side
+}
+
+// Read reads data from the stream.
+//
+// Read returns as soon as at least one byte of data is available.
+//
+// If the peer closes the stream cleanly, Read returns io.EOF after
+// returning all data sent by the peer.
+// If the peer aborts reads on the stream, Read returns
+// an error wrapping StreamResetCode.
+//
+// It is not safe to call Read concurrently.
+func (s *Stream) Read(b []byte) (n int, err error) {
+	if s.IsWriteOnly() {
+		return 0, errors.New("read from write-only stream")
+	}
+	if len(s.inbuf) > s.inbufoff {
+		// Fast path: If s.inbuf contains unread bytes, return them immediately
+		// without taking a lock.
+		n = copy(b, s.inbuf[s.inbufoff:])
+		s.inbufoff += n
+		return n, nil
+	}
+	if err := s.ingate.waitAndLock(s.inctx); err != nil {
+		return 0, err
+	}
+	if s.inbufoff > 0 {
+		// Discard bytes consumed by the fast path above.
+		s.in.discardBefore(s.in.start + int64(s.inbufoff))
+		s.inbufoff = 0
+		s.inbuf = nil
+	}
+	// bytesRead contains the number of bytes of connection-level flow control to return.
+	// We return flow control for bytes read by this Read call, as well as bytes moved
+	// to the fast-path read buffer (s.inbuf).
+	var bytesRead int64
+	defer func() {
+		s.inUnlock()
+		s.conn.handleStreamBytesReadOffLoop(bytesRead) // must be done with ingate unlocked
+	}()
+	if s.inresetcode != -1 {
+		if s.inresetcode == streamResetByConnClose {
+			if err := s.conn.finalError(); err != nil {
+				return 0, err
+			}
+		}
+		return 0, fmt.Errorf("stream reset by peer: %w", StreamErrorCode(s.inresetcode))
+	}
+	if s.inclosed.isSet() {
+		return 0, errors.New("read from closed stream")
+	}
+	if s.insize == s.in.start {
+		return 0, io.EOF
+	}
+	// Getting here indicates the stream contains data to be read.
+	if len(s.inset) < 1 || s.inset[0].start != 0 || s.inset[0].end <= s.in.start {
+		panic("BUG: inconsistent input stream state")
+	}
+	if size := int(s.inset[0].end - s.in.start); size < len(b) {
+		b = b[:size]
+	}
+	bytesRead = int64(len(b))
+	start := s.in.start
+	end := start + int64(len(b))
+	s.in.copy(start, b)
+	s.in.discardBefore(end)
+	if end == s.insize {
+		// We have read up to the end of the stream.
+		// No need to update stream flow control.
+		return len(b), io.EOF
+	}
+	if len(s.inset) > 0 && s.inset[0].start <= s.in.start && s.inset[0].end > s.in.start {
+		// If we have more readable bytes available, put the next chunk of data
+		// in s.inbuf for lock-free reads.
+		s.inbuf = s.in.peek(s.inset[0].end - s.in.start)
+		bytesRead += int64(len(s.inbuf))
+	}
+	if s.insize == -1 || s.insize > s.inwin {
+		newWindow := s.in.start + int64(len(s.inbuf)) + s.inmaxbuf
+		addedWindow := newWindow - s.inwin
+		if shouldUpdateFlowControl(s.inmaxbuf, addedWindow) {
+			// Update stream flow control with a MAX_STREAM_DATA frame.
+			s.insendmax.setUnsent()
+		}
+	}
+	return len(b), nil
+}
+
+// ReadByte reads and returns a single byte from the stream.
+//
+// It is not safe to call ReadByte concurrently.
+func (s *Stream) ReadByte() (byte, error) {
+	if len(s.inbuf) > s.inbufoff {
+		// Fast path: consume one byte from the lock-free read buffer.
+		b := s.inbuf[s.inbufoff]
+		s.inbufoff++
+		return b, nil
+	}
+	var b [1]byte
+	n, err := s.Read(b[:])
+	if n > 0 {
+		return b[0], nil
+	}
+	return 0, err
+}
+
+// shouldUpdateFlowControl determines whether to send a flow control window update.
+//
+// We want to balance keeping the peer well-supplied with flow control with not sending
+// many small updates.
+func shouldUpdateFlowControl(maxWindow, addedWindow int64) bool {
+	return addedWindow >= maxWindow/8
+}
+
+// Write writes data to the stream.
+//
+// Write writes data to the stream write buffer.
+// Buffered data is only sent when the buffer is sufficiently full.
+// Call the Flush method to ensure buffered data is sent.
+func (s *Stream) Write(b []byte) (n int, err error) {
+	if s.IsReadOnly() {
+		return 0, errors.New("write to read-only stream")
+	}
+	if len(b) > 0 && len(s.outbuf)-s.outbufoff >= len(b) {
+		// Fast path: The data to write fits in s.outbuf.
+		copy(s.outbuf[s.outbufoff:], b)
+		s.outbufoff += len(b)
+		return len(b), nil
+	}
+	canWrite := s.outgate.lock()
+	s.flushFastOutputBuffer()
+	for {
+		// The first time through this loop, we may or may not be write blocked.
+		// We exit the loop after writing all data, so on subsequent passes through
+		// the loop we are always write blocked.
+		if len(b) > 0 && !canWrite {
+			// Our send buffer is full. Wait for the peer to ack some data.
+			s.outUnlock()
+			if err := s.outgate.waitAndLock(s.outctx); err != nil {
+				return n, err
+			}
+			// Successfully returning from waitAndLockGate means we are no longer
+			// write blocked. (Unlike traditional condition variables, gates do not
+			// have spurious wakeups.)
+		}
+		if err := s.writeErrorLocked(); err != nil {
+			s.outUnlock()
+			return n, err
+		}
+		if len(b) == 0 {
+			break
+		}
+		// Write limit is our send buffer limit.
+		// This is a stream offset.
+		lim := s.out.start + s.outmaxbuf
+		// Amount to write is min(the full buffer, data up to the write limit).
+		// This is a number of bytes.
+		nn := min(int64(len(b)), lim-s.out.end)
+		// Copy the data into the output buffer.
+		s.out.writeAt(b[:nn], s.out.end)
+		b = b[nn:]
+		n += int(nn)
+		// Possibly flush the output buffer.
+		// We automatically flush if:
+		// - We have enough data to consume the send window.
+		//   Sending this data may cause the peer to extend the window.
+		// - We have buffered as much data as we're willing to.
+		//   We need to send data to clear out buffer space.
+		// - We have enough data to fill a 1-RTT packet using the smallest
+		//   possible maximum datagram size (1200 bytes, less header byte,
+		//   connection ID, packet number, and AEAD overhead).
+		const autoFlushSize = smallestMaxDatagramSize - 1 - connIDLen - 1 - aeadOverhead
+		shouldFlush := s.out.end >= s.outwin || // peer send window is full
+			s.out.end >= lim || // local send buffer is full
+			(s.out.end-s.outflushed) >= autoFlushSize // enough data buffered
+		if shouldFlush {
+			s.flushLocked()
+		}
+		if s.out.end > s.outwin {
+			// We're blocked by flow control.
+			// Send a STREAM_DATA_BLOCKED frame to let the peer know.
+			s.outblocked.set()
+		}
+		// If we have bytes left to send, we're blocked.
+		canWrite = false
+	}
+	if lim := s.out.start + s.outmaxbuf - s.out.end - 1; lim > 0 {
+		// If s.out has space allocated and available to be written into,
+		// then reference it in s.outbuf for fast-path writes.
+		//
+		// It's perhaps a bit pointless to limit s.outbuf to the send buffer limit.
+		// We've already allocated this buffer so we aren't saving any memory
+		// by not using it.
+		// For now, we limit it anyway to make it easier to reason about limits.
+		//
+		// We set the limit to one less than the send buffer limit (the -1 above)
+		// so that a write which completely fills the buffer will overflow
+		// s.outbuf and trigger a flush.
+		s.outbuf = s.out.availableBuffer()
+		if int64(len(s.outbuf)) > lim {
+			s.outbuf = s.outbuf[:lim]
+		}
+	}
+	s.outUnlock()
+	return n, nil
+}
+
+// WriteByte writes a single byte to the stream.
+func (s *Stream) WriteByte(c byte) error {
+	if s.outbufoff < len(s.outbuf) {
+		// Fast path: append one byte to the lock-free write buffer.
+		s.outbuf[s.outbufoff] = c
+		s.outbufoff++
+		return nil
+	}
+	b := [1]byte{c}
+	_, err := s.Write(b[:])
+	return err
+}
+
+// flushFastOutputBuffer commits any bytes written via the lock-free fast
+// path (s.outbuf) into s.out. Must be called with outgate held.
+func (s *Stream) flushFastOutputBuffer() {
+	if s.outbuf == nil {
+		return
+	}
+	// Commit data previously written to s.outbuf.
+	// s.outbuf is a reference to a buffer in s.out, so we just need to record
+	// that the output buffer has been extended.
+	s.out.end += int64(s.outbufoff)
+	s.outbuf = nil
+	s.outbufoff = 0
+}
+
+// Flush flushes data written to the stream.
+// It does not wait for the peer to acknowledge receipt of the data.
+// Use Close to wait for the peer's acknowledgement.
+func (s *Stream) Flush() error {
+	if s.IsReadOnly() {
+		return errors.New("flush of read-only stream")
+	}
+	s.outgate.lock()
+	defer s.outUnlock()
+	if err := s.writeErrorLocked(); err != nil {
+		return err
+	}
+	s.flushLocked()
+	return nil
+}
+
+// writeErrorLocked returns the error (if any) which should be returned by write operations
+// due to the stream being reset or closed.
+func (s *Stream) writeErrorLocked() error {
+	if s.outreset.isSet() {
+		if s.outresetcode == streamResetByConnClose {
+			if err := s.conn.finalError(); err != nil {
+				return err
+			}
+		}
+		return errors.New("write to reset stream")
+	}
+	if s.outclosed.isSet() {
+		return errors.New("write to closed stream")
+	}
+	return nil
+}
+
+// flushLocked marks all buffered data (up to s.out.end) as flushed,
+// recording the portion within the peer's flow control window as unsent
+// so it will be transmitted. Must be called with outgate held.
+func (s *Stream) flushLocked() {
+	s.flushFastOutputBuffer()
+	s.outopened.set()
+	if s.outflushed < s.outwin {
+		s.outunsent.add(s.outflushed, min(s.outwin, s.out.end))
+	}
+	s.outflushed = s.out.end
+}
+
+// Close closes the stream.
+// Any blocked stream operations will be unblocked and return errors.
+//
+// Close flushes any data in the stream write buffer and waits for the peer to
+// acknowledge receipt of the data.
+// If the stream has been reset, it waits for the peer to acknowledge the reset.
+// If the context expires before the peer receives the stream's data,
+// Close discards the buffer and returns the context error.
+func (s *Stream) Close() error {
+	s.CloseRead()
+	if s.IsReadOnly() {
+		return nil
+	}
+	s.CloseWrite()
+	// TODO: Return code from peer's RESET_STREAM frame?
+	if err := s.conn.waitOnDone(s.outctx, s.outdone); err != nil {
+		return err
+	}
+	s.outgate.lock()
+	defer s.outUnlock()
+	if s.outclosed.isReceived() && s.outacked.isrange(0, s.out.end) {
+		// The peer acknowledged all data and the FIN.
+		return nil
+	}
+	// outdone was closed without the peer acking everything,
+	// so the stream must have been reset.
+	return errors.New("stream reset")
+}
+
+// CloseRead aborts reads on the stream.
+// Any blocked reads will be unblocked and return errors.
+//
+// CloseRead notifies the peer that the stream has been closed for reading.
+// It does not wait for the peer to acknowledge the closure.
+// Use Close to wait for the peer's acknowledgement.
+func (s *Stream) CloseRead() {
+	if s.IsWriteOnly() {
+		return
+	}
+	s.ingate.lock()
+	if s.inset.isrange(0, s.insize) || s.inresetcode != -1 {
+		// We've already received all data from the peer,
+		// so there's no need to send STOP_SENDING.
+		// This is the same as saying we sent one and they got it.
+		s.inclosed.setReceived()
+	} else {
+		s.inclosed.set()
+	}
+	// Discard any buffered, unread data; it will never be delivered.
+	discarded := s.in.end - s.in.start
+	s.in.discardBefore(s.in.end)
+	s.inUnlock()
+	s.conn.handleStreamBytesReadOffLoop(discarded) // must be done with ingate unlocked
+}
+
+// CloseWrite aborts writes on the stream.
+// Any blocked writes will be unblocked and return errors.
+//
+// CloseWrite sends any data in the stream write buffer to the peer.
+// It does not wait for the peer to acknowledge receipt of the data.
+// Use Close to wait for the peer's acknowledgement.
+func (s *Stream) CloseWrite() {
+	if s.IsReadOnly() {
+		return
+	}
+	s.outgate.lock()
+	defer s.outUnlock()
+	s.outclosed.set()
+	s.flushLocked()
+}
+
+// Reset aborts writes on the stream and notifies the peer
+// that the stream was terminated abruptly.
+// Any blocked writes will be unblocked and return errors.
+//
+// Reset sends the application protocol error code, which must be
+// less than 2^62, to the peer.
+// It does not wait for the peer to acknowledge receipt of the error.
+// Use Close to wait for the peer's acknowledgement.
+//
+// Reset does not affect reads.
+// Use CloseRead to abort reads on the stream.
+func (s *Stream) Reset(code uint64) {
+	const userClosed = true
+	s.resetInternal(code, userClosed)
+}
+
+// resetInternal resets the send side of the stream.
+//
+// If userClosed is true, this is s.Reset.
+// If userClosed is false, this is a reaction to a STOP_SENDING frame.
+func (s *Stream) resetInternal(code uint64, userClosed bool) {
+	s.outgate.lock()
+	defer s.outUnlock()
+	if s.IsReadOnly() {
+		return
+	}
+	if userClosed {
+		// Mark that the user closed the stream.
+		s.outclosed.set()
+	}
+	if s.outreset.isSet() {
+		return
+	}
+	if code > quicwire.MaxVarint {
+		// Clamp the code to the largest value representable as a varint.
+		code = quicwire.MaxVarint
+	}
+	// We could check here to see if the stream is closed and the
+	// peer has acked all the data and the FIN, but sending an
+	// extra RESET_STREAM in this case is harmless.
+	s.outreset.set()
+	s.outresetcode = code
+	s.outbuf = nil
+	s.outbufoff = 0
+	s.out.discardBefore(s.out.end)
+	s.outunsent = rangeset[int64]{}
+	s.outblocked.clear()
+}
+
+// connHasClosed indicates the stream's conn has closed.
+func (s *Stream) connHasClosed() {
+	// If we're in the closing state, the user closed the conn.
+	// Otherwise, the peer initiated the close.
+	// This only matters for the error we're going to return from stream operations.
+	localClose := s.conn.lifetime.state == connStateClosing
+
+	s.ingate.lock()
+	if !s.inset.isrange(0, s.insize) && s.inresetcode == -1 {
+		if localClose {
+			s.inclosed.set()
+		} else {
+			s.inresetcode = streamResetByConnClose
+		}
+	}
+	s.inUnlock()
+
+	s.outgate.lock()
+	if localClose {
+		s.outclosed.set()
+		s.outreset.set()
+	} else {
+		s.outresetcode = streamResetByConnClose
+		s.outreset.setReceived()
+	}
+	s.outUnlock()
+}
+
+// inUnlock unlocks s.ingate.
+// It sets the gate condition if reads from s will not block.
+// If s has receive-related frames to write or if both directions
+// are done and the stream should be removed, it notifies the Conn.
+func (s *Stream) inUnlock() {
+	state := s.inUnlockNoQueue()
+	s.conn.maybeQueueStreamForSend(s, state)
+}
+
+// inUnlockNoQueue is inUnlock,
+// but reports whether s has frames to write rather than notifying the Conn.
+func (s *Stream) inUnlockNoQueue() streamState {
+	nextByte := s.in.start + int64(len(s.inbuf))
+	canRead := s.inset.contains(nextByte) || // data available to read
+		s.insize == s.in.start+int64(len(s.inbuf)) || // at EOF
+		s.inresetcode != -1 || // reset by peer
+		s.inclosed.isSet() // closed locally
+	defer s.ingate.unlock(canRead)
+	var state streamState
+	switch {
+	case s.IsWriteOnly():
+		state = streamInDone
+	case s.inresetcode != -1: // reset by peer
+		fallthrough
+	case s.in.start == s.insize: // all data received and read
+		// We don't increase MAX_STREAMS until the user calls CloseRead or Close,
+		// so the receive side is not finished until inclosed is set.
+		if s.inclosed.isSet() {
+			state = streamInDone
+		}
+	case s.insendmax.shouldSend(): // MAX_STREAM_DATA
+		state = streamInSendMeta
+	case s.inclosed.shouldSend(): // STOP_SENDING
+		state = streamInSendMeta
+	}
+	const mask = streamInDone | streamInSendMeta
+	return s.state.set(state, mask)
+}
+
+// outUnlock unlocks s.outgate.
+// It sets the gate condition if writes to s will not block.
+// If s has send-related frames to write or if both directions
+// are done and the stream should be removed, it notifies the Conn.
+func (s *Stream) outUnlock() {
+	state := s.outUnlockNoQueue()
+	s.conn.maybeQueueStreamForSend(s, state)
+}
+
+// outUnlockNoQueue is outUnlock,
+// but reports whether s has frames to write rather than notifying the Conn.
+func (s *Stream) outUnlockNoQueue() streamState {
+	isDone := s.outclosed.isReceived() && s.outacked.isrange(0, s.out.end) || // all data acked
+		s.outreset.isSet() // reset locally
+	if isDone {
+		select {
+		case <-s.outdone:
+		default:
+			// outdone is nil for read-only streams; only close it once.
+			if !s.IsReadOnly() {
+				close(s.outdone)
+			}
+		}
+	}
+	lim := s.out.start + s.outmaxbuf
+	canWrite := lim > s.out.end || // available send buffer
+		s.outclosed.isSet() || // closed locally
+		s.outreset.isSet() // reset locally
+	defer s.outgate.unlock(canWrite)
+	var state streamState
+	switch {
+	case s.IsReadOnly():
+		state = streamOutDone
+	case s.outclosed.isReceived() && s.outacked.isrange(0, s.out.end): // all data sent and acked
+		fallthrough
+	case s.outreset.isReceived(): // RESET_STREAM sent and acked
+		// We don't increase MAX_STREAMS until the user calls CloseWrite or Close,
+		// so the send side is not finished until outclosed is set.
+		if s.outclosed.isSet() {
+			state = streamOutDone
+		}
+	case s.outreset.shouldSend(): // RESET_STREAM
+		state = streamOutSendMeta
+	case s.outreset.isSet(): // RESET_STREAM sent but not acknowledged
+	case s.outblocked.shouldSend(): // STREAM_DATA_BLOCKED
+		state = streamOutSendMeta
+	case len(s.outunsent) > 0: // STREAM frame with data
+		if s.outunsent.min() < s.outmaxsent {
+			state = streamOutSendMeta // resent data, will not consume flow control
+		} else {
+			state = streamOutSendData // new data, requires flow control
+		}
+	case s.outclosed.shouldSend() && s.out.end == s.outmaxsent: // empty STREAM frame with FIN bit
+		state = streamOutSendMeta
+	case s.outopened.shouldSend(): // STREAM frame with no data
+		state = streamOutSendMeta
+	}
+	const mask = streamOutDone | streamOutSendMeta | streamOutSendData
+	return s.state.set(state, mask)
+}
+
+// handleData handles data received in a STREAM frame.
+func (s *Stream) handleData(off int64, b []byte, fin bool) error {
+	s.ingate.lock()
+	defer s.inUnlock()
+	end := off + int64(len(b))
+	if err := s.checkStreamBounds(end, fin); err != nil {
+		return err
+	}
+	if s.inclosed.isSet() || s.inresetcode != -1 {
+		// The user read-closed the stream, or the peer reset it.
+		// Either way, we can discard this frame.
+		return nil
+	}
+	if s.insize == -1 && end > s.in.end {
+		added := end - s.in.end
+		if err := s.conn.handleStreamBytesReceived(added); err != nil {
+			return err
+		}
+	}
+	s.in.writeAt(b, off)
+	s.inset.add(off, end)
+	if fin {
+		s.insize = end
+		// The peer has enough flow control window to send the entire stream.
+		s.insendmax.clear()
+	}
+	return nil
+}
+
+// handleReset handles a RESET_STREAM frame.
+func (s *Stream) handleReset(code uint64, finalSize int64) error {
+	s.ingate.lock()
+	defer s.inUnlock()
+	const fin = true
+	if err := s.checkStreamBounds(finalSize, fin); err != nil {
+		return err
+	}
+	if s.inresetcode != -1 {
+		// The stream was already reset.
+		return nil
+	}
+	if s.insize == -1 {
+		added := finalSize - s.in.end
+		if err := s.conn.handleStreamBytesReceived(added); err != nil {
+			return err
+		}
+	}
+	s.conn.handleStreamBytesReadOnLoop(finalSize - s.in.start)
+	s.in.discardBefore(s.in.end)
+	s.inresetcode = int64(code)
+	s.insize = finalSize
+	return nil
+}
+
+// checkStreamBounds validates the stream offset in a STREAM or RESET_STREAM frame.
+func (s *Stream) checkStreamBounds(end int64, fin bool) error {
+	if end > s.inwin {
+		// The peer sent us data past the maximum flow control window we gave them.
+		return localTransportError{
+			code:   errFlowControl,
+			reason: "stream flow control window exceeded",
+		}
+	}
+	if s.insize != -1 && end > s.insize {
+		// The peer sent us data past the final size of the stream they previously gave us.
+		return localTransportError{
+			code:   errFinalSize,
+			reason: "data received past end of stream",
+		}
+	}
+	if fin && s.insize != -1 && end != s.insize {
+		// The peer changed the final size of the stream.
+		return localTransportError{
+			code:   errFinalSize,
+			reason: "final size of stream changed",
+		}
+	}
+	if fin && end < s.in.end {
+		// The peer has previously sent us data past the final size.
+		return localTransportError{
+			code:   errFinalSize,
+			reason: "end of stream occurs before prior data",
+		}
+	}
+	return nil
+}
+
+// handleStopSending handles a STOP_SENDING frame.
+func (s *Stream) handleStopSending(code uint64) error {
+	// Peer requests that we reset this stream.
+	// https://www.rfc-editor.org/rfc/rfc9000#section-3.5-4
+	const userReset = false
+	s.resetInternal(code, userReset)
+	return nil
+}
+
+// handleMaxStreamData handles an update received in a MAX_STREAM_DATA frame.
+func (s *Stream) handleMaxStreamData(maxStreamData int64) error {
+	s.outgate.lock()
+	defer s.outUnlock()
+	if maxStreamData <= s.outwin {
+		// Not an increase over our current window; ignore it.
+		return nil
+	}
+	if s.outflushed > s.outwin {
+		// Flushed data that was previously beyond the window is now sendable.
+		s.outunsent.add(s.outwin, min(maxStreamData, s.outflushed))
+	}
+	s.outwin = maxStreamData
+	if s.out.end > s.outwin {
+		// We've still got more data than flow control window.
+		s.outblocked.setUnsent()
+	} else {
+		s.outblocked.clear()
+	}
+	return nil
+}
+
+// ackOrLoss handles the fate of stream frames other than STREAM.
+func (s *Stream) ackOrLoss(pnum packetNumber, ftype byte, fate packetFate) {
+	// Frames which carry new information each time they are sent
+	// (MAX_STREAM_DATA, STREAM_DATA_BLOCKED) must only be marked
+	// as received if the most recent packet carrying this frame is acked.
+	//
+	// Frames which are always the same (STOP_SENDING, RESET_STREAM)
+	// can be marked as received if any packet carrying this frame is acked.
+	switch ftype {
+	case frameTypeResetStream:
+		s.outgate.lock()
+		s.outreset.ackOrLoss(pnum, fate)
+		s.outUnlock()
+	case frameTypeStopSending:
+		s.ingate.lock()
+		s.inclosed.ackOrLoss(pnum, fate)
+		s.inUnlock()
+	case frameTypeMaxStreamData:
+		s.ingate.lock()
+		s.insendmax.ackLatestOrLoss(pnum, fate)
+		s.inUnlock()
+	case frameTypeStreamDataBlocked:
+		s.outgate.lock()
+		s.outblocked.ackLatestOrLoss(pnum, fate)
+		s.outUnlock()
+	default:
+		panic("unhandled frame type")
+	}
+}
+
+// ackOrLossData handles the fate of a STREAM frame.
+func (s *Stream) ackOrLossData(pnum packetNumber, start, end int64, fin bool, fate packetFate) {
+	s.outgate.lock()
+	defer s.outUnlock()
+	s.outopened.ackOrLoss(pnum, fate)
+	if fin {
+		s.outclosed.ackOrLoss(pnum, fate)
+	}
+	if s.outreset.isSet() {
+		// If the stream has been reset, we don't care any more.
+		return
+	}
+	switch fate {
+	case packetAcked:
+		s.outacked.add(start, end)
+		s.outunsent.sub(start, end)
+		// If this ack is for data at the start of the send buffer, we can now discard it.
+		if s.outacked.contains(s.out.start) {
+			s.out.discardBefore(s.outacked[0].end)
+		}
+	case packetLost:
+		// Mark everything lost, but not previously acked, as needing retransmission.
+		// We do this by adding all the lost bytes to outunsent, and then
+		// removing everything already acked.
+		s.outunsent.add(start, end)
+		for _, a := range s.outacked {
+			s.outunsent.sub(a.start, a.end)
+		}
+	}
+}
+
+// appendInFramesLocked appends STOP_SENDING and MAX_STREAM_DATA frames
+// to the current packet.
+//
+// It returns true if no more frames need appending,
+// false if not everything fit in the current packet.
+func (s *Stream) appendInFramesLocked(w *packetWriter, pnum packetNumber, pto bool) bool {
+	if s.inclosed.shouldSendPTO(pto) {
+		// STOP_SENDING
+		//
+		// We don't currently have an API for setting the error code.
+		// Just send zero.
+		code := uint64(0)
+		if !w.appendStopSendingFrame(s.id, code) {
+			return false
+		}
+		s.inclosed.setSent(pnum)
+	}
+	if s.insendmax.shouldSendPTO(pto) {
+		// MAX_STREAM_DATA
+		maxStreamData := s.in.start + s.inmaxbuf
+		if !w.appendMaxStreamDataFrame(s.id, maxStreamData) {
+			return false
+		}
+		s.inwin = maxStreamData
+		s.insendmax.setSent(pnum)
+	}
+	return true
+}
+
+// appendOutFramesLocked appends RESET_STREAM, STREAM_DATA_BLOCKED, and STREAM frames
+// to the current packet.
+//
+// It returns true if no more frames need appending,
+// false if not everything fit in the current packet.
+func (s *Stream) appendOutFramesLocked(w *packetWriter, pnum packetNumber, pto bool) bool {
+	if s.outreset.isSet() {
+		// RESET_STREAM
+		if s.outreset.shouldSendPTO(pto) {
+			if !w.appendResetStreamFrame(s.id, s.outresetcode, min(s.outwin, s.out.end)) {
+				return false
+			}
+			s.outreset.setSent(pnum)
+			s.frameOpensStream(pnum)
+		}
+		// A reset stream sends no further data.
+		return true
+	}
+	if s.outblocked.shouldSendPTO(pto) {
+		// STREAM_DATA_BLOCKED
+		if !w.appendStreamDataBlockedFrame(s.id, s.outwin) {
+			return false
+		}
+		s.outblocked.setSent(pnum)
+		s.frameOpensStream(pnum)
+	}
+	for {
+		// STREAM
+		off, size := dataToSend(min(s.out.start, s.outwin), min(s.outflushed, s.outwin), s.outunsent, s.outacked, pto)
+		if end := off + size; end > s.outmaxsent {
+			// This will require connection-level flow control to send.
+			end = min(end, s.outmaxsent+s.conn.streams.outflow.avail())
+			end = max(end, off)
+			size = end - off
+		}
+		fin := s.outclosed.isSet() && off+size == s.out.end
+		shouldSend := size > 0 || // have data to send
+			s.outopened.shouldSendPTO(pto) || // should open the stream
+			(fin && s.outclosed.shouldSendPTO(pto)) // should close the stream
+		if !shouldSend {
+			return true
+		}
+		b, added := w.appendStreamFrame(s.id, off, int(size), fin)
+		if !added {
+			return false
+		}
+		s.out.copy(off, b)
+		end := off + int64(len(b))
+		if end > s.outmaxsent {
+			// Sending new data; consume connection-level flow control.
+			s.conn.streams.outflow.consume(end - s.outmaxsent)
+			s.outmaxsent = end
+		}
+		s.outunsent.sub(off, end)
+		s.frameOpensStream(pnum)
+		if fin {
+			s.outclosed.setSent(pnum)
+		}
+		if pto {
+			// Send at most one STREAM frame per probe.
+			return true
+		}
+		if int64(len(b)) < size {
+			// The packet is full; we couldn't fit all the data.
+			return false
+		}
+	}
+}
+
+// frameOpensStream records that we're sending a frame that will open the stream.
+//
+// If we don't have an acknowledgement from the peer for a previous frame opening the stream,
+// record this packet as being the latest one to open it.
+func (s *Stream) frameOpensStream(pnum packetNumber) {
+	if !s.outopened.isReceived() {
+		s.outopened.setSent(pnum)
+	}
+}
+
+// dataToSend returns the next range of data to send in a STREAM or CRYPTO frame.
+func dataToSend(start, end int64, outunsent, outacked rangeset[int64], pto bool) (sendStart, size int64) {
+	switch {
+	case pto:
+		// On PTO, resend unacked data that fits in the probe packet.
+		// For simplicity, we send the range starting at s.out.start
+		// (which is definitely unacked, or else we would have discarded it)
+		// up to the next acked byte (if any).
+		//
+		// This may miss unacked data starting after that acked byte,
+		// but avoids resending data the peer has acked.
+		for _, r := range outacked {
+			if r.start > start {
+				return start, r.start - start
+			}
+		}
+		return start, end - start
+	case outunsent.numRanges() > 0:
+		return outunsent.min(), outunsent[0].size()
+	default:
+		return end, 0
+	}
+}
diff --git a/src/vendor/golang.org/x/net/quic/stream_limits.go b/src/vendor/golang.org/x/net/quic/stream_limits.go
new file mode 100644
index 0000000000..f1abcae99c
--- /dev/null
+++ b/src/vendor/golang.org/x/net/quic/stream_limits.go
@@ -0,0 +1,125 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package quic
+
+import (
+	"context"
+)
+
+// Limits on the number of open streams.
+// Every connection has separate limits for bidirectional and unidirectional streams.
+//
+// Note that the MAX_STREAMS limit includes closed as well as open streams.
+// Closing a stream doesn't enable an endpoint to open a new one;
+// only an increase in the MAX_STREAMS limit does.
+
+// localStreamLimits are limits on the number of open streams created by us.
+type localStreamLimits struct {
+	gate   gate
+	max    int64 // peer-provided MAX_STREAMS
+	opened int64 // number of streams opened by us, -1 when conn is closed
+}
+
+func (lim *localStreamLimits) init() {
+	lim.gate = newGate()
+}
+
+// open creates a new local stream, blocking until MAX_STREAMS quota is available.
+func (lim *localStreamLimits) open(ctx context.Context, c *Conn) (num int64, err error) {
+	// TODO: Send a STREAMS_BLOCKED when blocked.
+	if err := lim.gate.waitAndLock(ctx); err != nil {
+		return 0, err
+	}
+	if lim.opened < 0 {
+		lim.gate.unlock(true) // conn closed: fail all future opens immediately
+		return 0, errConnClosed
+	}
+	num = lim.opened
+	lim.opened++
+	lim.gate.unlock(lim.opened < lim.max) // condition: more quota remains
+	return num, nil
+}
+
+// connHasClosed indicates the connection has been closed, locally or by the peer.
+func (lim *localStreamLimits) connHasClosed() {
+	lim.gate.lock()
+	lim.opened = -1
+	lim.gate.unlock(true)
+}
+
+// setMax sets the MAX_STREAMS provided by the peer.
+func (lim *localStreamLimits) setMax(maxStreams int64) {
+	lim.gate.lock()
+	lim.max = max(lim.max, maxStreams) // MAX_STREAMS never decreases the limit
+	lim.gate.unlock(lim.opened < lim.max)
+}
+
+// remoteStreamLimits are limits on the number of open streams created by the peer.
+type remoteStreamLimits struct {
+	max     int64   // last MAX_STREAMS sent to the peer
+	opened  int64   // number of streams opened by the peer (including subsequently closed ones)
+	closed  int64   // number of peer streams in the "closed" state
+	maxOpen int64   // how many streams we want to let the peer simultaneously open
+	sendMax sentVal // set when we should send MAX_STREAMS
+}
+
+func (lim *remoteStreamLimits) init(maxOpen int64) {
+	lim.maxOpen = maxOpen
+	lim.max = min(maxOpen, implicitStreamLimit) // initial limit sent in transport parameters
+	lim.opened = 0
+}
+
+// open handles the peer opening a new stream.
+func (lim *remoteStreamLimits) open(id streamID) error {
+	num := id.num()
+	if num >= lim.max {
+		return localTransportError{
+			code:   errStreamLimit,
+			reason: "stream limit exceeded",
+		}
+	}
+	if num >= lim.opened {
+		lim.opened = num + 1 // stream numbers below num are implicitly opened too
+		lim.maybeUpdateMax()
+	}
+	return nil
+}
+
+// close handles the peer closing an open stream.
+func (lim *remoteStreamLimits) close() {
+	lim.closed++
+	lim.maybeUpdateMax()
+}
+
+// maybeUpdateMax updates the MAX_STREAMS value we will send to the peer.
+func (lim *remoteStreamLimits) maybeUpdateMax() {
+	newMax := min(
+		// Max streams the peer can have open at once.
+		lim.closed+lim.maxOpen,
+		// Max streams the peer can open with a single frame.
+		lim.opened+implicitStreamLimit,
+	)
+	avail := lim.max - lim.opened
+	if newMax > lim.max && (avail < 8 || newMax-lim.max >= 2*avail) {
+		// If the peer has fewer than 8 streams available, or if increasing
+		// the peer's stream limit would double it, then send a MAX_STREAMS.
+		lim.max = newMax
+		lim.sendMax.setUnsent()
+	}
+}
+
+// appendFrame appends a MAX_STREAMS frame to the current packet, if necessary.
+//
+// It returns true if no more frames need appending,
+// false if not everything fit in the current packet.
+func (lim *remoteStreamLimits) appendFrame(w *packetWriter, typ streamType, pnum packetNumber, pto bool) bool {
+	if lim.sendMax.shouldSendPTO(pto) {
+		if !w.appendMaxStreamsFrame(typ, lim.max) {
+			return false
+		}
+		lim.sendMax.setSent(pnum)
+	}
+	return true
+}
diff --git a/src/vendor/golang.org/x/net/quic/tls.go b/src/vendor/golang.org/x/net/quic/tls.go
new file mode 100644
index 0000000000..9f6e0bc29a
--- /dev/null
+++ b/src/vendor/golang.org/x/net/quic/tls.go
@@ -0,0 +1,123 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package quic
+
+import (
+	"context"
+	"crypto/tls"
+	"errors"
+	"fmt"
+	"net"
+	"time"
+)
+
+// startTLS starts the TLS handshake.
+func (c *Conn) startTLS(now time.Time, initialConnID []byte, peerHostname string, params transportParameters) error { + tlsConfig := c.config.TLSConfig + if a, _, err := net.SplitHostPort(peerHostname); err == nil { + peerHostname = a + } + if tlsConfig.ServerName == "" && peerHostname != "" { + tlsConfig = tlsConfig.Clone() + tlsConfig.ServerName = peerHostname + } + + c.keysInitial = initialKeys(initialConnID, c.side) + + qconfig := &tls.QUICConfig{TLSConfig: tlsConfig} + if c.side == clientSide { + c.tls = tls.QUICClient(qconfig) + } else { + c.tls = tls.QUICServer(qconfig) + } + c.tls.SetTransportParameters(marshalTransportParameters(params)) + // TODO: We don't need or want a context for cancellation here, + // but users can use a context to plumb values through to hooks defined + // in the tls.Config. Pass through a context. + if err := c.tls.Start(context.TODO()); err != nil { + return err + } + return c.handleTLSEvents(now) +} + +func (c *Conn) handleTLSEvents(now time.Time) error { + for { + e := c.tls.NextEvent() + if c.testHooks != nil { + c.testHooks.handleTLSEvent(e) + } + switch e.Kind { + case tls.QUICNoEvent: + return nil + case tls.QUICSetReadSecret: + if err := checkCipherSuite(e.Suite); err != nil { + return err + } + switch e.Level { + case tls.QUICEncryptionLevelHandshake: + c.keysHandshake.r.init(e.Suite, e.Data) + case tls.QUICEncryptionLevelApplication: + c.keysAppData.r.init(e.Suite, e.Data) + } + case tls.QUICSetWriteSecret: + if err := checkCipherSuite(e.Suite); err != nil { + return err + } + switch e.Level { + case tls.QUICEncryptionLevelHandshake: + c.keysHandshake.w.init(e.Suite, e.Data) + case tls.QUICEncryptionLevelApplication: + c.keysAppData.w.init(e.Suite, e.Data) + } + case tls.QUICWriteData: + var space numberSpace + switch e.Level { + case tls.QUICEncryptionLevelInitial: + space = initialSpace + case tls.QUICEncryptionLevelHandshake: + space = handshakeSpace + case tls.QUICEncryptionLevelApplication: + space = appDataSpace + 
default: + return fmt.Errorf("quic: internal error: write handshake data at level %v", e.Level) + } + c.crypto[space].write(e.Data) + case tls.QUICHandshakeDone: + if c.side == serverSide { + // "[...] the TLS handshake is considered confirmed + // at the server when the handshake completes." + // https://www.rfc-editor.org/rfc/rfc9001#section-4.1.2-1 + c.confirmHandshake(now) + } + c.handshakeDone() + case tls.QUICTransportParameters: + params, err := unmarshalTransportParams(e.Data) + if err != nil { + return err + } + if err := c.receiveTransportParameters(params); err != nil { + return err + } + } + } +} + +// handleCrypto processes data received in a CRYPTO frame. +func (c *Conn) handleCrypto(now time.Time, space numberSpace, off int64, data []byte) error { + var level tls.QUICEncryptionLevel + switch space { + case initialSpace: + level = tls.QUICEncryptionLevelInitial + case handshakeSpace: + level = tls.QUICEncryptionLevelHandshake + case appDataSpace: + level = tls.QUICEncryptionLevelApplication + default: + return errors.New("quic: internal error: received CRYPTO frame in unexpected number space") + } + return c.crypto[space].handleCrypto(off, data, func(b []byte) error { + return c.tls.HandleData(level, b) + }) +} diff --git a/src/vendor/golang.org/x/net/quic/transport_params.go b/src/vendor/golang.org/x/net/quic/transport_params.go new file mode 100644 index 0000000000..2734c586de --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/transport_params.go @@ -0,0 +1,283 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "encoding/binary" + "net/netip" + "time" + + "golang.org/x/net/internal/quic/quicwire" +) + +// transportParameters transferred in the quic_transport_parameters TLS extension. 
+// https://www.rfc-editor.org/rfc/rfc9000.html#section-18.2 +type transportParameters struct { + originalDstConnID []byte + maxIdleTimeout time.Duration + statelessResetToken []byte + maxUDPPayloadSize int64 + initialMaxData int64 + initialMaxStreamDataBidiLocal int64 + initialMaxStreamDataBidiRemote int64 + initialMaxStreamDataUni int64 + initialMaxStreamsBidi int64 + initialMaxStreamsUni int64 + ackDelayExponent int8 + maxAckDelay time.Duration + disableActiveMigration bool + preferredAddrV4 netip.AddrPort + preferredAddrV6 netip.AddrPort + preferredAddrConnID []byte + preferredAddrResetToken []byte + activeConnIDLimit int64 + initialSrcConnID []byte + retrySrcConnID []byte +} + +const ( + defaultParamMaxUDPPayloadSize = 65527 + defaultParamAckDelayExponent = 3 + defaultParamMaxAckDelayMilliseconds = 25 + defaultParamActiveConnIDLimit = 2 +) + +// defaultTransportParameters is initialized to the RFC 9000 default values. +func defaultTransportParameters() transportParameters { + return transportParameters{ + maxUDPPayloadSize: defaultParamMaxUDPPayloadSize, + ackDelayExponent: defaultParamAckDelayExponent, + maxAckDelay: defaultParamMaxAckDelayMilliseconds * time.Millisecond, + activeConnIDLimit: defaultParamActiveConnIDLimit, + } +} + +const ( + paramOriginalDestinationConnectionID = 0x00 + paramMaxIdleTimeout = 0x01 + paramStatelessResetToken = 0x02 + paramMaxUDPPayloadSize = 0x03 + paramInitialMaxData = 0x04 + paramInitialMaxStreamDataBidiLocal = 0x05 + paramInitialMaxStreamDataBidiRemote = 0x06 + paramInitialMaxStreamDataUni = 0x07 + paramInitialMaxStreamsBidi = 0x08 + paramInitialMaxStreamsUni = 0x09 + paramAckDelayExponent = 0x0a + paramMaxAckDelay = 0x0b + paramDisableActiveMigration = 0x0c + paramPreferredAddress = 0x0d + paramActiveConnectionIDLimit = 0x0e + paramInitialSourceConnectionID = 0x0f + paramRetrySourceConnectionID = 0x10 +) + +func marshalTransportParameters(p transportParameters) []byte { + var b []byte + if v := p.originalDstConnID; v != 
nil { + b = quicwire.AppendVarint(b, paramOriginalDestinationConnectionID) + b = quicwire.AppendVarintBytes(b, v) + } + if v := uint64(p.maxIdleTimeout / time.Millisecond); v != 0 { + b = quicwire.AppendVarint(b, paramMaxIdleTimeout) + b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(v))) + b = quicwire.AppendVarint(b, uint64(v)) + } + if v := p.statelessResetToken; v != nil { + b = quicwire.AppendVarint(b, paramStatelessResetToken) + b = quicwire.AppendVarintBytes(b, v) + } + if v := p.maxUDPPayloadSize; v != defaultParamMaxUDPPayloadSize { + b = quicwire.AppendVarint(b, paramMaxUDPPayloadSize) + b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v)))) + b = quicwire.AppendVarint(b, uint64(v)) + } + if v := p.initialMaxData; v != 0 { + b = quicwire.AppendVarint(b, paramInitialMaxData) + b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v)))) + b = quicwire.AppendVarint(b, uint64(v)) + } + if v := p.initialMaxStreamDataBidiLocal; v != 0 { + b = quicwire.AppendVarint(b, paramInitialMaxStreamDataBidiLocal) + b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v)))) + b = quicwire.AppendVarint(b, uint64(v)) + } + if v := p.initialMaxStreamDataBidiRemote; v != 0 { + b = quicwire.AppendVarint(b, paramInitialMaxStreamDataBidiRemote) + b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v)))) + b = quicwire.AppendVarint(b, uint64(v)) + } + if v := p.initialMaxStreamDataUni; v != 0 { + b = quicwire.AppendVarint(b, paramInitialMaxStreamDataUni) + b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v)))) + b = quicwire.AppendVarint(b, uint64(v)) + } + if v := p.initialMaxStreamsBidi; v != 0 { + b = quicwire.AppendVarint(b, paramInitialMaxStreamsBidi) + b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v)))) + b = quicwire.AppendVarint(b, uint64(v)) + } + if v := p.initialMaxStreamsUni; v != 0 { + b = quicwire.AppendVarint(b, paramInitialMaxStreamsUni) + b = quicwire.AppendVarint(b, 
uint64(quicwire.SizeVarint(uint64(v)))) + b = quicwire.AppendVarint(b, uint64(v)) + } + if v := p.ackDelayExponent; v != defaultParamAckDelayExponent { + b = quicwire.AppendVarint(b, paramAckDelayExponent) + b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v)))) + b = quicwire.AppendVarint(b, uint64(v)) + } + if v := uint64(p.maxAckDelay / time.Millisecond); v != defaultParamMaxAckDelayMilliseconds { + b = quicwire.AppendVarint(b, paramMaxAckDelay) + b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(v))) + b = quicwire.AppendVarint(b, v) + } + if p.disableActiveMigration { + b = quicwire.AppendVarint(b, paramDisableActiveMigration) + b = append(b, 0) // 0-length value + } + if p.preferredAddrConnID != nil { + b = append(b, paramPreferredAddress) + b = quicwire.AppendVarint(b, uint64(4+2+16+2+1+len(p.preferredAddrConnID)+16)) + b = append(b, p.preferredAddrV4.Addr().AsSlice()...) // 4 bytes + b = binary.BigEndian.AppendUint16(b, p.preferredAddrV4.Port()) // 2 bytes + b = append(b, p.preferredAddrV6.Addr().AsSlice()...) // 16 bytes + b = binary.BigEndian.AppendUint16(b, p.preferredAddrV6.Port()) // 2 bytes + b = quicwire.AppendUint8Bytes(b, p.preferredAddrConnID) // 1 byte + len(conn_id) + b = append(b, p.preferredAddrResetToken...) 
// 16 bytes + } + if v := p.activeConnIDLimit; v != defaultParamActiveConnIDLimit { + b = quicwire.AppendVarint(b, paramActiveConnectionIDLimit) + b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v)))) + b = quicwire.AppendVarint(b, uint64(v)) + } + if v := p.initialSrcConnID; v != nil { + b = quicwire.AppendVarint(b, paramInitialSourceConnectionID) + b = quicwire.AppendVarintBytes(b, v) + } + if v := p.retrySrcConnID; v != nil { + b = quicwire.AppendVarint(b, paramRetrySourceConnectionID) + b = quicwire.AppendVarintBytes(b, v) + } + return b +} + +func unmarshalTransportParams(params []byte) (transportParameters, error) { + p := defaultTransportParameters() + for len(params) > 0 { + id, n := quicwire.ConsumeVarint(params) + if n < 0 { + return p, localTransportError{code: errTransportParameter} + } + params = params[n:] + val, n := quicwire.ConsumeVarintBytes(params) + if n < 0 { + return p, localTransportError{code: errTransportParameter} + } + params = params[n:] + n = 0 + switch id { + case paramOriginalDestinationConnectionID: + p.originalDstConnID = val + n = len(val) + case paramMaxIdleTimeout: + var v uint64 + v, n = quicwire.ConsumeVarint(val) + // If this is unreasonably large, consider it as no timeout to avoid + // time.Duration overflows. 
+ if v > 1<<32 { + v = 0 + } + p.maxIdleTimeout = time.Duration(v) * time.Millisecond + case paramStatelessResetToken: + if len(val) != 16 { + return p, localTransportError{code: errTransportParameter} + } + p.statelessResetToken = val + n = 16 + case paramMaxUDPPayloadSize: + p.maxUDPPayloadSize, n = quicwire.ConsumeVarintInt64(val) + if p.maxUDPPayloadSize < 1200 { + return p, localTransportError{code: errTransportParameter} + } + case paramInitialMaxData: + p.initialMaxData, n = quicwire.ConsumeVarintInt64(val) + case paramInitialMaxStreamDataBidiLocal: + p.initialMaxStreamDataBidiLocal, n = quicwire.ConsumeVarintInt64(val) + case paramInitialMaxStreamDataBidiRemote: + p.initialMaxStreamDataBidiRemote, n = quicwire.ConsumeVarintInt64(val) + case paramInitialMaxStreamDataUni: + p.initialMaxStreamDataUni, n = quicwire.ConsumeVarintInt64(val) + case paramInitialMaxStreamsBidi: + p.initialMaxStreamsBidi, n = quicwire.ConsumeVarintInt64(val) + if p.initialMaxStreamsBidi > maxStreamsLimit { + return p, localTransportError{code: errTransportParameter} + } + case paramInitialMaxStreamsUni: + p.initialMaxStreamsUni, n = quicwire.ConsumeVarintInt64(val) + if p.initialMaxStreamsUni > maxStreamsLimit { + return p, localTransportError{code: errTransportParameter} + } + case paramAckDelayExponent: + var v uint64 + v, n = quicwire.ConsumeVarint(val) + if v > 20 { + return p, localTransportError{code: errTransportParameter} + } + p.ackDelayExponent = int8(v) + case paramMaxAckDelay: + var v uint64 + v, n = quicwire.ConsumeVarint(val) + if v >= 1<<14 { + return p, localTransportError{code: errTransportParameter} + } + p.maxAckDelay = time.Duration(v) * time.Millisecond + case paramDisableActiveMigration: + p.disableActiveMigration = true + case paramPreferredAddress: + if len(val) < 4+2+16+2+1 { + return p, localTransportError{code: errTransportParameter} + } + p.preferredAddrV4 = netip.AddrPortFrom( + netip.AddrFrom4(*(*[4]byte)(val[:4])), + 
binary.BigEndian.Uint16(val[4:][:2]), + ) + val = val[4+2:] + p.preferredAddrV6 = netip.AddrPortFrom( + netip.AddrFrom16(*(*[16]byte)(val[:16])), + binary.BigEndian.Uint16(val[16:][:2]), + ) + val = val[16+2:] + var nn int + p.preferredAddrConnID, nn = quicwire.ConsumeUint8Bytes(val) + if nn < 0 { + return p, localTransportError{code: errTransportParameter} + } + val = val[nn:] + if len(val) != 16 { + return p, localTransportError{code: errTransportParameter} + } + p.preferredAddrResetToken = val + val = nil + case paramActiveConnectionIDLimit: + p.activeConnIDLimit, n = quicwire.ConsumeVarintInt64(val) + if p.activeConnIDLimit < 2 { + return p, localTransportError{code: errTransportParameter} + } + case paramInitialSourceConnectionID: + p.initialSrcConnID = val + n = len(val) + case paramRetrySourceConnectionID: + p.retrySrcConnID = val + n = len(val) + default: + n = len(val) + } + if n != len(val) { + return p, localTransportError{code: errTransportParameter} + } + } + return p, nil +} diff --git a/src/vendor/golang.org/x/net/quic/udp.go b/src/vendor/golang.org/x/net/quic/udp.go new file mode 100644 index 0000000000..cf23c5ce88 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/udp.go @@ -0,0 +1,28 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import "net/netip" + +// Per-platform consts describing support for various features. +// +// const udpECNSupport indicates whether the platform supports setting +// the ECN (Explicit Congestion Notification) IP header bits. +// +// const udpInvalidLocalAddrIsError indicates whether sending a packet +// from a local address not associated with the system is an error. +// For example, assuming 127.0.0.2 is not a local address, does sending +// from it (using IP_PKTINFO or some other such feature) result in an error?
+ +// unmapAddrPort returns a with any IPv4-mapped IPv6 address prefix removed. +func unmapAddrPort(a netip.AddrPort) netip.AddrPort { + if a.Addr().Is4In6() { + return netip.AddrPortFrom( + a.Addr().Unmap(), + a.Port(), + ) + } + return a +} diff --git a/src/vendor/golang.org/x/net/quic/udp_darwin.go b/src/vendor/golang.org/x/net/quic/udp_darwin.go new file mode 100644 index 0000000000..91e8e81c17 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/udp_darwin.go @@ -0,0 +1,45 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build darwin + +package quic + +import ( + "encoding/binary" + "syscall" +) + +// These socket options are available on darwin, but are not in the syscall +// package. Since syscall package is frozen, just define them manually here. +const ( + ip_recvtos = 0x1b + ipv6_recvpktinfo = 0x3d + ipv6_pktinfo = 0x2e +) + +// See udp.go. +const ( + udpECNSupport = true + udpInvalidLocalAddrIsError = true +) + +// Confusingly, on Darwin the contents of the IP_TOS option differ depending on whether +// it is used as an inbound or outbound cmsg. + +func parseIPTOS(b []byte) (ecnBits, bool) { + // Single byte. The low two bits are the ECN field. + if len(b) != 1 { + return 0, false + } + return ecnBits(b[0] & ecnMask), true +} + +func appendCmsgECNv4(b []byte, ecn ecnBits) []byte { + // 32-bit integer. + // https://github.com/apple/darwin-xnu/blob/2ff845c2e033bd0ff64b5b6aa6063a1f8f65aa32/bsd/netinet/in_tclass.c#L1062-L1073 + b, data := appendCmsg(b, syscall.IPPROTO_IP, syscall.IP_TOS, 4) + binary.NativeEndian.PutUint32(data, uint32(ecn)) + return b +} diff --git a/src/vendor/golang.org/x/net/quic/udp_linux.go b/src/vendor/golang.org/x/net/quic/udp_linux.go new file mode 100644 index 0000000000..08deaf9829 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/udp_linux.go @@ -0,0 +1,39 @@ +// Copyright 2023 The Go Authors. 
All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build linux + +package quic + +import ( + "syscall" +) + +const ( + ip_recvtos = syscall.IP_RECVTOS + ipv6_recvpktinfo = syscall.IPV6_RECVPKTINFO + ipv6_pktinfo = syscall.IPV6_PKTINFO +) + +// See udp.go. +const ( + udpECNSupport = true + udpInvalidLocalAddrIsError = false +) + +// The IP_TOS socket option is a single byte containing the IP TOS field. +// The low two bits are the ECN field. + +func parseIPTOS(b []byte) (ecnBits, bool) { + if len(b) != 1 { + return 0, false + } + return ecnBits(b[0] & ecnMask), true +} + +func appendCmsgECNv4(b []byte, ecn ecnBits) []byte { + b, data := appendCmsg(b, syscall.IPPROTO_IP, syscall.IP_TOS, 1) + data[0] = byte(ecn) + return b +} diff --git a/src/vendor/golang.org/x/net/quic/udp_msg.go b/src/vendor/golang.org/x/net/quic/udp_msg.go new file mode 100644 index 0000000000..10909044fe --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/udp_msg.go @@ -0,0 +1,245 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !quicbasicnet && (darwin || linux) + +package quic + +import ( + "encoding/binary" + "net" + "net/netip" + "sync" + "syscall" + "unsafe" +) + +// Network interface for platforms using sendmsg/recvmsg with cmsgs. + +type netUDPConn struct { + c *net.UDPConn + localAddr netip.AddrPort +} + +func newNetUDPConn(uc *net.UDPConn) (*netUDPConn, error) { + a, _ := uc.LocalAddr().(*net.UDPAddr) + localAddr := a.AddrPort() + if localAddr.Addr().IsUnspecified() { + // If the conn is not bound to a specified (non-wildcard) address, + // then set localAddr.Addr to an invalid netip.Addr. + // This better conveys that this is not an address we should be using, + // and is a bit more efficient to test against. 
+ localAddr = netip.AddrPortFrom(netip.Addr{}, localAddr.Port()) + } + + sc, err := uc.SyscallConn() + if err != nil { + return nil, err + } + sc.Control(func(fd uintptr) { + // Ask for ECN info and (when we aren't bound to a fixed local address) + // destination info. + // + // If any of these calls fail, we won't get the requested information. + // That's fine, we'll gracefully handle the lack. + syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, ip_recvtos, 1) + syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_RECVTCLASS, 1) + if !localAddr.IsValid() { + syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_PKTINFO, 1) + syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, ipv6_recvpktinfo, 1) + } + }) + + return &netUDPConn{ + c: uc, + localAddr: localAddr, + }, nil +} + +func (c *netUDPConn) Close() error { return c.c.Close() } + +func (c *netUDPConn) LocalAddr() netip.AddrPort { + a, _ := c.c.LocalAddr().(*net.UDPAddr) + return a.AddrPort() +} + +func (c *netUDPConn) Read(f func(*datagram)) { + // We shouldn't ever see all of these messages at the same time, + // but the total is small so just allocate enough space for everything we use. 
+ const ( + inPktinfoSize = 12 // int + in_addr + in_addr + in6PktinfoSize = 20 // in6_addr + int + ipTOSSize = 4 + ipv6TclassSize = 4 + ) + control := make([]byte, 0+ + syscall.CmsgSpace(inPktinfoSize)+ + syscall.CmsgSpace(in6PktinfoSize)+ + syscall.CmsgSpace(ipTOSSize)+ + syscall.CmsgSpace(ipv6TclassSize)) + + for { + d := newDatagram() + n, controlLen, _, peerAddr, err := c.c.ReadMsgUDPAddrPort(d.b, control) + if err != nil { + return + } + if n == 0 { + continue + } + d.localAddr = c.localAddr + d.peerAddr = unmapAddrPort(peerAddr) + d.b = d.b[:n] + parseControl(d, control[:controlLen]) + f(d) + } +} + +var cmsgPool = sync.Pool{ + New: func() any { + return new([]byte) + }, +} + +func (c *netUDPConn) Write(dgram datagram) error { + controlp := cmsgPool.Get().(*[]byte) + control := *controlp + defer func() { + *controlp = control[:0] + cmsgPool.Put(controlp) + }() + + localIP := dgram.localAddr.Addr() + if localIP.IsValid() { + if localIP.Is4() { + control = appendCmsgIPSourceAddrV4(control, localIP) + } else { + control = appendCmsgIPSourceAddrV6(control, localIP) + } + } + if dgram.ecn != ecnNotECT { + if dgram.peerAddr.Addr().Is4() { + control = appendCmsgECNv4(control, dgram.ecn) + } else { + control = appendCmsgECNv6(control, dgram.ecn) + } + } + + _, _, err := c.c.WriteMsgUDPAddrPort(dgram.b, control, dgram.peerAddr) + return err +} + +func parseControl(d *datagram, control []byte) { + msgs, err := syscall.ParseSocketControlMessage(control) + if err != nil { + return + } + for _, m := range msgs { + switch m.Header.Level { + case syscall.IPPROTO_IP: + switch m.Header.Type { + case syscall.IP_TOS, ip_recvtos: + // (Linux sets the type to IP_TOS, Darwin to IP_RECVTOS, + // just check for both.) 
+ if ecn, ok := parseIPTOS(m.Data); ok { + d.ecn = ecn + } + case syscall.IP_PKTINFO: + if a, ok := parseInPktinfo(m.Data); ok { + d.localAddr = netip.AddrPortFrom(a, d.localAddr.Port()) + } + } + case syscall.IPPROTO_IPV6: + switch m.Header.Type { + case syscall.IPV6_TCLASS: + // 32-bit integer containing the traffic class field. + // The low two bits are the ECN field. + if ecn, ok := parseIPv6TCLASS(m.Data); ok { + d.ecn = ecn + } + case ipv6_pktinfo: + if a, ok := parseIn6Pktinfo(m.Data); ok { + d.localAddr = netip.AddrPortFrom(a, d.localAddr.Port()) + } + } + } + } +} + +// IPV6_TCLASS is specified by RFC 3542 as an int. + +func parseIPv6TCLASS(b []byte) (ecnBits, bool) { + if len(b) != 4 { + return 0, false + } + return ecnBits(binary.NativeEndian.Uint32(b) & ecnMask), true +} + +func appendCmsgECNv6(b []byte, ecn ecnBits) []byte { + b, data := appendCmsg(b, syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS, 4) + binary.NativeEndian.PutUint32(data, uint32(ecn)) + return b +} + +// struct in_pktinfo { +// unsigned int ipi_ifindex; /* send/recv interface index */ +// struct in_addr ipi_spec_dst; /* Local address */ +// struct in_addr ipi_addr; /* IP Header dst address */ +// }; + +// parseInPktinfo returns the destination address from an IP_PKTINFO. +func parseInPktinfo(b []byte) (dst netip.Addr, ok bool) { + if len(b) != 12 { + return netip.Addr{}, false + } + return netip.AddrFrom4([4]byte(b[8:][:4])), true +} + +// appendCmsgIPSourceAddrV4 appends an IP_PKTINFO setting the source address +// for an outbound datagram. 
+func appendCmsgIPSourceAddrV4(b []byte, src netip.Addr) []byte { + // struct in_pktinfo { + // unsigned int ipi_ifindex; /* send/recv interface index */ + // struct in_addr ipi_spec_dst; /* Local address */ + // struct in_addr ipi_addr; /* IP Header dst address */ + // }; + b, data := appendCmsg(b, syscall.IPPROTO_IP, syscall.IP_PKTINFO, 12) + ip := src.As4() + copy(data[4:], ip[:]) + return b +} + +// struct in6_pktinfo { +// struct in6_addr ipi6_addr; /* src/dst IPv6 address */ +// unsigned int ipi6_ifindex; /* send/recv interface index */ +// }; + +// parseIn6Pktinfo returns the destination address from an IPV6_PKTINFO. +func parseIn6Pktinfo(b []byte) (netip.Addr, bool) { + if len(b) != 20 { + return netip.Addr{}, false + } + return netip.AddrFrom16([16]byte(b[:16])).Unmap(), true +} + +// appendCmsgIPSourceAddrV6 appends an IPV6_PKTINFO setting the source address +// for an outbound datagram. +func appendCmsgIPSourceAddrV6(b []byte, src netip.Addr) []byte { + b, data := appendCmsg(b, syscall.IPPROTO_IPV6, ipv6_pktinfo, 20) + ip := src.As16() + copy(data[0:], ip[:]) + return b +} + +// appendCmsg appends a cmsg with the given level, type, and size to b. +// It returns the new buffer, and the data section of the cmsg. +func appendCmsg(b []byte, level, typ int32, size int) (_, data []byte) { + off := len(b) + b = append(b, make([]byte, syscall.CmsgSpace(size))...) + h := (*syscall.Cmsghdr)(unsafe.Pointer(&b[off])) + h.Level = level + h.Type = typ + h.SetLen(syscall.CmsgLen(size)) + return b, b[off+syscall.CmsgSpace(0):][:size] +} diff --git a/src/vendor/golang.org/x/net/quic/udp_other.go b/src/vendor/golang.org/x/net/quic/udp_other.go new file mode 100644 index 0000000000..02e4a5fc23 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/udp_other.go @@ -0,0 +1,62 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +//go:build quicbasicnet || !(darwin || linux) + +package quic + +import ( + "net" + "net/netip" +) + +// Lowest common denominator network interface: Basic net.UDPConn, no cmsgs. +// We will not be able to send or receive ECN bits, +// and we will not know what our local address is. +// +// The quicbasicnet build tag allows selecting this interface on any platform. + +// See udp.go. +const ( + udpECNSupport = false + udpInvalidLocalAddrIsError = false +) + +type netUDPConn struct { + c *net.UDPConn +} + +func newNetUDPConn(uc *net.UDPConn) (*netUDPConn, error) { + return &netUDPConn{ + c: uc, + }, nil +} + +func (c *netUDPConn) Close() error { return c.c.Close() } + +func (c *netUDPConn) LocalAddr() netip.AddrPort { + a, _ := c.c.LocalAddr().(*net.UDPAddr) + return a.AddrPort() +} + +func (c *netUDPConn) Read(f func(*datagram)) { + for { + dgram := newDatagram() + n, peerAddr, err := c.c.ReadFromUDPAddrPort(dgram.b) + if err != nil { + return + } + if n == 0 { + continue + } + dgram.peerAddr = unmapAddrPort(peerAddr) + dgram.b = dgram.b[:n] + f(dgram) + } +} + +func (c *netUDPConn) Write(dgram datagram) error { + _, err := c.c.WriteToUDPAddrPort(dgram.b, dgram.peerAddr) + return err +} diff --git a/src/vendor/golang.org/x/net/quic/udp_packetconn.go b/src/vendor/golang.org/x/net/quic/udp_packetconn.go new file mode 100644 index 0000000000..2c7e71cf61 --- /dev/null +++ b/src/vendor/golang.org/x/net/quic/udp_packetconn.go @@ -0,0 +1,67 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "net" + "net/netip" +) + +// netPacketConn is a packetConn implementation wrapping a net.PacketConn. +// +// This is mostly useful for tests, since PacketConn doesn't provide access to +// important features such as identifying the local address packets were received on. 
+type netPacketConn struct { + c net.PacketConn + localAddr netip.AddrPort +} + +func newNetPacketConn(pc net.PacketConn) (*netPacketConn, error) { + addr, err := addrPortFromAddr(pc.LocalAddr()) + if err != nil { + return nil, err + } + return &netPacketConn{ + c: pc, + localAddr: addr, + }, nil +} + +func (c *netPacketConn) Close() error { + return c.c.Close() +} + +func (c *netPacketConn) LocalAddr() netip.AddrPort { + return c.localAddr +} + +func (c *netPacketConn) Read(f func(*datagram)) { + for { + dgram := newDatagram() + n, peerAddr, err := c.c.ReadFrom(dgram.b) + if err != nil { + return + } + dgram.peerAddr, err = addrPortFromAddr(peerAddr) + if err != nil { + continue + } + dgram.b = dgram.b[:n] + f(dgram) + } +} + +func (c *netPacketConn) Write(dgram datagram) error { + _, err := c.c.WriteTo(dgram.b, net.UDPAddrFromAddrPort(dgram.peerAddr)) + return err +} + +func addrPortFromAddr(addr net.Addr) (netip.AddrPort, error) { + switch a := addr.(type) { + case *net.UDPAddr: + return a.AddrPort(), nil + } + return netip.ParseAddrPort(addr.String()) +} |
