diff options
| author | Adam Langley <agl@golang.org> | 2012-04-24 13:46:22 -0400 |
|---|---|---|
| committer | Adam Langley <agl@golang.org> | 2012-04-24 13:46:22 -0400 |
| commit | bcdd6a2fd3e36323c71ab4c80588f4e48e8a3678 (patch) | |
| tree | 40fc39ef3adec1c3e3955740da4af66fd3a4ef68 /ssh | |
| parent | 58afe880f197c244a2edbfab2bb090a5bf02dfe1 (diff) | |
| download | go-x-crypto-bcdd6a2fd3e36323c71ab4c80588f4e48e8a3678.tar.xz | |
ssh: handle bad servers better.
This change prevents bad servers from crashing a client by sending an
invalid channel ID. It also makes the client disconnect in more cases
of invalid messages from a server and cleans up the client channels
in the event of a disconnect.
R=dave
CC=golang-dev
https://golang.org/cl/6099050
Diffstat (limited to 'ssh')
| -rw-r--r-- | ssh/client.go | 67 | ||||
| -rw-r--r-- | ssh/session_test.go | 25 |
2 files changed, 75 insertions, 17 deletions
diff --git a/ssh/client.go b/ssh/client.go index 493d8ec..3b29923 100644 --- a/ssh/client.go +++ b/ssh/client.go @@ -184,8 +184,16 @@ func (c *ClientConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha // mainLoop reads incoming messages and routes channel messages // to their respective ClientChans. func (c *ClientConn) mainLoop() { - // TODO(dfc) signal the underlying close to all channels - defer c.Close() + defer func() { + // We don't check, for example, that the channel IDs from the + // server are valid before using them. Thus a bad server can + // cause us to panic, but we don't want to crash the program. + recover() + + c.Close() + c.closeAll() + }() + for { packet, err := c.readPacket() if err != nil { @@ -199,28 +207,34 @@ func (c *ClientConn) mainLoop() { case msgChannelData: if len(packet) < 9 { // malformed data packet - break + return } peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4]) - if length := int(packet[5])<<24 | int(packet[6])<<16 | int(packet[7])<<8 | int(packet[8]); length > 0 { - packet = packet[9:] - c.getChan(peersId).stdout.handleData(packet[:length]) + length := uint32(packet[5])<<24 | uint32(packet[6])<<16 | uint32(packet[7])<<8 | uint32(packet[8]) + packet = packet[9:] + + if length != uint32(len(packet)) { + return } + c.getChan(peersId).stdout.handleData(packet) case msgChannelExtendedData: if len(packet) < 13 { // malformed data packet - break + return } peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4]) datatype := uint32(packet[5])<<24 | uint32(packet[6])<<16 | uint32(packet[7])<<8 | uint32(packet[8]) - if length := int(packet[9])<<24 | int(packet[10])<<16 | int(packet[11])<<8 | int(packet[12]); length > 0 { - packet = packet[13:] - // RFC 4254 5.2 defines data_type_code 1 to be data destined - // for stderr on interactive sessions. Other data types are - // silently discarded. - if datatype == 1 { - c.getChan(peersId).stderr.handleData(packet[:length]) - } + length := uint32(packet[9])<<24 | uint32(packet[10])<<16 | uint32(packet[11])<<8 | uint32(packet[12]) + packet = packet[13:] + + if length != uint32(len(packet)) { + return + } + // RFC 4254 5.2 defines data_type_code 1 to be data destined + // for stderr on interactive sessions. Other data types are + // silently discarded. + if datatype == 1 { + c.getChan(peersId).stderr.handleData(packet) } default: switch msg := decode(packet).(type) { @@ -256,10 +270,10 @@ func (c *ClientConn) mainLoop() { case *windowAdjustMsg: if !c.getChan(msg.PeersId).stdin.win.add(msg.AdditionalBytes) { // invalid window update - break + return } case *disconnectMsg: - break + return default: fmt.Printf("mainLoop: unhandled message %T: %v\n", msg, msg) } @@ -408,6 +422,9 @@ func (c *chanlist) newChan(t *transport) *clientChan { func (c *chanlist) getChan(id uint32) *clientChan { c.Lock() defer c.Unlock() + if id >= uint32(len(c.chans)) { + return nil + } return c.chans[int(id)] } @@ -417,6 +434,22 @@ func (c *chanlist) remove(id uint32) { c.chans[int(id)] = nil } +func (c *chanlist) closeAll() { + c.Lock() + defer c.Unlock() + + for _, ch := range c.chans { + if ch == nil { + continue + } + + ch.theyClosed = true + ch.stdout.eof() + ch.stderr.eof() + close(ch.msg) + } +} + // A chanWriter represents the stdin of a remote process. type chanWriter struct { win *window diff --git a/ssh/session_test.go b/ssh/session_test.go index df66e1d..df97fcf 100644 --- a/ssh/session_test.go +++ b/ssh/session_test.go @@ -275,6 +275,20 @@ func TestExitWithoutStatusOrSignal(t *testing.T) { } } +func TestInvalidServerMessage(t *testing.T) { + conn := dial(sendInvalidRecord, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %s", err) + } + // Make sure that we closed all the clientChans when the connection + // failed. + session.wait() + + defer session.Close() +} + type exitStatusMsg struct { PeersId uint32 Request string @@ -373,3 +387,14 @@ func sendSignal(signal string, ch *channel) { } ch.serverConn.writePacket(marshal(msgChannelRequest, sig)) } + +func sendInvalidRecord(ch *channel) { + defer ch.Close() + packet := make([]byte, 1+4+4+1) + packet[0] = msgChannelData + marshalUint32(packet[1:], 29348723 /* invalid channel id */) + marshalUint32(packet[5:], 1) + packet[9] = 42 + + ch.serverConn.writePacket(packet) +} |
