diff options
Diffstat (limited to 'src/crypto/tls/conn.go')
| -rw-r--r-- | src/crypto/tls/conn.go | 189 |
1 files changed, 136 insertions, 53 deletions
diff --git a/src/crypto/tls/conn.go b/src/crypto/tls/conn.go index edcfecf81d..969f357834 100644 --- a/src/crypto/tls/conn.go +++ b/src/crypto/tls/conn.go @@ -8,11 +8,13 @@ package tls import ( "bytes" + "context" "crypto/cipher" "crypto/subtle" "crypto/x509" "errors" "fmt" + "hash" "io" "net" "sync" @@ -26,7 +28,7 @@ type Conn struct { // constant conn net.Conn isClient bool - handshakeFn func() error // (*Conn).clientHandshake or serverHandshake + handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake // handshakeStatus is 1 if the connection is currently transferring // application data (i.e. is not currently processing a handshake). @@ -86,15 +88,14 @@ type Conn struct { clientFinished [12]byte serverFinished [12]byte - clientProtocol string - clientProtocolFallback bool + // clientProtocol is the negotiated ALPN protocol. + clientProtocol string // input/output in, out halfConn rawInput bytes.Buffer // raw input, starting with a record header input bytes.Reader // application data waiting to be read, from rawInput.Next hand bytes.Buffer // handshake data waiting to be read - outBuf []byte // scratch buffer used by out.encrypt buffering bool // whether records are buffered in sendBuf sendBuf []byte // a buffer of records waiting to be sent @@ -155,31 +156,32 @@ func (c *Conn) SetWriteDeadline(t time.Time) error { type halfConn struct { sync.Mutex - err error // first permanent error - version uint16 // protocol version - cipher interface{} // cipher algorithm - mac macFunction - seq [8]byte // 64-bit sequence number - additionalData [13]byte // to avoid allocs; interface method args escape + err error // first permanent error + version uint16 // protocol version + cipher interface{} // cipher algorithm + mac hash.Hash + seq [8]byte // 64-bit sequence number + + scratchBuf [13]byte // to avoid allocs; interface method args escape nextCipher interface{} // next encryption state - nextMac macFunction // next MAC algorithm + nextMac hash.Hash // next MAC algorithm trafficSecret []byte // current TLS 1.3 traffic secret } -type permamentError struct { +type permanentError struct { err net.Error } -func (e *permamentError) Error() string { return e.err.Error() } -func (e *permamentError) Unwrap() error { return e.err } -func (e *permamentError) Timeout() bool { return e.err.Timeout() } -func (e *permamentError) Temporary() bool { return false } +func (e *permanentError) Error() string { return e.err.Error() } +func (e *permanentError) Unwrap() error { return e.err } +func (e *permanentError) Timeout() bool { return e.err.Timeout() } +func (e *permanentError) Temporary() bool { return false } func (hc *halfConn) setErrorLocked(err error) error { if e, ok := err.(net.Error); ok { - hc.err = &permamentError{err: e} + hc.err = &permanentError{err: e} } else { hc.err = err } @@ -188,7 +190,7 @@ func (hc *halfConn) setErrorLocked(err error) error { // prepareCipherSpec sets the encryption and MAC states // that a subsequent changeCipherSpec will use. -func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac macFunction) { +func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac hash.Hash) { hc.version = version hc.nextCipher = cipher hc.nextMac = mac @@ -350,15 +352,14 @@ func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) { } payload = payload[explicitNonceLen:] - additionalData := hc.additionalData[:] + var additionalData []byte if hc.version == VersionTLS13 { additionalData = record[:recordHeaderLen] } else { - copy(additionalData, hc.seq[:]) - copy(additionalData[8:], record[:3]) + additionalData = append(hc.scratchBuf[:0], hc.seq[:]...) + additionalData = append(additionalData, record[:3]...) n := len(payload) - c.Overhead() - additionalData[11] = byte(n >> 8) - additionalData[12] = byte(n) + additionalData = append(additionalData, byte(n>>8), byte(n)) } var err error @@ -424,7 +425,7 @@ func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) { record[3] = byte(n >> 8) record[4] = byte(n) remoteMAC := payload[n : n+macSize] - localMAC := hc.mac.MAC(hc.seq[0:], record[:recordHeaderLen], payload[:n], payload[n+macSize:]) + localMAC := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload[:n], payload[n+macSize:]) // This is equivalent to checking the MACs and paddingGood // separately, but in constant-time to prevent distinguishing @@ -460,7 +461,7 @@ func sliceForAppend(in []byte, n int) (head, tail []byte) { } // encrypt encrypts payload, adding the appropriate nonce and/or MAC, and -// appends it to record, which contains the record header. +// appends it to record, which must already contain the record header. func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) { if hc.cipher == nil { return append(record, payload...), nil @@ -477,7 +478,7 @@ func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, err // an 8 bytes nonce but its nonces must be unpredictable (see RFC // 5246, Appendix F.3), forcing us to use randomness. That's not // 3DES' biggest problem anyway because the birthday bound on block - // collision is reached first due to its simlarly small block size + // collision is reached first due to its similarly small block size // (see the Sweet32 attack). copy(explicitNonce, hc.seq[:]) } else { @@ -487,14 +488,10 @@ func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, err } } - var mac []byte - if hc.mac != nil { - mac = hc.mac.MAC(hc.seq[:], record[:recordHeaderLen], payload, nil) - } - var dst []byte switch c := hc.cipher.(type) { case cipher.Stream: + mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil) record, dst = sliceForAppend(record, len(payload)+len(mac)) c.XORKeyStream(dst[:len(payload)], payload) c.XORKeyStream(dst[len(payload):], mac) @@ -518,11 +515,12 @@ func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, err record = c.Seal(record[:recordHeaderLen], nonce, record[recordHeaderLen:], record[:recordHeaderLen]) } else { - copy(hc.additionalData[:], hc.seq[:]) - copy(hc.additionalData[8:], record) - record = c.Seal(record, nonce, payload, hc.additionalData[:]) + additionalData := append(hc.scratchBuf[:0], hc.seq[:]...) + additionalData = append(additionalData, record[:recordHeaderLen]...) + record = c.Seal(record, nonce, payload, additionalData) } case cbcMode: + mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil) blockSize := c.BlockSize() plaintextLen := len(payload) + len(mac) paddingLen := blockSize - plaintextLen%blockSize @@ -928,9 +926,28 @@ func (c *Conn) flush() (int, error) { return n, err } +// outBufPool pools the record-sized scratch buffers used by writeRecordLocked. +var outBufPool = sync.Pool{ + New: func() interface{} { + return new([]byte) + }, +} + // writeRecordLocked writes a TLS record with the given type and payload to the // connection and updates the record layer state. func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) { + outBufPtr := outBufPool.Get().(*[]byte) + outBuf := *outBufPtr + defer func() { + // You might be tempted to simplify this by just passing &outBuf to Put, + // but that would make the local copy of the outBuf slice header escape + // to the heap, causing an allocation. Instead, we keep around the + // pointer to the slice header returned by Get, which is already on the + // heap, and overwrite and return that. + *outBufPtr = outBuf + outBufPool.Put(outBufPtr) + }() + var n int for len(data) > 0 { m := len(data) @@ -938,8 +955,8 @@ func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) { m = maxPayload } - _, c.outBuf = sliceForAppend(c.outBuf[:0], recordHeaderLen) - c.outBuf[0] = byte(typ) + _, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen) + outBuf[0] = byte(typ) vers := c.vers if vers == 0 { // Some TLS servers fail if the record version is @@ -950,17 +967,17 @@ func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) { // See RFC 8446, Section 5.1. vers = VersionTLS12 } - c.outBuf[1] = byte(vers >> 8) - c.outBuf[2] = byte(vers) - c.outBuf[3] = byte(m >> 8) - c.outBuf[4] = byte(m) + outBuf[1] = byte(vers >> 8) + outBuf[2] = byte(vers) + outBuf[3] = byte(m >> 8) + outBuf[4] = byte(m) var err error - c.outBuf, err = c.out.encrypt(c.outBuf, data[:m], c.config.rand()) + outBuf, err = c.out.encrypt(outBuf, data[:m], c.config.rand()) if err != nil { return n, err } - if _, err := c.write(c.outBuf); err != nil { + if _, err := c.write(outBuf); err != nil { return n, err } n += m @@ -1070,17 +1087,21 @@ func (c *Conn) readHandshake() (interface{}, error) { } var ( - errClosed = errors.New("tls: use of closed connection") errShutdown = errors.New("tls: protocol is shutdown") ) // Write writes data to the connection. +// +// As Write calls Handshake, in order to prevent indefinite blocking a deadline +// must be set for both Read and Write before Write is called when the handshake +// has not yet completed. See SetDeadline, SetReadDeadline, and +// SetWriteDeadline. func (c *Conn) Write(b []byte) (int, error) { // interlock with Close below for { x := atomic.LoadInt32(&c.activeCall) if x&1 != 0 { - return 0, errClosed + return 0, net.ErrClosed } if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) { break @@ -1170,7 +1191,7 @@ func (c *Conn) handleRenegotiation() error { defer c.handshakeMutex.Unlock() atomic.StoreUint32(&c.handshakeStatus, 0) - if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil { + if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil { c.handshakes++ } return c.handshakeErr @@ -1233,8 +1254,12 @@ func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error { return nil } -// Read can be made to time out and return a net.Error with Timeout() == true -// after a fixed time limit; see SetDeadline and SetReadDeadline. +// Read reads data from the connection. +// +// As Read calls Handshake, in order to prevent indefinite blocking a deadline +// must be set for both Read and Write before Read is called when the handshake +// has not yet completed. See SetDeadline, SetReadDeadline, and +// SetWriteDeadline. func (c *Conn) Read(b []byte) (int, error) { if err := c.Handshake(); err != nil { return 0, err @@ -1285,7 +1310,7 @@ func (c *Conn) Close() error { for { x = atomic.LoadInt32(&c.activeCall) if x&1 != 0 { - return errClosed + return net.ErrClosed } if atomic.CompareAndSwapInt32(&c.activeCall, x, x|1) { break @@ -1302,9 +1327,10 @@ func (c *Conn) Close() error { } var alertErr error - if c.handshakeComplete() { - alertErr = c.closeNotify() + if err := c.closeNotify(); err != nil { + alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err) + } } if err := c.conn.Close(); err != nil { @@ -1331,8 +1357,12 @@ func (c *Conn) closeNotify() error { defer c.out.Unlock() if !c.closeNotifySent { + // Set a Write Deadline to prevent possibly blocking forever. + c.SetWriteDeadline(time.Now().Add(time.Second * 5)) c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify) c.closeNotifySent = true + // Any subsequent writes will fail. + c.SetWriteDeadline(time.Now()) } return c.closeNotifyErr } @@ -1344,8 +1374,61 @@ func (c *Conn) closeNotify() error { // first Read or Write will call it automatically. // // For control over canceling or setting a timeout on a handshake, use -// the Dialer's DialContext method. +// HandshakeContext or the Dialer's DialContext method instead. func (c *Conn) Handshake() error { + return c.HandshakeContext(context.Background()) +} + +// HandshakeContext runs the client or server handshake +// protocol if it has not yet been run. +// +// The provided Context must be non-nil. If the context is canceled before +// the handshake is complete, the handshake is interrupted and an error is returned. +// Once the handshake has completed, cancellation of the context will not affect the +// connection. +// +// Most uses of this package need not call HandshakeContext explicitly: the +// first Read or Write will call it automatically. +func (c *Conn) HandshakeContext(ctx context.Context) error { + // Delegate to unexported method for named return + // without confusing documented signature. + return c.handshakeContext(ctx) +} + +func (c *Conn) handshakeContext(ctx context.Context) (ret error) { + handshakeCtx, cancel := context.WithCancel(ctx) + // Note: defer this before starting the "interrupter" goroutine + // so that we can tell the difference between the input being canceled and + // this cancellation. In the former case, we need to close the connection. + defer cancel() + + // Start the "interrupter" goroutine, if this context might be canceled. + // (The background context cannot). + // + // The interrupter goroutine waits for the input context to be done and + // closes the connection if this happens before the function returns. + if ctx.Done() != nil { + done := make(chan struct{}) + interruptRes := make(chan error, 1) + defer func() { + close(done) + if ctxErr := <-interruptRes; ctxErr != nil { + // Return context error to user. + ret = ctxErr + } + }() + go func() { + select { + case <-handshakeCtx.Done(): + // Close the connection, discarding the error + _ = c.conn.Close() + interruptRes <- handshakeCtx.Err() + case <-done: + interruptRes <- nil + } + }() + } + c.handshakeMutex.Lock() defer c.handshakeMutex.Unlock() @@ -1359,7 +1442,7 @@ func (c *Conn) Handshake() error { c.in.Lock() defer c.in.Unlock() - c.handshakeErr = c.handshakeFn() + c.handshakeErr = c.handshakeFn(handshakeCtx) if c.handshakeErr == nil { c.handshakes++ } else { @@ -1388,7 +1471,7 @@ func (c *Conn) connectionStateLocked() ConnectionState { state.Version = c.vers state.NegotiatedProtocol = c.clientProtocol state.DidResume = c.didResume - state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback + state.NegotiatedProtocolIsMutual = true state.ServerName = c.serverName state.CipherSuite = c.cipherSuite state.PeerCertificates = c.peerCertificates |
