aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/crypto/tls/common.go25
-rw-r--r--src/crypto/tls/handshake_client.go7
-rw-r--r--src/crypto/tls/handshake_server.go7
-rw-r--r--src/crypto/tls/handshake_server_test.go214
-rw-r--r--src/crypto/tls/handshake_server_tls13.go8
5 files changed, 253 insertions, 8 deletions
diff --git a/src/crypto/tls/common.go b/src/crypto/tls/common.go
index 65cff5f5b9..dfd24d70b0 100644
--- a/src/crypto/tls/common.go
+++ b/src/crypto/tls/common.go
@@ -1847,13 +1847,28 @@ func fipsAllowChain(chain []*x509.Certificate) bool {
return true
}
-// anyUnexpiredChain reports if at least one of verifiedChains is still
-// unexpired. If verifiedChains is empty, it returns false.
-func anyUnexpiredChain(verifiedChains [][]*x509.Certificate, now time.Time) bool {
+// anyValidVerifiedChain reports if at least one of the chains in verifiedChains
+// is valid, as indicated by none of the certificates being expired and the root
+// being in opts.Roots (or in the system root pool if opts.Roots is nil). If
+// verifiedChains is empty, it returns false.
+func anyValidVerifiedChain(verifiedChains [][]*x509.Certificate, opts x509.VerifyOptions) bool {
for _, chain := range verifiedChains {
- if len(chain) != 0 && !slices.ContainsFunc(chain, func(cert *x509.Certificate) bool {
- return now.Before(cert.NotBefore) || now.After(cert.NotAfter) // cert is expired
+ if len(chain) == 0 {
+ continue
+ }
+ if slices.ContainsFunc(chain, func(cert *x509.Certificate) bool {
+ return opts.CurrentTime.Before(cert.NotBefore) || opts.CurrentTime.After(cert.NotAfter)
}) {
+ continue
+ }
+ // Since we already validated the chain, we only care that it is
+ // rooted in a CA in CAs, or in the system pool. On platforms where
+ // we control chain validation (e.g. not Windows or macOS) this is a
+ // simple lookup in the CertPool internal hash map. On other
+ // platforms, this may be more expensive, depending on how they
+ // implement verification of just root certificates.
+ root := chain[len(chain)-1]
+ if _, err := root.Verify(opts); err == nil {
return true
}
}
diff --git a/src/crypto/tls/handshake_client.go b/src/crypto/tls/handshake_client.go
index d1ad9d582b..7d4bd5bcce 100644
--- a/src/crypto/tls/handshake_client.go
+++ b/src/crypto/tls/handshake_client.go
@@ -412,7 +412,12 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (
// application from a faulty ClientSessionCache implementation.
return nil, nil, nil, nil
}
- if !anyUnexpiredChain(session.verifiedChains, c.config.time()) {
+ opts := x509.VerifyOptions{
+ CurrentTime: c.config.time(),
+ Roots: c.config.RootCAs,
+ KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+ }
+ if !anyValidVerifiedChain(session.verifiedChains, opts) {
// No valid chains, delete the entry.
c.config.ClientSessionCache.Put(cacheKey, nil)
return nil, nil, nil, nil
diff --git a/src/crypto/tls/handshake_server.go b/src/crypto/tls/handshake_server.go
index 64053e1a9e..34dfb13b67 100644
--- a/src/crypto/tls/handshake_server.go
+++ b/src/crypto/tls/handshake_server.go
@@ -523,8 +523,13 @@ func (hs *serverHandshakeState) checkForResumption() error {
if sessionHasClientCerts && c.config.time().After(sessionState.peerCertificates[0].NotAfter) {
return nil
}
+ opts := x509.VerifyOptions{
+ CurrentTime: c.config.time(),
+ Roots: c.config.ClientCAs,
+ KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
+ }
if sessionHasClientCerts && c.config.ClientAuth >= VerifyClientCertIfGiven &&
- !anyUnexpiredChain(sessionState.verifiedChains, c.config.time()) {
+ !anyValidVerifiedChain(sessionState.verifiedChains, opts) {
return nil
}
diff --git a/src/crypto/tls/handshake_server_test.go b/src/crypto/tls/handshake_server_test.go
index 8325c9fac3..cbcf012974 100644
--- a/src/crypto/tls/handshake_server_test.go
+++ b/src/crypto/tls/handshake_server_test.go
@@ -2275,3 +2275,217 @@ func testHandshakeChainExpiryResumption(t *testing.T, version uint16) {
testExpiration("LeafExpiresBeforeRoot", now.Add(2*time.Hour), now.Add(3*time.Hour))
testExpiration("LeafExpiresAfterRoot", now.Add(2*time.Hour), now.Add(time.Hour))
}
+
+func TestHandshakeGetConfigForClientDifferentClientCAs(t *testing.T) {
+ t.Run("TLS1.2", func(t *testing.T) {
+ testHandshakeGetConfigForClientDifferentClientCAs(t, VersionTLS12)
+ })
+ t.Run("TLS1.3", func(t *testing.T) {
+ testHandshakeGetConfigForClientDifferentClientCAs(t, VersionTLS13)
+ })
+}
+
+func testHandshakeGetConfigForClientDifferentClientCAs(t *testing.T, version uint16) {
+ now := time.Now()
+ tmpl := &x509.Certificate{
+ Subject: pkix.Name{CommonName: "root"},
+ NotBefore: now.Add(-time.Hour * 24),
+ NotAfter: now.Add(time.Hour * 24),
+ IsCA: true,
+ BasicConstraintsValid: true,
+ }
+ rootDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &testECDSAPrivateKey.PublicKey, testECDSAPrivateKey)
+ if err != nil {
+ t.Fatalf("CreateCertificate: %v", err)
+ }
+ rootA, err := x509.ParseCertificate(rootDER)
+ if err != nil {
+ t.Fatalf("ParseCertificate: %v", err)
+ }
+ rootDER, err = x509.CreateCertificate(rand.Reader, tmpl, tmpl, &testECDSAPrivateKey.PublicKey, testECDSAPrivateKey)
+ if err != nil {
+ t.Fatalf("CreateCertificate: %v", err)
+ }
+ rootB, err := x509.ParseCertificate(rootDER)
+ if err != nil {
+ t.Fatalf("ParseCertificate: %v", err)
+ }
+
+ tmpl = &x509.Certificate{
+ Subject: pkix.Name{},
+ DNSNames: []string{"example.com"},
+ NotBefore: now.Add(-time.Hour * 24),
+ NotAfter: now.Add(time.Hour * 24),
+ KeyUsage: x509.KeyUsageDigitalSignature,
+ }
+ certDER, err := x509.CreateCertificate(rand.Reader, tmpl, rootA, &testECDSAPrivateKey.PublicKey, testECDSAPrivateKey)
+ if err != nil {
+ t.Fatalf("CreateCertificate: %v", err)
+ }
+
+ serverConfig := testConfig.Clone()
+ serverConfig.MaxVersion = version
+ serverConfig.Certificates = []Certificate{{
+ Certificate: [][]byte{certDER},
+ PrivateKey: testECDSAPrivateKey,
+ }}
+ serverConfig.Time = func() time.Time {
+ return now
+ }
+ serverConfig.ClientCAs = x509.NewCertPool()
+ serverConfig.ClientCAs.AddCert(rootA)
+ serverConfig.ClientAuth = RequireAndVerifyClientCert
+ switchConfig := false
+ serverConfig.GetConfigForClient = func(clientHello *ClientHelloInfo) (*Config, error) {
+ if !switchConfig {
+ return nil, nil
+ }
+ cfg := serverConfig.Clone()
+ cfg.ClientCAs = x509.NewCertPool()
+ cfg.ClientCAs.AddCert(rootB)
+ return cfg, nil
+ }
+ serverConfig.InsecureSkipVerify = false
+ serverConfig.ServerName = "example.com"
+
+ clientConfig := testConfig.Clone()
+ clientConfig.MaxVersion = version
+ clientConfig.Certificates = []Certificate{{
+ Certificate: [][]byte{certDER},
+ PrivateKey: testECDSAPrivateKey,
+ }}
+ clientConfig.ClientSessionCache = NewLRUClientSessionCache(32)
+ clientConfig.RootCAs = x509.NewCertPool()
+ clientConfig.RootCAs.AddCert(rootA)
+ clientConfig.Time = func() time.Time {
+ return now
+ }
+ clientConfig.InsecureSkipVerify = false
+ clientConfig.ServerName = "example.com"
+
+ testResume := func(t *testing.T, sc, cc *Config, expectResume bool) {
+ t.Helper()
+ ss, cs, err := testHandshake(t, cc, sc)
+ if err != nil {
+ t.Fatalf("handshake: %v", err)
+ }
+ if cs.DidResume != expectResume {
+ t.Fatalf("DidResume = %v; want %v", cs.DidResume, expectResume)
+ }
+ if ss.DidResume != expectResume {
+ t.Fatalf("DidResume = %v; want %v", cs.DidResume, expectResume)
+ }
+ }
+
+ testResume(t, serverConfig, clientConfig, false)
+ testResume(t, serverConfig, clientConfig, true)
+
+ // Cause GetConfigForClient to return a config cloned from the base config,
+ // but with a different ClientCAs pool. This should cause resumption to fail.
+ switchConfig = true
+
+ testResume(t, serverConfig, clientConfig, false)
+ testResume(t, serverConfig, clientConfig, true)
+}
+
+func TestHandshakeChangeRootCAsResumption(t *testing.T) {
+ t.Run("TLS1.2", func(t *testing.T) {
+ testHandshakeChangeRootCAsResumption(t, VersionTLS12)
+ })
+ t.Run("TLS1.3", func(t *testing.T) {
+ testHandshakeChangeRootCAsResumption(t, VersionTLS13)
+ })
+}
+
+func testHandshakeChangeRootCAsResumption(t *testing.T, version uint16) {
+ now := time.Now()
+ tmpl := &x509.Certificate{
+ Subject: pkix.Name{CommonName: "root"},
+ NotBefore: now.Add(-time.Hour * 24),
+ NotAfter: now.Add(time.Hour * 24),
+ IsCA: true,
+ BasicConstraintsValid: true,
+ }
+ rootDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &testECDSAPrivateKey.PublicKey, testECDSAPrivateKey)
+ if err != nil {
+ t.Fatalf("CreateCertificate: %v", err)
+ }
+ rootA, err := x509.ParseCertificate(rootDER)
+ if err != nil {
+ t.Fatalf("ParseCertificate: %v", err)
+ }
+ rootDER, err = x509.CreateCertificate(rand.Reader, tmpl, tmpl, &testECDSAPrivateKey.PublicKey, testECDSAPrivateKey)
+ if err != nil {
+ t.Fatalf("CreateCertificate: %v", err)
+ }
+ rootB, err := x509.ParseCertificate(rootDER)
+ if err != nil {
+ t.Fatalf("ParseCertificate: %v", err)
+ }
+
+ tmpl = &x509.Certificate{
+ Subject: pkix.Name{},
+ DNSNames: []string{"example.com"},
+ NotBefore: now.Add(-time.Hour * 24),
+ NotAfter: now.Add(time.Hour * 24),
+ KeyUsage: x509.KeyUsageDigitalSignature,
+ }
+ certDER, err := x509.CreateCertificate(rand.Reader, tmpl, rootA, &testECDSAPrivateKey.PublicKey, testECDSAPrivateKey)
+ if err != nil {
+ t.Fatalf("CreateCertificate: %v", err)
+ }
+
+ serverConfig := testConfig.Clone()
+ serverConfig.MaxVersion = version
+ serverConfig.Certificates = []Certificate{{
+ Certificate: [][]byte{certDER},
+ PrivateKey: testECDSAPrivateKey,
+ }}
+ serverConfig.Time = func() time.Time {
+ return now
+ }
+ serverConfig.ClientCAs = x509.NewCertPool()
+ serverConfig.ClientCAs.AddCert(rootA)
+ serverConfig.ClientAuth = RequireAndVerifyClientCert
+ serverConfig.InsecureSkipVerify = false
+ serverConfig.ServerName = "example.com"
+
+ clientConfig := testConfig.Clone()
+ clientConfig.MaxVersion = version
+ clientConfig.Certificates = []Certificate{{
+ Certificate: [][]byte{certDER},
+ PrivateKey: testECDSAPrivateKey,
+ }}
+ clientConfig.ClientSessionCache = NewLRUClientSessionCache(32)
+ clientConfig.RootCAs = x509.NewCertPool()
+ clientConfig.RootCAs.AddCert(rootA)
+ clientConfig.Time = func() time.Time {
+ return now
+ }
+ clientConfig.InsecureSkipVerify = false
+ clientConfig.ServerName = "example.com"
+
+ testResume := func(t *testing.T, sc, cc *Config, expectResume bool) {
+ t.Helper()
+ ss, cs, err := testHandshake(t, cc, sc)
+ if err != nil {
+ t.Fatalf("handshake: %v", err)
+ }
+ if cs.DidResume != expectResume {
+ t.Fatalf("DidResume = %v; want %v", cs.DidResume, expectResume)
+ }
+ if ss.DidResume != expectResume {
+ t.Fatalf("DidResume = %v; want %v", cs.DidResume, expectResume)
+ }
+ }
+
+ testResume(t, serverConfig, clientConfig, false)
+ testResume(t, serverConfig, clientConfig, true)
+
+ clientConfig = clientConfig.Clone()
+ clientConfig.RootCAs = x509.NewCertPool()
+ clientConfig.RootCAs.AddCert(rootB)
+
+ testResume(t, serverConfig, clientConfig, false)
+ testResume(t, serverConfig, clientConfig, true)
+}
diff --git a/src/crypto/tls/handshake_server_tls13.go b/src/crypto/tls/handshake_server_tls13.go
index 11dbaa9f0a..bce94ed2d8 100644
--- a/src/crypto/tls/handshake_server_tls13.go
+++ b/src/crypto/tls/handshake_server_tls13.go
@@ -14,6 +14,7 @@ import (
"crypto/internal/fips140/tls13"
"crypto/rsa"
"crypto/tls/internal/fips140tls"
+ "crypto/x509"
"errors"
"fmt"
"hash"
@@ -369,8 +370,13 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error {
if sessionHasClientCerts && c.config.time().After(sessionState.peerCertificates[0].NotAfter) {
continue
}
+ opts := x509.VerifyOptions{
+ CurrentTime: c.config.time(),
+ Roots: c.config.ClientCAs,
+ KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
+ }
if sessionHasClientCerts && c.config.ClientAuth >= VerifyClientCertIfGiven &&
- !anyUnexpiredChain(sessionState.verifiedChains, c.config.time()) {
+ !anyValidVerifiedChain(sessionState.verifiedChains, opts) {
continue
}