aboutsummaryrefslogtreecommitdiff
path: root/src/crypto/tls/conn.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/crypto/tls/conn.go')
-rw-r--r--src/crypto/tls/conn.go46
1 files changed, 37 insertions, 9 deletions
diff --git a/src/crypto/tls/conn.go b/src/crypto/tls/conn.go
index fc65d87aaf..a5e19dcc52 100644
--- a/src/crypto/tls/conn.go
+++ b/src/crypto/tls/conn.go
@@ -1004,18 +1004,37 @@ func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
return n, nil
}
-// writeRecord writes a TLS record with the given type and payload to the
-// connection and updates the record layer state.
-func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) {
+// writeHandshakeRecord writes a handshake message to the connection and updates
+// the record layer state. If transcript is non-nil the marshalled message is
+// written to it.
+func (c *Conn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash) (int, error) {
c.out.Lock()
defer c.out.Unlock()
- return c.writeRecordLocked(typ, data)
+ data, err := msg.marshal()
+ if err != nil {
+ return 0, err
+ }
+ if transcript != nil {
+ transcript.Write(data)
+ }
+
+ return c.writeRecordLocked(recordTypeHandshake, data)
+}
+
+// writeChangeCipherRecord writes a ChangeCipherSpec message to the connection and
+// updates the record layer state.
+func (c *Conn) writeChangeCipherRecord() error {
+ c.out.Lock()
+ defer c.out.Unlock()
+ _, err := c.writeRecordLocked(recordTypeChangeCipherSpec, []byte{1})
+ return err
}
// readHandshake reads the next handshake message from
-// the record layer.
-func (c *Conn) readHandshake() (any, error) {
+// the record layer. If transcript is non-nil, the message
+// is written to the passed transcriptHash.
+func (c *Conn) readHandshake(transcript transcriptHash) (any, error) {
for c.hand.Len() < 4 {
if err := c.readRecord(); err != nil {
return nil, err
@@ -1094,6 +1113,11 @@ func (c *Conn) readHandshake() (any, error) {
if !m.unmarshal(data) {
return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
+
+ if transcript != nil {
+ transcript.Write(data)
+ }
+
return m, nil
}
@@ -1169,7 +1193,7 @@ func (c *Conn) handleRenegotiation() error {
return errors.New("tls: internal error: unexpected renegotiation")
}
- msg, err := c.readHandshake()
+ msg, err := c.readHandshake(nil)
if err != nil {
return err
}
@@ -1215,7 +1239,7 @@ func (c *Conn) handlePostHandshakeMessage() error {
return c.handleRenegotiation()
}
- msg, err := c.readHandshake()
+ msg, err := c.readHandshake(nil)
if err != nil {
return err
}
@@ -1251,7 +1275,11 @@ func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
defer c.out.Unlock()
msg := &keyUpdateMsg{}
- _, err := c.writeRecordLocked(recordTypeHandshake, msg.marshal())
+ msgBytes, err := msg.marshal()
+ if err != nil {
+ return err
+ }
+ _, err = c.writeRecordLocked(recordTypeHandshake, msgBytes)
if err != nil {
// Surface the error at the next write.
c.out.setErrorLocked(err)