aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/crypto/tls/conn.go49
1 files changed, 24 insertions, 25 deletions
diff --git a/src/crypto/tls/conn.go b/src/crypto/tls/conn.go
index 9c662ef8f6..54a4d1a883 100644
--- a/src/crypto/tls/conn.go
+++ b/src/crypto/tls/conn.go
@@ -800,29 +800,6 @@ func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error {
return c.readRecordOrCCS(expectChangeCipherSpec)
}
-// atLeastReader reads from R, stopping with EOF once at least N bytes have been
-// read. It is different from an io.LimitedReader in that it doesn't cut short
-// the last Read call, and in that it considers an early EOF an error.
-type atLeastReader struct {
- R io.Reader
- N int64
-}
-
-func (r *atLeastReader) Read(p []byte) (int, error) {
- if r.N <= 0 {
- return 0, io.EOF
- }
- n, err := r.R.Read(p)
- r.N -= int64(n) // won't underflow unless len(p) >= n > 9223372036854775809
- if r.N > 0 && err == io.EOF {
- return n, io.ErrUnexpectedEOF
- }
- if r.N <= 0 && err == nil {
- return n, io.EOF
- }
- return n, err
-}
-
// readFromUntil reads from r into c.rawInput until c.rawInput contains
// at least n bytes or else returns an error.
func (c *Conn) readFromUntil(r io.Reader, n int) error {
@@ -833,9 +810,31 @@ func (c *Conn) readFromUntil(r io.Reader, n int) error {
// There might be extra input waiting on the wire. Make a best effort
// attempt to fetch it so that it can be used in (*Conn).Read to
// "predict" closeNotify alerts.
+ // TODO(dmo): we use bytes.MinRead here because we used the buffer
+ // ReadFrom mechanism to avoid allocations, but we've hoisted this
+ // loop for performance. We really should use our own heuristic here
+ // for how much to read ahead.
c.rawInput.Grow(needs + bytes.MinRead)
- _, err := c.rawInput.ReadFrom(&atLeastReader{r, int64(needs)})
- return err
+ for {
+ buf := c.rawInput.AvailableBuffer()[:needs+bytes.MinRead]
+ n, err := r.Read(buf)
+ // This write is just to update the internal state of the
+ // rawInput bytes.Buffer. It cannot fail.
+ c.rawInput.Write(buf[:n])
+ needs -= n
+ if needs <= 0 {
+ if err == io.EOF {
+ err = nil
+ }
+ return err
+ }
+ if err == io.EOF {
+ return io.ErrUnexpectedEOF
+ }
+ if err != nil {
+ return err
+ }
+ }
}
// sendAlertLocked sends a TLS alert message.