diff options
Diffstat (limited to 'src/crypto/tls/conn.go')
| -rw-r--r-- | src/crypto/tls/conn.go | 46 |
1 files changed, 37 insertions, 9 deletions
diff --git a/src/crypto/tls/conn.go b/src/crypto/tls/conn.go index fc65d87aaf..a5e19dcc52 100644 --- a/src/crypto/tls/conn.go +++ b/src/crypto/tls/conn.go @@ -1004,18 +1004,37 @@ func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) { return n, nil } -// writeRecord writes a TLS record with the given type and payload to the -// connection and updates the record layer state. -func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) { +// writeHandshakeRecord writes a handshake message to the connection and updates +// the record layer state. If transcript is non-nil the marshalled message is +// written to it. +func (c *Conn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash) (int, error) { c.out.Lock() defer c.out.Unlock() - return c.writeRecordLocked(typ, data) + data, err := msg.marshal() + if err != nil { + return 0, err + } + if transcript != nil { + transcript.Write(data) + } + + return c.writeRecordLocked(recordTypeHandshake, data) +} + +// writeChangeCipherRecord writes a ChangeCipherSpec message to the connection and +// updates the record layer state. +func (c *Conn) writeChangeCipherRecord() error { + c.out.Lock() + defer c.out.Unlock() + _, err := c.writeRecordLocked(recordTypeChangeCipherSpec, []byte{1}) + return err } // readHandshake reads the next handshake message from -// the record layer. -func (c *Conn) readHandshake() (any, error) { +// the record layer. If transcript is non-nil, the message +// is written to the passed transcriptHash. +func (c *Conn) readHandshake(transcript transcriptHash) (any, error) { for c.hand.Len() < 4 { if err := c.readRecord(); err != nil { return nil, err @@ -1094,6 +1113,11 @@ func (c *Conn) readHandshake() (any, error) { if !m.unmarshal(data) { return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) } + + if transcript != nil { + transcript.Write(data) + } + return m, nil } @@ -1169,7 +1193,7 @@ func (c *Conn) handleRenegotiation() error { return errors.New("tls: internal error: unexpected renegotiation") } - msg, err := c.readHandshake() + msg, err := c.readHandshake(nil) if err != nil { return err } @@ -1215,7 +1239,7 @@ func (c *Conn) handlePostHandshakeMessage() error { return c.handleRenegotiation() } - msg, err := c.readHandshake() + msg, err := c.readHandshake(nil) if err != nil { return err } @@ -1251,7 +1275,11 @@ func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error { defer c.out.Unlock() msg := &keyUpdateMsg{} - _, err := c.writeRecordLocked(recordTypeHandshake, msg.marshal()) + msgBytes, err := msg.marshal() + if err != nil { + return err + } + _, err = c.writeRecordLocked(recordTypeHandshake, msgBytes) if err != nil { // Surface the error at the next write. c.out.setErrorLocked(err) |
