diff options
Diffstat (limited to 'src/crypto/tls/handshake_client.go')
| -rw-r--r-- | src/crypto/tls/handshake_client.go | 51 |
1 files changed, 27 insertions, 24 deletions
diff --git a/src/crypto/tls/handshake_client.go b/src/crypto/tls/handshake_client.go index 46b0a770d5..92e33e7169 100644 --- a/src/crypto/tls/handshake_client.go +++ b/src/crypto/tls/handshake_client.go @@ -6,6 +6,7 @@ package tls import ( "bytes" + "context" "crypto" "crypto/ecdsa" "crypto/ed25519" @@ -14,6 +15,7 @@ import ( "crypto/x509" "errors" "fmt" + "hash" "io" "net" "strings" @@ -23,6 +25,7 @@ import ( type clientHandshakeState struct { c *Conn + ctx context.Context serverHello *serverHelloMsg hello *clientHelloMsg suite *cipherSuite @@ -133,7 +136,7 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, ecdheParameters, error) { return hello, params, nil } -func (c *Conn) clientHandshake() (err error) { +func (c *Conn) clientHandshake(ctx context.Context) (err error) { if c.config == nil { c.config = defaultConfig() } @@ -197,6 +200,7 @@ func (c *Conn) clientHandshake() (err error) { if c.vers == VersionTLS13 { hs := &clientHandshakeStateTLS13{ c: c, + ctx: ctx, serverHello: serverHello, hello: hello, ecdheParams: ecdheParams, @@ -211,6 +215,7 @@ func (c *Conn) clientHandshake() (err error) { hs := &clientHandshakeState{ c: c, + ctx: ctx, serverHello: serverHello, hello: hello, session: session, @@ -539,7 +544,7 @@ func (hs *clientHandshakeState) doFullHandshake() error { certRequested = true hs.finishedHash.Write(certReq.marshal()) - cri := certificateRequestInfoFromMsg(c.vers, certReq) + cri := certificateRequestInfoFromMsg(hs.ctx, c.vers, certReq) if chainToSend, err = c.getClientCertificate(cri); err != nil { c.sendAlert(alertInternalError) return err @@ -647,12 +652,12 @@ func (hs *clientHandshakeState) establishKeys() error { clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) var clientCipher, serverCipher interface{} - var clientHash, serverHash macFunction + var clientHash, serverHash hash.Hash if hs.suite.cipher != nil { clientCipher = hs.suite.cipher(clientKey, clientIV, false /* not for reading */) - clientHash = hs.suite.mac(c.vers, clientMAC) + clientHash = hs.suite.mac(clientMAC) serverCipher = hs.suite.cipher(serverKey, serverIV, true /* for reading */) - serverHash = hs.suite.mac(c.vers, serverMAC) + serverHash = hs.suite.mac(serverMAC) } else { clientCipher = hs.suite.aead(clientKey, clientIV) serverCipher = hs.suite.aead(serverKey, serverIV) @@ -700,18 +705,18 @@ func (hs *clientHandshakeState) processServerHello() (bool, error) { } } - clientDidALPN := len(hs.hello.alpnProtocols) > 0 - serverHasALPN := len(hs.serverHello.alpnProtocol) > 0 - - if !clientDidALPN && serverHasALPN { - c.sendAlert(alertHandshakeFailure) - return false, errors.New("tls: server advertised unrequested ALPN extension") - } - - if serverHasALPN { + if hs.serverHello.alpnProtocol != "" { + if len(hs.hello.alpnProtocols) == 0 { + c.sendAlert(alertUnsupportedExtension) + return false, errors.New("tls: server advertised unrequested ALPN extension") + } + if mutualProtocol([]string{hs.serverHello.alpnProtocol}, hs.hello.alpnProtocols) == "" { + c.sendAlert(alertUnsupportedExtension) + return false, errors.New("tls: server selected unadvertised ALPN protocol") + } c.clientProtocol = hs.serverHello.alpnProtocol - c.clientProtocolFallback = false } + c.scts = hs.serverHello.scts if !hs.serverResumedSession() { @@ -879,10 +884,11 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error { // certificateRequestInfoFromMsg generates a CertificateRequestInfo from a TLS // <= 1.2 CertificateRequest, making an effort to fill in missing information. -func certificateRequestInfoFromMsg(vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo { +func certificateRequestInfoFromMsg(ctx context.Context, vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo { cri := &CertificateRequestInfo{ AcceptableCAs: certReq.certificateAuthorities, Version: vers, + ctx: ctx, } var rsaAvail, ecAvail bool @@ -967,20 +973,17 @@ func clientSessionCacheKey(serverAddr net.Addr, config *Config) string { return serverAddr.String() } -// mutualProtocol finds the mutual Next Protocol Negotiation or ALPN protocol -// given list of possible protocols and a list of the preference order. The -// first list must not be empty. It returns the resulting protocol and flag -// indicating if the fallback case was reached. -func mutualProtocol(protos, preferenceProtos []string) (string, bool) { +// mutualProtocol finds the mutual ALPN protocol given list of possible +// protocols and a list of the preference order. +func mutualProtocol(protos, preferenceProtos []string) string { for _, s := range preferenceProtos { for _, c := range protos { if s == c { - return s, false + return s } } } - - return protos[0], true + return "" } // hostnameInSNI converts name into an appropriate hostname for SNI. |
