aboutsummaryrefslogtreecommitdiff
path: root/src/crypto/tls/conn.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/crypto/tls/conn.go')
-rw-r--r--src/crypto/tls/conn.go189
1 files changed, 136 insertions, 53 deletions
diff --git a/src/crypto/tls/conn.go b/src/crypto/tls/conn.go
index edcfecf81d..969f357834 100644
--- a/src/crypto/tls/conn.go
+++ b/src/crypto/tls/conn.go
@@ -8,11 +8,13 @@ package tls
import (
"bytes"
+ "context"
"crypto/cipher"
"crypto/subtle"
"crypto/x509"
"errors"
"fmt"
+ "hash"
"io"
"net"
"sync"
@@ -26,7 +28,7 @@ type Conn struct {
// constant
conn net.Conn
isClient bool
- handshakeFn func() error // (*Conn).clientHandshake or serverHandshake
+ handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake
// handshakeStatus is 1 if the connection is currently transferring
// application data (i.e. is not currently processing a handshake).
@@ -86,15 +88,14 @@ type Conn struct {
clientFinished [12]byte
serverFinished [12]byte
- clientProtocol string
- clientProtocolFallback bool
+ // clientProtocol is the negotiated ALPN protocol.
+ clientProtocol string
// input/output
in, out halfConn
rawInput bytes.Buffer // raw input, starting with a record header
input bytes.Reader // application data waiting to be read, from rawInput.Next
hand bytes.Buffer // handshake data waiting to be read
- outBuf []byte // scratch buffer used by out.encrypt
buffering bool // whether records are buffered in sendBuf
sendBuf []byte // a buffer of records waiting to be sent
@@ -155,31 +156,32 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
type halfConn struct {
sync.Mutex
- err error // first permanent error
- version uint16 // protocol version
- cipher interface{} // cipher algorithm
- mac macFunction
- seq [8]byte // 64-bit sequence number
- additionalData [13]byte // to avoid allocs; interface method args escape
+ err error // first permanent error
+ version uint16 // protocol version
+ cipher interface{} // cipher algorithm
+ mac hash.Hash
+ seq [8]byte // 64-bit sequence number
+
+ scratchBuf [13]byte // to avoid allocs; interface method args escape
nextCipher interface{} // next encryption state
- nextMac macFunction // next MAC algorithm
+ nextMac hash.Hash // next MAC algorithm
trafficSecret []byte // current TLS 1.3 traffic secret
}
-type permamentError struct {
+type permanentError struct {
err net.Error
}
-func (e *permamentError) Error() string { return e.err.Error() }
-func (e *permamentError) Unwrap() error { return e.err }
-func (e *permamentError) Timeout() bool { return e.err.Timeout() }
-func (e *permamentError) Temporary() bool { return false }
+func (e *permanentError) Error() string { return e.err.Error() }
+func (e *permanentError) Unwrap() error { return e.err }
+func (e *permanentError) Timeout() bool { return e.err.Timeout() }
+func (e *permanentError) Temporary() bool { return false }
func (hc *halfConn) setErrorLocked(err error) error {
if e, ok := err.(net.Error); ok {
- hc.err = &permamentError{err: e}
+ hc.err = &permanentError{err: e}
} else {
hc.err = err
}
@@ -188,7 +190,7 @@ func (hc *halfConn) setErrorLocked(err error) error {
// prepareCipherSpec sets the encryption and MAC states
// that a subsequent changeCipherSpec will use.
-func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac macFunction) {
+func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac hash.Hash) {
hc.version = version
hc.nextCipher = cipher
hc.nextMac = mac
@@ -350,15 +352,14 @@ func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) {
}
payload = payload[explicitNonceLen:]
- additionalData := hc.additionalData[:]
+ var additionalData []byte
if hc.version == VersionTLS13 {
additionalData = record[:recordHeaderLen]
} else {
- copy(additionalData, hc.seq[:])
- copy(additionalData[8:], record[:3])
+ additionalData = append(hc.scratchBuf[:0], hc.seq[:]...)
+ additionalData = append(additionalData, record[:3]...)
n := len(payload) - c.Overhead()
- additionalData[11] = byte(n >> 8)
- additionalData[12] = byte(n)
+ additionalData = append(additionalData, byte(n>>8), byte(n))
}
var err error
@@ -424,7 +425,7 @@ func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) {
record[3] = byte(n >> 8)
record[4] = byte(n)
remoteMAC := payload[n : n+macSize]
- localMAC := hc.mac.MAC(hc.seq[0:], record[:recordHeaderLen], payload[:n], payload[n+macSize:])
+ localMAC := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload[:n], payload[n+macSize:])
// This is equivalent to checking the MACs and paddingGood
// separately, but in constant-time to prevent distinguishing
@@ -460,7 +461,7 @@ func sliceForAppend(in []byte, n int) (head, tail []byte) {
}
// encrypt encrypts payload, adding the appropriate nonce and/or MAC, and
-// appends it to record, which contains the record header.
+// appends it to record, which must already contain the record header.
func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) {
if hc.cipher == nil {
return append(record, payload...), nil
@@ -477,7 +478,7 @@ func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, err
// an 8 bytes nonce but its nonces must be unpredictable (see RFC
// 5246, Appendix F.3), forcing us to use randomness. That's not
// 3DES' biggest problem anyway because the birthday bound on block
- // collision is reached first due to its simlarly small block size
+ // collision is reached first due to its similarly small block size
// (see the Sweet32 attack).
copy(explicitNonce, hc.seq[:])
} else {
@@ -487,14 +488,10 @@ func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, err
}
}
- var mac []byte
- if hc.mac != nil {
- mac = hc.mac.MAC(hc.seq[:], record[:recordHeaderLen], payload, nil)
- }
-
var dst []byte
switch c := hc.cipher.(type) {
case cipher.Stream:
+ mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
record, dst = sliceForAppend(record, len(payload)+len(mac))
c.XORKeyStream(dst[:len(payload)], payload)
c.XORKeyStream(dst[len(payload):], mac)
@@ -518,11 +515,12 @@ func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, err
record = c.Seal(record[:recordHeaderLen],
nonce, record[recordHeaderLen:], record[:recordHeaderLen])
} else {
- copy(hc.additionalData[:], hc.seq[:])
- copy(hc.additionalData[8:], record)
- record = c.Seal(record, nonce, payload, hc.additionalData[:])
+ additionalData := append(hc.scratchBuf[:0], hc.seq[:]...)
+ additionalData = append(additionalData, record[:recordHeaderLen]...)
+ record = c.Seal(record, nonce, payload, additionalData)
}
case cbcMode:
+ mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
blockSize := c.BlockSize()
plaintextLen := len(payload) + len(mac)
paddingLen := blockSize - plaintextLen%blockSize
@@ -928,9 +926,28 @@ func (c *Conn) flush() (int, error) {
return n, err
}
+// outBufPool pools the record-sized scratch buffers used by writeRecordLocked.
+var outBufPool = sync.Pool{
+ New: func() interface{} {
+ return new([]byte)
+ },
+}
+
// writeRecordLocked writes a TLS record with the given type and payload to the
// connection and updates the record layer state.
func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
+ outBufPtr := outBufPool.Get().(*[]byte)
+ outBuf := *outBufPtr
+ defer func() {
+ // You might be tempted to simplify this by just passing &outBuf to Put,
+ // but that would make the local copy of the outBuf slice header escape
+ // to the heap, causing an allocation. Instead, we keep around the
+ // pointer to the slice header returned by Get, which is already on the
+ // heap, and overwrite and return that.
+ *outBufPtr = outBuf
+ outBufPool.Put(outBufPtr)
+ }()
+
var n int
for len(data) > 0 {
m := len(data)
@@ -938,8 +955,8 @@ func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
m = maxPayload
}
- _, c.outBuf = sliceForAppend(c.outBuf[:0], recordHeaderLen)
- c.outBuf[0] = byte(typ)
+ _, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen)
+ outBuf[0] = byte(typ)
vers := c.vers
if vers == 0 {
// Some TLS servers fail if the record version is
@@ -950,17 +967,17 @@ func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
// See RFC 8446, Section 5.1.
vers = VersionTLS12
}
- c.outBuf[1] = byte(vers >> 8)
- c.outBuf[2] = byte(vers)
- c.outBuf[3] = byte(m >> 8)
- c.outBuf[4] = byte(m)
+ outBuf[1] = byte(vers >> 8)
+ outBuf[2] = byte(vers)
+ outBuf[3] = byte(m >> 8)
+ outBuf[4] = byte(m)
var err error
- c.outBuf, err = c.out.encrypt(c.outBuf, data[:m], c.config.rand())
+ outBuf, err = c.out.encrypt(outBuf, data[:m], c.config.rand())
if err != nil {
return n, err
}
- if _, err := c.write(c.outBuf); err != nil {
+ if _, err := c.write(outBuf); err != nil {
return n, err
}
n += m
@@ -1070,17 +1087,21 @@ func (c *Conn) readHandshake() (interface{}, error) {
}
var (
- errClosed = errors.New("tls: use of closed connection")
errShutdown = errors.New("tls: protocol is shutdown")
)
// Write writes data to the connection.
+//
+// As Write calls Handshake, in order to prevent indefinite blocking a deadline
+// must be set for both Read and Write before Write is called when the handshake
+// has not yet completed. See SetDeadline, SetReadDeadline, and
+// SetWriteDeadline.
func (c *Conn) Write(b []byte) (int, error) {
// interlock with Close below
for {
x := atomic.LoadInt32(&c.activeCall)
if x&1 != 0 {
- return 0, errClosed
+ return 0, net.ErrClosed
}
if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) {
break
@@ -1170,7 +1191,7 @@ func (c *Conn) handleRenegotiation() error {
defer c.handshakeMutex.Unlock()
atomic.StoreUint32(&c.handshakeStatus, 0)
- if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil {
+ if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil {
c.handshakes++
}
return c.handshakeErr
@@ -1233,8 +1254,12 @@ func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
return nil
}
-// Read can be made to time out and return a net.Error with Timeout() == true
-// after a fixed time limit; see SetDeadline and SetReadDeadline.
+// Read reads data from the connection.
+//
+// As Read calls Handshake, in order to prevent indefinite blocking a deadline
+// must be set for both Read and Write before Read is called when the handshake
+// has not yet completed. See SetDeadline, SetReadDeadline, and
+// SetWriteDeadline.
func (c *Conn) Read(b []byte) (int, error) {
if err := c.Handshake(); err != nil {
return 0, err
@@ -1285,7 +1310,7 @@ func (c *Conn) Close() error {
for {
x = atomic.LoadInt32(&c.activeCall)
if x&1 != 0 {
- return errClosed
+ return net.ErrClosed
}
if atomic.CompareAndSwapInt32(&c.activeCall, x, x|1) {
break
@@ -1302,9 +1327,10 @@ func (c *Conn) Close() error {
}
var alertErr error
-
if c.handshakeComplete() {
- alertErr = c.closeNotify()
+ if err := c.closeNotify(); err != nil {
+ alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err)
+ }
}
if err := c.conn.Close(); err != nil {
@@ -1331,8 +1357,12 @@ func (c *Conn) closeNotify() error {
defer c.out.Unlock()
if !c.closeNotifySent {
+ // Set a Write Deadline to prevent possibly blocking forever.
+ c.SetWriteDeadline(time.Now().Add(time.Second * 5))
c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify)
c.closeNotifySent = true
+ // Any subsequent writes will fail.
+ c.SetWriteDeadline(time.Now())
}
return c.closeNotifyErr
}
@@ -1344,8 +1374,61 @@ func (c *Conn) closeNotify() error {
// first Read or Write will call it automatically.
//
// For control over canceling or setting a timeout on a handshake, use
-// the Dialer's DialContext method.
+// HandshakeContext or the Dialer's DialContext method instead.
func (c *Conn) Handshake() error {
+ return c.HandshakeContext(context.Background())
+}
+
+// HandshakeContext runs the client or server handshake
+// protocol if it has not yet been run.
+//
+// The provided Context must be non-nil. If the context is canceled before
+// the handshake is complete, the handshake is interrupted and an error is returned.
+// Once the handshake has completed, cancellation of the context will not affect the
+// connection.
+//
+// Most uses of this package need not call HandshakeContext explicitly: the
+// first Read or Write will call it automatically.
+func (c *Conn) HandshakeContext(ctx context.Context) error {
+ // Delegate to unexported method for named return
+ // without confusing documented signature.
+ return c.handshakeContext(ctx)
+}
+
+func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
+ handshakeCtx, cancel := context.WithCancel(ctx)
+ // Note: defer this before starting the "interrupter" goroutine
+ // so that we can tell the difference between the input being canceled and
+ // this cancellation. In the former case, we need to close the connection.
+ defer cancel()
+
+ // Start the "interrupter" goroutine, if this context might be canceled.
+ // (The background context cannot).
+ //
+ // The interrupter goroutine waits for the input context to be done and
+ // closes the connection if this happens before the function returns.
+ if ctx.Done() != nil {
+ done := make(chan struct{})
+ interruptRes := make(chan error, 1)
+ defer func() {
+ close(done)
+ if ctxErr := <-interruptRes; ctxErr != nil {
+ // Return context error to user.
+ ret = ctxErr
+ }
+ }()
+ go func() {
+ select {
+ case <-handshakeCtx.Done():
+ // Close the connection, discarding the error
+ _ = c.conn.Close()
+ interruptRes <- handshakeCtx.Err()
+ case <-done:
+ interruptRes <- nil
+ }
+ }()
+ }
+
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
@@ -1359,7 +1442,7 @@ func (c *Conn) Handshake() error {
c.in.Lock()
defer c.in.Unlock()
- c.handshakeErr = c.handshakeFn()
+ c.handshakeErr = c.handshakeFn(handshakeCtx)
if c.handshakeErr == nil {
c.handshakes++
} else {
@@ -1388,7 +1471,7 @@ func (c *Conn) connectionStateLocked() ConnectionState {
state.Version = c.vers
state.NegotiatedProtocol = c.clientProtocol
state.DidResume = c.didResume
- state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback
+ state.NegotiatedProtocolIsMutual = true
state.ServerName = c.serverName
state.CipherSuite = c.cipherSuite
state.PeerCertificates = c.peerCertificates