aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/crypto/tls/conn.go10
-rw-r--r--src/crypto/tls/handshake_client_tls13.go4
-rw-r--r--src/crypto/tls/handshake_server_tls13.go4
-rw-r--r--src/crypto/tls/handshake_test.go48
4 files changed, 59 insertions, 7 deletions
diff --git a/src/crypto/tls/conn.go b/src/crypto/tls/conn.go
index 9c662ef8f6..cfadec68f1 100644
--- a/src/crypto/tls/conn.go
+++ b/src/crypto/tls/conn.go
@@ -1363,7 +1363,7 @@ func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
}
newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret)
- if err := c.setReadTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret); err != nil {
+ if err := c.setReadTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret, keyUpdate.updateRequested); err != nil {
return err
}
@@ -1683,12 +1683,16 @@ func (c *Conn) VerifyHostname(host string) error {
// setReadTrafficSecret sets the read traffic secret for the given encryption level. If
// being called at the same time as setWriteTrafficSecret, the caller must ensure the call
// to setWriteTrafficSecret happens first so any alerts are sent at the write level.
-func (c *Conn) setReadTrafficSecret(suite *cipherSuiteTLS13, level QUICEncryptionLevel, secret []byte) error {
+func (c *Conn) setReadTrafficSecret(suite *cipherSuiteTLS13, level QUICEncryptionLevel, secret []byte, locked bool) error {
// Ensure that there are no buffered handshake messages before changing the
// read keys, since that can cause messages to be parsed that were encrypted
// using old keys which are no longer appropriate.
if c.hand.Len() != 0 {
- c.sendAlert(alertUnexpectedMessage)
+ if locked {
+ c.sendAlertLocked(alertUnexpectedMessage)
+ } else {
+ c.sendAlert(alertUnexpectedMessage)
+ }
return errors.New("tls: handshake buffer not empty before setting read traffic secret")
}
c.in.setTrafficSecret(suite, level, secret)
diff --git a/src/crypto/tls/handshake_client_tls13.go b/src/crypto/tls/handshake_client_tls13.go
index 77a24b4a78..65177767a0 100644
--- a/src/crypto/tls/handshake_client_tls13.go
+++ b/src/crypto/tls/handshake_client_tls13.go
@@ -492,7 +492,7 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error {
clientSecret := handshakeSecret.ClientHandshakeTrafficSecret(hs.transcript)
c.setWriteTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret)
serverSecret := handshakeSecret.ServerHandshakeTrafficSecret(hs.transcript)
- if err := c.setReadTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret); err != nil {
+ if err := c.setReadTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret, false); err != nil {
return err
}
@@ -711,7 +711,7 @@ func (hs *clientHandshakeStateTLS13) readServerFinished() error {
hs.trafficSecret = hs.masterSecret.ClientApplicationTrafficSecret(hs.transcript)
serverSecret := hs.masterSecret.ServerApplicationTrafficSecret(hs.transcript)
- if err := c.setReadTrafficSecret(hs.suite, QUICEncryptionLevelApplication, serverSecret); err != nil {
+ if err := c.setReadTrafficSecret(hs.suite, QUICEncryptionLevelApplication, serverSecret, false); err != nil {
return err
}
diff --git a/src/crypto/tls/handshake_server_tls13.go b/src/crypto/tls/handshake_server_tls13.go
index bce94ed2d8..b45d7cbc53 100644
--- a/src/crypto/tls/handshake_server_tls13.go
+++ b/src/crypto/tls/handshake_server_tls13.go
@@ -752,7 +752,7 @@ func (hs *serverHandshakeStateTLS13) sendServerParameters() error {
serverSecret := hs.handshakeSecret.ServerHandshakeTrafficSecret(hs.transcript)
c.setWriteTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret)
clientSecret := hs.handshakeSecret.ClientHandshakeTrafficSecret(hs.transcript)
- if err := c.setReadTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret); err != nil {
+ if err := c.setReadTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret, false); err != nil {
return err
}
@@ -1136,7 +1136,7 @@ func (hs *serverHandshakeStateTLS13) readClientFinished() error {
return errors.New("tls: invalid client finished hash")
}
- if err := c.setReadTrafficSecret(hs.suite, QUICEncryptionLevelApplication, hs.trafficSecret); err != nil {
+ if err := c.setReadTrafficSecret(hs.suite, QUICEncryptionLevelApplication, hs.trafficSecret, false); err != nil {
return err
}
diff --git a/src/crypto/tls/handshake_test.go b/src/crypto/tls/handshake_test.go
index 9cea8182d0..6e00b4348e 100644
--- a/src/crypto/tls/handshake_test.go
+++ b/src/crypto/tls/handshake_test.go
@@ -778,3 +778,51 @@ func concatHandshakeMessages(msgs ...handshakeMessage) ([]byte, error) {
outBuf = append(outBuf, marshalled...)
return outBuf, nil
}
+
+func TestMultipleKeyUpdate(t *testing.T) {
+ for _, requestUpdate := range []bool{true, false} {
+ t.Run(fmt.Sprintf("requestUpdate=%t", requestUpdate), func(t *testing.T) {
+
+ c, s := localPipe(t)
+ cfg := testConfig.Clone()
+ cfg.MinVersion = VersionTLS13
+ cfg.MaxVersion = VersionTLS13
+ client := Client(c, testConfig)
+ server := Server(s, testConfig)
+
+ clientHandshakeDone := make(chan struct{})
+ go func() {
+ if err := client.Handshake(); err != nil {
+ }
+ close(clientHandshakeDone)
+ io.Copy(io.Discard, server)
+ }()
+
+ if err := server.Handshake(); err != nil {
+ t.Fatalf("server handshake failed: %v\n", err)
+ }
+ <-clientHandshakeDone
+
+ c.SetReadDeadline(time.Now().Add(1 * time.Second))
+ s.SetReadDeadline(time.Now().Add(1 * time.Second))
+
+ kuMsg, err := (&keyUpdateMsg{updateRequested: requestUpdate}).marshal()
+ if err != nil {
+ t.Fatalf("failed to marshal key update message: %v", err)
+ }
+
+ client.out.Lock()
+ if _, err := client.writeRecordLocked(recordTypeHandshake, append(kuMsg, kuMsg...)); err != nil {
+ t.Fatalf("failed to write key update messages: %v", err)
+ }
+ client.out.Unlock()
+
+ _, err = io.Copy(io.Discard, client)
+ if err == nil {
+ t.Fatal("expected multiple key update messages to cause an error, got nil")
+ } else if !strings.HasSuffix(err.Error(), "tls: unexpected message") {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ })
+ }
+}