aboutsummaryrefslogtreecommitdiff
path: root/ssh/server.go
diff options
context:
space:
mode:
authorDave Cheney <dave@cheney.net>2013-10-25 06:29:58 +1100
committerDave Cheney <dave@cheney.net>2013-10-25 06:29:58 +1100
commitc0d640c88782f757a45d3f7b93eec2ec63b229cb (patch)
treef1d2d258b71d469f94760ba35e47716bf7400b85 /ssh/server.go
parent105632d35b7181298edeb557a23e66534203796f (diff)
downloadgo-x-crypto-c0d640c88782f757a45d3f7b93eec2ec63b229cb.tar.xz
go.crypto/ssh: ensure {Server,Client}Conn do not expose io.ReadWriter
Transport should not be a ReadWriter. It can only write packets, i.e. no partial reads or writes. Furthermore, you can currently do ClientConn.Write() while the connection is live, which sends raw bytes over the connection. Doing so will confuse the transports because the data is not encrypted. As a consequence, ClientConn and ServerConn stop being a net.Conn Finally, ensure that {Server,Client}Conn implement LocalAddr and RemoteAddr methods that previously were exposed by an embedded net.Conn field. R=hanwen CC=golang-dev https://golang.org/cl/16610043
Diffstat (limited to 'ssh/server.go')
-rw-r--r--ssh/server.go47
1 files changed, 28 insertions, 19 deletions
diff --git a/ssh/server.go b/ssh/server.go
index f9791e2..e4ec0e0 100644
--- a/ssh/server.go
+++ b/ssh/server.go
@@ -97,8 +97,8 @@ const maxCachedPubKeys = 16
// A ServerConn represents an incoming connection.
type ServerConn struct {
- *transport
- config *ServerConfig
+ transport *transport
+ config *ServerConfig
channels map[uint32]*serverChan
nextChanId uint32
@@ -147,6 +147,15 @@ func signAndMarshal(k Signer, rand io.Reader, data []byte) ([]byte, error) {
return serializeSignature(k.PublicKey().PrivateKeyAlgo(), sig), nil
}
+// Close closes the connection.
+func (s *ServerConn) Close() error { return s.transport.Close() }
+
+// LocalAddr returns the local network address.
+func (c *ServerConn) LocalAddr() net.Addr { return c.transport.LocalAddr() }
+
+// RemoteAddr returns the remote network address.
+func (c *ServerConn) RemoteAddr() net.Addr { return c.transport.RemoteAddr() }
+
// Handshake performs an SSH transport and client authentication on the given ServerConn.
func (s *ServerConn) Handshake() error {
var err error
@@ -160,7 +169,7 @@ func (s *ServerConn) Handshake() error {
}
var packet []byte
- if packet, err = s.readPacket(); err != nil {
+ if packet, err = s.transport.readPacket(); err != nil {
return err
}
var serviceRequest serviceRequestMsg
@@ -173,7 +182,7 @@ func (s *ServerConn) Handshake() error {
serviceAccept := serviceAcceptMsg{
Service: serviceUserAuth,
}
- if err := s.writePacket(marshal(msgServiceAccept, serviceAccept)); err != nil {
+ if err := s.transport.writePacket(marshal(msgServiceAccept, serviceAccept)); err != nil {
return err
}
@@ -199,13 +208,13 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
}
serverKexInitPacket := marshal(msgKexInit, serverKexInit)
- if err = s.writePacket(serverKexInitPacket); err != nil {
+ if err = s.transport.writePacket(serverKexInitPacket); err != nil {
return
}
if clientKexInitPacket == nil {
clientKexInit = new(kexInitMsg)
- if clientKexInitPacket, err = s.readPacket(); err != nil {
+ if clientKexInitPacket, err = s.transport.readPacket(); err != nil {
return
}
if err = unmarshal(clientKexInit, clientKexInitPacket, msgKexInit); err != nil {
@@ -221,7 +230,7 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
if clientKexInit.FirstKexFollows && algs.kex != clientKexInit.KexAlgos[0] {
// The client sent a Kex message for the wrong algorithm,
// which we have to ignore.
- if _, err = s.readPacket(); err != nil {
+ if _, err = s.transport.readPacket(); err != nil {
return
}
}
@@ -244,7 +253,7 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
serverKexInit: marshal(msgKexInit, serverKexInit),
clientKexInit: clientKexInitPacket,
}
- result, err := kex.Server(s, s.config.rand(), &magics, hostKey)
+ result, err := kex.Server(s.transport, s.config.rand(), &magics, hostKey)
if err != nil {
return err
}
@@ -253,10 +262,10 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
return err
}
- if err = s.writePacket([]byte{msgNewKeys}); err != nil {
+ if err = s.transport.writePacket([]byte{msgNewKeys}); err != nil {
return
}
- if packet, err := s.readPacket(); err != nil {
+ if packet, err := s.transport.readPacket(); err != nil {
return err
} else if packet[0] != msgNewKeys {
return UnexpectedMessageError{msgNewKeys, packet[0]}
@@ -308,7 +317,7 @@ func (s *ServerConn) authenticate(H []byte) error {
userAuthLoop:
for {
- if packet, err = s.readPacket(); err != nil {
+ if packet, err = s.transport.readPacket(); err != nil {
return err
}
if err = unmarshal(&userAuthReq, packet, msgUserAuthRequest); err != nil {
@@ -382,7 +391,7 @@ userAuthLoop:
Algo: algo,
PubKey: string(pubKey),
}
- if err = s.writePacket(marshal(msgUserAuthPubKeyOk, okMsg)); err != nil {
+ if err = s.transport.writePacket(marshal(msgUserAuthPubKeyOk, okMsg)); err != nil {
return err
}
continue userAuthLoop
@@ -432,13 +441,13 @@ userAuthLoop:
return errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
}
- if err = s.writePacket(marshal(msgUserAuthFailure, failureMsg)); err != nil {
+ if err = s.transport.writePacket(marshal(msgUserAuthFailure, failureMsg)); err != nil {
return err
}
}
packet = []byte{msgUserAuthSuccess}
- if err = s.writePacket(packet); err != nil {
+ if err = s.transport.writePacket(packet); err != nil {
return err
}
@@ -462,7 +471,7 @@ func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, quest
prompts = appendBool(prompts, echos[i])
}
- if err := c.writePacket(marshal(msgUserAuthInfoRequest, userAuthInfoRequestMsg{
+ if err := c.transport.writePacket(marshal(msgUserAuthInfoRequest, userAuthInfoRequestMsg{
Instruction: instruction,
NumPrompts: uint32(len(questions)),
Prompts: prompts,
@@ -470,7 +479,7 @@ func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, quest
return nil, err
}
- packet, err := c.readPacket()
+ packet, err := c.transport.readPacket()
if err != nil {
return nil, err
}
@@ -511,7 +520,7 @@ func (s *ServerConn) Accept() (Channel, error) {
}
for {
- packet, err := s.readPacket()
+ packet, err := s.transport.readPacket()
if err != nil {
s.lock.Lock()
@@ -557,7 +566,7 @@ func (s *ServerConn) Accept() (Channel, error) {
}
c := &serverChan{
channel: channel{
- packetConn: s,
+ packetConn: s.transport,
remoteId: msg.PeersId,
remoteWin: window{Cond: newCond()},
maxPacket: msg.MaxPacketSize,
@@ -619,7 +628,7 @@ func (s *ServerConn) Accept() (Channel, error) {
case *globalRequestMsg:
if msg.WantReply {
- if err := s.writePacket([]byte{msgRequestFailure}); err != nil {
+ if err := s.transport.writePacket([]byte{msgRequestFailure}); err != nil {
return nil, err
}
}