diff options
| author | Han-Wen Nienhuys <hanwen@google.com> | 2013-06-21 12:46:35 -0400 |
|---|---|---|
| committer | Adam Langley <agl@golang.org> | 2013-06-21 12:46:35 -0400 |
| commit | afdc305bc8582a7ba5d9ea2c622ce9927a92050a (patch) | |
| tree | 43bf856358cb1c757d3e650ea8d74a1b8025e0c8 /ssh | |
| parent | b88b0165229e30fa2f41d4cdfa5ac2b6e282917d (diff) | |
| download | go-x-crypto-afdc305bc8582a7ba5d9ea2c622ce9927a92050a.tar.xz | |
go.crypto/ssh: add hook for host key checking.
R=dave, agl
CC=gobot, golang-dev
https://golang.org/cl/9922043
Diffstat (limited to 'ssh')
| -rw-r--r-- | ssh/client.go | 22 | ||||
| -rw-r--r-- | ssh/client_auth.go | 12 | ||||
| -rw-r--r-- | ssh/test/session_test.go | 19 | ||||
| -rw-r--r-- | ssh/test/test_unix_test.go | 47 |
4 files changed, 95 insertions, 5 deletions
diff --git a/ssh/client.go b/ssh/client.go index a42d13a..16569a8 100644 --- a/ssh/client.go +++ b/ssh/client.go @@ -26,6 +26,9 @@ type ClientConn struct { chanList // channels associated with this connection forwardList // forwarded tcpip connections from the remote side globalRequest + + // Address as passed to the Dial function. + dialAddress string } type globalRequest struct { @@ -35,11 +38,17 @@ type globalRequest struct { // Client returns a new SSH client connection using c as the underlying transport. func Client(c net.Conn, config *ClientConfig) (*ClientConn, error) { + return clientWithAddress(c, "", config) +} + +func clientWithAddress(c net.Conn, addr string, config *ClientConfig) (*ClientConn, error) { conn := &ClientConn{ transport: newTransport(c, config.rand()), config: config, globalRequest: globalRequest{response: make(chan interface{}, 1)}, + dialAddress: addr, } + if err := conn.handshake(); err != nil { conn.Close() return nil, fmt.Errorf("handshake failed: %v", err) @@ -168,6 +177,12 @@ func (c *ClientConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha return nil, nil, err } + if checker := c.config.HostKeyChecker; checker != nil { + if err = checker.Check(c.dialAddress, c.RemoteAddr(), hostKeyAlgo, kexDHReply.HostKey); err != nil { + return nil, nil, err + } + } + kInt, err := group.diffieHellman(kexDHReply.Y, x) if err != nil { return nil, nil, err @@ -445,7 +460,7 @@ func Dial(network, addr string, config *ClientConfig) (*ClientConn, error) { if err != nil { return nil, err } - return Client(conn, config) + return clientWithAddress(conn, addr, config) } // A ClientConfig structure is used to configure a ClientConn. After one has @@ -463,6 +478,11 @@ type ClientConfig struct { // of a particular RFC 4252 method will be used during authentication. Auth []ClientAuth + // HostKeyChecker, if not nil, is called during the cryptographic + // handshake to validate the server's host key. A nil HostKeyChecker + // implies that all host keys are accepted. + HostKeyChecker HostKeyChecker + // Cryptographic-related configuration. Crypto CryptoConfig } diff --git a/ssh/client_auth.go b/ssh/client_auth.go index ebb74a2..5282415 100644 --- a/ssh/client_auth.go +++ b/ssh/client_auth.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "io" + "net" ) // authenticate authenticates with the remote server. See RFC 4252. @@ -63,6 +64,17 @@ func keys(m map[string]bool) (s []string) { return } +// HostKeyChecker represents a database of known server host keys. +type HostKeyChecker interface { + // Check is called during the handshake to check server's + // public key for unexpected changes. The hostKey argument is + // in SSH wire format. It can be parsed using + // ssh.ParsePublicKey. The address before DNS resolution is + // passed in the addr argument, so the key can also be checked + // against the hostname. + Check(addr string, remote net.Addr, algorithm string, hostKey []byte) error +} + // A ClientAuth represents an instance of an RFC 4252 authentication method. type ClientAuth interface { // auth authenticates user over transport t. diff --git a/ssh/test/session_test.go b/ssh/test/session_test.go index 4393ee9..5c849fc 100644 --- a/ssh/test/session_test.go +++ b/ssh/test/session_test.go @@ -33,6 +33,25 @@ func TestRunCommandSuccess(t *testing.T) { } } +func TestHostKeyCheck(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + + conf := clientConfig() + k := conf.HostKeyChecker.(*storedHostKey) + + // change the key. + k.keys["ssh-rsa"][25]++ + + conn, err := server.TryDial(conf) + if err == nil { + conn.Close() + t.Fatalf("dial should have failed.") + } else if !strings.Contains(err.Error(), "host key mismatch") { + t.Fatalf("'host key mismatch' not found in %v", err) + } +} + func TestRunCommandFailed(t *testing.T) { server := newServer(t) defer server.Shutdown() diff --git a/ssh/test/test_unix_test.go b/ssh/test/test_unix_test.go index 4553254..bc4967a 100644 --- a/ssh/test/test_unix_test.go +++ b/ssh/test/test_unix_test.go @@ -55,14 +55,25 @@ HostbasedAuthentication no ` var ( - configTmpl template.Template - rsakey *rsa.PrivateKey + configTmpl template.Template + rsakey *rsa.PrivateKey + serializedHostKey []byte ) func init() { template.Must(configTmpl.Parse(sshd_config)) block, _ := pem.Decode([]byte(testClientPrivateKey)) rsakey, _ = x509.ParsePKCS1PrivateKey(block.Bytes) + + block, _ = pem.Decode([]byte(keys["ssh_host_rsa_key"])) + if block == nil { + panic("pem.Decode ssh_host_rsa_key") + } + priv, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + panic("ParsePKCS1PrivateKey: " + err.Error()) + } + serializedHostKey = ssh.MarshalPublicKey(&priv.PublicKey) } type server struct { @@ -89,7 +100,29 @@ func username() string { return username } +type storedHostKey struct { + // keys map from an algorithm string to binary key data. + keys map[string][]byte +} + +func (k *storedHostKey) Add(algo string, public []byte) { + if k.keys == nil { + k.keys = map[string][]byte{} + } + k.keys[algo] = append([]byte(nil), public...) +} + +func (k *storedHostKey) Check(addr string, remote net.Addr, algo string, key []byte) error { + if k.keys == nil || bytes.Compare(key, k.keys[algo]) != 0 { + return errors.New("host key mismatch") + } + return nil +} + func clientConfig() *ssh.ClientConfig { + keyChecker := storedHostKey{} + keyChecker.Add("ssh-rsa", serializedHostKey) + kc := new(keychain) kc.keys = append(kc.keys, rsakey) config := &ssh.ClientConfig{ @@ -97,11 +130,12 @@ func clientConfig() *ssh.ClientConfig { Auth: []ssh.ClientAuth{ ssh.ClientAuthKeyring(kc), }, + HostKeyChecker: &keyChecker, } return config } -func (s *server) Dial(config *ssh.ClientConfig) *ssh.ClientConn { +func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.ClientConn, error) { sshd, err := exec.LookPath("sshd") if err != nil { s.t.Skipf("skipping test: %v", err) @@ -123,7 +157,12 @@ func (s *server) Dial(config *ssh.ClientConfig) *ssh.ClientConn { s.Shutdown() s.t.Fatalf("s.cmd.Start: %v", err) } - conn, err := ssh.Client(&client{wc: w2, r: r1}, config) + + return ssh.Client(&client{wc: w2, r: r1}, config) +} + +func (s *server) Dial(config *ssh.ClientConfig) *ssh.ClientConn { + conn, err := s.TryDial(config) if err != nil { s.t.Fail() s.Shutdown() |
