aboutsummaryrefslogtreecommitdiff
path: root/ssh
diff options
context:
space:
mode:
authorHan-Wen Nienhuys <hanwen@google.com>2013-06-21 12:46:35 -0400
committerAdam Langley <agl@golang.org>2013-06-21 12:46:35 -0400
commitafdc305bc8582a7ba5d9ea2c622ce9927a92050a (patch)
tree43bf856358cb1c757d3e650ea8d74a1b8025e0c8 /ssh
parentb88b0165229e30fa2f41d4cdfa5ac2b6e282917d (diff)
downloadgo-x-crypto-afdc305bc8582a7ba5d9ea2c622ce9927a92050a.tar.xz
go.crypto/ssh: add hook for host key checking.
R=dave, agl CC=gobot, golang-dev https://golang.org/cl/9922043
Diffstat (limited to 'ssh')
-rw-r--r--ssh/client.go22
-rw-r--r--ssh/client_auth.go12
-rw-r--r--ssh/test/session_test.go19
-rw-r--r--ssh/test/test_unix_test.go47
4 files changed, 95 insertions, 5 deletions
diff --git a/ssh/client.go b/ssh/client.go
index a42d13a..16569a8 100644
--- a/ssh/client.go
+++ b/ssh/client.go
@@ -26,6 +26,9 @@ type ClientConn struct {
chanList // channels associated with this connection
forwardList // forwarded tcpip connections from the remote side
globalRequest
+
+ // Address as passed to the Dial function.
+ dialAddress string
}
type globalRequest struct {
@@ -35,11 +38,17 @@ type globalRequest struct {
// Client returns a new SSH client connection using c as the underlying transport.
func Client(c net.Conn, config *ClientConfig) (*ClientConn, error) {
+ return clientWithAddress(c, "", config)
+}
+
+func clientWithAddress(c net.Conn, addr string, config *ClientConfig) (*ClientConn, error) {
conn := &ClientConn{
transport: newTransport(c, config.rand()),
config: config,
globalRequest: globalRequest{response: make(chan interface{}, 1)},
+ dialAddress: addr,
}
+
if err := conn.handshake(); err != nil {
conn.Close()
return nil, fmt.Errorf("handshake failed: %v", err)
@@ -168,6 +177,12 @@ func (c *ClientConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
return nil, nil, err
}
+ if checker := c.config.HostKeyChecker; checker != nil {
+ if err = checker.Check(c.dialAddress, c.RemoteAddr(), hostKeyAlgo, kexDHReply.HostKey); err != nil {
+ return nil, nil, err
+ }
+ }
+
kInt, err := group.diffieHellman(kexDHReply.Y, x)
if err != nil {
return nil, nil, err
@@ -445,7 +460,7 @@ func Dial(network, addr string, config *ClientConfig) (*ClientConn, error) {
if err != nil {
return nil, err
}
- return Client(conn, config)
+ return clientWithAddress(conn, addr, config)
}
// A ClientConfig structure is used to configure a ClientConn. After one has
@@ -463,6 +478,11 @@ type ClientConfig struct {
// of a particular RFC 4252 method will be used during authentication.
Auth []ClientAuth
+ // HostKeyChecker, if not nil, is called during the cryptographic
+ // handshake to validate the server's host key. A nil HostKeyChecker
+ // implies that all host keys are accepted.
+ HostKeyChecker HostKeyChecker
+
// Cryptographic-related configuration.
Crypto CryptoConfig
}
diff --git a/ssh/client_auth.go b/ssh/client_auth.go
index ebb74a2..5282415 100644
--- a/ssh/client_auth.go
+++ b/ssh/client_auth.go
@@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"io"
+ "net"
)
// authenticate authenticates with the remote server. See RFC 4252.
@@ -63,6 +64,17 @@ func keys(m map[string]bool) (s []string) {
return
}
+// HostKeyChecker represents a database of known server host keys.
+type HostKeyChecker interface {
+ // Check is called during the handshake to check server's
+ // public key for unexpected changes. The hostKey argument is
+ // in SSH wire format. It can be parsed using
+ // ssh.ParsePublicKey. The address before DNS resolution is
+ // passed in the addr argument, so the key can also be checked
+ // against the hostname.
+ Check(addr string, remote net.Addr, algorithm string, hostKey []byte) error
+}
+
// A ClientAuth represents an instance of an RFC 4252 authentication method.
type ClientAuth interface {
// auth authenticates user over transport t.
diff --git a/ssh/test/session_test.go b/ssh/test/session_test.go
index 4393ee9..5c849fc 100644
--- a/ssh/test/session_test.go
+++ b/ssh/test/session_test.go
@@ -33,6 +33,25 @@ func TestRunCommandSuccess(t *testing.T) {
}
}
+func TestHostKeyCheck(t *testing.T) {
+ server := newServer(t)
+ defer server.Shutdown()
+
+ conf := clientConfig()
+ k := conf.HostKeyChecker.(*storedHostKey)
+
+ // change the key.
+ k.keys["ssh-rsa"][25]++
+
+ conn, err := server.TryDial(conf)
+ if err == nil {
+ conn.Close()
+ t.Fatalf("dial should have failed.")
+ } else if !strings.Contains(err.Error(), "host key mismatch") {
+ t.Fatalf("'host key mismatch' not found in %v", err)
+ }
+}
+
func TestRunCommandFailed(t *testing.T) {
server := newServer(t)
defer server.Shutdown()
diff --git a/ssh/test/test_unix_test.go b/ssh/test/test_unix_test.go
index 4553254..bc4967a 100644
--- a/ssh/test/test_unix_test.go
+++ b/ssh/test/test_unix_test.go
@@ -55,14 +55,25 @@ HostbasedAuthentication no
`
var (
- configTmpl template.Template
- rsakey *rsa.PrivateKey
+ configTmpl template.Template
+ rsakey *rsa.PrivateKey
+ serializedHostKey []byte
)
func init() {
template.Must(configTmpl.Parse(sshd_config))
block, _ := pem.Decode([]byte(testClientPrivateKey))
rsakey, _ = x509.ParsePKCS1PrivateKey(block.Bytes)
+
+ block, _ = pem.Decode([]byte(keys["ssh_host_rsa_key"]))
+ if block == nil {
+ panic("pem.Decode ssh_host_rsa_key")
+ }
+ priv, err := x509.ParsePKCS1PrivateKey(block.Bytes)
+ if err != nil {
+ panic("ParsePKCS1PrivateKey: " + err.Error())
+ }
+ serializedHostKey = ssh.MarshalPublicKey(&priv.PublicKey)
}
type server struct {
@@ -89,7 +100,29 @@ func username() string {
return username
}
+type storedHostKey struct {
+ // keys map from an algorithm string to binary key data.
+ keys map[string][]byte
+}
+
+func (k *storedHostKey) Add(algo string, public []byte) {
+ if k.keys == nil {
+ k.keys = map[string][]byte{}
+ }
+ k.keys[algo] = append([]byte(nil), public...)
+}
+
+func (k *storedHostKey) Check(addr string, remote net.Addr, algo string, key []byte) error {
+ if k.keys == nil || bytes.Compare(key, k.keys[algo]) != 0 {
+ return errors.New("host key mismatch")
+ }
+ return nil
+}
+
func clientConfig() *ssh.ClientConfig {
+ keyChecker := storedHostKey{}
+ keyChecker.Add("ssh-rsa", serializedHostKey)
+
kc := new(keychain)
kc.keys = append(kc.keys, rsakey)
config := &ssh.ClientConfig{
@@ -97,11 +130,12 @@ func clientConfig() *ssh.ClientConfig {
Auth: []ssh.ClientAuth{
ssh.ClientAuthKeyring(kc),
},
+ HostKeyChecker: &keyChecker,
}
return config
}
-func (s *server) Dial(config *ssh.ClientConfig) *ssh.ClientConn {
+func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.ClientConn, error) {
sshd, err := exec.LookPath("sshd")
if err != nil {
s.t.Skipf("skipping test: %v", err)
@@ -123,7 +157,12 @@ func (s *server) Dial(config *ssh.ClientConfig) *ssh.ClientConn {
s.Shutdown()
s.t.Fatalf("s.cmd.Start: %v", err)
}
- conn, err := ssh.Client(&client{wc: w2, r: r1}, config)
+
+ return ssh.Client(&client{wc: w2, r: r1}, config)
+}
+
+func (s *server) Dial(config *ssh.ClientConfig) *ssh.ClientConn {
+ conn, err := s.TryDial(config)
if err != nil {
s.t.Fail()
s.Shutdown()