aboutsummaryrefslogtreecommitdiff
path: root/ssh
diff options
context:
space:
mode:
authorDave Cheney <dave@cheney.net>2013-10-25 06:29:58 +1100
committerDave Cheney <dave@cheney.net>2013-10-25 06:29:58 +1100
commitc0d640c88782f757a45d3f7b93eec2ec63b229cb (patch)
treef1d2d258b71d469f94760ba35e47716bf7400b85 /ssh
parent105632d35b7181298edeb557a23e66534203796f (diff)
downloadgo-x-crypto-c0d640c88782f757a45d3f7b93eec2ec63b229cb.tar.xz
go.crypto/ssh: ensure {Server,Client}Conn do not expose io.ReadWriter
Transport should not be a ReadWriter. It can only write packets, i.e. no partial reads or writes. Furthermore, you can currently do ClientConn.Write() while the connection is live, which sends raw bytes over the connection. Doing so will confuse the transports because the data is not encrypted. As a consequence, ClientConn and ServerConn stop being a net.Conn Finally, ensure that {Server,Client}Conn implement LocalAddr and RemoteAddr methods that previously were exposed by an embedded net.Conn field. R=hanwen CC=golang-dev https://golang.org/cl/16610043
Diffstat (limited to 'ssh')
-rw-r--r--ssh/client.go41
-rw-r--r--ssh/client_auth.go4
-rw-r--r--ssh/common_test.go31
-rw-r--r--ssh/server.go47
-rw-r--r--ssh/session.go2
-rw-r--r--ssh/tcpip.go2
6 files changed, 88 insertions, 39 deletions
diff --git a/ssh/client.go b/ssh/client.go
index 65787bf..7a7f1b8 100644
--- a/ssh/client.go
+++ b/ssh/client.go
@@ -16,7 +16,7 @@ import (
// ClientConn represents the client side of an SSH connection.
type ClientConn struct {
- *transport
+ transport *transport
config *ClientConfig
chanList // channels associated with this connection
forwardList // forwarded tcpip connections from the remote side
@@ -47,13 +47,22 @@ func clientWithAddress(c net.Conn, addr string, config *ClientConfig) (*ClientCo
}
if err := conn.handshake(); err != nil {
- conn.Close()
+ conn.transport.Close()
return nil, fmt.Errorf("handshake failed: %v", err)
}
go conn.mainLoop()
return conn, nil
}
+// Close closes the connection.
+func (c *ClientConn) Close() error { return c.transport.Close() }
+
+// LocalAddr returns the local network address.
+func (c *ClientConn) LocalAddr() net.Addr { return c.transport.LocalAddr() }
+
+// RemoteAddr returns the remote network address.
+func (c *ClientConn) RemoteAddr() net.Addr { return c.transport.RemoteAddr() }
+
// handshake performs the client side key exchange. See RFC 4253 Section 7.
func (c *ClientConn) handshake() error {
clientVersion := []byte(packageVersion)
@@ -78,10 +87,10 @@ func (c *ClientConn) handshake() error {
CompressionServerClient: supportedCompressions,
}
kexInitPacket := marshal(msgKexInit, clientKexInit)
- if err := c.writePacket(kexInitPacket); err != nil {
+ if err := c.transport.writePacket(kexInitPacket); err != nil {
return err
}
- packet, err := c.readPacket()
+ packet, err := c.transport.readPacket()
if err != nil {
return err
}
@@ -99,7 +108,7 @@ func (c *ClientConn) handshake() error {
if serverKexInit.FirstKexFollows && algs.kex != serverKexInit.KexAlgos[0] {
// The server sent a Kex message for the wrong algorithm,
// which we have to ignore.
- if _, err := c.readPacket(); err != nil {
+ if _, err := c.transport.readPacket(); err != nil {
return err
}
}
@@ -115,7 +124,7 @@ func (c *ClientConn) handshake() error {
clientKexInit: kexInitPacket,
serverKexInit: packet,
}
- result, err := kex.Client(c, c.config.rand(), &magics)
+ result, err := kex.Client(c.transport, c.config.rand(), &magics)
if err != nil {
return err
}
@@ -126,7 +135,7 @@ func (c *ClientConn) handshake() error {
}
if checker := c.config.HostKeyChecker; checker != nil {
- err = checker.Check(c.dialAddress, c.RemoteAddr(), algs.hostKey, result.HostKey)
+ err = checker.Check(c.dialAddress, c.transport.RemoteAddr(), algs.hostKey, result.HostKey)
if err != nil {
return err
}
@@ -134,10 +143,10 @@ func (c *ClientConn) handshake() error {
c.transport.prepareKeyChange(algs, result)
- if err = c.writePacket([]byte{msgNewKeys}); err != nil {
+ if err = c.transport.writePacket([]byte{msgNewKeys}); err != nil {
return err
}
- if packet, err = c.readPacket(); err != nil {
+ if packet, err = c.transport.readPacket(); err != nil {
return err
}
if packet[0] != msgNewKeys {
@@ -171,13 +180,13 @@ func verifyHostKeySignature(hostKeyAlgo string, hostKeyBytes []byte, data []byte
// to their respective ClientChans.
func (c *ClientConn) mainLoop() {
defer func() {
- c.Close()
+ c.transport.Close()
c.chanList.closeAll()
c.forwardList.closeAll()
}()
for {
- packet, err := c.readPacket()
+ packet, err := c.transport.readPacket()
if err != nil {
break
}
@@ -298,7 +307,7 @@ func (c *ClientConn) mainLoop() {
// This handles keepalive messages and matches
// the behaviour of OpenSSH.
if msg.WantReply {
- c.writePacket(marshal(msgRequestFailure, globalRequestFailureMsg{}))
+ c.transport.writePacket(marshal(msgRequestFailure, globalRequestFailureMsg{}))
}
case *globalRequestSuccessMsg, *globalRequestFailureMsg:
c.globalRequest.response <- msg
@@ -355,7 +364,7 @@ func (c *ClientConn) handleChanOpen(msg *channelOpenMsg) {
MaxPacketSize: 1 << 15,
}
- c.writePacket(marshal(msgChannelOpenConfirm, m))
+ c.transport.writePacket(marshal(msgChannelOpenConfirm, m))
l <- forward{ch, raddr}
default:
// unknown channel type
@@ -365,7 +374,7 @@ func (c *ClientConn) handleChanOpen(msg *channelOpenMsg) {
Message: fmt.Sprintf("unknown channel type: %v", msg.ChanType),
Language: "en_US.UTF-8",
}
- c.writePacket(marshal(msgChannelOpenFailure, m))
+ c.transport.writePacket(marshal(msgChannelOpenFailure, m))
}
}
@@ -375,7 +384,7 @@ func (c *ClientConn) handleChanOpen(msg *channelOpenMsg) {
func (c *ClientConn) sendGlobalRequest(m interface{}) (*globalRequestSuccessMsg, error) {
c.globalRequest.Lock()
defer c.globalRequest.Unlock()
- if err := c.writePacket(marshal(msgGlobalRequest, m)); err != nil {
+ if err := c.transport.writePacket(marshal(msgGlobalRequest, m)); err != nil {
return nil, err
}
r := <-c.globalRequest.response
@@ -394,7 +403,7 @@ func (c *ClientConn) sendConnectionFailed(remoteId uint32) error {
Message: "invalid request",
Language: "en_US.UTF-8",
}
- return c.writePacket(marshal(msgChannelOpenFailure, m))
+ return c.transport.writePacket(marshal(msgChannelOpenFailure, m))
}
// parseTCPAddr parses the originating address from the remote into a *net.TCPAddr.
diff --git a/ssh/client_auth.go b/ssh/client_auth.go
index 47443b3..c22d45c 100644
--- a/ssh/client_auth.go
+++ b/ssh/client_auth.go
@@ -14,10 +14,10 @@ import (
// authenticate authenticates with the remote server. See RFC 4252.
func (c *ClientConn) authenticate(session []byte) error {
// initiate user auth session
- if err := c.writePacket(marshal(msgServiceRequest, serviceRequestMsg{serviceUserAuth})); err != nil {
+ if err := c.transport.writePacket(marshal(msgServiceRequest, serviceRequestMsg{serviceUserAuth})); err != nil {
return err
}
- packet, err := c.readPacket()
+ packet, err := c.transport.readPacket()
if err != nil {
return err
}
diff --git a/ssh/common_test.go b/ssh/common_test.go
index 058fb04..d9df56f 100644
--- a/ssh/common_test.go
+++ b/ssh/common_test.go
@@ -5,6 +5,8 @@
package ssh
import (
+ "io"
+ "net"
"testing"
)
@@ -24,3 +26,32 @@ func TestSafeString(t *testing.T) {
}
}
}
+
+// Make sure Read/Write are not exposed.
+func TestConnHideRWMethods(t *testing.T) {
+ for _, c := range []interface{}{new(ServerConn), new(ClientConn)} {
+ if _, ok := c.(io.Reader); ok {
+ t.Errorf("%T implements io.Reader", c)
+ }
+ if _, ok := c.(io.Writer); ok {
+ t.Errorf("%T implements io.Writer", c)
+ }
+ }
+}
+
+func TestConnSupportsLocalRemoteMethods(t *testing.T) {
+ type LocalAddr interface {
+ LocalAddr() net.Addr
+ }
+ type RemoteAddr interface {
+ RemoteAddr() net.Addr
+ }
+ for _, c := range []interface{}{new(ServerConn), new(ClientConn)} {
+ if _, ok := c.(LocalAddr); !ok {
+ t.Errorf("%T does not implement LocalAddr", c)
+ }
+ if _, ok := c.(RemoteAddr); !ok {
+ t.Errorf("%T does not implement RemoteAddr", c)
+ }
+ }
+}
diff --git a/ssh/server.go b/ssh/server.go
index f9791e2..e4ec0e0 100644
--- a/ssh/server.go
+++ b/ssh/server.go
@@ -97,8 +97,8 @@ const maxCachedPubKeys = 16
// A ServerConn represents an incoming connection.
type ServerConn struct {
- *transport
- config *ServerConfig
+ transport *transport
+ config *ServerConfig
channels map[uint32]*serverChan
nextChanId uint32
@@ -147,6 +147,15 @@ func signAndMarshal(k Signer, rand io.Reader, data []byte) ([]byte, error) {
return serializeSignature(k.PublicKey().PrivateKeyAlgo(), sig), nil
}
+// Close closes the connection.
+func (s *ServerConn) Close() error { return s.transport.Close() }
+
+// LocalAddr returns the local network address.
+func (c *ServerConn) LocalAddr() net.Addr { return c.transport.LocalAddr() }
+
+// RemoteAddr returns the remote network address.
+func (c *ServerConn) RemoteAddr() net.Addr { return c.transport.RemoteAddr() }
+
// Handshake performs an SSH transport and client authentication on the given ServerConn.
func (s *ServerConn) Handshake() error {
var err error
@@ -160,7 +169,7 @@ func (s *ServerConn) Handshake() error {
}
var packet []byte
- if packet, err = s.readPacket(); err != nil {
+ if packet, err = s.transport.readPacket(); err != nil {
return err
}
var serviceRequest serviceRequestMsg
@@ -173,7 +182,7 @@ func (s *ServerConn) Handshake() error {
serviceAccept := serviceAcceptMsg{
Service: serviceUserAuth,
}
- if err := s.writePacket(marshal(msgServiceAccept, serviceAccept)); err != nil {
+ if err := s.transport.writePacket(marshal(msgServiceAccept, serviceAccept)); err != nil {
return err
}
@@ -199,13 +208,13 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
}
serverKexInitPacket := marshal(msgKexInit, serverKexInit)
- if err = s.writePacket(serverKexInitPacket); err != nil {
+ if err = s.transport.writePacket(serverKexInitPacket); err != nil {
return
}
if clientKexInitPacket == nil {
clientKexInit = new(kexInitMsg)
- if clientKexInitPacket, err = s.readPacket(); err != nil {
+ if clientKexInitPacket, err = s.transport.readPacket(); err != nil {
return
}
if err = unmarshal(clientKexInit, clientKexInitPacket, msgKexInit); err != nil {
@@ -221,7 +230,7 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
if clientKexInit.FirstKexFollows && algs.kex != clientKexInit.KexAlgos[0] {
// The client sent a Kex message for the wrong algorithm,
// which we have to ignore.
- if _, err = s.readPacket(); err != nil {
+ if _, err = s.transport.readPacket(); err != nil {
return
}
}
@@ -244,7 +253,7 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
serverKexInit: marshal(msgKexInit, serverKexInit),
clientKexInit: clientKexInitPacket,
}
- result, err := kex.Server(s, s.config.rand(), &magics, hostKey)
+ result, err := kex.Server(s.transport, s.config.rand(), &magics, hostKey)
if err != nil {
return err
}
@@ -253,10 +262,10 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
return err
}
- if err = s.writePacket([]byte{msgNewKeys}); err != nil {
+ if err = s.transport.writePacket([]byte{msgNewKeys}); err != nil {
return
}
- if packet, err := s.readPacket(); err != nil {
+ if packet, err := s.transport.readPacket(); err != nil {
return err
} else if packet[0] != msgNewKeys {
return UnexpectedMessageError{msgNewKeys, packet[0]}
@@ -308,7 +317,7 @@ func (s *ServerConn) authenticate(H []byte) error {
userAuthLoop:
for {
- if packet, err = s.readPacket(); err != nil {
+ if packet, err = s.transport.readPacket(); err != nil {
return err
}
if err = unmarshal(&userAuthReq, packet, msgUserAuthRequest); err != nil {
@@ -382,7 +391,7 @@ userAuthLoop:
Algo: algo,
PubKey: string(pubKey),
}
- if err = s.writePacket(marshal(msgUserAuthPubKeyOk, okMsg)); err != nil {
+ if err = s.transport.writePacket(marshal(msgUserAuthPubKeyOk, okMsg)); err != nil {
return err
}
continue userAuthLoop
@@ -432,13 +441,13 @@ userAuthLoop:
return errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
}
- if err = s.writePacket(marshal(msgUserAuthFailure, failureMsg)); err != nil {
+ if err = s.transport.writePacket(marshal(msgUserAuthFailure, failureMsg)); err != nil {
return err
}
}
packet = []byte{msgUserAuthSuccess}
- if err = s.writePacket(packet); err != nil {
+ if err = s.transport.writePacket(packet); err != nil {
return err
}
@@ -462,7 +471,7 @@ func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, quest
prompts = appendBool(prompts, echos[i])
}
- if err := c.writePacket(marshal(msgUserAuthInfoRequest, userAuthInfoRequestMsg{
+ if err := c.transport.writePacket(marshal(msgUserAuthInfoRequest, userAuthInfoRequestMsg{
Instruction: instruction,
NumPrompts: uint32(len(questions)),
Prompts: prompts,
@@ -470,7 +479,7 @@ func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, quest
return nil, err
}
- packet, err := c.readPacket()
+ packet, err := c.transport.readPacket()
if err != nil {
return nil, err
}
@@ -511,7 +520,7 @@ func (s *ServerConn) Accept() (Channel, error) {
}
for {
- packet, err := s.readPacket()
+ packet, err := s.transport.readPacket()
if err != nil {
s.lock.Lock()
@@ -557,7 +566,7 @@ func (s *ServerConn) Accept() (Channel, error) {
}
c := &serverChan{
channel: channel{
- packetConn: s,
+ packetConn: s.transport,
remoteId: msg.PeersId,
remoteWin: window{Cond: newCond()},
maxPacket: msg.MaxPacketSize,
@@ -619,7 +628,7 @@ func (s *ServerConn) Accept() (Channel, error) {
case *globalRequestMsg:
if msg.WantReply {
- if err := s.writePacket([]byte{msgRequestFailure}); err != nil {
+ if err := s.transport.writePacket([]byte{msgRequestFailure}); err != nil {
return nil, err
}
}
diff --git a/ssh/session.go b/ssh/session.go
index 4fa2994..9035b49 100644
--- a/ssh/session.go
+++ b/ssh/session.go
@@ -564,7 +564,7 @@ func (s *Session) StderrPipe() (io.Reader, error) {
// NewSession returns a new interactive session on the remote host.
func (c *ClientConn) NewSession() (*Session, error) {
ch := c.newChan(c.transport)
- if err := c.writePacket(marshal(msgChannelOpen, channelOpenMsg{
+ if err := c.transport.writePacket(marshal(msgChannelOpen, channelOpenMsg{
ChanType: "session",
PeersId: ch.localId,
PeersWindow: 1 << 14,
diff --git a/ssh/tcpip.go b/ssh/tcpip.go
index 3ff444e..19c2418 100644
--- a/ssh/tcpip.go
+++ b/ssh/tcpip.go
@@ -296,7 +296,7 @@ type channelOpenDirectMsg struct {
// strings and are expected to be resolvable at the remote end.
func (c *ClientConn) dial(laddr string, lport int, raddr string, rport int) (*tcpChan, error) {
ch := c.newChan(c.transport)
- if err := c.writePacket(marshal(msgChannelOpen, channelOpenDirectMsg{
+ if err := c.transport.writePacket(marshal(msgChannelOpen, channelOpenDirectMsg{
ChanType: "direct-tcpip",
PeersId: ch.localId,
PeersWindow: 1 << 14,