diff options
| author | Adam Langley <agl@golang.org> | 2014-04-09 13:57:52 -0700 |
|---|---|---|
| committer | Adam Langley <agl@golang.org> | 2014-04-09 13:57:52 -0700 |
| commit | fa50e7408b9ef89ff2965535b59f1a0010c0770b (patch) | |
| tree | e045a3f48f9ffd3bb712002f8f9f6fd489e8f7ef | |
| parent | 8f45c680ceb25c200b8c301d9184532aeb7cb36e (diff) | |
| download | go-x-crypto-fa50e7408b9ef89ff2965535b59f1a0010c0770b.tar.xz | |
go.crypto/ssh: import gosshnew.
See https://groups.google.com/d/msg/Golang-nuts/AoVxQ4bB5XQ/i8kpMxdbVlEJ
R=hanwen
CC=golang-codereviews
https://golang.org/cl/86190043
53 files changed, 7215 insertions, 3601 deletions
diff --git a/ssh/agent/client.go b/ssh/agent/client.go new file mode 100644 index 0000000..9c11d32 --- /dev/null +++ b/ssh/agent/client.go @@ -0,0 +1,563 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* + Package agent implements a client to an ssh-agent daemon. + +References: + [PROTOCOL.agent]: http://www.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.agent +*/ +package agent + +import ( + "bytes" + "crypto/dsa" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "encoding/base64" + "encoding/binary" + "errors" + "fmt" + "io" + "math/big" + "sync" + + "code.google.com/p/go.crypto/ssh" +) + +// Agent represents the capabilities of an ssh-agent. +type Agent interface { + // List returns the identities known to the agent. + List() ([]*Key, error) + + // Sign has the agent sign the data using a protocol 2 key as defined + // in [PROTOCOL.agent] section 2.6.2. + Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) + + // Insert adds a private key to the agent. If a certificate + // is given, that certificate is added as public key. + Add(s interface{}, cert *ssh.Certificate, comment string) error + + // Remove removes all identities with the given public key. + Remove(key ssh.PublicKey) error + + // RemoveAll removes all identities. + RemoveAll() error + + // Lock locks the agent. Sign and Remove will fail, and List will empty an empty list. + Lock(passphrase []byte) error + + // Unlock undoes the effect of Lock + Unlock(passphrase []byte) error + + // Signers returns signers for all the known keys. + Signers() ([]ssh.Signer, error) +} + +// See [PROTOCOL.agent], section 3. +const ( + agentRequestV1Identities = 1 + + // 3.2 Requests from client to agent for protocol 2 key operations + agentAddIdentity = 17 + agentRemoveIdentity = 18 + agentRemoveAllIdentities = 19 + agentAddIdConstrained = 25 + + // 3.3 Key-type independent requests from client to agent + agentAddSmartcardKey = 20 + agentRemoveSmartcardKey = 21 + agentLock = 22 + agentUnlock = 23 + agentAddSmartcardKeyConstrained = 26 + + // 3.7 Key constraint identifiers + agentConstrainLifetime = 1 + agentConstrainConfirm = 2 +) + +// maxAgentResponseBytes is the maximum agent reply size that is accepted. This +// is a sanity check, not a limit in the spec. +const maxAgentResponseBytes = 16 << 20 + +// Agent messages: +// These structures mirror the wire format of the corresponding ssh agent +// messages found in [PROTOCOL.agent]. + +// 3.4 Generic replies from agent to client +const agentFailure = 5 + +type failureAgentMsg struct{} + +const agentSuccess = 6 + +type successAgentMsg struct{} + +// See [PROTOCOL.agent], section 2.5.2. +const agentRequestIdentities = 11 + +type requestIdentitiesAgentMsg struct{} + +// See [PROTOCOL.agent], section 2.5.2. +const agentIdentitiesAnswer = 12 + +type identitiesAnswerAgentMsg struct { + NumKeys uint32 `sshtype:"12"` + Keys []byte `ssh:"rest"` +} + +// See [PROTOCOL.agent], section 2.6.2. +const agentSignRequest = 13 + +type signRequestAgentMsg struct { + KeyBlob []byte `sshtype:"13"` + Data []byte + Flags uint32 +} + +// See [PROTOCOL.agent], section 2.6.2. + +// 3.6 Replies from agent to client for protocol 2 key operations +const agentSignResponse = 14 + +type signResponseAgentMsg struct { + SigBlob []byte `sshtype:"14"` +} + +type publicKey struct { + Format string + Rest []byte `ssh:"rest"` +} + +// Key represents a protocol 2 public key as defined in +// [PROTOCOL.agent], section 2.5.2. +type Key struct { + Format string + Blob []byte + Comment string +} + +func clientErr(err error) error { + return fmt.Errorf("agent: client error: %v", err) +} + +// String returns the storage form of an agent key with the format, base64 +// encoded serialized key, and the comment if it is not empty. +func (k *Key) String() string { + s := string(k.Format) + " " + base64.StdEncoding.EncodeToString(k.Blob) + + if k.Comment != "" { + s += " " + k.Comment + } + + return s +} + +// Type returns the public key type. +func (k *Key) Type() string { + return k.Format +} + +// Marshal returns key blob to satisfy the ssh.PublicKey interface. +func (k *Key) Marshal() []byte { + return k.Blob +} + +// Verify satisfies the ssh.PublicKey interface, but is not +// implemented for agent keys. +func (k *Key) Verify(data []byte, sig *ssh.Signature) error { + return errors.New("agent: agent key does not know how to verify") +} + +type wireKey struct { + Format string + Rest []byte `ssh:"rest"` +} + +func parseKey(in []byte) (out *Key, rest []byte, err error) { + var record struct { + Blob []byte + Comment string + Rest []byte `ssh:"rest"` + } + + if err := ssh.Unmarshal(in, &record); err != nil { + return nil, nil, err + } + + var wk wireKey + if err := ssh.Unmarshal(record.Blob, &wk); err != nil { + return nil, nil, err + } + + return &Key{ + Format: wk.Format, + Blob: record.Blob, + Comment: record.Comment, + }, record.Rest, nil +} + +// client is a client for an ssh-agent process. +type client struct { + // conn is typically a *net.UnixConn + conn io.ReadWriter + // mu is used to prevent concurrent access to the agent + mu sync.Mutex +} + +// NewClient returns an Agent that talks to an ssh-agent process over +// the given connection. +func NewClient(rw io.ReadWriter) Agent { + return &client{conn: rw} +} + +// call sends an RPC to the agent. On success, the reply is +// unmarshaled into reply and replyType is set to the first byte of +// the reply, which contains the type of the message. +func (c *client) call(req []byte) (reply interface{}, err error) { + c.mu.Lock() + defer c.mu.Unlock() + + msg := make([]byte, 4+len(req)) + binary.BigEndian.PutUint32(msg, uint32(len(req))) + copy(msg[4:], req) + if _, err = c.conn.Write(msg); err != nil { + return nil, clientErr(err) + } + + var respSizeBuf [4]byte + if _, err = io.ReadFull(c.conn, respSizeBuf[:]); err != nil { + return nil, clientErr(err) + } + respSize := binary.BigEndian.Uint32(respSizeBuf[:]) + if respSize > maxAgentResponseBytes { + return nil, clientErr(err) + } + + buf := make([]byte, respSize) + if _, err = io.ReadFull(c.conn, buf); err != nil { + return nil, clientErr(err) + } + reply, err = unmarshal(buf) + if err != nil { + return nil, clientErr(err) + } + return reply, err +} + +func (c *client) simpleCall(req []byte) error { + resp, err := c.call(req) + if err != nil { + return err + } + if _, ok := resp.(*successAgentMsg); ok { + return nil + } + return errors.New("agent: failure") +} + +func (c *client) RemoveAll() error { + return c.simpleCall([]byte{agentRemoveAllIdentities}) +} + +func (c *client) Remove(key ssh.PublicKey) error { + req := ssh.Marshal(&agentRemoveIdentityMsg{ + KeyBlob: key.Marshal(), + }) + return c.simpleCall(req) +} + +func (c *client) Lock(passphrase []byte) error { + req := ssh.Marshal(&agentLockMsg{ + Passphrase: passphrase, + }) + return c.simpleCall(req) +} + +func (c *client) Unlock(passphrase []byte) error { + req := ssh.Marshal(&agentUnlockMsg{ + Passphrase: passphrase, + }) + return c.simpleCall(req) +} + +// List returns the identities known to the agent. +func (c *client) List() ([]*Key, error) { + // see [PROTOCOL.agent] section 2.5.2. + req := []byte{agentRequestIdentities} + + msg, err := c.call(req) + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *identitiesAnswerAgentMsg: + if msg.NumKeys > maxAgentResponseBytes/8 { + return nil, errors.New("ssh: too many keys in agent reply") + } + keys := make([]*Key, msg.NumKeys) + data := msg.Keys + for i := uint32(0); i < msg.NumKeys; i++ { + var key *Key + var err error + if key, data, err = parseKey(data); err != nil { + return nil, err + } + keys[i] = key + } + return keys, nil + case *failureAgentMsg: + return nil, errors.New("ssh: failed to list keys") + } + panic("unreachable") +} + +// Sign has the agent sign the data using a protocol 2 key as defined +// in [PROTOCOL.agent] section 2.6.2. +func (c *client) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) { + req := ssh.Marshal(signRequestAgentMsg{ + KeyBlob: key.Marshal(), + Data: data, + }) + + msg, err := c.call(req) + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *signResponseAgentMsg: + var sig ssh.Signature + if err := ssh.Unmarshal(msg.SigBlob, &sig); err != nil { + return nil, err + } + + return &sig, nil + case *failureAgentMsg: + return nil, errors.New("ssh: failed to sign challenge") + } + panic("unreachable") +} + +// unmarshal parses an agent message in packet, returning the parsed +// form and the message type of packet. +func unmarshal(packet []byte) (interface{}, error) { + if len(packet) < 1 { + return nil, errors.New("agent: empty packet") + } + var msg interface{} + switch packet[0] { + case agentFailure: + return new(failureAgentMsg), nil + case agentSuccess: + return new(successAgentMsg), nil + case agentIdentitiesAnswer: + msg = new(identitiesAnswerAgentMsg) + case agentSignResponse: + msg = new(signResponseAgentMsg) + default: + return nil, fmt.Errorf("agent: unknown type tag %d", packet[0]) + } + if err := ssh.Unmarshal(packet, msg); err != nil { + return nil, err + } + return msg, nil +} + +type rsaKeyMsg struct { + Type string `sshtype:"17"` + N *big.Int + E *big.Int + D *big.Int + Iqmp *big.Int // IQMP = Inverse Q Mod P + P *big.Int + Q *big.Int + Comments string +} + +type dsaKeyMsg struct { + Type string `sshtype:"17"` + P *big.Int + Q *big.Int + G *big.Int + Y *big.Int + X *big.Int + Comments string +} + +type ecdsaKeyMsg struct { + Type string `sshtype:"17"` + Curve string + KeyBytes []byte + D *big.Int + Comments string +} + +// Insert adds a private key to the agent. +func (c *client) insertKey(s interface{}, comment string) error { + var req []byte + switch k := s.(type) { + case *rsa.PrivateKey: + if len(k.Primes) != 2 { + return fmt.Errorf("ssh: unsupported RSA key with %d primes", len(k.Primes)) + } + k.Precompute() + req = ssh.Marshal(rsaKeyMsg{ + Type: ssh.KeyAlgoRSA, + N: k.N, + E: big.NewInt(int64(k.E)), + D: k.D, + Iqmp: k.Precomputed.Qinv, + P: k.Primes[0], + Q: k.Primes[1], + Comments: comment, + }) + case *dsa.PrivateKey: + req = ssh.Marshal(dsaKeyMsg{ + Type: ssh.KeyAlgoDSA, + P: k.P, + Q: k.Q, + G: k.G, + Y: k.Y, + X: k.X, + Comments: comment, + }) + case *ecdsa.PrivateKey: + nistID := fmt.Sprintf("nistp%d", k.Params().BitSize) + req = ssh.Marshal(ecdsaKeyMsg{ + Type: "ecdsa-sha2-" + nistID, + Curve: nistID, + KeyBytes: elliptic.Marshal(k.Curve, k.X, k.Y), + D: k.D, + Comments: comment, + }) + default: + return fmt.Errorf("ssh: unsupported key type %T", s) + } + resp, err := c.call(req) + if err != nil { + return err + } + if _, ok := resp.(*successAgentMsg); ok { + return nil + } + return errors.New("ssh: failure") +} + +type rsaCertMsg struct { + Type string `sshtype:"17"` + CertBytes []byte + D *big.Int + Iqmp *big.Int // IQMP = Inverse Q Mod P + P *big.Int + Q *big.Int + Comments string +} + +type dsaCertMsg struct { + Type string `sshtype:"17"` + CertBytes []byte + X *big.Int + Comments string +} + +type ecdsaCertMsg struct { + Type string `sshtype:"17"` + CertBytes []byte + D *big.Int + Comments string +} + +// Insert adds a private key to the agent. If a certificate is given, +// that certificate is added instead as public key. +func (c *client) Add(s interface{}, cert *ssh.Certificate, comment string) error { + if cert == nil { + return c.insertKey(s, comment) + } else { + return c.insertCert(s, cert, comment) + } +} + +func (c *client) insertCert(s interface{}, cert *ssh.Certificate, comment string) error { + var req []byte + switch k := s.(type) { + case *rsa.PrivateKey: + if len(k.Primes) != 2 { + return fmt.Errorf("ssh: unsupported RSA key with %d primes", len(k.Primes)) + } + k.Precompute() + req = ssh.Marshal(rsaCertMsg{ + Type: cert.Type(), + CertBytes: cert.Marshal(), + D: k.D, + Iqmp: k.Precomputed.Qinv, + P: k.Primes[0], + Q: k.Primes[1], + Comments: comment, + }) + case *dsa.PrivateKey: + req = ssh.Marshal(dsaCertMsg{ + Type: cert.Type(), + CertBytes: cert.Marshal(), + X: k.X, + Comments: comment, + }) + case *ecdsa.PrivateKey: + req = ssh.Marshal(ecdsaCertMsg{ + Type: cert.Type(), + CertBytes: cert.Marshal(), + D: k.D, + Comments: comment, + }) + default: + return fmt.Errorf("ssh: unsupported key type %T", s) + } + + signer, err := ssh.NewSignerFromKey(s) + if err != nil { + return err + } + if bytes.Compare(cert.Key.Marshal(), signer.PublicKey().Marshal()) != 0 { + return errors.New("ssh: signer and cert have different public key") + } + + resp, err := c.call(req) + if err != nil { + return err + } + if _, ok := resp.(*successAgentMsg); ok { + return nil + } + return errors.New("ssh: failure") +} + +// Signers provides a callback for client authentication. +func (c *client) Signers() ([]ssh.Signer, error) { + keys, err := c.List() + if err != nil { + return nil, err + } + + var result []ssh.Signer + for _, k := range keys { + result = append(result, &agentKeyringSigner{c, k}) + } + return result, nil +} + +type agentKeyringSigner struct { + agent *client + pub ssh.PublicKey +} + +func (s *agentKeyringSigner) PublicKey() ssh.PublicKey { + return s.pub +} + +func (s *agentKeyringSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) { + // The agent has its own entropy source, so the rand argument is ignored. + return s.agent.Sign(s.pub, data) +} diff --git a/ssh/agent/client_test.go b/ssh/agent/client_test.go new file mode 100644 index 0000000..aa99e27 --- /dev/null +++ b/ssh/agent/client_test.go @@ -0,0 +1,270 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package agent + +import ( + "bytes" + "crypto/rand" + "errors" + "net" + "os" + "os/exec" + "strconv" + "testing" + + "code.google.com/p/go.crypto/ssh" +) + +func startAgent(t *testing.T) (client Agent, socket string, cleanup func()) { + bin, err := exec.LookPath("ssh-agent") + if err != nil { + t.Skip("could not find ssh-agent") + } + + cmd := exec.Command(bin, "-s") + out, err := cmd.Output() + if err != nil { + t.Fatalf("cmd.Output: %v", err) + } + + /* Output looks like: + + SSH_AUTH_SOCK=/tmp/ssh-P65gpcqArqvH/agent.15541; export SSH_AUTH_SOCK; + SSH_AGENT_PID=15542; export SSH_AGENT_PID; + echo Agent pid 15542; + */ + fields := bytes.Split(out, []byte(";")) + line := bytes.SplitN(fields[0], []byte("="), 2) + line[0] = bytes.TrimLeft(line[0], "\n") + if string(line[0]) != "SSH_AUTH_SOCK" { + t.Fatalf("could not find key SSH_AUTH_SOCK in %q", fields[0]) + } + socket = string(line[1]) + + line = bytes.SplitN(fields[2], []byte("="), 2) + line[0] = bytes.TrimLeft(line[0], "\n") + if string(line[0]) != "SSH_AGENT_PID" { + t.Fatalf("could not find key SSH_AGENT_PID in %q", fields[2]) + } + pidStr := line[1] + pid, err := strconv.Atoi(string(pidStr)) + if err != nil { + t.Fatalf("Atoi(%q): %v", pidStr, err) + } + + conn, err := net.Dial("unix", string(socket)) + if err != nil { + t.Fatalf("net.Dial: %v", err) + } + + ac := NewClient(conn) + return ac, socket, func() { + proc, _ := os.FindProcess(pid) + if proc != nil { + proc.Kill() + } + conn.Close() + } +} + +func testAgent(t *testing.T, key interface{}, cert *ssh.Certificate) { + agent, _, cleanup := startAgent(t) + defer cleanup() + + testAgentInterface(t, agent, key, cert) +} + +func testAgentInterface(t *testing.T, agent Agent, key interface{}, cert *ssh.Certificate) { + signer, err := ssh.NewSignerFromKey(key) + if err != nil { + t.Fatalf("NewSignerFromKey: %v", err) + } + // The agent should start up empty. + if keys, err := agent.List(); err != nil { + t.Fatalf("RequestIdentities: %v", err) + } else if len(keys) > 0 { + t.Fatalf("got %d keys, want 0: %v", len(keys), keys) + } + + // Attempt to insert the key, with certificate if specified. + var pubKey ssh.PublicKey + if cert != nil { + err = agent.Add(key, cert, "comment") + pubKey = cert + } else { + err = agent.Add(key, nil, "comment") + pubKey = signer.PublicKey() + } + if err != nil { + t.Fatalf("insert: %v", err) + } + + // Did the key get inserted successfully? + if keys, err := agent.List(); err != nil { + t.Fatalf("List: %v", err) + } else if len(keys) != 1 { + t.Fatalf("got %v, want 1 key", keys) + } else if keys[0].Comment != "comment" { + t.Fatalf("key comment: got %v, want %v", keys[0].Comment, "comment") + } else if !bytes.Equal(keys[0].Blob, pubKey.Marshal()) { + t.Fatalf("key mismatch") + } + + // Can the agent make a valid signature? + data := []byte("hello") + sig, err := agent.Sign(pubKey, data) + if err != nil { + t.Fatalf("Sign: %v", err) + } + + if err := pubKey.Verify(data, sig); err != nil { + t.Fatalf("key signature Verify: %v", err) + } +} + +func TestAgent(t *testing.T) { + for _, keyType := range []string{"rsa", "dsa", "ecdsa"} { + t.Log(keyType) + testAgent(t, testPrivateKeys[keyType], nil) + } +} + +func TestCert(t *testing.T) { + cert := &ssh.Certificate{ + Key: testPublicKeys["rsa"], + ValidBefore: ssh.CertTimeInfinity, + CertType: ssh.UserCert, + } + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + + testAgent(t, testPrivateKeys["rsa"], cert) +} + +// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and +// therefore is buffered (net.Pipe deadlocks if both sides start with +// a write.) +func netPipe() (net.Conn, net.Conn, error) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, nil, err + } + defer listener.Close() + c1, err := net.Dial("tcp", listener.Addr().String()) + if err != nil { + return nil, nil, err + } + + c2, err := listener.Accept() + if err != nil { + c1.Close() + return nil, nil, err + } + + return c1, c2, nil +} + +func TestAuth(t *testing.T) { + a, b, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + + defer a.Close() + defer b.Close() + + agent, _, cleanup := startAgent(t) + defer cleanup() + + if err := agent.Add(testPrivateKeys["rsa"], nil, "comment"); err != nil { + t.Errorf("Add: %v", err) + } + + serverConf := ssh.ServerConfig{} + serverConf.AddHostKey(testSigners["rsa"]) + serverConf.PublicKeyCallback = func(c ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) { + return nil, nil + } + + return nil, errors.New("pubkey rejected") + } + + go func() { + conn, _, _, err := ssh.NewServerConn(a, &serverConf) + if err != nil { + t.Fatalf("Server: %v", err) + } + conn.Close() + }() + + conf := ssh.ClientConfig{} + conf.Auth = append(conf.Auth, ssh.PublicKeysCallback(agent.Signers)) + conn, _, _, err := ssh.NewClientConn(b, "", &conf) + if err != nil { + t.Fatalf("NewClientConn: %v", err) + } + conn.Close() +} + +func TestLockClient(t *testing.T) { + agent, _, cleanup := startAgent(t) + defer cleanup() + testLockAgent(agent, t) +} + +func testLockAgent(agent Agent, t *testing.T) { + if err := agent.Add(testPrivateKeys["rsa"], nil, "comment 1"); err != nil { + t.Errorf("Add: %v", err) + } + if err := agent.Add(testPrivateKeys["dsa"], nil, "comment dsa"); err != nil { + t.Errorf("Add: %v", err) + } + if keys, err := agent.List(); err != nil { + t.Errorf("List: %v", err) + } else if len(keys) != 2 { + t.Errorf("Want 2 keys, got %v", keys) + } + + passphrase := []byte("secret") + if err := agent.Lock(passphrase); err != nil { + t.Errorf("Lock: %v", err) + } + + if keys, err := agent.List(); err != nil { + t.Errorf("List: %v", err) + } else if len(keys) != 0 { + t.Errorf("Want 0 keys, got %v", keys) + } + + signer, _ := ssh.NewSignerFromKey(testPrivateKeys["rsa"]) + if _, err := agent.Sign(signer.PublicKey(), []byte("hello")); err == nil { + t.Fatalf("Sign did not fail") + } + + if err := agent.Remove(signer.PublicKey()); err == nil { + t.Fatalf("Remove did not fail") + } + + if err := agent.RemoveAll(); err == nil { + t.Fatalf("RemoveAll did not fail") + } + + if err := agent.Unlock(nil); err == nil { + t.Errorf("Unlock with wrong passphrase succeeded") + } + if err := agent.Unlock(passphrase); err != nil { + t.Errorf("Unlock: %v", err) + } + + if err := agent.Remove(signer.PublicKey()); err != nil { + t.Fatalf("Remove: %v", err) + } + + if keys, err := agent.List(); err != nil { + t.Errorf("List: %v", err) + } else if len(keys) != 1 { + t.Errorf("Want 1 keys, got %v", keys) + } +} diff --git a/ssh/agent/forward.go b/ssh/agent/forward.go new file mode 100644 index 0000000..dd45c3e --- /dev/null +++ b/ssh/agent/forward.go @@ -0,0 +1,103 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package agent + +import ( + "errors" + "io" + "net" + "sync" + + "code.google.com/p/go.crypto/ssh" +) + +// RequestAgentForwarding sets up agent forwarding for the session. +// SetupForwardKeyring or SetupForwardAgent should be called to route +// the authentication requests. +func RequestAgentForwarding(session *ssh.Session) error { + ok, err := session.SendRequest("auth-agent-req@openssh.com", true, nil) + if err != nil { + return err + } + if !ok { + return errors.New("forwarding request denied") + } + return nil +} + +// ForwardToAgent routes authentication requests to the given keyring. +func ForwardToAgent(client *ssh.Client, keyring Agent) error { + channels := client.HandleChannelOpen(channelType) + if channels == nil { + return errors.New("agent: already have handler for " + channelType) + } + + go func() { + for ch := range channels { + channel, reqs, err := ch.Accept() + if err != nil { + continue + } + go ssh.DiscardRequests(reqs) + go func() { + ServeAgent(keyring, channel) + channel.Close() + }() + } + }() + return nil +} + +const channelType = "auth-agent@openssh.com" + +// ForwardToRemote routes authentication requests to the ssh-agent +// process serving on the given unix socket. +func ForwardToRemote(client *ssh.Client, addr string) error { + channels := client.HandleChannelOpen(channelType) + if channels == nil { + return errors.New("agent: already have handler for " + channelType) + } + conn, err := net.Dial("unix", addr) + if err != nil { + return err + } + conn.Close() + + go func() { + for ch := range channels { + channel, reqs, err := ch.Accept() + if err != nil { + continue + } + go ssh.DiscardRequests(reqs) + go forwardUnixSocket(channel, addr) + } + }() + return nil +} + +func forwardUnixSocket(channel ssh.Channel, addr string) { + conn, err := net.Dial("unix", addr) + if err != nil { + return + } + + var wg sync.WaitGroup + wg.Add(2) + go func() { + io.Copy(conn, channel) + conn.(*net.UnixConn).CloseWrite() + wg.Done() + }() + go func() { + io.Copy(channel, conn) + channel.CloseWrite() + wg.Done() + }() + + wg.Wait() + conn.Close() + channel.Close() +} diff --git a/ssh/agent/keyring.go b/ssh/agent/keyring.go new file mode 100644 index 0000000..ecfa66f --- /dev/null +++ b/ssh/agent/keyring.go @@ -0,0 +1,183 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package agent + +import ( + "bytes" + "crypto/rand" + "crypto/subtle" + "errors" + "fmt" + "sync" + + "code.google.com/p/go.crypto/ssh" +) + +type privKey struct { + signer ssh.Signer + comment string +} + +type keyring struct { + mu sync.Mutex + keys []privKey + + locked bool + passphrase []byte +} + +var errLocked = errors.New("agent: locked") + +// NewKeyring returns an Agent that holds keys in memory. It is safe +// for concurrent use by multiple goroutines. +func NewKeyring() Agent { + return &keyring{} +} + +// RemoveAll removes all identities. +func (r *keyring) RemoveAll() error { + r.mu.Lock() + defer r.mu.Unlock() + if r.locked { + return errLocked + } + + r.keys = nil + return nil +} + +// Remove removes all identities with the given public key. +func (r *keyring) Remove(key ssh.PublicKey) error { + r.mu.Lock() + defer r.mu.Unlock() + if r.locked { + return errLocked + } + + want := key.Marshal() + found := false + for i := 0; i < len(r.keys); { + if bytes.Equal(r.keys[i].signer.PublicKey().Marshal(), want) { + found = true + r.keys[i] = r.keys[len(r.keys)-1] + r.keys = r.keys[len(r.keys)-1:] + continue + } else { + i++ + } + } + + if !found { + return errors.New("agent: key not found") + } + return nil +} + +// Lock locks the agent. Sign and Remove will fail, and List will empty an empty list. +func (r *keyring) Lock(passphrase []byte) error { + r.mu.Lock() + defer r.mu.Unlock() + if r.locked { + return errLocked + } + + r.locked = true + r.passphrase = passphrase + return nil +} + +// Unlock undoes the effect of Lock +func (r *keyring) Unlock(passphrase []byte) error { + r.mu.Lock() + defer r.mu.Unlock() + if !r.locked { + return errors.New("agent: not locked") + } + if len(passphrase) != len(r.passphrase) || 1 != subtle.ConstantTimeCompare(passphrase, r.passphrase) { + return fmt.Errorf("agent: incorrect passphrase") + } + + r.locked = false + r.passphrase = nil + return nil +} + +// List returns the identities known to the agent. +func (r *keyring) List() ([]*Key, error) { + r.mu.Lock() + defer r.mu.Unlock() + if r.locked { + // section 2.7: locked agents return empty. + return nil, nil + } + + var ids []*Key + for _, k := range r.keys { + pub := k.signer.PublicKey() + ids = append(ids, &Key{ + Format: pub.Type(), + Blob: pub.Marshal(), + Comment: k.comment}) + } + return ids, nil +} + +// Insert adds a private key to the keyring. If a certificate +// is given, that certificate is added as public key. +func (r *keyring) Add(priv interface{}, cert *ssh.Certificate, comment string) error { + r.mu.Lock() + defer r.mu.Unlock() + if r.locked { + return errLocked + } + signer, err := ssh.NewSignerFromKey(priv) + + if err != nil { + return err + } + + if cert != nil { + signer, err = ssh.NewCertSigner(cert, signer) + if err != nil { + return err + } + } + + r.keys = append(r.keys, privKey{signer, comment}) + + return nil +} + +// Sign returns a signature for the data. +func (r *keyring) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) { + r.mu.Lock() + defer r.mu.Unlock() + if r.locked { + return nil, errLocked + } + + wanted := key.Marshal() + for _, k := range r.keys { + if bytes.Equal(k.signer.PublicKey().Marshal(), wanted) { + return k.signer.Sign(rand.Reader, data) + } + } + return nil, errors.New("not found") +} + +// Signers returns signers for all the known keys. +func (r *keyring) Signers() ([]ssh.Signer, error) { + r.mu.Lock() + defer r.mu.Unlock() + if r.locked { + return nil, errLocked + } + + s := make([]ssh.Signer, len(r.keys)) + for _, k := range r.keys { + s = append(s, k.signer) + } + return s, nil +} diff --git a/ssh/agent/server.go b/ssh/agent/server.go new file mode 100644 index 0000000..2d55dc9 --- /dev/null +++ b/ssh/agent/server.go @@ -0,0 +1,209 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package agent + +import ( + "crypto/rsa" + "encoding/binary" + "fmt" + "io" + "log" + "math/big" + + "code.google.com/p/go.crypto/ssh" +) + +// Server wraps an Agent and uses it to implement the agent side of +// the SSH-agent, wire protocol. +type server struct { + agent Agent +} + +func (s *server) processRequestBytes(reqData []byte) []byte { + rep, err := s.processRequest(reqData) + if err != nil { + if err != errLocked { + // TODO(hanwen): provide better logging interface? + log.Printf("agent %d: %v", reqData[0], err) + } + return []byte{agentFailure} + } + + if err == nil && rep == nil { + return []byte{agentSuccess} + } + + return ssh.Marshal(rep) +} + +func marshalKey(k *Key) []byte { + var record struct { + Blob []byte + Comment string + } + record.Blob = k.Marshal() + record.Comment = k.Comment + + return ssh.Marshal(&record) +} + +type agentV1IdentityMsg struct { + Numkeys uint32 `sshtype:"2"` +} + +type agentRemoveIdentityMsg struct { + KeyBlob []byte `sshtype:"18"` +} + +type agentLockMsg struct { + Passphrase []byte `sshtype:"22"` +} + +type agentUnlockMsg struct { + Passphrase []byte `sshtype:"23"` +} + +func (s *server) processRequest(data []byte) (interface{}, error) { + switch data[0] { + case agentRequestV1Identities: + return &agentV1IdentityMsg{0}, nil + case agentRemoveIdentity: + var req agentRemoveIdentityMsg + if err := ssh.Unmarshal(data, &req); err != nil { + return nil, err + } + + var wk wireKey + if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil { + return nil, err + } + + return nil, s.agent.Remove(&Key{Format: wk.Format, Blob: req.KeyBlob}) + + case agentRemoveAllIdentities: + return nil, s.agent.RemoveAll() + + case agentLock: + var req agentLockMsg + if err := ssh.Unmarshal(data, &req); err != nil { + return nil, err + } + + return nil, s.agent.Lock(req.Passphrase) + + case agentUnlock: + var req agentLockMsg + if err := ssh.Unmarshal(data, &req); err != nil { + return nil, err + } + return nil, s.agent.Unlock(req.Passphrase) + + case agentSignRequest: + var req signRequestAgentMsg + if err := ssh.Unmarshal(data, &req); err != nil { + return nil, err + } + + var wk wireKey + if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil { + return nil, err + } + + k := &Key{ + Format: wk.Format, + Blob: req.KeyBlob, + } + + sig, err := s.agent.Sign(k, req.Data) // TODO(hanwen): flags. + if err != nil { + return nil, err + } + return &signResponseAgentMsg{SigBlob: ssh.Marshal(sig)}, nil + case agentRequestIdentities: + keys, err := s.agent.List() + if err != nil { + return nil, err + } + + rep := identitiesAnswerAgentMsg{ + NumKeys: uint32(len(keys)), + } + for _, k := range keys { + rep.Keys = append(rep.Keys, marshalKey(k)...) + } + return rep, nil + case agentAddIdentity: + return nil, s.insertIdentity(data) + } + + return nil, fmt.Errorf("unknown opcode %d", data[0]) +} + +func (s *server) insertIdentity(req []byte) error { + var record struct { + Type string `sshtype:"17"` + Rest []byte `ssh:"rest"` + } + if err := ssh.Unmarshal(req, &record); err != nil { + return err + } + + switch record.Type { + case ssh.KeyAlgoRSA: + var k rsaKeyMsg + if err := ssh.Unmarshal(req, &k); err != nil { + return err + } + + priv := rsa.PrivateKey{ + PublicKey: rsa.PublicKey{ + E: int(k.E.Int64()), + N: k.N, + }, + D: k.D, + Primes: []*big.Int{k.P, k.Q}, + } + priv.Precompute() + + return s.agent.Add(&priv, nil, k.Comments) + } + return fmt.Errorf("not implemented: %s", record.Type) +} + +// ServeAgent serves the agent protocol on the given connection. It +// returns when an I/O error occurs. +func ServeAgent(agent Agent, c io.ReadWriter) error { + s := &server{agent} + + var length [4]byte + for { + if _, err := io.ReadFull(c, length[:]); err != nil { + return err + } + l := binary.BigEndian.Uint32(length[:]) + if l > maxAgentResponseBytes { + // We also cap requests. + return fmt.Errorf("agent: request too large: %d", l) + } + + req := make([]byte, l) + if _, err := io.ReadFull(c, req); err != nil { + return err + } + + repData := s.processRequestBytes(req) + if len(repData) > maxAgentResponseBytes { + return fmt.Errorf("agent: reply too large: %d bytes", len(repData)) + } + + binary.BigEndian.PutUint32(length[:], uint32(len(repData))) + if _, err := c.Write(length[:]); err != nil { + return err + } + if _, err := c.Write(repData); err != nil { + return err + } + } +} diff --git a/ssh/agent/server_test.go b/ssh/agent/server_test.go new file mode 100644 index 0000000..ad2996b --- /dev/null +++ b/ssh/agent/server_test.go @@ -0,0 +1,77 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package agent + +import ( + "testing" + + "code.google.com/p/go.crypto/ssh" +) + +func TestServer(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + client := NewClient(c1) + + go ServeAgent(NewKeyring(), c2) + + testAgentInterface(t, client, testPrivateKeys["rsa"], nil) +} + +func TestLockServer(t *testing.T) { + testLockAgent(NewKeyring(), t) +} + +func TestSetupForwardAgent(t *testing.T) { + a, b, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + + defer a.Close() + defer b.Close() + + _, socket, cleanup := startAgent(t) + defer cleanup() + + serverConf := ssh.ServerConfig{ + NoClientAuth: true, + } + serverConf.AddHostKey(testSigners["rsa"]) + incoming := make(chan *ssh.ServerConn, 1) + go func() { + conn, _, _, err := ssh.NewServerConn(a, &serverConf) + if err != nil { + t.Fatalf("Server: %v", err) + } + incoming <- conn + }() + + conf := ssh.ClientConfig{} + conn, chans, reqs, err := ssh.NewClientConn(b, "", &conf) + if err != nil { + t.Fatalf("NewClientConn: %v", err) + } + client := ssh.NewClient(conn, chans, reqs) + + if err := ForwardToRemote(client, socket); err != nil { + t.Fatalf("SetupForwardAgent: %v", err) + } + + server := <-incoming + ch, reqs, err := server.OpenChannel(channelType, nil) + if err != nil { + t.Fatalf("OpenChannel(%q): %v", channelType, err) + } + go ssh.DiscardRequests(reqs) + + agentClient := NewClient(ch) + testAgentInterface(t, agentClient, testPrivateKeys["rsa"], nil) + conn.Close() +} diff --git a/ssh/agent/testdata_test.go b/ssh/agent/testdata_test.go new file mode 100644 index 0000000..6bb75a9 --- /dev/null +++ b/ssh/agent/testdata_test.go @@ -0,0 +1,64 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// IMPLEMENTOR NOTE: To avoid a package loop, this file is in three places: +// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three +// instances. + +package agent + +import ( + "crypto/rand" + "fmt" + + "code.google.com/p/go.crypto/ssh" + "code.google.com/p/go.crypto/ssh/testdata" +) + +var ( + testPrivateKeys map[string]interface{} + testSigners map[string]ssh.Signer + testPublicKeys map[string]ssh.PublicKey +) + +func init() { + var err error + + n := len(testdata.PEMBytes) + testPrivateKeys = make(map[string]interface{}, n) + testSigners = make(map[string]ssh.Signer, n) + testPublicKeys = make(map[string]ssh.PublicKey, n) + for t, k := range testdata.PEMBytes { + testPrivateKeys[t], err = ssh.ParseRawPrivateKey(k) + if err != nil { + panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err)) + } + testSigners[t], err = ssh.NewSignerFromKey(testPrivateKeys[t]) + if err != nil { + panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err)) + } + testPublicKeys[t] = testSigners[t].PublicKey() + } + + // Create a cert and sign it for use in tests. + testCert := &ssh.Certificate{ + Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil + ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage + ValidAfter: 0, // unix epoch + ValidBefore: ssh.CertTimeInfinity, // The end of currently representable time. + Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil + Key: testPublicKeys["ecdsa"], + SignatureKey: testPublicKeys["rsa"], + Permissions: ssh.Permissions{ + CriticalOptions: map[string]string{}, + Extensions: map[string]string{}, + }, + } + testCert.SignCert(rand.Reader, testSigners["rsa"]) + testPrivateKeys["cert"] = testPrivateKeys["ecdsa"] + testSigners["cert"], err = ssh.NewCertSigner(testCert, testSigners["ecdsa"]) + if err != nil { + panic(fmt.Sprintf("Unable to create certificate signer: %v", err)) + } +} diff --git a/ssh/benchmark_test.go b/ssh/benchmark_test.go new file mode 100644 index 0000000..d9f7eb9 --- /dev/null +++ b/ssh/benchmark_test.go @@ -0,0 +1,122 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "errors" + "io" + "net" + "testing" +) + +type server struct { + *ServerConn + chans <-chan NewChannel +} + +func newServer(c net.Conn, conf *ServerConfig) (*server, error) { + sconn, chans, reqs, err := NewServerConn(c, conf) + if err != nil { + return nil, err + } + go DiscardRequests(reqs) + return &server{sconn, chans}, nil +} + +func (s *server) Accept() (NewChannel, error) { + n, ok := <-s.chans + if !ok { + return nil, io.EOF + } + return n, nil +} + +func sshPipe() (Conn, *server, error) { + c1, c2, err := netPipe() + if err != nil { + return nil, nil, err + } + + clientConf := ClientConfig{ + User: "user", + } + serverConf := ServerConfig{ + NoClientAuth: true, + } + serverConf.AddHostKey(testSigners["ecdsa"]) + done := make(chan *server, 1) + go func() { + server, err := newServer(c2, &serverConf) + if err != nil { + done <- nil + } + done <- server + }() + + client, _, reqs, err := NewClientConn(c1, "", &clientConf) + if err != nil { + return nil, nil, err + } + + server := <-done + if server == nil { + return nil, nil, errors.New("server handshake failed.") + } + go DiscardRequests(reqs) + + return client, server, nil +} + +func BenchmarkEndToEnd(b *testing.B) { + b.StopTimer() + + client, server, err := sshPipe() + if err != nil { + b.Fatalf("sshPipe: %v", err) + } + + defer client.Close() + defer server.Close() + + size := (1 << 20) + input := make([]byte, size) + output := make([]byte, size) + b.SetBytes(int64(size)) + done := make(chan int, 1) + + go func() { + newCh, err := server.Accept() + if err != nil { + b.Fatalf("Client: %v", err) + } + ch, incoming, err := newCh.Accept() + go DiscardRequests(incoming) + for i := 0; i < b.N; i++ { + if _, err := io.ReadFull(ch, output); err != nil { + b.Fatalf("ReadFull: %v", err) + } + } + ch.Close() + done <- 1 + }() + + ch, in, err := client.OpenChannel("speed", nil) + if err != nil { + b.Fatalf("OpenChannel: %v", err) + } + go DiscardRequests(in) + + b.ResetTimer() + b.StartTimer() + for i := 0; i < b.N; i++ { + if _, err := ch.Write(input); err != nil { + b.Fatalf("WriteFull: %v", err) + } + } + ch.Close() + b.StopTimer() + + <-done +} diff --git a/ssh/buffer.go b/ssh/buffer.go index 601dad3..6931b51 100644 --- a/ssh/buffer.go +++ b/ssh/buffer.go @@ -43,29 +43,29 @@ func newBuffer() *buffer { // buf must not be modified after the call to write. func (b *buffer) write(buf []byte) { b.Cond.L.Lock() - defer b.Cond.L.Unlock() e := &element{buf: buf} b.tail.next = e b.tail = e b.Cond.Signal() + b.Cond.L.Unlock() } // eof closes the buffer. Reads from the buffer once all // the data has been consumed will receive os.EOF. func (b *buffer) eof() error { b.Cond.L.Lock() - defer b.Cond.L.Unlock() b.closed = true b.Cond.Signal() + b.Cond.L.Unlock() return nil } -// Read reads data from the internal buffer in buf. -// Reads will block if no data is available, or until -// the buffer is closed. +// Read reads data from the internal buffer in buf. Reads will block +// if no data is available, or until the buffer is closed. func (b *buffer) Read(buf []byte) (n int, err error) { b.Cond.L.Lock() defer b.Cond.L.Unlock() + for len(buf) > 0 { // if there is data in b.head, copy it if len(b.head.buf) > 0 { @@ -79,10 +79,12 @@ func (b *buffer) Read(buf []byte) (n int, err error) { b.head = b.head.next continue } + // if at least one byte has been copied, return if n > 0 { break } + // if nothing was read, and there is nothing outstanding // check to see if the buffer is closed. if b.closed { diff --git a/ssh/buffer_test.go b/ssh/buffer_test.go index 135c4ae..d5781cb 100644 --- a/ssh/buffer_test.go +++ b/ssh/buffer_test.go @@ -9,33 +9,33 @@ import ( "testing" ) -var BYTES = []byte("abcdefghijklmnopqrstuvwxyz") +var alphabet = []byte("abcdefghijklmnopqrstuvwxyz") func TestBufferReadwrite(t *testing.T) { b := newBuffer() - b.write(BYTES[:10]) + b.write(alphabet[:10]) r, _ := b.Read(make([]byte, 10)) if r != 10 { t.Fatalf("Expected written == read == 10, written: 10, read %d", r) } b = newBuffer() - b.write(BYTES[:5]) + b.write(alphabet[:5]) r, _ = b.Read(make([]byte, 10)) if r != 5 { t.Fatalf("Expected written == read == 5, written: 5, read %d", r) } b = newBuffer() - b.write(BYTES[:10]) + b.write(alphabet[:10]) r, _ = b.Read(make([]byte, 5)) if r != 5 { t.Fatalf("Expected written == 10, read == 5, written: 10, read %d", r) } b = newBuffer() - b.write(BYTES[:5]) - b.write(BYTES[5:15]) + b.write(alphabet[:5]) + b.write(alphabet[5:15]) r, _ = b.Read(make([]byte, 10)) r2, _ := b.Read(make([]byte, 10)) if r != 10 || r2 != 5 || 15 != r+r2 { @@ -45,14 +45,14 @@ func TestBufferReadwrite(t *testing.T) { func TestBufferClose(t *testing.T) { b := newBuffer() - b.write(BYTES[:10]) + b.write(alphabet[:10]) b.eof() _, err := b.Read(make([]byte, 5)) if err != nil { t.Fatal("expected read of 5 to not return EOF") } b = newBuffer() - b.write(BYTES[:10]) + b.write(alphabet[:10]) b.eof() r, err := b.Read(make([]byte, 5)) r2, err2 := b.Read(make([]byte, 10)) @@ -61,7 +61,7 @@ func TestBufferClose(t *testing.T) { } b = newBuffer() - b.write(BYTES[:10]) + b.write(alphabet[:10]) b.eof() r, err = b.Read(make([]byte, 5)) r2, err2 = b.Read(make([]byte, 10)) diff --git a/ssh/certs.go b/ssh/certs.go index d958f31..9962ff0 100644 --- a/ssh/certs.go +++ b/ssh/certs.go @@ -5,6 +5,12 @@ package ssh import ( + "bytes" + "errors" + "fmt" + "io" + "net" + "sort" "time" ) @@ -18,342 +24,432 @@ const ( CertAlgoECDSA521v01 = "ecdsa-sha2-nistp521-cert-v01@openssh.com" ) -// Certificate types are used to specify whether a certificate is for identification -// of a user or a host. Current identities are defined in [PROTOCOL.certkeys]. +// Certificate types distinguish between host and user +// certificates. The values can be set in the CertType field of +// Certificate. const ( UserCert = 1 HostCert = 2 ) -type signature struct { +// Signature represents a cryptographic signature. +type Signature struct { Format string Blob []byte } -type tuple struct { - Name string - Data string -} +// CertTimeInfinity can be used for OpenSSHCertV01.ValidBefore to indicate that +// a certificate does not expire. +const CertTimeInfinity = 1<<64 - 1 -const ( - maxUint64 = 1<<64 - 1 - maxInt64 = 1<<63 - 1 -) +// An Certificate represents an OpenSSH certificate as defined in +// [PROTOCOL.certkeys]?rev=1.8. +type Certificate struct { + Nonce []byte + Key PublicKey + Serial uint64 + CertType uint32 + KeyId string + ValidPrincipals []string + ValidAfter uint64 + ValidBefore uint64 + Permissions + Reserved []byte + SignatureKey PublicKey + Signature *Signature +} -// CertTime represents an unsigned 64-bit time value in seconds starting from -// UNIX epoch. We use CertTime instead of time.Time in order to properly handle -// the "infinite" time value ^0, which would become negative when expressed as -// an int64. -type CertTime uint64 +// genericCertData holds the key-independent part of the certificate data. +// Overall, certificates contain an nonce, public key fields and +// key-independent fields. +type genericCertData struct { + Serial uint64 + CertType uint32 + KeyId string + ValidPrincipals []byte + ValidAfter uint64 + ValidBefore uint64 + CriticalOptions []byte + Extensions []byte + Reserved []byte + SignatureKey []byte + Signature []byte +} -func (ct CertTime) Time() time.Time { - if ct > maxInt64 { - return time.Unix(maxInt64, 0) +func marshalStringList(namelist []string) []byte { + var to []byte + for _, name := range namelist { + s := struct{ N string }{name} + to = append(to, Marshal(&s)...) } - return time.Unix(int64(ct), 0) + return to } -func (ct CertTime) IsInfinite() bool { - return ct == maxUint64 -} +func marshalTuples(tups map[string]string) []byte { + keys := make([]string, 0, len(tups)) + for k := range tups { + keys = append(keys, k) + } + sort.Strings(keys) -// An OpenSSHCertV01 represents an OpenSSH certificate as defined in -// [PROTOCOL.certkeys]?rev=1.8. -type OpenSSHCertV01 struct { - Nonce []byte - Key PublicKey - Serial uint64 - Type uint32 - KeyId string - ValidPrincipals []string - ValidAfter, ValidBefore CertTime - CriticalOptions []tuple - Extensions []tuple - Reserved []byte - SignatureKey PublicKey - Signature *signature + var r []byte + for _, k := range keys { + s := struct{ K, V string }{k, tups[k]} + r = append(r, Marshal(&s)...) + } + return r } -// validateOpenSSHCertV01Signature uses the cert's SignatureKey to verify that -// the cert's Signature.Blob is the result of signing the cert bytes starting -// from the algorithm string and going up to and including the SignatureKey. -func validateOpenSSHCertV01Signature(cert *OpenSSHCertV01) bool { - return cert.SignatureKey.Verify(cert.BytesForSigning(), cert.Signature.Blob) -} +func parseTuples(in []byte) (map[string]string, error) { + tups := map[string]string{} + var lastKey string + var haveLastKey bool -var certAlgoNames = map[string]string{ - KeyAlgoRSA: CertAlgoRSAv01, - KeyAlgoDSA: CertAlgoDSAv01, - KeyAlgoECDSA256: CertAlgoECDSA256v01, - KeyAlgoECDSA384: CertAlgoECDSA384v01, - KeyAlgoECDSA521: CertAlgoECDSA521v01, -} + for len(in) > 0 { + nameBytes, rest, ok := parseString(in) + if !ok { + return nil, errShortRead + } + data, rest, ok := parseString(rest) + if !ok { + return nil, errShortRead + } + name := string(nameBytes) -// certToPrivAlgo returns the underlying algorithm for a certificate algorithm. -// Panics if a non-certificate algorithm is passed. -func certToPrivAlgo(algo string) string { - for privAlgo, pubAlgo := range certAlgoNames { - if pubAlgo == algo { - return privAlgo + // according to [PROTOCOL.certkeys], the names must be in + // lexical order. + if haveLastKey && name <= lastKey { + return nil, fmt.Errorf("ssh: certificate options are not in lexical order") } + lastKey, haveLastKey = name, true + + tups[name] = string(data) + in = rest } - panic("unknown cert algorithm") + return tups, nil } -func (cert *OpenSSHCertV01) marshal(includeAlgo, includeSig bool) []byte { - algoName := cert.PublicKeyAlgo() - pubKey := cert.Key.Marshal() - sigKey := MarshalPublicKey(cert.SignatureKey) +func parseCert(in []byte, privAlgo string) (*Certificate, error) { + nonce, rest, ok := parseString(in) + if !ok { + return nil, errShortRead + } + + key, rest, err := parsePubKey(rest, privAlgo) + if err != nil { + return nil, err + } - var length int - if includeAlgo { - length += stringLength(len(algoName)) + var g genericCertData + if err := Unmarshal(rest, &g); err != nil { + return nil, err } - length += stringLength(len(cert.Nonce)) - length += len(pubKey) - length += 8 // Length of Serial - length += 4 // Length of Type - length += stringLength(len(cert.KeyId)) - length += lengthPrefixedNameListLength(cert.ValidPrincipals) - length += 8 // Length of ValidAfter - length += 8 // Length of ValidBefore - length += tupleListLength(cert.CriticalOptions) - length += tupleListLength(cert.Extensions) - length += stringLength(len(cert.Reserved)) - length += stringLength(len(sigKey)) - if includeSig { - length += signatureLength(cert.Signature) + + c := &Certificate{ + Nonce: nonce, + Key: key, + Serial: g.Serial, + CertType: g.CertType, + KeyId: g.KeyId, + ValidAfter: g.ValidAfter, + ValidBefore: g.ValidBefore, } - ret := make([]byte, length) - r := ret - if includeAlgo { - r = marshalString(r, []byte(algoName)) + for principals := g.ValidPrincipals; len(principals) > 0; { + principal, rest, ok := parseString(principals) + if !ok { + return nil, errShortRead + } + c.ValidPrincipals = append(c.ValidPrincipals, string(principal)) + principals = rest + } + + c.CriticalOptions, err = parseTuples(g.CriticalOptions) + if err != nil { + return nil, err } - r = marshalString(r, cert.Nonce) - copy(r, pubKey) - r = r[len(pubKey):] - r = marshalUint64(r, cert.Serial) - r = marshalUint32(r, cert.Type) - r = marshalString(r, []byte(cert.KeyId)) - r = marshalLengthPrefixedNameList(r, cert.ValidPrincipals) - r = marshalUint64(r, uint64(cert.ValidAfter)) - r = marshalUint64(r, uint64(cert.ValidBefore)) - r = marshalTupleList(r, cert.CriticalOptions) - r = marshalTupleList(r, cert.Extensions) - r = marshalString(r, cert.Reserved) - r = marshalString(r, sigKey) - if includeSig { - r = marshalSignature(r, cert.Signature) + c.Extensions, err = parseTuples(g.Extensions) + if err != nil { + return nil, err } - if len(r) > 0 { - panic("ssh: internal error, marshaling certificate did not fill the entire buffer") + c.Reserved = g.Reserved + k, err := ParsePublicKey(g.SignatureKey) + if err != nil { + return nil, err } - return ret -} -func (cert *OpenSSHCertV01) BytesForSigning() []byte { - return cert.marshal(true, false) + c.SignatureKey = k + c.Signature, rest, ok = parseSignatureBody(g.Signature) + if !ok || len(rest) > 0 { + return nil, errors.New("ssh: signature parse error") + } + + return c, nil } -func (cert *OpenSSHCertV01) Marshal() []byte { - return cert.marshal(false, true) +type openSSHCertSigner struct { + pub *Certificate + signer Signer } -func (c *OpenSSHCertV01) PublicKeyAlgo() string { - algo, ok := certAlgoNames[c.Key.PublicKeyAlgo()] - if !ok { - panic("unknown cert key type") +// NewCertSigner returns a Signer that signs with the given Certificate, whose +// private key is held by signer. It returns an error if the public key in cert +// doesn't match the key used by signer. +func NewCertSigner(cert *Certificate, signer Signer) (Signer, error) { + if bytes.Compare(cert.Key.Marshal(), signer.PublicKey().Marshal()) != 0 { + return nil, errors.New("ssh: signer and cert have different public key") } - return algo + + return &openSSHCertSigner{cert, signer}, nil } -func (c *OpenSSHCertV01) PrivateKeyAlgo() string { - return c.Key.PrivateKeyAlgo() +func (s *openSSHCertSigner) Sign(rand io.Reader, data []byte) (*Signature, error) { + return s.signer.Sign(rand, data) } -func (c *OpenSSHCertV01) Verify(data []byte, sig []byte) bool { - return c.Key.Verify(data, sig) +func (s *openSSHCertSigner) PublicKey() PublicKey { + return s.pub } -func parseOpenSSHCertV01(in []byte, algo string) (out *OpenSSHCertV01, rest []byte, ok bool) { - cert := new(OpenSSHCertV01) +const sourceAddressCriticalOption = "source-address" - if cert.Nonce, in, ok = parseString(in); !ok { - return - } +// CertChecker does the work of verifying a certificate. Its methods +// can be plugged into ClientConfig.HostKeyCallback and +// ServerConfig.PublicKeyCallback. For the CertChecker to work, +// minimally, the IsAuthority callback should be set. +type CertChecker struct { + // SupportedCriticalOptions lists the CriticalOptions that the + // server application layer understands. These are only used + // for user certificates. + SupportedCriticalOptions []string + + // IsAuthority should return true if the key is recognized as + // an authority. This allows for certificates to be signed by other + // certificates. + IsAuthority func(auth PublicKey) bool + + // Clock is used for verifying time stamps. If nil, time.Now + // is used. + Clock func() time.Time + + // UserKeyFallback is called when CertChecker.Authenticate encounters a + // public key that is not a certificate. It must implement validation + // of user keys or else, if nil, all such keys are rejected. + UserKeyFallback func(conn ConnMetadata, key PublicKey) (*Permissions, error) - privAlgo := certToPrivAlgo(algo) - cert.Key, in, ok = parsePubKey(in, privAlgo) + // HostKeyFallback is called when CertChecker.CheckHostKey encounters a + // public key that is not a certificate. It must implement host key + // validation or else, if nil, all such keys are rejected. + HostKeyFallback func(addr string, remote net.Addr, key PublicKey) error + + // IsRevoked is called for each certificate so that revocation checking + // can be implemented. It should return true if the given certificate + // is revoked and false otherwise. If nil, no certificates are + // considered to have been revoked. + IsRevoked func(cert *Certificate) bool +} + +// CheckHostKey checks a host key certificate. This method can be +// plugged into ClientConfig.HostKeyCallback. +func (c *CertChecker) CheckHostKey(addr string, remote net.Addr, key PublicKey) error { + cert, ok := key.(*Certificate) if !ok { - return + if c.HostKeyFallback != nil { + return c.HostKeyFallback(addr, remote, key) + } + return errors.New("ssh: non-certificate host key") } - - // We test PublicKeyAlgo to make sure we don't use some weird sub-cert. - if cert.Key.PublicKeyAlgo() != privAlgo { - ok = false - return + if cert.CertType != HostCert { + return fmt.Errorf("ssh: certificate presented as a host key has type %d", cert.CertType) } - if cert.Serial, in, ok = parseUint64(in); !ok { - return - } + return c.CheckCert(addr, cert) +} - if cert.Type, in, ok = parseUint32(in); !ok { - return +// Authenticate checks a user certificate. Authenticate can be used as +// a value for ServerConfig.PublicKeyCallback. +func (c *CertChecker) Authenticate(conn ConnMetadata, pubKey PublicKey) (*Permissions, error) { + cert, ok := pubKey.(*Certificate) + if !ok { + if c.UserKeyFallback != nil { + return c.UserKeyFallback(conn, pubKey) + } + return nil, errors.New("ssh: normal key pairs not accepted") } - keyId, in, ok := parseString(in) - if !ok { - return + if cert.CertType != UserCert { + return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType) } - cert.KeyId = string(keyId) - if cert.ValidPrincipals, in, ok = parseLengthPrefixedNameList(in); !ok { - return + if err := c.CheckCert(conn.User(), cert); err != nil { + return nil, err } - va, in, ok := parseUint64(in) - if !ok { - return + return &cert.Permissions, nil +} + +// CheckCert checks CriticalOptions, ValidPrincipals, revocation, timestamp and +// the signature of the certificate. +func (c *CertChecker) CheckCert(principal string, cert *Certificate) error { + if c.IsRevoked != nil && c.IsRevoked(cert) { + return fmt.Errorf("ssh: certicate serial %d revoked", cert.Serial) } - cert.ValidAfter = CertTime(va) - vb, in, ok := parseUint64(in) - if !ok { - return + for opt, _ := range cert.CriticalOptions { + // sourceAddressCriticalOption will be enforced by + // serverAuthenticate + if opt == sourceAddressCriticalOption { + continue + } + + found := false + for _, supp := range c.SupportedCriticalOptions { + if supp == opt { + found = true + break + } + } + if !found { + return fmt.Errorf("ssh: unsupported critical option %q in certificate", opt) + } } - cert.ValidBefore = CertTime(vb) - if cert.CriticalOptions, in, ok = parseTupleList(in); !ok { - return + if len(cert.ValidPrincipals) > 0 { + // By default, certs are valid for all users/hosts. + found := false + for _, p := range cert.ValidPrincipals { + if p == principal { + found = true + break + } + } + if !found { + return fmt.Errorf("ssh: principal %q not in the set of valid principals for given certificate: %q", principal, cert.ValidPrincipals) + } } - if cert.Extensions, in, ok = parseTupleList(in); !ok { - return + if !c.IsAuthority(cert.SignatureKey) { + return fmt.Errorf("ssh: certificate signed by unrecognized authority") } - if cert.Reserved, in, ok = parseString(in); !ok { - return + clock := c.Clock + if clock == nil { + clock = time.Now } - sigKey, in, ok := parseString(in) - if !ok { - return + unixNow := clock().Unix() + if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) { + return fmt.Errorf("ssh: cert is not yet valid") } - if cert.SignatureKey, _, ok = ParsePublicKey(sigKey); !ok { - return + if before := int64(cert.ValidBefore); cert.ValidBefore != CertTimeInfinity && (unixNow >= before || before < 0) { + return fmt.Errorf("ssh: cert has expired") } - - if cert.Signature, in, ok = parseSignature(in); !ok { - return + if err := cert.SignatureKey.Verify(cert.bytesForSigning(), cert.Signature); err != nil { + return fmt.Errorf("ssh: certificate signature does not verify") } - ok = true - return cert, in, ok + return nil } -func lengthPrefixedNameListLength(namelist []string) int { - length := 4 // length prefix for list - for _, name := range namelist { - length += 4 // length prefix for name - length += len(name) +// SignCert sets c.SignatureKey to the authority's public key and stores a +// Signature, by authority, in the certificate. +func (c *Certificate) SignCert(rand io.Reader, authority Signer) error { + c.Nonce = make([]byte, 32) + if _, err := io.ReadFull(rand, c.Nonce); err != nil { + return err } - return length -} + c.SignatureKey = authority.PublicKey() -func marshalLengthPrefixedNameList(to []byte, namelist []string) []byte { - length := uint32(lengthPrefixedNameListLength(namelist) - 4) - to = marshalUint32(to, length) - for _, name := range namelist { - to = marshalString(to, []byte(name)) + sig, err := authority.Sign(rand, c.bytesForSigning()) + if err != nil { + return err } - return to + c.Signature = sig + return nil } -func parseLengthPrefixedNameList(in []byte) (out []string, rest []byte, ok bool) { - list, rest, ok := parseString(in) - if !ok { - return - } +var certAlgoNames = map[string]string{ + KeyAlgoRSA: CertAlgoRSAv01, + KeyAlgoDSA: CertAlgoDSAv01, + KeyAlgoECDSA256: CertAlgoECDSA256v01, + KeyAlgoECDSA384: CertAlgoECDSA384v01, + KeyAlgoECDSA521: CertAlgoECDSA521v01, +} - for len(list) > 0 { - var next []byte - if next, list, ok = parseString(list); !ok { - return nil, nil, false +// certToPrivAlgo returns the underlying algorithm for a certificate algorithm. +// Panics if a non-certificate algorithm is passed. +func certToPrivAlgo(algo string) string { + for privAlgo, pubAlgo := range certAlgoNames { + if pubAlgo == algo { + return privAlgo } - out = append(out, string(next)) } - ok = true - return + panic("unknown cert algorithm") } -func tupleListLength(tupleList []tuple) int { - length := 4 // length prefix for list - for _, t := range tupleList { - length += 4 // length prefix for t.Name - length += len(t.Name) - length += 4 // length prefix for t.Data - length += len(t.Data) - } - return length +func (cert *Certificate) bytesForSigning() []byte { + c2 := *cert + c2.Signature = nil + out := c2.Marshal() + // Drop trailing signature length. + return out[:len(out)-4] } -func marshalTupleList(to []byte, tuplelist []tuple) []byte { - length := uint32(tupleListLength(tuplelist) - 4) - to = marshalUint32(to, length) - for _, t := range tuplelist { - to = marshalString(to, []byte(t.Name)) - to = marshalString(to, []byte(t.Data)) +// Marshal serializes c into OpenSSH's wire format. It is part of the +// PublicKey interface. +func (c *Certificate) Marshal() []byte { + generic := genericCertData{ + Serial: c.Serial, + CertType: c.CertType, + KeyId: c.KeyId, + ValidPrincipals: marshalStringList(c.ValidPrincipals), + ValidAfter: uint64(c.ValidAfter), + ValidBefore: uint64(c.ValidBefore), + CriticalOptions: marshalTuples(c.CriticalOptions), + Extensions: marshalTuples(c.Extensions), + Reserved: c.Reserved, + SignatureKey: c.SignatureKey.Marshal(), } - return to -} - -func parseTupleList(in []byte) (out []tuple, rest []byte, ok bool) { - list, rest, ok := parseString(in) - if !ok { - return + if c.Signature != nil { + generic.Signature = Marshal(c.Signature) } + genericBytes := Marshal(&generic) + keyBytes := c.Key.Marshal() + _, keyBytes, _ = parseString(keyBytes) + prefix := Marshal(&struct { + Name string + Nonce []byte + Key []byte `ssh:"rest"` + }{c.Type(), c.Nonce, keyBytes}) - for len(list) > 0 { - var name, data []byte - var ok bool - name, list, ok = parseString(list) - if !ok { - return nil, nil, false - } - data, list, ok = parseString(list) - if !ok { - return nil, nil, false - } - out = append(out, tuple{string(name), string(data)}) - } - ok = true - return + result := make([]byte, 0, len(prefix)+len(genericBytes)) + result = append(result, prefix...) + result = append(result, genericBytes...) + return result } -func signatureLength(sig *signature) int { - length := 4 // length prefix for signature - length += stringLength(len(sig.Format)) - length += stringLength(len(sig.Blob)) - return length +// Type returns the key name. It is part of the PublicKey interface. +func (c *Certificate) Type() string { + algo, ok := certAlgoNames[c.Key.Type()] + if !ok { + panic("unknown cert key type") + } + return algo } -func marshalSignature(to []byte, sig *signature) []byte { - length := uint32(signatureLength(sig) - 4) - to = marshalUint32(to, length) - to = marshalString(to, []byte(sig.Format)) - to = marshalString(to, sig.Blob) - return to +// Verify verifies a signature against the certificate's public +// key. It is part of the PublicKey interface. +func (c *Certificate) Verify(data []byte, sig *Signature) error { + return c.Key.Verify(data, sig) } -func parseSignatureBody(in []byte) (out *signature, rest []byte, ok bool) { - var format []byte - if format, in, ok = parseString(in); !ok { +func parseSignatureBody(in []byte) (out *Signature, rest []byte, ok bool) { + format, in, ok := parseString(in) + if !ok { return } - out = &signature{ + out = &Signature{ Format: string(format), } @@ -364,14 +460,14 @@ func parseSignatureBody(in []byte) (out *signature, rest []byte, ok bool) { return out, in, ok } -func parseSignature(in []byte) (out *signature, rest []byte, ok bool) { - var sigBytes []byte - if sigBytes, rest, ok = parseString(in); !ok { +func parseSignature(in []byte) (out *Signature, rest []byte, ok bool) { + sigBytes, rest, ok := parseString(in) + if !ok { return } - out, sigBytes, ok = parseSignatureBody(sigBytes) - if !ok || len(sigBytes) > 0 { + out, trailing, ok := parseSignatureBody(sigBytes) + if !ok || len(trailing) > 0 { return nil, nil, false } return diff --git a/ssh/certs_test.go b/ssh/certs_test.go index 3cec28e..7d1b00f 100644 --- a/ssh/certs_test.go +++ b/ssh/certs_test.go @@ -6,7 +6,9 @@ package ssh import ( "bytes" + "crypto/rand" "testing" + "time" ) // Cert generated by ssh-keygen 6.0p1 Debian-4. @@ -16,16 +18,16 @@ var exampleSSHCert = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb func TestParseCert(t *testing.T) { authKeyBytes := []byte(exampleSSHCert) - key, _, _, rest, ok := ParseAuthorizedKey(authKeyBytes) - if !ok { - t.Fatalf("could not parse certificate") + key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes) + if err != nil { + t.Fatalf("ParseAuthorizedKey: %v", err) } if len(rest) > 0 { t.Errorf("rest: got %q, want empty", rest) } - if _, ok = key.(*OpenSSHCertV01); !ok { - t.Fatalf("got %#v, want *OpenSSHCertV01", key) + if _, ok := key.(*Certificate); !ok { + t.Fatalf("got %#v, want *Certificate", key) } marshaled := MarshalAuthorizedKey(key) @@ -37,19 +39,118 @@ func TestParseCert(t *testing.T) { } } -func TestVerifyCert(t *testing.T) { - key, _, _, _, _ := ParseAuthorizedKey([]byte(exampleSSHCert)) - validCert := key.(*OpenSSHCertV01) - if ok := validateOpenSSHCertV01Signature(validCert); !ok { - t.Error("Unable to validate certificate!") +func TestValidateCert(t *testing.T) { + key, _, _, _, err := ParseAuthorizedKey([]byte(exampleSSHCert)) + if err != nil { + t.Fatalf("ParseAuthorizedKey: %v", err) + } + validCert, ok := key.(*Certificate) + if !ok { + t.Fatalf("got %v (%T), want *Certificate", key, key) + } + checker := CertChecker{} + checker.IsAuthority = func(k PublicKey) bool { + return bytes.Equal(k.Marshal(), validCert.SignatureKey.Marshal()) + } + + if err := checker.CheckCert("user", validCert); err != nil { + t.Errorf("Unable to validate certificate: %v", err) + } + invalidCert := &Certificate{ + Key: testPublicKeys["rsa"], + SignatureKey: testPublicKeys["ecdsa"], + ValidBefore: CertTimeInfinity, + Signature: &Signature{}, } + if err := checker.CheckCert("user", invalidCert); err == nil { + t.Error("Invalid cert signature passed validation") + } +} + +func TestValidateCertTime(t *testing.T) { + cert := Certificate{ + ValidPrincipals: []string{"user"}, + Key: testPublicKeys["rsa"], + ValidAfter: 50, + ValidBefore: 100, + } + + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + + for ts, ok := range map[int64]bool{ + 25: false, + 50: true, + 99: true, + 100: false, + 125: false, + } { + checker := CertChecker{ + Clock: func() time.Time { return time.Unix(ts, 0) }, + } + checker.IsAuthority = func(k PublicKey) bool { + return bytes.Equal(k.Marshal(), + testPublicKeys["ecdsa"].Marshal()) + } - invalidCert := &OpenSSHCertV01{ - Key: rsaKey.PublicKey(), - SignatureKey: ecdsaKey.PublicKey(), - Signature: &signature{}, + if v := checker.CheckCert("user", &cert); (v == nil) != ok { + t.Errorf("Authenticate(%d): %v", ts, v) + } } - if ok := validateOpenSSHCertV01Signature(invalidCert); ok { - t.Error("Invalid cert signature passed validation!") +} + +// TODO(hanwen): tests for +// +// host keys: +// * fallbacks + +func TestHostKeyCert(t *testing.T) { + cert := &Certificate{ + ValidPrincipals: []string{"hostname", "hostname.domain"}, + Key: testPublicKeys["rsa"], + ValidBefore: CertTimeInfinity, + CertType: HostCert, + } + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + + checker := &CertChecker{ + IsAuthority: func(p PublicKey) bool { + return bytes.Equal(testPublicKeys["ecdsa"].Marshal(), p.Marshal()) + }, + } + + certSigner, err := NewCertSigner(cert, testSigners["rsa"]) + if err != nil { + t.Errorf("NewCertSigner: %v", err) + } + + for _, name := range []string{"hostname", "otherhost"} { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + go func() { + conf := ServerConfig{ + NoClientAuth: true, + } + conf.AddHostKey(certSigner) + _, _, _, err := NewServerConn(c1, &conf) + if err != nil { + t.Fatalf("NewServerConn: %v", err) + } + }() + + config := &ClientConfig{ + User: "user", + HostKeyCallback: checker.CheckHostKey, + } + _, _, _, err = NewClientConn(c2, name, config) + + succeed := name == "hostname" + if (err == nil) != succeed { + t.Fatalf("NewClientConn(%q): %v", name, err) + } } } diff --git a/ssh/channel.go b/ssh/channel.go index c5413c9..8e777bb 100644 --- a/ssh/channel.go +++ b/ssh/channel.go @@ -5,71 +5,100 @@ package ssh import ( + "encoding/binary" "errors" "fmt" "io" + "log" "sync" - "sync/atomic" ) -// extendedDataTypeCode identifies an OpenSSL extended data type. See RFC 4254, -// section 5.2. -type extendedDataTypeCode uint32 - const ( - // extendedDataStderr is the extended data type that is used for stderr. - extendedDataStderr extendedDataTypeCode = 1 - - // minPacketLength defines the smallest valid packet minPacketLength = 9 - - // channelMaxPacketSize defines the maximum packet size advertised in open messages - channelMaxPacketSize = 1 << 15 // RFC 4253 6.1, minimum 32 KiB - - // channelWindowSize defines the window size advertised in open messages - channelWindowSize = 64 * channelMaxPacketSize // Like OpenSSH + // channelMaxPacket contains the maximum number of bytes that will be + // sent in a single packet. As per RFC 4253, section 6.1, 32k is also + // the minimum. + channelMaxPacket = 1 << 15 + // We follow OpenSSH here. + channelWindowSize = 64 * channelMaxPacket ) -// A Channel is an ordered, reliable, duplex stream that is multiplexed over an -// SSH connection. Channel.Read can return a ChannelRequest as an error. -type Channel interface { - // Accept accepts the channel creation request. - Accept() error - // Reject rejects the channel creation request. After calling this, no - // other methods on the Channel may be called. If they are then the - // peer is likely to signal a protocol error and drop the connection. - Reject(reason RejectionReason, message string) error +// NewChannel represents an incoming request to a channel. It must either be +// accepted for use by calling Accept, or rejected by calling Reject. +type NewChannel interface { + // Accept accepts the channel creation request. It returns the Channel + // and a Go channel containing SSH requests. The Go channel must be + // serviced otherwise the Channel will hang. + Accept() (Channel, <-chan *Request, error) - // Read may return a ChannelRequest as an error. - Read(data []byte) (int, error) - Write(data []byte) (int, error) - Close() error - - // Stderr returns an io.Writer that writes to this channel with the - // extended data type set to stderr. - Stderr() io.Writer - - // AckRequest either sends an ack or nack to the channel request. - AckRequest(ok bool) error + // Reject rejects the channel creation request. After calling + // this, no other methods on the Channel may be called. + Reject(reason RejectionReason, message string) error // ChannelType returns the type of the channel, as supplied by the // client. ChannelType() string + // ExtraData returns the arbitrary payload for this channel, as supplied // by the client. This data is specific to the channel type. ExtraData() []byte } -// ChannelRequest represents a request sent on a channel, outside of the normal -// stream of bytes. It may result from calling Read on a Channel. -type ChannelRequest struct { - Request string +// A Channel is an ordered, reliable, flow-controlled, duplex stream +// that is multiplexed over an SSH connection. +type Channel interface { + // Read reads up to len(data) bytes from the channel. + Read(data []byte) (int, error) + + // Write writes len(data) bytes to the channel. + Write(data []byte) (int, error) + + // Close signals end of channel use. No data may be sent after this + // call. + Close() error + + // CloseWrite signals the end of sending in-band + // data. Requests may still be sent, and the other side may + // still send data + CloseWrite() error + + // SendRequest sends a channel request. If wantReply is true, + // it will wait for a reply and return the result as a + // boolean, otherwise the return value will be false. Channel + // requests are out-of-band messages so they may be sent even + // if the data stream is closed or blocked by flow control. + SendRequest(name string, wantReply bool, payload []byte) (bool, error) + + // Stderr returns an io.ReadWriter that writes to this channel with the + // extended data type set to stderr. + Stderr() io.ReadWriter +} + +// Request is a request sent outside of the normal stream of +// data. Requests can either be specific to an SSH channel, or they +// can be global. +type Request struct { + Type string WantReply bool Payload []byte + + ch *channel + mux *mux } -func (c ChannelRequest) Error() string { - return "ssh: channel request received" +// Reply sends a response to a request. It must be called for all requests +// where WantReply is true and is a no-op otherwise. The payload argument is +// ignored for replies to channel-specific requests. +func (r *Request) Reply(ok bool, payload []byte) error { + if !r.WantReply { + return nil + } + + if r.ch == nil { + return r.mux.ackRequest(ok, payload) + } + + return r.ch.ackRequest(ok) } // RejectionReason is an enumeration used when rejecting channel creation @@ -98,497 +127,482 @@ func (r RejectionReason) String() string { return fmt.Sprintf("unknown reason %d", int(r)) } -type channel struct { - packetConn // the underlying transport - localId, remoteId uint32 - remoteWin window - maxPacket uint32 - isClosed uint32 // atomic bool, non zero if true -} - -func (c *channel) sendWindowAdj(n int) error { - msg := windowAdjustMsg{ - PeersId: c.remoteId, - AdditionalBytes: uint32(n), +func min(a uint32, b int) uint32 { + if a < uint32(b) { + return a } - return c.writePacket(marshal(msgChannelWindowAdjust, msg)) + return uint32(b) } -// sendEOF sends EOF to the remote side. RFC 4254 Section 5.3 -func (c *channel) sendEOF() error { - return c.writePacket(marshal(msgChannelEOF, channelEOFMsg{ - PeersId: c.remoteId, - })) -} +type channelDirection uint8 -// sendClose informs the remote side of our intent to close the channel. -func (c *channel) sendClose() error { - return c.packetConn.writePacket(marshal(msgChannelClose, channelCloseMsg{ - PeersId: c.remoteId, - })) -} +const ( + channelInbound channelDirection = iota + channelOutbound +) -func (c *channel) sendChannelOpenFailure(reason RejectionReason, message string) error { - reject := channelOpenFailureMsg{ - PeersId: c.remoteId, - Reason: reason, - Message: message, - Language: "en", - } - return c.writePacket(marshal(msgChannelOpenFailure, reject)) -} +// channel is an implementation of the Channel interface that works +// with the mux class. +type channel struct { + // R/O after creation + chanType string + extraData []byte + localId, remoteId uint32 -func (c *channel) writePacket(b []byte) error { - if c.closed() { - return io.EOF - } - if uint32(len(b)) > c.maxPacket { - return fmt.Errorf("ssh: cannot write %d bytes, maxPacket is %d bytes", len(b), c.maxPacket) - } - return c.packetConn.writePacket(b) -} + // maxIncomingPayload and maxRemotePayload are the maximum + // payload sizes of normal and extended data packets for + // receiving and sending, respectively. The wire packet will + // be 9 or 13 bytes larger (excluding encryption overhead). + maxIncomingPayload uint32 + maxRemotePayload uint32 -func (c *channel) closed() bool { - return atomic.LoadUint32(&c.isClosed) > 0 -} + mux *mux -func (c *channel) setClosed() bool { - return atomic.CompareAndSwapUint32(&c.isClosed, 0, 1) -} + // decided is set to true if an accept or reject message has been sent + // (for outbound channels) or received (for inbound channels). + decided bool -type serverChan struct { - channel - // immutable once created - chanType string - extraData []byte + // direction contains either channelOutbound, for channels created + // locally, or channelInbound, for channels created by the peer. + direction channelDirection - serverConn *ServerConn - myWindow uint32 - theyClosed bool // indicates the close msg has been received from the remote side - theySentEOF bool - isDead uint32 - err error + // Pending internal channel messages. + msg chan interface{} - pendingRequests []ChannelRequest - pendingData []byte - head, length int + // Since requests have no ID, there can be only one request + // with WantReply=true outstanding. This lock is held by a + // goroutine that has such an outgoing request pending. + sentRequestMu sync.Mutex - // This lock is inferior to serverConn.lock - cond *sync.Cond -} + incomingRequests chan *Request -func (c *serverChan) Accept() error { - c.serverConn.lock.Lock() - defer c.serverConn.lock.Unlock() + sentEOF bool - if c.serverConn.err != nil { - return c.serverConn.err - } + // thread-safe data + remoteWin window + pending *buffer + extPending *buffer - confirm := channelOpenConfirmMsg{ - PeersId: c.remoteId, - MyId: c.localId, - MyWindow: c.myWindow, - MaxPacketSize: c.maxPacket, - } - return c.writePacket(marshal(msgChannelOpenConfirm, confirm)) + // windowMu protects myWindow, the flow-control window. + windowMu sync.Mutex + myWindow uint32 + + // writeMu serializes calls to mux.conn.writePacket() and + // protects sentClose. This mutex must be different from + // windowMu, as writePacket can block if there is a key + // exchange pending + writeMu sync.Mutex + sentClose bool } -func (c *serverChan) Reject(reason RejectionReason, message string) error { - c.serverConn.lock.Lock() - defer c.serverConn.lock.Unlock() +// writePacket sends a packet. If the packet is a channel close, it updates +// sentClose. This method takes the lock c.writeMu. +func (c *channel) writePacket(packet []byte) error { + c.writeMu.Lock() + if c.sentClose { + c.writeMu.Unlock() + return io.EOF + } + c.sentClose = (packet[0] == msgChannelClose) + err := c.mux.conn.writePacket(packet) + c.writeMu.Unlock() + return err +} - if c.serverConn.err != nil { - return c.serverConn.err +func (c *channel) sendMessage(msg interface{}) error { + if debugMux { + log.Printf("send %d: %#v", c.mux.chanList.offset, msg) } - return c.sendChannelOpenFailure(reason, message) + p := Marshal(msg) + binary.BigEndian.PutUint32(p[1:], c.remoteId) + return c.writePacket(p) } -func (c *serverChan) handlePacket(packet interface{}) { - c.cond.L.Lock() - defer c.cond.L.Unlock() +// WriteExtended writes data to a specific extended stream. These streams are +// used, for example, for stderr. +func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) { + if c.sentEOF { + return 0, io.EOF + } + // 1 byte message type, 4 bytes remoteId, 4 bytes data length + opCode := byte(msgChannelData) + headerLength := uint32(9) + if extendedCode > 0 { + headerLength += 4 + opCode = msgChannelExtendedData + } - switch packet := packet.(type) { - case *channelRequestMsg: - req := ChannelRequest{ - Request: packet.Request, - WantReply: packet.WantReply, - Payload: packet.RequestSpecificData, + for len(data) > 0 { + space := min(c.maxRemotePayload, len(data)) + if space, err = c.remoteWin.reserve(space); err != nil { + return n, err } + todo := data[:space] - c.pendingRequests = append(c.pendingRequests, req) - c.cond.Signal() - case *channelCloseMsg: - c.theyClosed = true - c.cond.Signal() - case *channelEOFMsg: - c.theySentEOF = true - c.cond.Signal() - case *windowAdjustMsg: - if !c.remoteWin.add(packet.AdditionalBytes) { - panic("illegal window update") + packet := make([]byte, headerLength+uint32(len(todo))) + packet[0] = opCode + binary.BigEndian.PutUint32(packet[1:], c.remoteId) + if extendedCode > 0 { + binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode)) } - default: - panic("unknown packet type") + binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo))) + copy(packet[headerLength:], todo) + if err = c.writePacket(packet); err != nil { + return n, err + } + + n += len(todo) + data = data[len(todo):] } + + return n, err } -func (c *serverChan) handleData(data []byte) { - c.cond.L.Lock() - defer c.cond.L.Unlock() +func (c *channel) handleData(packet []byte) error { + headerLen := 9 + isExtendedData := packet[0] == msgChannelExtendedData + if isExtendedData { + headerLen = 13 + } + if len(packet) < headerLen { + // malformed data packet + return parseError(packet[0]) + } - // The other side should never send us more than our window. - if len(data)+c.length > len(c.pendingData) { - // TODO(agl): we should tear down the channel with a protocol - // error. - return + var extended uint32 + if isExtendedData { + extended = binary.BigEndian.Uint32(packet[5:]) } - c.myWindow -= uint32(len(data)) - for i := 0; i < 2; i++ { - tail := c.head + c.length - if tail >= len(c.pendingData) { - tail -= len(c.pendingData) - } - n := copy(c.pendingData[tail:], data) - data = data[n:] - c.length += n + length := binary.BigEndian.Uint32(packet[headerLen-4 : headerLen]) + if length == 0 { + return nil + } + if length > c.maxIncomingPayload { + // TODO(hanwen): should send Disconnect? + return errors.New("ssh: incoming packet exceeds maximum payload size") } - c.cond.Signal() -} + data := packet[headerLen:] + if length != uint32(len(data)) { + return errors.New("ssh: wrong packet length") + } -func (c *serverChan) Stderr() io.Writer { - return extendedDataChannel{c: c, t: extendedDataStderr} -} + c.windowMu.Lock() + if c.myWindow < length { + c.windowMu.Unlock() + // TODO(hanwen): should send Disconnect with reason? + return errors.New("ssh: remote side wrote too much") + } + c.myWindow -= length + c.windowMu.Unlock() -// extendedDataChannel is an io.Writer that writes any data to c as extended -// data of the given type. -type extendedDataChannel struct { - t extendedDataTypeCode - c *serverChan + if extended == 1 { + c.extPending.write(data) + } else if extended > 0 { + // discard other extended data. + } else { + c.pending.write(data) + } + return nil } -func (edc extendedDataChannel) Write(data []byte) (n int, err error) { - const headerLength = 13 // 1 byte message type, 4 bytes remoteId, 4 bytes extended message type, 4 bytes data length - c := edc.c - for len(data) > 0 { - space := min(c.maxPacket-headerLength, len(data)) - if space, err = c.getWindowSpace(space); err != nil { - return 0, err - } - todo := data - if uint32(len(todo)) > space { - todo = todo[:space] - } +func (c *channel) adjustWindow(n uint32) error { + c.windowMu.Lock() + // Since myWindow is managed on our side, and can never exceed + // the initial window setting, we don't worry about overflow. + c.myWindow += uint32(n) + c.windowMu.Unlock() + return c.sendMessage(windowAdjustMsg{ + AdditionalBytes: uint32(n), + }) +} - packet := make([]byte, headerLength+len(todo)) - packet[0] = msgChannelExtendedData - marshalUint32(packet[1:], c.remoteId) - marshalUint32(packet[5:], uint32(edc.t)) - marshalUint32(packet[9:], uint32(len(todo))) - copy(packet[13:], todo) +func (c *channel) ReadExtended(data []byte, extended uint32) (n int, err error) { + switch extended { + case 1: + n, err = c.extPending.Read(data) + case 0: + n, err = c.pending.Read(data) + default: + return 0, fmt.Errorf("ssh: extended code %d unimplemented", extended) + } - if err = c.writePacket(packet); err != nil { - return + if n > 0 { + err = c.adjustWindow(uint32(n)) + // sendWindowAdjust can return io.EOF if the remote + // peer has closed the connection, however we want to + // defer forwarding io.EOF to the caller of Read until + // the buffer has been drained. + if n > 0 && err == io.EOF { + err = nil } - - n += len(todo) - data = data[len(todo):] } - return + return n, err } -func (c *serverChan) Read(data []byte) (n int, err error) { - n, err, windowAdjustment := c.read(data) +func (c *channel) close() { + c.pending.eof() + c.extPending.eof() + close(c.msg) + close(c.incomingRequests) + c.writeMu.Lock() + // This is not necesary for a normal channel teardown, but if + // there was another error, it is. + c.sentClose = true + c.writeMu.Unlock() + // Unblock writers. + c.remoteWin.close() +} - if windowAdjustment > 0 { - packet := marshal(msgChannelWindowAdjust, windowAdjustMsg{ - PeersId: c.remoteId, - AdditionalBytes: windowAdjustment, - }) - err = c.writePacket(packet) +// responseMessageReceived is called when a success or failure message is +// received on a channel to check that such a message is reasonable for the +// given channel. +func (c *channel) responseMessageReceived() error { + if c.direction == channelInbound { + return errors.New("ssh: channel response message received on inbound channel") } - - return + if c.decided { + return errors.New("ssh: duplicate response received for channel") + } + c.decided = true + return nil } -func (c *serverChan) read(data []byte) (n int, err error, windowAdjustment uint32) { - c.cond.L.Lock() - defer c.cond.L.Unlock() +func (c *channel) handlePacket(packet []byte) error { + switch packet[0] { + case msgChannelData, msgChannelExtendedData: + return c.handleData(packet) + case msgChannelClose: + c.sendMessage(channelCloseMsg{PeersId: c.remoteId}) + c.mux.chanList.remove(c.localId) + c.close() + return nil + case msgChannelEOF: + // RFC 4254 is mute on how EOF affects dataExt messages but + // it is logical to signal EOF at the same time. + c.extPending.eof() + c.pending.eof() + return nil + } - if c.err != nil { - return 0, c.err, 0 + decoded, err := decode(packet) + if err != nil { + return err } - for { - if c.theySentEOF || c.theyClosed || c.dead() { - return 0, io.EOF, 0 + switch msg := decoded.(type) { + case *channelOpenFailureMsg: + if err := c.responseMessageReceived(); err != nil { + return err } - - if len(c.pendingRequests) > 0 { - req := c.pendingRequests[0] - if len(c.pendingRequests) == 1 { - c.pendingRequests = nil - } else { - oldPendingRequests := c.pendingRequests - c.pendingRequests = make([]ChannelRequest, len(oldPendingRequests)-1) - copy(c.pendingRequests, oldPendingRequests[1:]) - } - - return 0, req, 0 + c.mux.chanList.remove(msg.PeersId) + c.msg <- msg + case *channelOpenConfirmMsg: + if err := c.responseMessageReceived(); err != nil { + return err } - - if c.length > 0 { - tail := min(uint32(c.head+c.length), len(c.pendingData)) - n = copy(data, c.pendingData[c.head:tail]) - c.head += n - c.length -= n - if c.head == len(c.pendingData) { - c.head = 0 - } - - windowAdjustment = uint32(len(c.pendingData)-c.length) - c.myWindow - if windowAdjustment < uint32(len(c.pendingData)/2) { - windowAdjustment = 0 - } - c.myWindow += windowAdjustment - - return + if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { + return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize) + } + c.remoteId = msg.MyId + c.maxRemotePayload = msg.MaxPacketSize + c.remoteWin.add(msg.MyWindow) + c.msg <- msg + case *windowAdjustMsg: + if !c.remoteWin.add(msg.AdditionalBytes) { + return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes) + } + case *channelRequestMsg: + req := Request{ + Type: msg.Request, + WantReply: msg.WantReply, + Payload: msg.RequestSpecificData, + ch: c, } - c.cond.Wait() + c.incomingRequests <- &req + default: + c.msg <- msg } - - panic("unreachable") + return nil } -// getWindowSpace takes, at most, max bytes of space from the peer's window. It -// returns the number of bytes actually reserved. -func (c *serverChan) getWindowSpace(max uint32) (uint32, error) { - if c.dead() || c.closed() { - return 0, io.EOF +func (m *mux) newChannel(chanType string, direction channelDirection, extraData []byte) *channel { + ch := &channel{ + remoteWin: window{Cond: newCond()}, + myWindow: channelWindowSize, + pending: newBuffer(), + extPending: newBuffer(), + direction: direction, + incomingRequests: make(chan *Request, 16), + msg: make(chan interface{}, 16), + chanType: chanType, + extraData: extraData, + mux: m, } - return c.remoteWin.reserve(max), nil + ch.localId = m.chanList.add(ch) + return ch } -func (c *serverChan) dead() bool { - return atomic.LoadUint32(&c.isDead) > 0 -} +var errUndecided = errors.New("ssh: must Accept or Reject channel") +var errDecidedAlready = errors.New("ssh: can call Accept or Reject only once") -func (c *serverChan) setDead() { - atomic.StoreUint32(&c.isDead, 1) +type extChannel struct { + code uint32 + ch *channel } -func (c *serverChan) Write(data []byte) (n int, err error) { - const headerLength = 9 // 1 byte message type, 4 bytes remoteId, 4 bytes data length - for len(data) > 0 { - space := min(c.maxPacket-headerLength, len(data)) - if space, err = c.getWindowSpace(space); err != nil { - return 0, err - } - todo := data - if uint32(len(todo)) > space { - todo = todo[:space] - } - - packet := make([]byte, headerLength+len(todo)) - packet[0] = msgChannelData - marshalUint32(packet[1:], c.remoteId) - marshalUint32(packet[5:], uint32(len(todo))) - copy(packet[9:], todo) +func (e *extChannel) Write(data []byte) (n int, err error) { + return e.ch.WriteExtended(data, e.code) +} - if err = c.writePacket(packet); err != nil { - return - } +func (e *extChannel) Read(data []byte) (n int, err error) { + return e.ch.ReadExtended(data, e.code) +} - n += len(todo) - data = data[len(todo):] +func (c *channel) Accept() (Channel, <-chan *Request, error) { + if c.decided { + return nil, nil, errDecidedAlready + } + c.maxIncomingPayload = channelMaxPacket + confirm := channelOpenConfirmMsg{ + PeersId: c.remoteId, + MyId: c.localId, + MyWindow: c.myWindow, + MaxPacketSize: c.maxIncomingPayload, + } + c.decided = true + if err := c.sendMessage(confirm); err != nil { + return nil, nil, err } - return + return c, c.incomingRequests, nil } -// Close signals the intent to close the channel. -func (c *serverChan) Close() error { - c.serverConn.lock.Lock() - defer c.serverConn.lock.Unlock() - - if c.serverConn.err != nil { - return c.serverConn.err +func (ch *channel) Reject(reason RejectionReason, message string) error { + if ch.decided { + return errDecidedAlready } - - if !c.setClosed() { - return errors.New("ssh: channel already closed") + reject := channelOpenFailureMsg{ + PeersId: ch.remoteId, + Reason: reason, + Message: message, + Language: "en", } - return c.sendClose() + ch.decided = true + return ch.sendMessage(reject) } -func (c *serverChan) AckRequest(ok bool) error { - c.serverConn.lock.Lock() - defer c.serverConn.lock.Unlock() - - if c.serverConn.err != nil { - return c.serverConn.err +func (ch *channel) Read(data []byte) (int, error) { + if !ch.decided { + return 0, errUndecided } + return ch.ReadExtended(data, 0) +} - if !ok { - ack := channelRequestFailureMsg{ - PeersId: c.remoteId, - } - return c.writePacket(marshal(msgChannelFailure, ack)) +func (ch *channel) Write(data []byte) (int, error) { + if !ch.decided { + return 0, errUndecided } + return ch.WriteExtended(data, 0) +} - ack := channelRequestSuccessMsg{ - PeersId: c.remoteId, +func (ch *channel) CloseWrite() error { + if !ch.decided { + return errUndecided } - return c.writePacket(marshal(msgChannelSuccess, ack)) + ch.sentEOF = true + return ch.sendMessage(channelEOFMsg{ + PeersId: ch.remoteId}) } -func (c *serverChan) ChannelType() string { - return c.chanType +func (ch *channel) Close() error { + if !ch.decided { + return errUndecided + } + + return ch.sendMessage(channelCloseMsg{ + PeersId: ch.remoteId}) } -func (c *serverChan) ExtraData() []byte { - return c.extraData +// Extended returns an io.ReadWriter that sends and receives data on the given, +// SSH extended stream. Such streams are used, for example, for stderr. +func (ch *channel) Extended(code uint32) io.ReadWriter { + if !ch.decided { + return nil + } + return &extChannel{code, ch} } -// A clientChan represents a single RFC 4254 channel multiplexed -// over a SSH connection. -type clientChan struct { - channel - stdin *chanWriter - stdout *chanReader - stderr *chanReader - msg chan interface{} +func (ch *channel) Stderr() io.ReadWriter { + return ch.Extended(1) } -// newClientChan returns a partially constructed *clientChan -// using the local id provided. To be usable clientChan.remoteId -// needs to be assigned once known. -func newClientChan(cc packetConn, id uint32) *clientChan { - c := &clientChan{ - channel: channel{ - packetConn: cc, - localId: id, - remoteWin: window{Cond: newCond()}, - }, - msg: make(chan interface{}, 16), - } - c.stdin = &chanWriter{ - channel: &c.channel, - } - c.stdout = &chanReader{ - channel: &c.channel, - buffer: newBuffer(), - } - c.stderr = &chanReader{ - channel: &c.channel, - buffer: newBuffer(), +func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { + if !ch.decided { + return false, errUndecided } - return c -} -// waitForChannelOpenResponse, if successful, fills out -// the remoteId and records any initial window advertisement. -func (c *clientChan) waitForChannelOpenResponse() error { - switch msg := (<-c.msg).(type) { - case *channelOpenConfirmMsg: - if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { - return errors.New("ssh: invalid MaxPacketSize from peer") - } - // fixup remoteId field - c.remoteId = msg.MyId - c.maxPacket = msg.MaxPacketSize - c.remoteWin.add(msg.MyWindow) - return nil - case *channelOpenFailureMsg: - return errors.New(safeString(msg.Message)) + if wantReply { + ch.sentRequestMu.Lock() + defer ch.sentRequestMu.Unlock() } - return errors.New("ssh: unexpected packet") -} -// Close signals the intent to close the channel. -func (c *clientChan) Close() error { - if !c.setClosed() { - return errors.New("ssh: channel already closed") + msg := channelRequestMsg{ + PeersId: ch.remoteId, + Request: name, + WantReply: wantReply, + RequestSpecificData: payload, } - c.stdout.eof() - c.stderr.eof() - return c.sendClose() -} -// A chanWriter represents the stdin of a remote process. -type chanWriter struct { - *channel - // indicates the writer has been closed. eof is owned by the - // caller of Write/Close. - eof bool -} + if err := ch.sendMessage(msg); err != nil { + return false, err + } -// Write writes data to the remote process's standard input. -func (w *chanWriter) Write(data []byte) (written int, err error) { - const headerLength = 9 // 1 byte message type, 4 bytes remoteId, 4 bytes data length - for len(data) > 0 { - if w.eof || w.closed() { - err = io.EOF - return - } - // never send more data than maxPacket even if - // there is sufficient window. - n := min(w.maxPacket-headerLength, len(data)) - r := w.remoteWin.reserve(n) - n = r - remoteId := w.remoteId - packet := []byte{ - msgChannelData, - byte(remoteId >> 24), byte(remoteId >> 16), byte(remoteId >> 8), byte(remoteId), - byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n), + if wantReply { + m, ok := (<-ch.msg) + if !ok { + return false, io.EOF } - if err = w.writePacket(append(packet, data[:n]...)); err != nil { - break + switch m.(type) { + case *channelRequestFailureMsg: + return false, nil + case *channelRequestSuccessMsg: + return true, nil + default: + return false, fmt.Errorf("ssh: unexpected response to channel request: %#v", m) } - data = data[n:] - written += int(n) } - return + + return false, nil } -func min(a uint32, b int) uint32 { - if a < uint32(b) { - return a +// ackRequest either sends an ack or nack to the channel request. +func (ch *channel) ackRequest(ok bool) error { + if !ch.decided { + return errUndecided } - return uint32(b) -} -func (w *chanWriter) Close() error { - w.eof = true - return w.sendEOF() + var msg interface{} + if !ok { + msg = channelRequestFailureMsg{ + PeersId: ch.remoteId, + } + } else { + msg = channelRequestSuccessMsg{ + PeersId: ch.remoteId, + } + } + return ch.sendMessage(msg) } -// A chanReader represents stdout or stderr of a remote process. -type chanReader struct { - *channel // the channel backing this reader - *buffer +func (ch *channel) ChannelType() string { + return ch.chanType } -// Read reads data from the remote process's stdout or stderr. -func (r *chanReader) Read(buf []byte) (int, error) { - n, err := r.buffer.Read(buf) - if err != nil { - if err == io.EOF { - return n, err - } - return 0, err - } - err = r.sendWindowAdj(n) - if err == io.EOF && n > 0 { - // sendWindowAdjust can return io.EOF if the remote peer has - // closed the connection, however we want to defer forwarding io.EOF to the - // caller of Read until the buffer has been drained. - err = nil - } - return n, err +func (ch *channel) ExtraData() []byte { + return ch.extraData } diff --git a/ssh/cipher.go b/ssh/cipher.go index bc2e983..a58f10b 100644 --- a/ssh/cipher.go +++ b/ssh/cipher.go @@ -8,11 +8,28 @@ import ( "crypto/aes" "crypto/cipher" "crypto/rc4" + "crypto/subtle" + "encoding/binary" + "errors" + "fmt" + "hash" + "io" ) -// streamDump is used to dump the initial keystream for stream ciphers. It is a -// a write-only buffer, and not intended for reading so do not require a mutex. -var streamDump [512]byte +const ( + packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher. + + // RFC 4253 section 6.1 defines a minimum packet size of 32768 that implementations + // MUST be able to process (plus a few more kilobytes for padding and mac). The RFC + // indicates implementations SHOULD be able to handle larger packet sizes, but then + // waffles on about reasonable limits. + // + // OpenSSH caps their maxPacket at 256kB so we choose to do + // the same. maxPacket is also used to ensure that uint32 + // length fields do not overflow, so it should remain well + // below 4G. + maxPacket = 256 * 1024 +) // noneCipher implements cipher.Stream and provides no encryption. It is used // by the transport before the first key-exchange. @@ -34,14 +51,14 @@ func newRC4(key, iv []byte) (cipher.Stream, error) { return rc4.NewCipher(key) } -type cipherMode struct { +type streamCipherMode struct { keySize int ivSize int skip int createFunc func(key, iv []byte) (cipher.Stream, error) } -func (c *cipherMode) createCipher(key, iv []byte) (cipher.Stream, error) { +func (c *streamCipherMode) createStream(key, iv []byte) (cipher.Stream, error) { if len(key) < c.keySize { panic("ssh: key length too small for cipher") } @@ -54,6 +71,11 @@ func (c *cipherMode) createCipher(key, iv []byte) (cipher.Stream, error) { return nil, err } + var streamDump []byte + if c.skip > 0 { + streamDump = make([]byte, 512) + } + for remainingToDump := c.skip; remainingToDump > 0; { dumpThisTime := remainingToDump if dumpThisTime > len(streamDump) { @@ -66,18 +88,10 @@ func (c *cipherMode) createCipher(key, iv []byte) (cipher.Stream, error) { return stream, nil } -// Specifies a default set of ciphers and a preference order. This is based on -// OpenSSH's default client preference order, minus algorithms that are not -// implemented. -var DefaultCipherOrder = []string{ - "aes128-ctr", "aes192-ctr", "aes256-ctr", - "arcfour256", "arcfour128", -} - // cipherModes documents properties of supported ciphers. Ciphers not included // are not supported and will not be negotiated, even if explicitly requested in // ClientConfig.Crypto.Ciphers. -var cipherModes = map[string]*cipherMode{ +var cipherModes = map[string]*streamCipherMode{ // Ciphers from RFC4344, which introduced many CTR-based ciphers. Algorithms // are defined in the order specified in the RFC. "aes128-ctr": {16, aes.BlockSize, 0, newAESCTR}, @@ -88,13 +102,237 @@ var cipherModes = map[string]*cipherMode{ // They are defined in the order specified in the RFC. "arcfour128": {16, 0, 1536, newRC4}, "arcfour256": {32, 0, 1536, newRC4}, + + // AES-GCM is not a stream cipher, so it is constructed with a + // special case. If we add any more non-stream ciphers, we + // should invest a cleaner way to do this. + gcmCipherID: {16, 12, 0, nil}, +} + +// prefixLen is the length of the packet prefix that contains the packet length +// and number of padding bytes. +const prefixLen = 5 + +// streamPacketCipher is a packetCipher using a stream cipher. +type streamPacketCipher struct { + mac hash.Hash + cipher cipher.Stream + + // The following members are to avoid per-packet allocations. + prefix [prefixLen]byte + seqNumBytes [4]byte + padding [2 * packetSizeMultiple]byte + packetData []byte + macResult []byte +} + +// readPacket reads and decrypt a single packet from the reader argument. +func (s *streamPacketCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) { + if _, err := io.ReadFull(r, s.prefix[:]); err != nil { + return nil, err + } + + s.cipher.XORKeyStream(s.prefix[:], s.prefix[:]) + length := binary.BigEndian.Uint32(s.prefix[0:4]) + paddingLength := uint32(s.prefix[4]) + + var macSize uint32 + if s.mac != nil { + s.mac.Reset() + binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum) + s.mac.Write(s.seqNumBytes[:]) + s.mac.Write(s.prefix[:]) + macSize = uint32(s.mac.Size()) + } + + if length <= paddingLength+1 { + return nil, errors.New("ssh: invalid packet length, packet too small") + } + + if length > maxPacket { + return nil, errors.New("ssh: invalid packet length, packet too large") + } + + // the maxPacket check above ensures that length-1+macSize + // does not overflow. + if uint32(cap(s.packetData)) < length-1+macSize { + s.packetData = make([]byte, length-1+macSize) + } else { + s.packetData = s.packetData[:length-1+macSize] + } + + if _, err := io.ReadFull(r, s.packetData); err != nil { + return nil, err + } + mac := s.packetData[length-1:] + data := s.packetData[:length-1] + s.cipher.XORKeyStream(data, data) + + if s.mac != nil { + s.mac.Write(data) + s.macResult = s.mac.Sum(s.macResult[:0]) + if subtle.ConstantTimeCompare(s.macResult, mac) != 1 { + return nil, errors.New("ssh: MAC failure") + } + } + + return s.packetData[:length-paddingLength-1], nil +} + +// writePacket encrypts and sends a packet of data to the writer argument +func (s *streamPacketCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { + if len(packet) > maxPacket { + return errors.New("ssh: packet too large") + } + + paddingLength := packetSizeMultiple - (prefixLen+len(packet))%packetSizeMultiple + if paddingLength < 4 { + paddingLength += packetSizeMultiple + } + + length := len(packet) + 1 + paddingLength + binary.BigEndian.PutUint32(s.prefix[:], uint32(length)) + s.prefix[4] = byte(paddingLength) + padding := s.padding[:paddingLength] + if _, err := io.ReadFull(rand, padding); err != nil { + return err + } + + if s.mac != nil { + s.mac.Reset() + binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum) + s.mac.Write(s.seqNumBytes[:]) + s.mac.Write(s.prefix[:]) + s.mac.Write(packet) + s.mac.Write(padding) + } + + s.cipher.XORKeyStream(s.prefix[:], s.prefix[:]) + s.cipher.XORKeyStream(packet, packet) + s.cipher.XORKeyStream(padding, padding) + + if _, err := w.Write(s.prefix[:]); err != nil { + return err + } + if _, err := w.Write(packet); err != nil { + return err + } + if _, err := w.Write(padding); err != nil { + return err + } + + if s.mac != nil { + s.macResult = s.mac.Sum(s.macResult[:0]) + if _, err := w.Write(s.macResult); err != nil { + return err + } + } + + return nil +} + +type gcmCipher struct { + aead cipher.AEAD + prefix [4]byte + iv []byte + buf []byte +} + +func newGCMCipher(iv, key, macKey []byte) (packetCipher, error) { + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + aead, err := cipher.NewGCM(c) + if err != nil { + return nil, err + } + + return &gcmCipher{ + aead: aead, + iv: iv, + }, nil } -// defaultKeyExchangeOrder specifies a default set of key exchange algorithms -// with preferences. -var defaultKeyExchangeOrder = []string{ - // P384 and P521 are not constant-time yet, but since we don't - // reuse ephemeral keys, using them for ECDH should be OK. - kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521, - kexAlgoDH14SHA1, kexAlgoDH1SHA1, +const gcmTagSize = 16 + +func (c *gcmCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { + // Pad out to multiple of 16 bytes. This is different from the + // stream cipher because that encrypts the length too. + padding := byte(packetSizeMultiple - (1+len(packet))%packetSizeMultiple) + if padding < 4 { + padding += packetSizeMultiple + } + + length := uint32(len(packet) + int(padding) + 1) + binary.BigEndian.PutUint32(c.prefix[:], length) + if _, err := w.Write(c.prefix[:]); err != nil { + return err + } + + if cap(c.buf) < int(length) { + c.buf = make([]byte, length) + } else { + c.buf = c.buf[:length] + } + + c.buf[0] = padding + copy(c.buf[1:], packet) + if _, err := io.ReadFull(rand, c.buf[1+len(packet):]); err != nil { + return err + } + c.buf = c.aead.Seal(c.buf[:0], c.iv, c.buf, c.prefix[:]) + if _, err := w.Write(c.buf); err != nil { + return err + } + c.incIV() + + return nil +} + +func (c *gcmCipher) incIV() { + for i := 4 + 7; i >= 4; i-- { + c.iv[i]++ + if c.iv[i] != 0 { + break + } + } +} + +func (c *gcmCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) { + if _, err := io.ReadFull(r, c.prefix[:]); err != nil { + return nil, err + } + length := binary.BigEndian.Uint32(c.prefix[:]) + if length > maxPacket { + return nil, errors.New("ssh: max packet length exceeded.") + } + + if cap(c.buf) < int(length+gcmTagSize) { + c.buf = make([]byte, length+gcmTagSize) + } else { + c.buf = c.buf[:length+gcmTagSize] + } + + if _, err := io.ReadFull(r, c.buf); err != nil { + return nil, err + } + + plain, err := c.aead.Open(c.buf[:0], c.iv, c.buf, c.prefix[:]) + if err != nil { + return nil, err + } + c.incIV() + + padding := plain[0] + if padding < 4 || padding >= 20 { + return nil, fmt.Errorf("ssh: illegal padding %d", padding) + } + + if int(padding+1) >= len(plain) { + return nil, fmt.Errorf("ssh: padding %d too large", padding) + } + plain = plain[1 : length-uint32(padding)] + return plain, nil } diff --git a/ssh/cipher_test.go b/ssh/cipher_test.go index ea27bd8..e279af0 100644 --- a/ssh/cipher_test.go +++ b/ssh/cipher_test.go @@ -6,57 +6,54 @@ package ssh import ( "bytes" + "crypto" + "crypto/rand" "testing" ) -// TestCipherReversal tests that each cipher factory produces ciphers that can -// encrypt and decrypt some data successfully. -func TestCipherReversal(t *testing.T) { - testData := []byte("abcdefghijklmnopqrstuvwxyz012345") - testKey := []byte("AbCdEfGhIjKlMnOpQrStUvWxYz012345") - testIv := []byte("sdflkjhsadflkjhasdflkjhsadfklhsa") - - cryptBuffer := make([]byte, 32) +func TestDefaultCiphersExist(t *testing.T) { + for _, cipherAlgo := range supportedCiphers { + if _, ok := cipherModes[cipherAlgo]; !ok { + t.Errorf("default cipher %q is unknown", cipherAlgo) + } + } +} - for name, cipherMode := range cipherModes { - encrypter, err := cipherMode.createCipher(testKey, testIv) +func TestPacketCiphers(t *testing.T) { + for cipher := range cipherModes { + kr := &kexResult{Hash: crypto.SHA1} + algs := directionAlgorithms{ + Cipher: cipher, + MAC: "hmac-sha1", + Compression: "none", + } + client, err := newPacketCipher(clientKeys, algs, kr) if err != nil { - t.Errorf("failed to create encrypter for %q: %s", name, err) + t.Errorf("newPacketCipher(client, %q): %v", cipher, err) continue } - decrypter, err := cipherMode.createCipher(testKey, testIv) + server, err := newPacketCipher(clientKeys, algs, kr) if err != nil { - t.Errorf("failed to create decrypter for %q: %s", name, err) + t.Errorf("newPacketCipher(client, %q): %v", cipher, err) continue } - copy(cryptBuffer, testData) - - encrypter.XORKeyStream(cryptBuffer, cryptBuffer) - if name == "none" { - if !bytes.Equal(cryptBuffer, testData) { - t.Errorf("encryption made change with 'none' cipher") - continue - } - } else { - if bytes.Equal(cryptBuffer, testData) { - t.Errorf("encryption made no change with %q", name) - continue - } + want := "bla bla" + input := []byte(want) + buf := &bytes.Buffer{} + if err := client.writePacket(0, buf, rand.Reader, input); err != nil { + t.Errorf("writePacket(%q): %v", cipher, err) + continue } - decrypter.XORKeyStream(cryptBuffer, cryptBuffer) - if !bytes.Equal(cryptBuffer, testData) { - t.Errorf("decrypted bytes not equal to input with %q", name) + packet, err := server.readPacket(0, buf) + if err != nil { + t.Errorf("readPacket(%q): %v", cipher, err) continue } - } -} -func TestDefaultCiphersExist(t *testing.T) { - for _, cipherAlgo := range DefaultCipherOrder { - if _, ok := cipherModes[cipherAlgo]; !ok { - t.Errorf("default cipher %q is unknown", cipherAlgo) + if string(packet) != want { + t.Errorf("roundtrip(%q): got %q, want %q", cipher, packet, want) } } } diff --git a/ssh/client.go b/ssh/client.go index e2d2557..a8d5235 100644 --- a/ssh/client.go +++ b/ssh/client.go @@ -5,403 +5,158 @@ package ssh import ( - "crypto/rand" - "encoding/binary" "errors" "fmt" - "io" "net" "sync" ) -// ClientConn represents the client side of an SSH connection. -type ClientConn struct { - transport *transport - config *ClientConfig - chanList // channels associated with this connection - forwardList // forwarded tcpip connections from the remote side - globalRequest +// Client implements a traditional SSH client that supports shells, +// subprocesses, port forwarding and tunneled dialing. +type Client struct { + Conn - // Address as passed to the Dial function. - dialAddress string - - serverVersion string -} - -type globalRequest struct { - sync.Mutex - response chan interface{} + forwards forwardList // forwarded tcpip connections from the remote side + mu sync.Mutex + channelHandlers map[string]chan NewChannel } -// Client returns a new SSH client connection using c as the underlying transport. -func Client(c net.Conn, config *ClientConfig) (*ClientConn, error) { - return clientWithAddress(c, "", config) -} - -func clientWithAddress(c net.Conn, addr string, config *ClientConfig) (*ClientConn, error) { - conn := &ClientConn{ - transport: newTransport(c, config.rand(), true /* is client */), - config: config, - globalRequest: globalRequest{response: make(chan interface{}, 1)}, - dialAddress: addr, +// HandleChannelOpen returns a channel on which NewChannel requests +// for the given type are sent. If the type already is being handled, +// nil is returned. The channel is closed when the connection is closed. +func (c *Client) HandleChannelOpen(channelType string) <-chan NewChannel { + c.mu.Lock() + defer c.mu.Unlock() + if c.channelHandlers == nil { + // The SSH channel has been closed. + c := make(chan NewChannel) + close(c) + return c } - if err := conn.handshake(); err != nil { - conn.transport.Close() - return nil, fmt.Errorf("handshake failed: %v", err) - } - go conn.mainLoop() - return conn, nil -} - -// Close closes the connection. -func (c *ClientConn) Close() error { return c.transport.Close() } - -// LocalAddr returns the local network address. -func (c *ClientConn) LocalAddr() net.Addr { return c.transport.LocalAddr() } - -// RemoteAddr returns the remote network address. -func (c *ClientConn) RemoteAddr() net.Addr { return c.transport.RemoteAddr() } - -// handshake performs the client side key exchange. See RFC 4253 Section 7. -func (c *ClientConn) handshake() error { - clientVersion := []byte(packageVersion) - if c.config.ClientVersion != "" { - clientVersion = []byte(c.config.ClientVersion) + ch := c.channelHandlers[channelType] + if ch != nil { + return nil } - serverVersion, err := exchangeVersions(c.transport.Conn, clientVersion) - if err != nil { - return err - } - c.serverVersion = string(serverVersion) + ch = make(chan NewChannel, 16) + c.channelHandlers[channelType] = ch + return ch +} - clientKexInit := kexInitMsg{ - KexAlgos: c.config.Crypto.kexes(), - ServerHostKeyAlgos: supportedHostKeyAlgos, - CiphersClientServer: c.config.Crypto.ciphers(), - CiphersServerClient: c.config.Crypto.ciphers(), - MACsClientServer: c.config.Crypto.macs(), - MACsServerClient: c.config.Crypto.macs(), - CompressionClientServer: supportedCompressions, - CompressionServerClient: supportedCompressions, - } - kexInitPacket := marshal(msgKexInit, clientKexInit) - if err := c.transport.writePacket(kexInitPacket); err != nil { - return err - } - packet, err := c.transport.readPacket() - if err != nil { - return err +// NewClient creates a Client on top of the given connection. +func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client { + conn := &Client{ + Conn: c, + channelHandlers: make(map[string]chan NewChannel, 1), } - var serverKexInit kexInitMsg - if err = unmarshal(&serverKexInit, packet, msgKexInit); err != nil { - return err - } + go conn.handleGlobalRequests(reqs) + go conn.handleChannelOpens(chans) + go func() { + conn.Wait() + conn.forwards.closeAll() + }() + go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-tcpip")) + return conn +} - algs := findAgreedAlgorithms(&clientKexInit, &serverKexInit) - if algs == nil { - return errors.New("ssh: no common algorithms") +// NewClientConn establishes an authenticated SSH connection using c +// as the underlying transport. The Request and NewChannel channels +// must be serviced or the connection will hang. +func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan NewChannel, <-chan *Request, error) { + fullConf := *config + fullConf.SetDefaults() + conn := &connection{ + sshConn: sshConn{conn: c}, } - if serverKexInit.FirstKexFollows && algs.kex != serverKexInit.KexAlgos[0] { - // The server sent a Kex message for the wrong algorithm, - // which we have to ignore. - if _, err := c.transport.readPacket(); err != nil { - return err - } + if err := conn.clientHandshake(addr, &fullConf); err != nil { + c.Close() + return nil, nil, nil, fmt.Errorf("ssh: handshake failed: %v", err) } + conn.mux = newMux(conn.transport) + return conn, conn.mux.incomingChannels, conn.mux.incomingRequests, nil +} - kex, ok := kexAlgoMap[algs.kex] - if !ok { - return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex) +// clientHandshake performs the client side key exchange. See RFC 4253 Section +// 7. +func (c *connection) clientHandshake(dialAddress string, config *ClientConfig) error { + c.clientVersion = []byte(packageVersion) + if config.ClientVersion != "" { + c.clientVersion = []byte(config.ClientVersion) } - magics := handshakeMagics{ - clientVersion: clientVersion, - serverVersion: serverVersion, - clientKexInit: kexInitPacket, - serverKexInit: packet, - } - result, err := kex.Client(c.transport, c.config.rand(), &magics) + var err error + c.serverVersion, err = exchangeVersions(c.sshConn.conn, c.clientVersion) if err != nil { return err } - err = verifyHostKeySignature(algs.hostKey, result.HostKey, result.H, result.Signature) - if err != nil { + c.transport = newClientTransport( + newTransport(c.sshConn.conn, config.Rand, true /* is client */), + c.clientVersion, c.serverVersion, config, dialAddress, c.sshConn.RemoteAddr()) + if err := c.transport.requestKeyChange(); err != nil { return err } - if checker := c.config.HostKeyChecker; checker != nil { - err = checker.Check(c.dialAddress, c.transport.RemoteAddr(), algs.hostKey, result.HostKey) - if err != nil { - return err - } - } - - c.transport.prepareKeyChange(algs, result) - - if err = c.transport.writePacket([]byte{msgNewKeys}); err != nil { - return err - } - if packet, err = c.transport.readPacket(); err != nil { + if packet, err := c.transport.readPacket(); err != nil { return err + } else if packet[0] != msgNewKeys { + return unexpectedMessageError(msgNewKeys, packet[0]) } - if packet[0] != msgNewKeys { - return UnexpectedMessageError{msgNewKeys, packet[0]} - } - return c.authenticate() + return c.clientAuthenticate(config) } -// Verify the host key obtained in the key exchange. -func verifyHostKeySignature(hostKeyAlgo string, hostKeyBytes []byte, data []byte, signature []byte) error { - hostKey, rest, ok := ParsePublicKey(hostKeyBytes) - if len(rest) > 0 || !ok { - return errors.New("ssh: could not parse hostkey") - } - - sig, rest, ok := parseSignatureBody(signature) +// verifyHostKeySignature verifies the host key obtained in the key +// exchange. +func verifyHostKeySignature(hostKey PublicKey, result *kexResult) error { + sig, rest, ok := parseSignatureBody(result.Signature) if len(rest) > 0 || !ok { return errors.New("ssh: signature parse error") } - if sig.Format != hostKeyAlgo { - return fmt.Errorf("ssh: unexpected signature type %q", sig.Format) - } - if !hostKey.Verify(data, sig.Blob) { - return errors.New("ssh: host key signature error") - } - return nil + return hostKey.Verify(result.H, sig) } -// mainLoop reads incoming messages and routes channel messages -// to their respective ClientChans. -func (c *ClientConn) mainLoop() { - defer func() { - c.transport.Close() - c.chanList.closeAll() - c.forwardList.closeAll() - }() - - for { - packet, err := c.transport.readPacket() - if err != nil { - break - } - // TODO(dfc) A note on blocking channel use. - // The msg, data and dataExt channels of a clientChan can - // cause this loop to block indefinitely if the consumer does - // not service them. - switch packet[0] { - case msgChannelData: - if len(packet) < 9 { - // malformed data packet - return - } - remoteId := binary.BigEndian.Uint32(packet[1:5]) - length := binary.BigEndian.Uint32(packet[5:9]) - packet = packet[9:] - - if length != uint32(len(packet)) { - return - } - ch, ok := c.getChan(remoteId) - if !ok { - return - } - ch.stdout.write(packet) - case msgChannelExtendedData: - if len(packet) < 13 { - // malformed data packet - return - } - remoteId := binary.BigEndian.Uint32(packet[1:5]) - datatype := binary.BigEndian.Uint32(packet[5:9]) - length := binary.BigEndian.Uint32(packet[9:13]) - packet = packet[13:] - - if length != uint32(len(packet)) { - return - } - // RFC 4254 5.2 defines data_type_code 1 to be data destined - // for stderr on interactive sessions. Other data types are - // silently discarded. - if datatype == 1 { - ch, ok := c.getChan(remoteId) - if !ok { - return - } - ch.stderr.write(packet) - } - default: - decoded, err := decode(packet) - if err != nil { - if _, ok := err.(UnexpectedMessageError); ok { - fmt.Printf("mainLoop: unexpected message: %v\n", err) - continue - } - return - } - switch msg := decoded.(type) { - case *channelOpenMsg: - c.handleChanOpen(msg) - case *channelOpenConfirmMsg: - ch, ok := c.getChan(msg.PeersId) - if !ok { - return - } - ch.msg <- msg - case *channelOpenFailureMsg: - ch, ok := c.getChan(msg.PeersId) - if !ok { - return - } - ch.msg <- msg - case *channelCloseMsg: - ch, ok := c.getChan(msg.PeersId) - if !ok { - return - } - ch.Close() - close(ch.msg) - c.chanList.remove(msg.PeersId) - case *channelEOFMsg: - ch, ok := c.getChan(msg.PeersId) - if !ok { - return - } - ch.stdout.eof() - // RFC 4254 is mute on how EOF affects dataExt messages but - // it is logical to signal EOF at the same time. - ch.stderr.eof() - case *channelRequestSuccessMsg: - ch, ok := c.getChan(msg.PeersId) - if !ok { - return - } - ch.msg <- msg - case *channelRequestFailureMsg: - ch, ok := c.getChan(msg.PeersId) - if !ok { - return - } - ch.msg <- msg - case *channelRequestMsg: - ch, ok := c.getChan(msg.PeersId) - if !ok { - return - } - ch.msg <- msg - case *windowAdjustMsg: - ch, ok := c.getChan(msg.PeersId) - if !ok { - return - } - if !ch.remoteWin.add(msg.AdditionalBytes) { - // invalid window update - return - } - case *globalRequestMsg: - // This handles keepalive messages and matches - // the behaviour of OpenSSH. - if msg.WantReply { - c.transport.writePacket(marshal(msgRequestFailure, globalRequestFailureMsg{})) - } - case *globalRequestSuccessMsg, *globalRequestFailureMsg: - c.globalRequest.response <- msg - case *disconnectMsg: - return - default: - fmt.Printf("mainLoop: unhandled message %T: %v\n", msg, msg) - } - } +// NewSession opens a new Session for this client. (A session is a remote +// execution of a program.) +func (c *Client) NewSession() (*Session, error) { + ch, in, err := c.OpenChannel("session", nil) + if err != nil { + return nil, err } + return newSession(ch, in) } -// Handle channel open messages from the remote side. -func (c *ClientConn) handleChanOpen(msg *channelOpenMsg) { - if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { - c.sendConnectionFailed(msg.PeersId) +func (c *Client) handleGlobalRequests(incoming <-chan *Request) { + for r := range incoming { + // This handles keepalive messages and matches + // the behaviour of OpenSSH. + r.Reply(false, nil) } +} - switch msg.ChanType { - case "forwarded-tcpip": - laddr, rest, ok := parseTCPAddr(msg.TypeSpecificData) - if !ok { - // invalid request - c.sendConnectionFailed(msg.PeersId) - return - } - - l, ok := c.forwardList.lookup(*laddr) - if !ok { - // TODO: print on a more structured log. - fmt.Println("could not find forward list entry for", laddr) - // Section 7.2, implementations MUST reject spurious incoming - // connections. - c.sendConnectionFailed(msg.PeersId) - return - } - raddr, rest, ok := parseTCPAddr(rest) - if !ok { - // invalid request - c.sendConnectionFailed(msg.PeersId) - return - } - ch := c.newChan(c.transport) - ch.remoteId = msg.PeersId - ch.remoteWin.add(msg.PeersWindow) - ch.maxPacket = msg.MaxPacketSize - - m := channelOpenConfirmMsg{ - PeersId: ch.remoteId, - MyId: ch.localId, - MyWindow: channelWindowSize, - MaxPacketSize: channelMaxPacketSize, - } +// handleChannelOpens channel open messages from the remote side. +func (c *Client) handleChannelOpens(in <-chan NewChannel) { + for ch := range in { + c.mu.Lock() + handler := c.channelHandlers[ch.ChannelType()] + c.mu.Unlock() - c.transport.writePacket(marshal(msgChannelOpenConfirm, m)) - l <- forward{ch, raddr} - default: - // unknown channel type - m := channelOpenFailureMsg{ - PeersId: msg.PeersId, - Reason: UnknownChannelType, - Message: fmt.Sprintf("unknown channel type: %v", msg.ChanType), - Language: "en_US.UTF-8", + if handler != nil { + handler <- ch + } else { + ch.Reject(UnknownChannelType, fmt.Sprintf("unknown channel type: %v", ch.ChannelType())) } - c.transport.writePacket(marshal(msgChannelOpenFailure, m)) } -} -// sendGlobalRequest sends a global request message as specified -// in RFC4254 section 4. To correctly synchronise messages, a lock -// is held internally until a response is returned. -func (c *ClientConn) sendGlobalRequest(m interface{}) (*globalRequestSuccessMsg, error) { - c.globalRequest.Lock() - defer c.globalRequest.Unlock() - if err := c.transport.writePacket(marshal(msgGlobalRequest, m)); err != nil { - return nil, err - } - r := <-c.globalRequest.response - if r, ok := r.(*globalRequestSuccessMsg); ok { - return r, nil + c.mu.Lock() + for _, ch := range c.channelHandlers { + close(ch) } - return nil, errors.New("request failed") -} - -// sendConnectionFailed rejects an incoming channel identified -// by remoteId. -func (c *ClientConn) sendConnectionFailed(remoteId uint32) error { - m := channelOpenFailureMsg{ - PeersId: remoteId, - Reason: ConnectionFailed, - Message: "invalid request", - Language: "en_US.UTF-8", - } - return c.transport.writePacket(marshal(msgChannelOpenFailure, m)) + c.channelHandlers = nil + c.mu.Unlock() } // parseTCPAddr parses the originating address from the remote into a *net.TCPAddr. @@ -413,7 +168,7 @@ func parseTCPAddr(b []byte) (*net.TCPAddr, []byte, bool) { return nil, b, false } port, b, ok := parseUint32(b) - if !ok { + if !ok || port == 0 || port > 65535 { return nil, b, false } ip := net.ParseIP(string(addr)) @@ -423,102 +178,44 @@ func parseTCPAddr(b []byte) (*net.TCPAddr, []byte, bool) { return &net.TCPAddr{IP: ip, Port: int(port)}, b, true } -// Dial connects to the given network address using net.Dial and -// then initiates a SSH handshake, returning the resulting client connection. -func Dial(network, addr string, config *ClientConfig) (*ClientConn, error) { +// Dial starts a client connection to the given SSH server. It is a +// convenience function that connects to the given network address, +// initiates the SSH handshake, and then sets up a Client. For access +// to incoming channels and requests, use net.Dial with NewClientConn +// instead. +func Dial(network, addr string, config *ClientConfig) (*Client, error) { conn, err := net.Dial(network, addr) if err != nil { return nil, err } - return clientWithAddress(conn, addr, config) + c, chans, reqs, err := NewClientConn(conn, addr, config) + if err != nil { + return nil, err + } + return NewClient(c, chans, reqs), nil } -// A ClientConfig structure is used to configure a ClientConn. After one has -// been passed to an SSH function it must not be modified. +// A ClientConfig structure is used to configure a Client. It must not be +// modified after having been passed to an SSH function. type ClientConfig struct { - // Rand provides the source of entropy for key exchange. If Rand is - // nil, the cryptographic random reader in package crypto/rand will - // be used. - Rand io.Reader + // Config contains configuration that is shared between clients and + // servers. + Config - // The username to authenticate. + // User contains the username to authenticate as. User string - // A slice of ClientAuth methods. Only the first instance - // of a particular RFC 4252 method will be used during authentication. - Auth []ClientAuth + // Auth contains possible authentication methods to use with the + // server. Only the first instance of a particular RFC 4252 method will + // be used during authentication. + Auth []AuthMethod - // HostKeyChecker, if not nil, is called during the cryptographic - // handshake to validate the server's host key. A nil HostKeyChecker + // HostKeyCallback, if not nil, is called during the cryptographic + // handshake to validate the server's host key. A nil HostKeyCallback // implies that all host keys are accepted. - HostKeyChecker HostKeyChecker - - // Cryptographic-related configuration. - Crypto CryptoConfig + HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error - // The identification string that will be used for the connection. - // If empty, a reasonable default is used. + // ClientVersion contains the version identification string that will + // be used for the connection. If empty, a reasonable default is used. ClientVersion string } - -func (c *ClientConfig) rand() io.Reader { - if c.Rand == nil { - return rand.Reader - } - return c.Rand -} - -// Thread safe channel list. -type chanList struct { - // protects concurrent access to chans - sync.Mutex - // chans are indexed by the local id of the channel, clientChan.localId. - // The PeersId value of messages received by ClientConn.mainLoop is - // used to locate the right local clientChan in this slice. - chans []*clientChan -} - -// Allocate a new ClientChan with the next avail local id. -func (c *chanList) newChan(p packetConn) *clientChan { - c.Lock() - defer c.Unlock() - for i := range c.chans { - if c.chans[i] == nil { - ch := newClientChan(p, uint32(i)) - c.chans[i] = ch - return ch - } - } - i := len(c.chans) - ch := newClientChan(p, uint32(i)) - c.chans = append(c.chans, ch) - return ch -} - -func (c *chanList) getChan(id uint32) (*clientChan, bool) { - c.Lock() - defer c.Unlock() - if id >= uint32(len(c.chans)) { - return nil, false - } - return c.chans[id], true -} - -func (c *chanList) remove(id uint32) { - c.Lock() - defer c.Unlock() - c.chans[id] = nil -} - -func (c *chanList) closeAll() { - c.Lock() - defer c.Unlock() - - for _, ch := range c.chans { - if ch == nil { - continue - } - ch.Close() - close(ch.msg) - } -} diff --git a/ssh/client_auth.go b/ssh/client_auth.go index 29be0ca..5b7aa30 100644 --- a/ssh/client_auth.go +++ b/ssh/client_auth.go @@ -5,16 +5,16 @@ package ssh import ( + "bytes" "errors" "fmt" "io" - "net" ) -// authenticate authenticates with the remote server. See RFC 4252. -func (c *ClientConn) authenticate() error { +// clientAuthenticate authenticates with the remote server. See RFC 4252. +func (c *connection) clientAuthenticate(config *ClientConfig) error { // initiate user auth session - if err := c.transport.writePacket(marshal(msgServiceRequest, serviceRequestMsg{serviceUserAuth})); err != nil { + if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil { return err } packet, err := c.transport.readPacket() @@ -22,14 +22,15 @@ func (c *ClientConn) authenticate() error { return err } var serviceAccept serviceAcceptMsg - if err := unmarshal(&serviceAccept, packet, msgServiceAccept); err != nil { + if err := Unmarshal(packet, &serviceAccept); err != nil { return err } + // during the authentication phase the client first attempts the "none" method // then any untried methods suggested by the server. - tried, remain := make(map[string]bool), make(map[string]bool) - for auth := ClientAuth(new(noneAuth)); auth != nil; { - ok, methods, err := auth.auth(c.transport.sessionID, c.config.User, c.transport, c.config.rand()) + tried := make(map[string]bool) + for auth := AuthMethod(new(noneAuth)); auth != nil; { + ok, methods, err := auth.auth(c.transport.getSessionID(), config.User, c.transport, config.Rand) if err != nil { return err } @@ -38,45 +39,35 @@ func (c *ClientConn) authenticate() error { return nil } tried[auth.method()] = true - delete(remain, auth.method()) - for _, meth := range methods { - if tried[meth] { - // if we've tried meth already, skip it. - continue - } - remain[meth] = true - } + auth = nil - for _, a := range c.config.Auth { - if remain[a.method()] { - auth = a - break + for _, a := range config.Auth { + candidateMethod := a.method() + for _, meth := range methods { + if meth != candidateMethod { + continue + } + if !tried[meth] { + auth = a + break + } } } } return fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", keys(tried)) } -func keys(m map[string]bool) (s []string) { - for k := range m { - s = append(s, k) - } - return -} +func keys(m map[string]bool) []string { + s := make([]string, 0, len(m)) -// HostKeyChecker represents a database of known server host keys. -type HostKeyChecker interface { - // Check is called during the handshake to check server's - // public key for unexpected changes. The hostKey argument is - // in SSH wire format. It can be parsed using - // ssh.ParsePublicKey. The address before DNS resolution is - // passed in the addr argument, so the key can also be checked - // against the hostname. - Check(addr string, remote net.Addr, algorithm string, hostKey []byte) error + for key := range m { + s = append(s, key) + } + return s } -// A ClientAuth represents an instance of an RFC 4252 authentication method. -type ClientAuth interface { +// An AuthMethod represents an instance of an RFC 4252 authentication method. +type AuthMethod interface { // auth authenticates user over transport t. // Returns true if authentication is successful. // If authentication is not successful, a []string of alternative @@ -91,7 +82,7 @@ type ClientAuth interface { type noneAuth int func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { - if err := c.writePacket(marshal(msgUserAuthRequest, userAuthRequestMsg{ + if err := c.writePacket(Marshal(&userAuthRequestMsg{ User: user, Service: serviceSSH, Method: "none", @@ -106,29 +97,31 @@ func (n *noneAuth) method() string { return "none" } -// "password" authentication, RFC 4252 Section 8. -type passwordAuth struct { - ClientPassword -} +// passwordCallback is an AuthMethod that fetches the password through +// a function call, e.g. by prompting the user. +type passwordCallback func() (password string, err error) -func (p *passwordAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { +func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { type passwordAuthMsg struct { - User string + User string `sshtype:"50"` Service string Method string Reply bool Password string } - pw, err := p.Password(user) + pw, err := cb() + // REVIEW NOTE: is there a need to support skipping a password attempt? + // The program may only find out that the user doesn't have a password + // when prompting. if err != nil { return false, nil, err } - if err := c.writePacket(marshal(msgUserAuthRequest, passwordAuthMsg{ + if err := c.writePacket(Marshal(&passwordAuthMsg{ User: user, Service: serviceSSH, - Method: "password", + Method: cb.method(), Reply: false, Password: pw, })); err != nil { @@ -138,106 +131,93 @@ func (p *passwordAuth) auth(session []byte, user string, c packetConn, rand io.R return handleAuthResponse(c) } -func (p *passwordAuth) method() string { +func (cb passwordCallback) method() string { return "password" } -// A ClientPassword implements access to a client's passwords. -type ClientPassword interface { - // Password returns the password to use for user. - Password(user string) (password string, err error) -} - -// ClientAuthPassword returns a ClientAuth using password authentication. -func ClientAuthPassword(impl ClientPassword) ClientAuth { - return &passwordAuth{impl} -} - -// ClientKeyring implements access to a client key ring. -type ClientKeyring interface { - // Key returns the i'th Publickey, or nil if no key exists at i. - Key(i int) (key PublicKey, err error) - - // Sign returns a signature of the given data using the i'th key - // and the supplied random source. - Sign(i int, rand io.Reader, data []byte) (sig []byte, err error) +// Password returns an AuthMethod using the given password. +func Password(secret string) AuthMethod { + return passwordCallback(func() (string, error) { return secret, nil }) } -// "publickey" authentication, RFC 4252 Section 7. -type publickeyAuth struct { - ClientKeyring +// PasswordCallback returns an AuthMethod that uses a callback for +// fetching a password. +func PasswordCallback(prompt func() (secret string, err error)) AuthMethod { + return passwordCallback(prompt) } type publickeyAuthMsg struct { - User string + User string `sshtype:"50"` Service string Method string // HasSig indicates to the receiver packet that the auth request is signed and // should be used for authentication of the request. HasSig bool Algoname string - Pubkey string - // Sig is defined as []byte so marshal will exclude it during validateKey + PubKey []byte + // Sig is tagged with "rest" so Marshal will exclude it during + // validateKey Sig []byte `ssh:"rest"` } -func (p *publickeyAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { +// publicKeyCallback is an AuthMethod that uses a set of key +// pairs for authentication. +type publicKeyCallback func() ([]Signer, error) + +func (cb publicKeyCallback) method() string { + return "publickey" +} + +func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { // Authentication is performed in two stages. The first stage sends an // enquiry to test if each key is acceptable to the remote. The second // stage attempts to authenticate with the valid keys obtained in the // first stage. - var index int - // a map of public keys to their index in the keyring - validKeys := make(map[int]PublicKey) - for { - key, err := p.Key(index) - if err != nil { - return false, nil, err - } - if key == nil { - // no more keys in the keyring - break - } - - if ok, err := p.validateKey(key, user, c); ok { - validKeys[index] = key + signers, err := cb() + if err != nil { + return false, nil, err + } + var validKeys []Signer + for _, signer := range signers { + if ok, err := validateKey(signer.PublicKey(), user, c); ok { + validKeys = append(validKeys, signer) } else { if err != nil { return false, nil, err } } - index++ } // methods that may continue if this auth is not successful. var methods []string - for i, key := range validKeys { - pubkey := MarshalPublicKey(key) - algoname := key.PublicKeyAlgo() - data := buildDataSignedForAuth(session, userAuthRequestMsg{ + for _, signer := range validKeys { + pub := signer.PublicKey() + + pubKey := pub.Marshal() + sign, err := signer.Sign(rand, buildDataSignedForAuth(session, userAuthRequestMsg{ User: user, Service: serviceSSH, - Method: p.method(), - }, []byte(algoname), pubkey) - sigBlob, err := p.Sign(i, rand, data) + Method: cb.method(), + }, []byte(pub.Type()), pubKey)) if err != nil { return false, nil, err } + // manually wrap the serialized signature in a string - s := serializeSignature(key.PublicKeyAlgo(), sigBlob) + s := Marshal(sign) sig := make([]byte, stringLength(len(s))) marshalString(sig, s) msg := publickeyAuthMsg{ User: user, Service: serviceSSH, - Method: p.method(), + Method: cb.method(), HasSig: true, - Algoname: algoname, - Pubkey: string(pubkey), + Algoname: pub.Type(), + PubKey: pubKey, Sig: sig, } - p := marshal(msgUserAuthRequest, msg) + p := Marshal(&msg) if err := c.writePacket(p); err != nil { return false, nil, err } @@ -252,28 +232,27 @@ func (p *publickeyAuth) auth(session []byte, user string, c packetConn, rand io. return false, methods, nil } -// validateKey validates the key provided it is acceptable to the server. -func (p *publickeyAuth) validateKey(key PublicKey, user string, c packetConn) (bool, error) { - pubkey := MarshalPublicKey(key) - algoname := key.PublicKeyAlgo() +// validateKey validates the key provided is acceptable to the server. +func validateKey(key PublicKey, user string, c packetConn) (bool, error) { + pubKey := key.Marshal() msg := publickeyAuthMsg{ User: user, Service: serviceSSH, - Method: p.method(), + Method: "publickey", HasSig: false, - Algoname: algoname, - Pubkey: string(pubkey), + Algoname: key.Type(), + PubKey: pubKey, } - if err := c.writePacket(marshal(msgUserAuthRequest, msg)); err != nil { + if err := c.writePacket(Marshal(&msg)); err != nil { return false, err } - return p.confirmKeyAck(key, c) + return confirmKeyAck(key, c) } -func (p *publickeyAuth) confirmKeyAck(key PublicKey, c packetConn) (bool, error) { - pubkey := MarshalPublicKey(key) - algoname := key.PublicKeyAlgo() +func confirmKeyAck(key PublicKey, c packetConn) (bool, error) { + pubKey := key.Marshal() + algoname := key.Type() for { packet, err := c.readPacket() @@ -284,30 +263,32 @@ func (p *publickeyAuth) confirmKeyAck(key PublicKey, c packetConn) (bool, error) case msgUserAuthBanner: // TODO(gpaul): add callback to present the banner to the user case msgUserAuthPubKeyOk: - msg := userAuthPubKeyOkMsg{} - if err := unmarshal(&msg, packet, msgUserAuthPubKeyOk); err != nil { + var msg userAuthPubKeyOkMsg + if err := Unmarshal(packet, &msg); err != nil { return false, err } - if msg.Algo != algoname || msg.PubKey != string(pubkey) { + if msg.Algo != algoname || !bytes.Equal(msg.PubKey, pubKey) { return false, nil } return true, nil case msgUserAuthFailure: return false, nil default: - return false, UnexpectedMessageError{msgUserAuthSuccess, packet[0]} + return false, unexpectedMessageError(msgUserAuthSuccess, packet[0]) } } - panic("unreachable") } -func (p *publickeyAuth) method() string { - return "publickey" +// PublicKeys returns an AuthMethod that uses the given key +// pairs. +func PublicKeys(signers ...Signer) AuthMethod { + return publicKeyCallback(func() ([]Signer, error) { return signers, nil }) } -// ClientAuthKeyring returns a ClientAuth using public key authentication. -func ClientAuthKeyring(impl ClientKeyring) ClientAuth { - return &publickeyAuth{impl} +// PublicKeysCallback returns an AuthMethod that runs the given +// function to obtain a list of key pairs. +func PublicKeysCallback(getSigners func() (signers []Signer, err error)) AuthMethod { + return publicKeyCallback(getSigners) } // handleAuthResponse returns whether the preceding authentication request succeeded @@ -324,8 +305,8 @@ func handleAuthResponse(c packetConn) (bool, []string, error) { case msgUserAuthBanner: // TODO: add callback to present the banner to the user case msgUserAuthFailure: - msg := userAuthFailureMsg{} - if err := unmarshal(&msg, packet, msgUserAuthFailure); err != nil { + var msg userAuthFailureMsg + if err := Unmarshal(packet, &msg); err != nil { return false, nil, err } return false, msg.Methods, nil @@ -334,98 +315,40 @@ func handleAuthResponse(c packetConn) (bool, []string, error) { case msgDisconnect: return false, nil, io.EOF default: - return false, nil, UnexpectedMessageError{msgUserAuthSuccess, packet[0]} + return false, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0]) } } - panic("unreachable") -} - -// ClientAuthAgent returns a ClientAuth using public key authentication via -// an agent. -func ClientAuthAgent(agent *AgentClient) ClientAuth { - return ClientAuthKeyring(&agentKeyring{agent: agent}) -} - -// agentKeyring implements ClientKeyring. -type agentKeyring struct { - agent *AgentClient - keys []*AgentKey -} - -func (kr *agentKeyring) Key(i int) (key PublicKey, err error) { - if kr.keys == nil { - if kr.keys, err = kr.agent.RequestIdentities(); err != nil { - return - } - } - if i >= len(kr.keys) { - return - } - return kr.keys[i].Key() } -func (kr *agentKeyring) Sign(i int, rand io.Reader, data []byte) (sig []byte, err error) { - var key PublicKey - if key, err = kr.Key(i); err != nil { - return - } - if key == nil { - return nil, errors.New("ssh: key index out of range") - } - if sig, err = kr.agent.SignRequest(key, data); err != nil { - return - } - - // Unmarshal the signature. - - var ok bool - if _, sig, ok = parseString(sig); !ok { - return nil, errors.New("ssh: malformed signature response from agent") - } - if sig, _, ok = parseString(sig); !ok { - return nil, errors.New("ssh: malformed signature response from agent") - } - return sig, nil -} - -// ClientKeyboardInteractive should prompt the user for the given -// questions. -type ClientKeyboardInteractive interface { - // Challenge should print the questions, optionally disabling - // echoing (eg. for passwords), and return all the answers. - // Challenge may be called multiple times in a single - // session. After successful authentication, the server may - // send a challenge with no questions, for which the user and - // instruction messages should be printed. RFC 4256 section - // 3.3 details how the UI should behave for both CLI and - // GUI environments. - Challenge(user, instruction string, questions []string, echos []bool) ([]string, error) -} - -// ClientAuthKeyboardInteractive returns a ClientAuth using a -// prompt/response sequence controlled by the server. -func ClientAuthKeyboardInteractive(impl ClientKeyboardInteractive) ClientAuth { - return &keyboardInteractiveAuth{impl} -} +// KeyboardInteractiveChallenge should print questions, optionally +// disabling echoing (e.g. for passwords), and return all the answers. +// Challenge may be called multiple times in a single session. After +// successful authentication, the server may send a challenge with no +// questions, for which the user and instruction messages should be +// printed. RFC 4256 section 3.3 details how the UI should behave for +// both CLI and GUI environments. +type KeyboardInteractiveChallenge func(user, instruction string, questions []string, echos []bool) (answers []string, err error) -type keyboardInteractiveAuth struct { - ClientKeyboardInteractive +// KeyboardInteractive returns a AuthMethod using a prompt/response +// sequence controlled by the server. +func KeyboardInteractive(challenge KeyboardInteractiveChallenge) AuthMethod { + return challenge } -func (k *keyboardInteractiveAuth) method() string { +func (cb KeyboardInteractiveChallenge) method() string { return "keyboard-interactive" } -func (k *keyboardInteractiveAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { +func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { type initiateMsg struct { - User string + User string `sshtype:"50"` Service string Method string Language string Submethods string } - if err := c.writePacket(marshal(msgUserAuthRequest, initiateMsg{ + if err := c.writePacket(Marshal(&initiateMsg{ User: user, Service: serviceSSH, Method: "keyboard-interactive", @@ -448,18 +371,18 @@ func (k *keyboardInteractiveAuth) auth(session []byte, user string, c packetConn // OK case msgUserAuthFailure: var msg userAuthFailureMsg - if err := unmarshal(&msg, packet, msgUserAuthFailure); err != nil { + if err := Unmarshal(packet, &msg); err != nil { return false, nil, err } return false, msg.Methods, nil case msgUserAuthSuccess: return true, nil, nil default: - return false, nil, UnexpectedMessageError{msgUserAuthInfoRequest, packet[0]} + return false, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0]) } var msg userAuthInfoRequestMsg - if err := unmarshal(&msg, packet, packet[0]); err != nil { + if err := Unmarshal(packet, &msg); err != nil { return false, nil, err } @@ -478,10 +401,10 @@ func (k *keyboardInteractiveAuth) auth(session []byte, user string, c packetConn } if len(rest) != 0 { - return false, nil, fmt.Errorf("ssh: junk following message %q", rest) + return false, nil, errors.New("ssh: extra data following keyboard-interactive pairs") } - answers, err := k.Challenge(msg.User, msg.Instruction, prompts, echos) + answers, err := cb(msg.User, msg.Instruction, prompts, echos) if err != nil { return false, nil, err } diff --git a/ssh/client_auth_test.go b/ssh/client_auth_test.go index f2fc9c6..3173255 100644 --- a/ssh/client_auth_test.go +++ b/ssh/client_auth_test.go @@ -6,363 +6,317 @@ package ssh import ( "bytes" - "crypto/dsa" - "io" - "io/ioutil" - "math/big" + "crypto/rand" + "errors" + "fmt" "strings" "testing" - - _ "crypto/sha1" ) -// private key for mock server -const testServerPrivateKey = `-----BEGIN RSA PRIVATE KEY----- -MIIEpAIBAAKCAQEA19lGVsTqIT5iiNYRgnoY1CwkbETW5cq+Rzk5v/kTlf31XpSU -70HVWkbTERECjaYdXM2gGcbb+sxpq6GtXf1M3kVomycqhxwhPv4Cr6Xp4WT/jkFx -9z+FFzpeodGJWjOH6L2H5uX1Cvr9EDdQp9t9/J32/qBFntY8GwoUI/y/1MSTmMiF -tupdMODN064vd3gyMKTwrlQ8tZM6aYuyOPsutLlUY7M5x5FwMDYvnPDSeyT/Iw0z -s3B+NCyqeeMd2T7YzQFnRATj0M7rM5LoSs7DVqVriOEABssFyLj31PboaoLhOKgc -qoM9khkNzr7FHVvi+DhYM2jD0DwvqZLN6NmnLwIDAQABAoIBAQCGVj+kuSFOV1lT -+IclQYA6bM6uY5mroqcSBNegVxCNhWU03BxlW//BE9tA/+kq53vWylMeN9mpGZea -riEMIh25KFGWXqXlOOioH8bkMsqA8S7sBmc7jljyv+0toQ9vCCtJ+sueNPhxQQxH -D2YvUjfzBQ04I9+wn30BByDJ1QA/FoPsunxIOUCcRBE/7jxuLYcpR+JvEF68yYIh -atXRld4W4in7T65YDR8jK1Uj9XAcNeDYNpT/M6oFLx1aPIlkG86aCWRO19S1jLPT -b1ZAKHHxPMCVkSYW0RqvIgLXQOR62D0Zne6/2wtzJkk5UCjkSQ2z7ZzJpMkWgDgN -ifCULFPBAoGBAPoMZ5q1w+zB+knXUD33n1J+niN6TZHJulpf2w5zsW+m2K6Zn62M -MXndXlVAHtk6p02q9kxHdgov34Uo8VpuNjbS1+abGFTI8NZgFo+bsDxJdItemwC4 -KJ7L1iz39hRN/ZylMRLz5uTYRGddCkeIHhiG2h7zohH/MaYzUacXEEy3AoGBANz8 -e/msleB+iXC0cXKwds26N4hyMdAFE5qAqJXvV3S2W8JZnmU+sS7vPAWMYPlERPk1 -D8Q2eXqdPIkAWBhrx4RxD7rNc5qFNcQWEhCIxC9fccluH1y5g2M+4jpMX2CT8Uv+ -3z+NoJ5uDTXZTnLCfoZzgZ4nCZVZ+6iU5U1+YXFJAoGBANLPpIV920n/nJmmquMj -orI1R/QXR9Cy56cMC65agezlGOfTYxk5Cfl5Ve+/2IJCfgzwJyjWUsFx7RviEeGw -64o7JoUom1HX+5xxdHPsyZ96OoTJ5RqtKKoApnhRMamau0fWydH1yeOEJd+TRHhc -XStGfhz8QNa1dVFvENczja1vAoGABGWhsd4VPVpHMc7lUvrf4kgKQtTC2PjA4xoc -QJ96hf/642sVE76jl+N6tkGMzGjnVm4P2j+bOy1VvwQavKGoXqJBRd5Apppv727g -/SM7hBXKFc/zH80xKBBgP/i1DR7kdjakCoeu4ngeGywvu2jTS6mQsqzkK+yWbUxJ -I7mYBsECgYB/KNXlTEpXtz/kwWCHFSYA8U74l7zZbVD8ul0e56JDK+lLcJ0tJffk -gqnBycHj6AhEycjda75cs+0zybZvN4x65KZHOGW/O/7OAWEcZP5TPb3zf9ned3Hl -NsZoFj52ponUM6+99A2CmezFCN16c4mbA//luWF+k3VVqR6BpkrhKw== ------END RSA PRIVATE KEY-----` - -const testClientPrivateKey = `-----BEGIN RSA PRIVATE KEY----- -MIIBOwIBAAJBALdGZxkXDAjsYk10ihwU6Id2KeILz1TAJuoq4tOgDWxEEGeTrcld -r/ZwVaFzjWzxaf6zQIJbfaSEAhqD5yo72+sCAwEAAQJBAK8PEVU23Wj8mV0QjwcJ -tZ4GcTUYQL7cF4+ezTCE9a1NrGnCP2RuQkHEKxuTVrxXt+6OF15/1/fuXnxKjmJC -nxkCIQDaXvPPBi0c7vAxGwNY9726x01/dNbHCE0CBtcotobxpwIhANbbQbh3JHVW -2haQh4fAG5mhesZKAGcxTyv4mQ7uMSQdAiAj+4dzMpJWdSzQ+qGHlHMIBvVHLkqB -y2VdEyF7DPCZewIhAI7GOI/6LDIFOvtPo6Bj2nNmyQ1HU6k/LRtNIXi4c9NJAiAr -rrxx26itVhJmcvoUhOjwuzSlP2bE5VHAvkGB352YBg== ------END RSA PRIVATE KEY-----` - -// keychain implements the ClientKeyring interface -type keychain struct { - keys []Signer -} +type keyboardInteractive map[string]string -func (k *keychain) Key(i int) (PublicKey, error) { - if i < 0 || i >= len(k.keys) { - return nil, nil +func (cr keyboardInteractive) Challenge(user string, instruction string, questions []string, echos []bool) ([]string, error) { + var answers []string + for _, q := range questions { + answers = append(answers, cr[q]) } - - return k.keys[i].PublicKey(), nil -} - -func (k *keychain) Sign(i int, rand io.Reader, data []byte) (sig []byte, err error) { - return k.keys[i].Sign(rand, data) + return answers, nil } -func (k *keychain) add(key Signer) { - k.keys = append(k.keys, key) -} +// reused internally by tests +var clientPassword = "tiger" -func (k *keychain) loadPEM(file string) error { - buf, err := ioutil.ReadFile(file) - if err != nil { - return err - } - key, err := ParsePrivateKey(buf) +// tryAuth runs a handshake with a given config against an SSH server +// with config serverConfig +func tryAuth(t *testing.T, config *ClientConfig) error { + c1, c2, err := netPipe() if err != nil { - return err + t.Fatalf("netPipe: %v", err) } - k.add(key) - return nil -} - -// password implements the ClientPassword interface -type password string + defer c1.Close() + defer c2.Close() -func (p password) Password(user string) (string, error) { - return string(p), nil -} - -type keyboardInteractive map[string]string + certChecker := CertChecker{ + IsAuthority: func(k PublicKey) bool { + return bytes.Equal(k.Marshal(), testPublicKeys["ecdsa"].Marshal()) + }, + UserKeyFallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + if conn.User() == "testuser" && bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) { + return nil, nil + } -func (cr *keyboardInteractive) Challenge(user string, instruction string, questions []string, echos []bool) ([]string, error) { - var answers []string - for _, q := range questions { - answers = append(answers, (*cr)[q]) + return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User()) + }, + IsRevoked: func(c *Certificate) bool { + return c.Serial == 666 + }, } - return answers, nil -} -// reused internally by tests -var ( - rsaKey Signer - dsaKey Signer - clientKeychain = new(keychain) - clientPassword = password("tiger") - serverConfig = &ServerConfig{ - PasswordCallback: func(conn *ServerConn, user, pass string) bool { - return user == "testuser" && pass == string(clientPassword) - }, - PublicKeyCallback: func(conn *ServerConn, user, algo string, pubkey []byte) bool { - key, _ := clientKeychain.Key(0) - expected := MarshalPublicKey(key) - algoname := key.PublicKeyAlgo() - return user == "testuser" && algo == algoname && bytes.Equal(pubkey, expected) + serverConfig := &ServerConfig{ + PasswordCallback: func(conn ConnMetadata, pass []byte) (*Permissions, error) { + if conn.User() == "testuser" && string(pass) == clientPassword { + return nil, nil + } + return nil, errors.New("password auth failed") }, - KeyboardInteractiveCallback: func(conn *ServerConn, user string, client ClientKeyboardInteractive) bool { - ans, err := client.Challenge("user", + PublicKeyCallback: certChecker.Authenticate, + KeyboardInteractiveCallback: func(conn ConnMetadata, challenge KeyboardInteractiveChallenge) (*Permissions, error) { + ans, err := challenge("user", "instruction", []string{"question1", "question2"}, []bool{true, true}) if err != nil { - return false + return nil, err + } + ok := conn.User() == "testuser" && ans[0] == "answer1" && ans[1] == "answer2" + if ok { + challenge("user", "motd", nil, nil) + return nil, nil } - ok := user == "testuser" && ans[0] == "answer1" && ans[1] == "answer2" - client.Challenge("user", "motd", nil, nil) - return ok + return nil, errors.New("keyboard-interactive failed") + }, + AuthLogCallback: func(conn ConnMetadata, method string, err error) { + t.Logf("user %q, method %q: %v", conn.User(), method, err) }, } -) - -func init() { - var err error - rsaKey, err = ParsePrivateKey([]byte(testServerPrivateKey)) - if err != nil { - panic("unable to set private key: " + err.Error()) - } - rawDSAKey := new(dsa.PrivateKey) - - // taken from crypto/dsa/dsa_test.go - rawDSAKey.P, _ = new(big.Int).SetString("A9B5B793FB4785793D246BAE77E8FF63CA52F442DA763C440259919FE1BC1D6065A9350637A04F75A2F039401D49F08E066C4D275A5A65DA5684BC563C14289D7AB8A67163BFBF79D85972619AD2CFF55AB0EE77A9002B0EF96293BDD0F42685EBB2C66C327079F6C98000FBCB79AACDE1BC6F9D5C7B1A97E3D9D54ED7951FEF", 16) - rawDSAKey.Q, _ = new(big.Int).SetString("E1D3391245933D68A0714ED34BBCB7A1F422B9C1", 16) - rawDSAKey.G, _ = new(big.Int).SetString("634364FC25248933D01D1993ECABD0657CC0CB2CEED7ED2E3E8AECDFCDC4A25C3B15E9E3B163ACA2984B5539181F3EFF1A5E8903D71D5B95DA4F27202B77D2C44B430BB53741A8D59A8F86887525C9F2A6A5980A195EAA7F2FF910064301DEF89D3AA213E1FAC7768D89365318E370AF54A112EFBA9246D9158386BA1B4EEFDA", 16) - rawDSAKey.Y, _ = new(big.Int).SetString("32969E5780CFE1C849A1C276D7AEB4F38A23B591739AA2FE197349AEEBD31366AEE5EB7E6C6DDB7C57D02432B30DB5AA66D9884299FAA72568944E4EEDC92EA3FBC6F39F53412FBCC563208F7C15B737AC8910DBC2D9C9B8C001E72FDC40EB694AB1F06A5A2DBD18D9E36C66F31F566742F11EC0A52E9F7B89355C02FB5D32D2", 16) - rawDSAKey.X, _ = new(big.Int).SetString("5078D4D29795CBE76D3AACFE48C9AF0BCDBEE91A", 16) - - dsaKey, err = NewSignerFromKey(rawDSAKey) - if err != nil { - panic("NewSignerFromKey: " + err.Error()) - } - clientKeychain.add(rsaKey) - serverConfig.AddHostKey(rsaKey) -} + serverConfig.AddHostKey(testSigners["rsa"]) -// newMockAuthServer creates a new Server bound to -// the loopback interface. The server exits after -// processing one handshake. -func newMockAuthServer(t *testing.T) string { - l, err := Listen("tcp", "127.0.0.1:0", serverConfig) - if err != nil { - t.Fatalf("unable to newMockAuthServer: %s", err) - } - go func() { - defer l.Close() - c, err := l.Accept() - if err != nil { - t.Errorf("Unable to accept incoming connection: %v", err) - return - } - if err := c.Handshake(); err != nil { - // not Errorf because this is expected to - // fail for some tests. - t.Logf("Handshaking error: %v", err) - return - } - defer c.Close() - }() - return l.Addr().String() + go newServer(c1, serverConfig) + _, _, _, err = NewClientConn(c2, "", config) + return err } func TestClientAuthPublicKey(t *testing.T) { config := &ClientConfig{ User: "testuser", - Auth: []ClientAuth{ - ClientAuthKeyring(clientKeychain), + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), }, } - c, err := Dial("tcp", newMockAuthServer(t), config) - if err != nil { + if err := tryAuth(t, config); err != nil { t.Fatalf("unable to dial remote side: %s", err) } - c.Close() } -func TestClientAuthPassword(t *testing.T) { +func TestAuthMethodPassword(t *testing.T) { config := &ClientConfig{ User: "testuser", - Auth: []ClientAuth{ - ClientAuthPassword(clientPassword), + Auth: []AuthMethod{ + Password(clientPassword), }, } - c, err := Dial("tcp", newMockAuthServer(t), config) - if err != nil { + if err := tryAuth(t, config); err != nil { t.Fatalf("unable to dial remote side: %s", err) } - c.Close() } -func TestClientAuthWrongPassword(t *testing.T) { - wrongPw := password("wrong") +func TestAuthMethodWrongPassword(t *testing.T) { config := &ClientConfig{ User: "testuser", - Auth: []ClientAuth{ - ClientAuthPassword(wrongPw), - ClientAuthKeyring(clientKeychain), + Auth: []AuthMethod{ + Password("wrong"), + PublicKeys(testSigners["rsa"]), }, } - c, err := Dial("tcp", newMockAuthServer(t), config) - if err != nil { + if err := tryAuth(t, config); err != nil { t.Fatalf("unable to dial remote side: %s", err) } - c.Close() } -func TestClientAuthKeyboardInteractive(t *testing.T) { +func TestAuthMethodKeyboardInteractive(t *testing.T) { answers := keyboardInteractive(map[string]string{ "question1": "answer1", "question2": "answer2", }) config := &ClientConfig{ User: "testuser", - Auth: []ClientAuth{ - ClientAuthKeyboardInteractive(&answers), + Auth: []AuthMethod{ + KeyboardInteractive(answers.Challenge), }, } - c, err := Dial("tcp", newMockAuthServer(t), config) - if err != nil { + if err := tryAuth(t, config); err != nil { t.Fatalf("unable to dial remote side: %s", err) } - c.Close() } -func TestClientAuthWrongKeyboardInteractive(t *testing.T) { +func TestAuthMethodWrongKeyboardInteractive(t *testing.T) { answers := keyboardInteractive(map[string]string{ "question1": "answer1", "question2": "WRONG", }) config := &ClientConfig{ User: "testuser", - Auth: []ClientAuth{ - ClientAuthKeyboardInteractive(&answers), + Auth: []AuthMethod{ + KeyboardInteractive(answers.Challenge), }, } - c, err := Dial("tcp", newMockAuthServer(t), config) - if err == nil { - c.Close() + if err := tryAuth(t, config); err == nil { t.Fatalf("wrong answers should not have authenticated with KeyboardInteractive") } } // the mock server will only authenticate ssh-rsa keys -func TestClientAuthInvalidPublicKey(t *testing.T) { - kc := new(keychain) - - kc.add(dsaKey) +func TestAuthMethodInvalidPublicKey(t *testing.T) { config := &ClientConfig{ User: "testuser", - Auth: []ClientAuth{ - ClientAuthKeyring(kc), + Auth: []AuthMethod{ + PublicKeys(testSigners["dsa"]), }, } - c, err := Dial("tcp", newMockAuthServer(t), config) - if err == nil { - c.Close() + if err := tryAuth(t, config); err == nil { t.Fatalf("dsa private key should not have authenticated with rsa public key") } } // the client should authenticate with the second key -func TestClientAuthRSAandDSA(t *testing.T) { - kc := new(keychain) - kc.add(dsaKey) - kc.add(rsaKey) +func TestAuthMethodRSAandDSA(t *testing.T) { config := &ClientConfig{ User: "testuser", - Auth: []ClientAuth{ - ClientAuthKeyring(kc), + Auth: []AuthMethod{ + PublicKeys(testSigners["dsa"], testSigners["rsa"]), }, } - c, err := Dial("tcp", newMockAuthServer(t), config) - if err != nil { + if err := tryAuth(t, config); err != nil { t.Fatalf("client could not authenticate with rsa key: %v", err) } - c.Close() } func TestClientHMAC(t *testing.T) { - kc := new(keychain) - kc.add(rsaKey) - for _, mac := range DefaultMACOrder { + for _, mac := range supportedMACs { config := &ClientConfig{ User: "testuser", - Auth: []ClientAuth{ - ClientAuthKeyring(kc), + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), }, - Crypto: CryptoConfig{ + Config: Config{ MACs: []string{mac}, }, } - c, err := Dial("tcp", newMockAuthServer(t), config) - if err != nil { + if err := tryAuth(t, config); err != nil { t.Fatalf("client could not authenticate with mac algo %s: %v", mac, err) } - c.Close() } } // issue 4285. func TestClientUnsupportedCipher(t *testing.T) { - kc := new(keychain) config := &ClientConfig{ User: "testuser", - Auth: []ClientAuth{ - ClientAuthKeyring(kc), + Auth: []AuthMethod{ + PublicKeys(), }, - Crypto: CryptoConfig{ + Config: Config{ Ciphers: []string{"aes128-cbc"}, // not currently supported }, } - c, err := Dial("tcp", newMockAuthServer(t), config) - if err == nil { + if err := tryAuth(t, config); err == nil { t.Errorf("expected no ciphers in common") - c.Close() } } func TestClientUnsupportedKex(t *testing.T) { - kc := new(keychain) config := &ClientConfig{ User: "testuser", - Auth: []ClientAuth{ - ClientAuthKeyring(kc), + Auth: []AuthMethod{ + PublicKeys(), }, - Crypto: CryptoConfig{ + Config: Config{ KeyExchanges: []string{"diffie-hellman-group-exchange-sha256"}, // not currently supported }, } - c, err := Dial("tcp", newMockAuthServer(t), config) - if err == nil || !strings.Contains(err.Error(), "no common algorithms") { + if err := tryAuth(t, config); err == nil || !strings.Contains(err.Error(), "no common algorithms") { t.Errorf("got %v, expected 'no common algorithms'", err) } - if c != nil { - c.Close() +} + +func TestClientLoginCert(t *testing.T) { + cert := &Certificate{ + Key: testPublicKeys["rsa"], + ValidBefore: CertTimeInfinity, + CertType: UserCert, + } + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + certSigner, err := NewCertSigner(cert, testSigners["rsa"]) + if err != nil { + t.Fatalf("NewCertSigner: %v", err) + } + + clientConfig := &ClientConfig{ + User: "user", + } + clientConfig.Auth = append(clientConfig.Auth, PublicKeys(certSigner)) + + t.Log("should succeed") + if err := tryAuth(t, clientConfig); err != nil { + t.Errorf("cert login failed: %v", err) + } + + t.Log("corrupted signature") + cert.Signature.Blob[0]++ + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with corrupted sig") + } + + t.Log("revoked") + cert.Serial = 666 + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("revoked cert login succeeded") + } + cert.Serial = 1 + + t.Log("sign with wrong key") + cert.SignCert(rand.Reader, testSigners["dsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with non-authoritive key") + } + + t.Log("host cert") + cert.CertType = HostCert + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with wrong type") + } + cert.CertType = UserCert + + t.Log("principal specified") + cert.ValidPrincipals = []string{"user"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err != nil { + t.Errorf("cert login failed: %v", err) + } + + t.Log("wrong principal specified") + cert.ValidPrincipals = []string{"fred"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with wrong principal") + } + cert.ValidPrincipals = nil + + t.Log("added critical option") + cert.CriticalOptions = map[string]string{"root-access": "yes"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with unrecognized critical option") + } + + t.Log("allowed source address") + cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42/24"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err != nil { + t.Errorf("cert login with source-address failed: %v", err) + } + + t.Log("disallowed source address") + cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login with source-address succeeded") } } diff --git a/ssh/client_test.go b/ssh/client_test.go index f6c11b9..1fe790c 100644 --- a/ssh/client_test.go +++ b/ssh/client_test.go @@ -1,3 +1,7 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package ssh import ( @@ -7,6 +11,7 @@ import ( func testClientVersion(t *testing.T, config *ClientConfig, expected string) { clientConn, serverConn := net.Pipe() + defer clientConn.Close() receivedVersion := make(chan string, 1) go func() { version, err := readVersion(serverConn) @@ -17,7 +22,7 @@ func testClientVersion(t *testing.T, config *ClientConfig, expected string) { } serverConn.Close() }() - Client(clientConn, config) + NewClientConn(clientConn, "", config) actual := <-receivedVersion if actual != expected { t.Fatalf("got %s; want %s", actual, expected) diff --git a/ssh/common.go b/ssh/common.go index 4870e56..2fd7fd9 100644 --- a/ssh/common.go +++ b/ssh/common.go @@ -6,7 +6,9 @@ package ssh import ( "crypto" + "crypto/rand" "fmt" + "io" "sync" _ "crypto/sha1" @@ -21,16 +23,39 @@ const ( serviceSSH = "ssh-connection" ) +// supportedCiphers specifies the supported ciphers in preference order. +var supportedCiphers = []string{ + "aes128-ctr", "aes192-ctr", "aes256-ctr", + "aes128-gcm@openssh.com", + "arcfour256", "arcfour128", +} + +// supportedKexAlgos specifies the supported key-exchange algorithms in +// preference order. var supportedKexAlgos = []string{ + // P384 and P521 are not constant-time yet, but since we don't + // reuse ephemeral keys, using them for ECDH should be OK. kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521, kexAlgoDH14SHA1, kexAlgoDH1SHA1, } +// supportedKexAlgos specifies the supported host-key algorithms (i.e. methods +// of authenticating servers) in preference order. var supportedHostKeyAlgos = []string{ + CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, + CertAlgoECDSA384v01, CertAlgoECDSA521v01, + KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, KeyAlgoRSA, KeyAlgoDSA, } +// supportedMACs specifies a default set of MAC algorithms in preference order. +// This is based on RFC 4253, section 6.4, but with hmac-md5 variants removed +// because they have reached the end of their useful life. +var supportedMACs = []string{ + "hmac-sha1", "hmac-sha1-96", +} + var supportedCompressions = []string{compressionNone} // hashFuncs keeps the mapping of supported algorithms to their respective @@ -48,23 +73,15 @@ var hashFuncs = map[string]crypto.Hash{ CertAlgoECDSA521v01: crypto.SHA512, } -// UnexpectedMessageError results when the SSH message that we received didn't +// unexpectedMessageError results when the SSH message that we received didn't // match what we wanted. -type UnexpectedMessageError struct { - expected, got uint8 +func unexpectedMessageError(expected, got uint8) error { + return fmt.Errorf("ssh: unexpected message type %d (expected %d)", got, expected) } -func (u UnexpectedMessageError) Error() string { - return fmt.Sprintf("ssh: unexpected message type %d (expected %d)", u.got, u.expected) -} - -// ParseError results from a malformed SSH message. -type ParseError struct { - msgType uint8 -} - -func (p ParseError) Error() string { - return fmt.Sprintf("ssh: parse error in message type %d", p.msgType) +// parseError results from a malformed SSH message. +func parseError(tag uint8) error { + return fmt.Errorf("ssh: parse error in message type %d", tag) } func findCommonAlgorithm(clientAlgos []string, serverAlgos []string) (commonAlgo string, ok bool) { @@ -90,15 +107,17 @@ func findCommonCipher(clientCiphers []string, serverCiphers []string) (commonCip return } +type directionAlgorithms struct { + Cipher string + MAC string + Compression string +} + type algorithms struct { - kex string - hostKey string - wCipher string - rCipher string - rMAC string - wMAC string - rCompression string - wCompression string + kex string + hostKey string + w directionAlgorithms + r directionAlgorithms } func findAgreedAlgorithms(clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms) { @@ -114,32 +133,32 @@ func findAgreedAlgorithms(clientKexInit, serverKexInit *kexInitMsg) (algs *algor return } - result.wCipher, ok = findCommonCipher(clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer) + result.w.Cipher, ok = findCommonCipher(clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer) if !ok { return } - result.rCipher, ok = findCommonCipher(clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient) + result.r.Cipher, ok = findCommonCipher(clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient) if !ok { return } - result.wMAC, ok = findCommonAlgorithm(clientKexInit.MACsClientServer, serverKexInit.MACsClientServer) + result.w.MAC, ok = findCommonAlgorithm(clientKexInit.MACsClientServer, serverKexInit.MACsClientServer) if !ok { return } - result.rMAC, ok = findCommonAlgorithm(clientKexInit.MACsServerClient, serverKexInit.MACsServerClient) + result.r.MAC, ok = findCommonAlgorithm(clientKexInit.MACsServerClient, serverKexInit.MACsServerClient) if !ok { return } - result.wCompression, ok = findCommonAlgorithm(clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer) + result.w.Compression, ok = findCommonAlgorithm(clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer) if !ok { return } - result.rCompression, ok = findCommonAlgorithm(clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient) + result.r.Compression, ok = findCommonAlgorithm(clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient) if !ok { return } @@ -147,133 +166,87 @@ func findAgreedAlgorithms(clientKexInit, serverKexInit *kexInitMsg) (algs *algor return result } -// Cryptographic configuration common to both ServerConfig and ClientConfig. -type CryptoConfig struct { +// If rekeythreshold is too small, we can't make any progress sending +// stuff. +const minRekeyThreshold uint64 = 256 + +// Config contains configuration data common to both ServerConfig and +// ClientConfig. +type Config struct { + // Rand provides the source of entropy for cryptographic + // primitives. If Rand is nil, the cryptographic random reader + // in package crypto/rand will be used. + Rand io.Reader + + // The maximum number of bytes sent or received after which a + // new key is negotiated. It must be at least 256. If + // unspecified, 1 gigabyte is used. + RekeyThreshold uint64 + // The allowed key exchanges algorithms. If unspecified then a // default set of algorithms is used. KeyExchanges []string - // The allowed cipher algorithms. If unspecified then DefaultCipherOrder is - // used. + // The allowed cipher algorithms. If unspecified then a sensible + // default is used. Ciphers []string - // The allowed MAC algorithms. If unspecified then DefaultMACOrder is used. + // The allowed MAC algorithms. If unspecified then a sensible default + // is used. MACs []string } -func (c *CryptoConfig) ciphers() []string { +// SetDefaults sets sensible values for unset fields in config. This is +// exported for testing: Configs passed to SSH functions are copied and have +// default values set automatically. +func (c *Config) SetDefaults() { + if c.Rand == nil { + c.Rand = rand.Reader + } if c.Ciphers == nil { - return DefaultCipherOrder + c.Ciphers = supportedCiphers } - return c.Ciphers -} -func (c *CryptoConfig) kexes() []string { if c.KeyExchanges == nil { - return defaultKeyExchangeOrder + c.KeyExchanges = supportedKexAlgos } - return c.KeyExchanges -} -func (c *CryptoConfig) macs() []string { if c.MACs == nil { - return DefaultMACOrder + c.MACs = supportedMACs } - return c.MACs -} - -// serialize a signed slice according to RFC 4254 6.6. The name should -// be a key type name, rather than a cert type name. -func serializeSignature(name string, sig []byte) []byte { - length := stringLength(len(name)) - length += stringLength(len(sig)) - - ret := make([]byte, length) - r := marshalString(ret, []byte(name)) - r = marshalString(r, sig) - - return ret -} -// MarshalPublicKey serializes a supported key or certificate for use -// by the SSH wire protocol. It can be used for comparison with the -// pubkey argument of ServerConfig's PublicKeyCallback as well as for -// generating an authorized_keys or host_keys file. -func MarshalPublicKey(key PublicKey) []byte { - // See also RFC 4253 6.6. - algoname := key.PublicKeyAlgo() - blob := key.Marshal() - - length := stringLength(len(algoname)) - length += len(blob) - ret := make([]byte, length) - r := marshalString(ret, []byte(algoname)) - copy(r, blob) - return ret -} - -// pubAlgoToPrivAlgo returns the private key algorithm format name that -// corresponds to a given public key algorithm format name. For most -// public keys, the private key algorithm name is the same. For some -// situations, such as openssh certificates, the private key algorithm and -// public key algorithm names differ. This accounts for those situations. -func pubAlgoToPrivAlgo(pubAlgo string) string { - switch pubAlgo { - case CertAlgoRSAv01: - return KeyAlgoRSA - case CertAlgoDSAv01: - return KeyAlgoDSA - case CertAlgoECDSA256v01: - return KeyAlgoECDSA256 - case CertAlgoECDSA384v01: - return KeyAlgoECDSA384 - case CertAlgoECDSA521v01: - return KeyAlgoECDSA521 + if c.RekeyThreshold == 0 { + // RFC 4253, section 9 suggests rekeying after 1G. + c.RekeyThreshold = 1 << 30 + } + if c.RekeyThreshold < minRekeyThreshold { + c.RekeyThreshold = minRekeyThreshold } - return pubAlgo } // buildDataSignedForAuth returns the data that is signed in order to prove // possession of a private key. See RFC 4252, section 7. func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte { - user := []byte(req.User) - service := []byte(req.Service) - method := []byte(req.Method) - - length := stringLength(len(sessionId)) - length += 1 - length += stringLength(len(user)) - length += stringLength(len(service)) - length += stringLength(len(method)) - length += 1 - length += stringLength(len(algo)) - length += stringLength(len(pubKey)) - - ret := make([]byte, length) - r := marshalString(ret, sessionId) - r[0] = msgUserAuthRequest - r = r[1:] - r = marshalString(r, user) - r = marshalString(r, service) - r = marshalString(r, method) - r[0] = 1 - r = r[1:] - r = marshalString(r, algo) - r = marshalString(r, pubKey) - return ret -} - -// safeString sanitises s according to RFC 4251, section 9.2. -// All control characters except tab, carriage return and newline are -// replaced by 0x20. -func safeString(s string) string { - out := []byte(s) - for i, c := range out { - if c < 0x20 && c != 0xd && c != 0xa && c != 0x9 { - out[i] = 0x20 - } + data := struct { + Session []byte + Type byte + User string + Service string + Method string + Sign bool + Algo []byte + PubKey []byte + }{ + sessionId, + msgUserAuthRequest, + req.User, + req.Service, + req.Method, + true, + algo, + pubKey, } - return string(out) + return Marshal(data) } func appendU16(buf []byte, n uint16) []byte { @@ -284,6 +257,12 @@ func appendU32(buf []byte, n uint32) []byte { return append(buf, byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) } +func appendU64(buf []byte, n uint64) []byte { + return append(buf, + byte(n>>56), byte(n>>48), byte(n>>40), byte(n>>32), + byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) +} + func appendInt(buf []byte, n int) []byte { return appendU32(buf, uint32(n)) } @@ -296,11 +275,9 @@ func appendString(buf []byte, s string) []byte { func appendBool(buf []byte, b bool) []byte { if b { - buf = append(buf, 1) - } else { - buf = append(buf, 0) + return append(buf, 1) } - return buf + return append(buf, 0) } // newCond is a helper to hide the fact that there is no usable zero @@ -311,7 +288,9 @@ func newCond() *sync.Cond { return sync.NewCond(new(sync.Mutex)) } // wishing to write to a channel. type window struct { *sync.Cond - win uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1 + win uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1 + writeWaiters int + closed bool } // add adds win to the amount of window available @@ -335,18 +314,44 @@ func (w *window) add(win uint32) bool { return true } +// close sets the window to closed, so all reservations fail +// immediately. +func (w *window) close() { + w.L.Lock() + w.closed = true + w.Broadcast() + w.L.Unlock() +} + // reserve reserves win from the available window capacity. // If no capacity remains, reserve will block. reserve may // return less than requested. -func (w *window) reserve(win uint32) uint32 { +func (w *window) reserve(win uint32) (uint32, error) { + var err error w.L.Lock() - for w.win == 0 { + w.writeWaiters++ + w.Broadcast() + for w.win == 0 && !w.closed { w.Wait() } + w.writeWaiters-- if w.win < win { win = w.win } w.win -= win + if w.closed { + err = io.EOF + } w.L.Unlock() - return win + return win, err +} + +// waitWriterBlocked waits until some goroutine is blocked for further +// writes. It is used in tests only. +func (w *window) waitWriterBlocked() { + w.Cond.L.Lock() + for w.writeWaiters == 0 { + w.Cond.Wait() + } + w.Cond.L.Unlock() } diff --git a/ssh/connection.go b/ssh/connection.go new file mode 100644 index 0000000..93551e2 --- /dev/null +++ b/ssh/connection.go @@ -0,0 +1,144 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "fmt" + "net" +) + +// OpenChannelError is returned if the other side rejects an +// OpenChannel request. +type OpenChannelError struct { + Reason RejectionReason + Message string +} + +func (e *OpenChannelError) Error() string { + return fmt.Sprintf("ssh: rejected: %s (%s)", e.Reason, e.Message) +} + +// ConnMetadata holds metadata for the connection. +type ConnMetadata interface { + // User returns the user ID for this connection. + // It is empty if no authentication is used. + User() string + + // SessionID returns the sesson hash, also denoted by H. + SessionID() []byte + + // ClientVersion returns the client's version string as hashed + // into the session ID. + ClientVersion() []byte + + // ServerVersion returns the client's version string as hashed + // into the session ID. + ServerVersion() []byte + + // RemoteAddr returns the remote address for this connection. + RemoteAddr() net.Addr + + // LocalAddr returns the local address for this connection. + LocalAddr() net.Addr +} + +// Conn represents an SSH connection for both server and client roles. +// Conn is the basis for implementing an application layer, such +// as ClientConn, which implements the traditional shell access for +// clients. +type Conn interface { + ConnMetadata + + // SendRequest sends a global request, and returns the + // reply. If wantReply is true, it returns the response status + // and payload. See also RFC4254, section 4. + SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) + + // OpenChannel tries to open an channel. If the request is + // rejected, it returns *OpenChannelError. On success it returns + // the SSH Channel and a Go channel for incoming, out-of-band + // requests. The Go channel must be serviced, or the + // connection will hang. + OpenChannel(name string, data []byte) (Channel, <-chan *Request, error) + + // Close closes the underlying network connection + Close() error + + // Wait blocks until the connection has shut down, and returns the + // error causing the shutdown. + Wait() error + + // TODO(hanwen): consider exposing: + // RequestKeyChange + // Disconnect +} + +// DiscardRequests consumes and rejects all requests from the +// passed-in channel. +func DiscardRequests(in <-chan *Request) { + for req := range in { + if req.WantReply { + req.Reply(false, nil) + } + } +} + +// A connection represents an incoming connection. +type connection struct { + transport *handshakeTransport + sshConn + + // The connection protocol. + *mux +} + +func (c *connection) Close() error { + return c.sshConn.conn.Close() +} + +// sshconn provides net.Conn metadata, but disallows direct reads and +// writes. +type sshConn struct { + conn net.Conn + + user string + sessionID []byte + clientVersion []byte + serverVersion []byte +} + +func dup(src []byte) []byte { + dst := make([]byte, len(src)) + copy(dst, src) + return dst +} + +func (c *sshConn) User() string { + return c.user +} + +func (c *sshConn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *sshConn) Close() error { + return c.conn.Close() +} + +func (c *sshConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *sshConn) SessionID() []byte { + return dup(c.sessionID) +} + +func (c *sshConn) ClientVersion() []byte { + return dup(c.clientVersion) +} + +func (c *sshConn) ServerVersion() []byte { + return dup(c.serverVersion) +} @@ -13,7 +13,6 @@ others. References: [PROTOCOL.certkeys]: http://www.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys - [PROTOCOL.agent]: http://www.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.agent [SSH-PARAMETERS]: http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1 */ package ssh diff --git a/ssh/example_test.go b/ssh/example_test.go index a88a677..d9d6a54 100644 --- a/ssh/example_test.go +++ b/ssh/example_test.go @@ -9,17 +9,23 @@ import ( "fmt" "io/ioutil" "log" + "net" "net/http" "code.google.com/p/go.crypto/ssh/terminal" ) -func ExampleListen() { +func ExampleNewServerConn() { // An SSH server is represented by a ServerConfig, which holds // certificate details and handles authentication of ServerConns. config := &ServerConfig{ - PasswordCallback: func(conn *ServerConn, user, pass string) bool { - return user == "testuser" && pass == "tiger" + PasswordCallback: func(c ConnMetadata, pass []byte) (*Permissions, error) { + // Should use constant-time compare (or better, salt+hash) in + // a production setting. + if c.User() == "testuser" && string(pass) == "tiger" { + return nil, nil + } + return nil, fmt.Errorf("password rejected for %q", c.User()) }, } @@ -37,50 +43,65 @@ func ExampleListen() { // Once a ServerConfig has been configured, connections can be // accepted. - listener, err := Listen("tcp", "0.0.0.0:2022", config) + listener, err := net.Listen("tcp", "0.0.0.0:2022") if err != nil { panic("failed to listen for connection") } - sConn, err := listener.Accept() + nConn, err := listener.Accept() if err != nil { panic("failed to accept incoming connection") } - if err := sConn.Handshake(); err != nil { + + // Before use, a handshake must be performed on the incoming + // net.Conn. + _, chans, reqs, err := NewServerConn(nConn, config) + if err != nil { panic("failed to handshake") } + // The incoming Request channel must be serviced. + go DiscardRequests(reqs) - // A ServerConn multiplexes several channels, which must - // themselves be Accepted. - for { - // Accept reads from the connection, demultiplexes packets - // to their corresponding channels and returns when a new - // channel request is seen. Some goroutine must always be - // calling Accept; otherwise no messages will be forwarded - // to the channels. - channel, err := sConn.Accept() - if err != nil { - panic("error from Accept") - } - + // Service the incoming Channel channel. + for newChannel := range chans { // Channels have a type, depending on the application level // protocol intended. In the case of a shell, the type is // "session" and ServerShell may be used to present a simple // terminal interface. - if channel.ChannelType() != "session" { - channel.Reject(UnknownChannelType, "unknown channel type") + if newChannel.ChannelType() != "session" { + newChannel.Reject(UnknownChannelType, "unknown channel type") continue } - channel.Accept() + channel, requests, err := newChannel.Accept() + if err != nil { + panic("could not accept channel.") + } + + // Sessions have out-of-band requests such as "shell", + // "pty-req" and "env". Here we handle only the + // "shell" request. + go func(in <-chan *Request) { + for req := range in { + ok := false + switch req.Type { + case "shell": + ok = true + if len(req.Payload) > 0 { + // We don't accept any + // commands, only the + // default shell. + ok = false + } + } + req.Reply(ok, nil) + } + }(requests) term := terminal.NewTerminal(channel, "> ") - serverTerm := &ServerTerminal{ - Term: term, - Channel: channel, - } + go func() { defer channel.Close() for { - line, err := serverTerm.ReadLine() + line, err := term.ReadLine() if err != nil { break } @@ -95,13 +116,11 @@ func ExampleDial() { // the "password" authentication method is supported. // // To authenticate with the remote server you must pass at least one - // implementation of ClientAuth via the Auth field in ClientConfig. + // implementation of AuthMethod via the Auth field in ClientConfig. config := &ClientConfig{ User: "username", - Auth: []ClientAuth{ - // ClientAuthPassword wraps a ClientPassword implementation - // in a type that implements ClientAuth. - ClientAuthPassword(password("yourpassword")), + Auth: []AuthMethod{ + Password("yourpassword"), }, } client, err := Dial("tcp", "yourserver.com:22", config) @@ -127,11 +146,11 @@ func ExampleDial() { fmt.Println(b.String()) } -func ExampleClientConn_Listen() { +func ExampleClient_Listen() { config := &ClientConfig{ User: "username", - Auth: []ClientAuth{ - ClientAuthPassword(password("password")), + Auth: []AuthMethod{ + Password("password"), }, } // Dial your ssh server. @@ -158,8 +177,8 @@ func ExampleSession_RequestPty() { // Create client config config := &ClientConfig{ User: "username", - Auth: []ClientAuth{ - ClientAuthPassword(password("password")), + Auth: []AuthMethod{ + Password("password"), }, } // Connect to ssh server diff --git a/ssh/handshake.go b/ssh/handshake.go new file mode 100644 index 0000000..a1e2c23 --- /dev/null +++ b/ssh/handshake.go @@ -0,0 +1,393 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "crypto/rand" + "errors" + "fmt" + "io" + "log" + "net" + "sync" +) + +// debugHandshake, if set, prints messages sent and received. Key +// exchange messages are printed as if DH were used, so the debug +// messages are wrong when using ECDH. +const debugHandshake = false + +// keyingTransport is a packet based transport that supports key +// changes. It need not be thread-safe. It should pass through +// msgNewKeys in both directions. +type keyingTransport interface { + packetConn + + // prepareKeyChange sets up a key change. The key change for a + // direction will be effected if a msgNewKeys message is sent + // or received. + prepareKeyChange(*algorithms, *kexResult) error + + // getSessionID returns the session ID. prepareKeyChange must + // have been called once. + getSessionID() []byte +} + +// rekeyingTransport is the interface of handshakeTransport that we +// (internally) expose to ClientConn and ServerConn. +type rekeyingTransport interface { + packetConn + + // requestKeyChange asks the remote side to change keys. All + // writes are blocked until the key change succeeds, which is + // signaled by reading a msgNewKeys. + requestKeyChange() error + + // getSessionID returns the session ID. This is only valid + // after the first key change has completed. + getSessionID() []byte +} + +// handshakeTransport implements rekeying on top of a keyingTransport +// and offers a thread-safe writePacket() interface. +type handshakeTransport struct { + conn keyingTransport + config *Config + + serverVersion []byte + clientVersion []byte + + hostKeys []Signer // If hostKeys are given, we are the server. + + // On read error, incoming is closed, and readError is set. + incoming chan []byte + readError error + + // data for host key checking + hostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error + dialAddress string + remoteAddr net.Addr + + readSinceKex uint64 + + // Protects the writing side of the connection + mu sync.Mutex + cond *sync.Cond + sentInitPacket []byte + sentInitMsg *kexInitMsg + writtenSinceKex uint64 + writeError error +} + +func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport { + t := &handshakeTransport{ + conn: conn, + serverVersion: serverVersion, + clientVersion: clientVersion, + incoming: make(chan []byte, 16), + config: config, + } + t.cond = sync.NewCond(&t.mu) + return t +} + +func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport { + t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) + t.dialAddress = dialAddr + t.remoteAddr = addr + t.hostKeyCallback = config.HostKeyCallback + go t.readLoop() + return t +} + +func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport { + t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) + t.hostKeys = config.hostKeys + go t.readLoop() + return t +} + +func (t *handshakeTransport) getSessionID() []byte { + return t.conn.getSessionID() +} + +func (t *handshakeTransport) id() string { + if len(t.hostKeys) > 0 { + return "server" + } + return "client" +} + +func (t *handshakeTransport) readPacket() ([]byte, error) { + p, ok := <-t.incoming + if !ok { + return nil, t.readError + } + return p, nil +} + +func (t *handshakeTransport) readLoop() { + for { + p, err := t.readOnePacket() + if err != nil { + t.readError = err + close(t.incoming) + break + } + if p[0] == msgIgnore || p[0] == msgDebug { + continue + } + t.incoming <- p + } +} + +func (t *handshakeTransport) readOnePacket() ([]byte, error) { + if t.readSinceKex > t.config.RekeyThreshold { + if err := t.requestKeyChange(); err != nil { + return nil, err + } + } + + p, err := t.conn.readPacket() + if err != nil { + return nil, err + } + + t.readSinceKex += uint64(len(p)) + if debugHandshake { + msg, err := decode(p) + log.Printf("%s got %T %v (%v)", t.id(), msg, msg, err) + } + if p[0] != msgKexInit { + return p, nil + } + err = t.enterKeyExchange(p) + + t.mu.Lock() + if err != nil { + // drop connection + t.conn.Close() + t.writeError = err + } + + if debugHandshake { + log.Printf("%s exited key exchange, err %v", t.id(), err) + } + + // Unblock writers. + t.sentInitMsg = nil + t.sentInitPacket = nil + t.cond.Broadcast() + t.writtenSinceKex = 0 + t.mu.Unlock() + + if err != nil { + return nil, err + } + + t.readSinceKex = 0 + return []byte{msgNewKeys}, nil +} + +// sendKexInit sends a key change message, and returns the message +// that was sent. After initiating the key change, all writes will be +// blocked until the change is done, and a failed key change will +// close the underlying transport. This function is safe for +// concurrent use by multiple goroutines. +func (t *handshakeTransport) sendKexInit() (*kexInitMsg, []byte, error) { + t.mu.Lock() + defer t.mu.Unlock() + return t.sendKexInitLocked() +} + +func (t *handshakeTransport) requestKeyChange() error { + _, _, err := t.sendKexInit() + return err +} + +// sendKexInitLocked sends a key change message. t.mu must be locked +// while this happens. +func (t *handshakeTransport) sendKexInitLocked() (*kexInitMsg, []byte, error) { + // kexInits may be sent either in response to the other side, + // or because our side wants to initiate a key change, so we + // may have already sent a kexInit. In that case, don't send a + // second kexInit. + if t.sentInitMsg != nil { + return t.sentInitMsg, t.sentInitPacket, nil + } + msg := &kexInitMsg{ + KexAlgos: t.config.KeyExchanges, + CiphersClientServer: t.config.Ciphers, + CiphersServerClient: t.config.Ciphers, + MACsClientServer: t.config.MACs, + MACsServerClient: t.config.MACs, + CompressionClientServer: supportedCompressions, + CompressionServerClient: supportedCompressions, + } + io.ReadFull(rand.Reader, msg.Cookie[:]) + + if len(t.hostKeys) > 0 { + for _, k := range t.hostKeys { + msg.ServerHostKeyAlgos = append( + msg.ServerHostKeyAlgos, k.PublicKey().Type()) + } + } else { + msg.ServerHostKeyAlgos = supportedHostKeyAlgos + } + packet := Marshal(msg) + + // writePacket destroys the contents, so save a copy. + packetCopy := make([]byte, len(packet)) + copy(packetCopy, packet) + + if err := t.conn.writePacket(packetCopy); err != nil { + return nil, nil, err + } + + t.sentInitMsg = msg + t.sentInitPacket = packet + return msg, packet, nil +} + +func (t *handshakeTransport) writePacket(p []byte) error { + t.mu.Lock() + if t.writtenSinceKex > t.config.RekeyThreshold { + t.sendKexInitLocked() + } + for t.sentInitMsg != nil { + t.cond.Wait() + } + if t.writeError != nil { + return t.writeError + } + t.writtenSinceKex += uint64(len(p)) + + var err error + switch p[0] { + case msgKexInit: + err = errors.New("ssh: only handshakeTransport can send kexInit") + case msgNewKeys: + err = errors.New("ssh: only handshakeTransport can send newKeys") + default: + err = t.conn.writePacket(p) + } + t.mu.Unlock() + return err +} + +func (t *handshakeTransport) Close() error { + return t.conn.Close() +} + +// enterKeyExchange runs the key exchange. +func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error { + if debugHandshake { + log.Printf("%s entered key exchange", t.id()) + } + myInit, myInitPacket, err := t.sendKexInit() + if err != nil { + return err + } + + otherInit := &kexInitMsg{} + if err := Unmarshal(otherInitPacket, otherInit); err != nil { + return err + } + + magics := handshakeMagics{ + clientVersion: t.clientVersion, + serverVersion: t.serverVersion, + clientKexInit: otherInitPacket, + serverKexInit: myInitPacket, + } + + clientInit := otherInit + serverInit := myInit + if len(t.hostKeys) == 0 { + clientInit = myInit + serverInit = otherInit + + magics.clientKexInit = myInitPacket + magics.serverKexInit = otherInitPacket + } + + algs := findAgreedAlgorithms(clientInit, serverInit) + if algs == nil { + return errors.New("ssh: no common algorithms") + } + + // We don't send FirstKexFollows, but we handle receiving it. + if otherInit.FirstKexFollows && algs.kex != otherInit.KexAlgos[0] { + // other side sent a kex message for the wrong algorithm, + // which we have to ignore. + if _, err := t.conn.readPacket(); err != nil { + return err + } + } + + kex, ok := kexAlgoMap[algs.kex] + if !ok { + return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex) + } + + var result *kexResult + if len(t.hostKeys) > 0 { + result, err = t.server(kex, algs, &magics) + } else { + result, err = t.client(kex, algs, &magics) + } + + if err != nil { + return err + } + + t.conn.prepareKeyChange(algs, result) + if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil { + return err + } + if packet, err := t.conn.readPacket(); err != nil { + return err + } else if packet[0] != msgNewKeys { + return unexpectedMessageError(msgNewKeys, packet[0]) + } + return nil +} + +func (t *handshakeTransport) server(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) { + var hostKey Signer + for _, k := range t.hostKeys { + if algs.hostKey == k.PublicKey().Type() { + hostKey = k + } + } + + r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey) + return r, err +} + +func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) { + result, err := kex.Client(t.conn, t.config.Rand, magics) + if err != nil { + return nil, err + } + + hostKey, err := ParsePublicKey(result.HostKey) + if err != nil { + return nil, err + } + + if err := verifyHostKeySignature(hostKey, result); err != nil { + return nil, err + } + + if t.hostKeyCallback != nil { + err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey) + if err != nil { + return nil, err + } + } + + return result, nil +} diff --git a/ssh/handshake_test.go b/ssh/handshake_test.go new file mode 100644 index 0000000..613c498 --- /dev/null +++ b/ssh/handshake_test.go @@ -0,0 +1,311 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto/rand" + "fmt" + "net" + "testing" +) + +type testChecker struct { + calls []string +} + +func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { + if dialAddr == "bad" { + return fmt.Errorf("dialAddr is bad") + } + + if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil { + return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr) + } + + t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal())) + + return nil +} + +// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and +// therefore is buffered (net.Pipe deadlocks if both sides start with +// a write.) +func netPipe() (net.Conn, net.Conn, error) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, nil, err + } + defer listener.Close() + c1, err := net.Dial("tcp", listener.Addr().String()) + if err != nil { + return nil, nil, err + } + + c2, err := listener.Accept() + if err != nil { + c1.Close() + return nil, nil, err + } + + return c1, c2, nil +} + +func handshakePair(clientConf *ClientConfig, addr string) (client *handshakeTransport, server *handshakeTransport, err error) { + a, b, err := netPipe() + if err != nil { + return nil, nil, err + } + + trC := newTransport(a, rand.Reader, true) + trS := newTransport(b, rand.Reader, false) + clientConf.SetDefaults() + + v := []byte("version") + client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr()) + + serverConf := &ServerConfig{} + serverConf.AddHostKey(testSigners["ecdsa"]) + serverConf.SetDefaults() + server = newServerTransport(trS, v, v, serverConf) + + return client, server, nil +} + +func TestHandshakeBasic(t *testing.T) { + checker := &testChecker{} + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + + defer trC.Close() + defer trS.Close() + + go func() { + // Client writes a bunch of stuff, and does a key + // change in the middle. This should not confuse the + // handshake in progress + for i := 0; i < 10; i++ { + p := []byte{msgRequestSuccess, byte(i)} + if err := trC.writePacket(p); err != nil { + t.Fatalf("sendPacket: %v", err) + } + if i == 5 { + // halfway through, we request a key change. + _, _, err := trC.sendKexInit() + if err != nil { + t.Fatalf("sendKexInit: %v", err) + } + } + } + trC.Close() + }() + + // Server checks that client messages come in cleanly + i := 0 + for { + p, err := trS.readPacket() + if err != nil { + break + } + if p[0] == msgNewKeys { + continue + } + want := []byte{msgRequestSuccess, byte(i)} + if bytes.Compare(p, want) != 0 { + t.Errorf("message %d: got %q, want %q", i, p, want) + } + i++ + } + if i != 10 { + t.Errorf("received %d messages, want 10.", i) + } + + // If all went well, we registered exactly 1 key change. + if len(checker.calls) != 1 { + t.Fatalf("got %d host key checks, want 1", len(checker.calls)) + } + + pub := testSigners["ecdsa"].PublicKey() + want := fmt.Sprintf("%s %v %s %x", "addr", trC.remoteAddr, pub.Type(), pub.Marshal()) + if want != checker.calls[0] { + t.Errorf("got %q want %q for host key check", checker.calls[0], want) + } +} + +func TestHandshakeError(t *testing.T) { + checker := &testChecker{} + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "bad") + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + defer trC.Close() + defer trS.Close() + + // send a packet + packet := []byte{msgRequestSuccess, 42} + if err := trC.writePacket(packet); err != nil { + t.Errorf("writePacket: %v", err) + } + + // Now request a key change. + _, _, err = trC.sendKexInit() + if err != nil { + t.Errorf("sendKexInit: %v", err) + } + + // the key change will fail, and afterwards we can't write. + if err := trC.writePacket([]byte{msgRequestSuccess, 43}); err == nil { + t.Errorf("writePacket after botched rekey succeeded.") + } + + readback, err := trS.readPacket() + if err != nil { + t.Fatalf("server closed too soon: %v", err) + } + if bytes.Compare(readback, packet) != 0 { + t.Errorf("got %q want %q", readback, packet) + } + readback, err = trS.readPacket() + if err == nil { + t.Errorf("got a message %q after failed key change", readback) + } +} + +func TestHandshakeTwice(t *testing.T) { + checker := &testChecker{} + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + + defer trC.Close() + defer trS.Close() + + // send a packet + packet := make([]byte, 5) + packet[0] = msgRequestSuccess + if err := trC.writePacket(packet); err != nil { + t.Errorf("writePacket: %v", err) + } + + // Now request a key change. + _, _, err = trC.sendKexInit() + if err != nil { + t.Errorf("sendKexInit: %v", err) + } + + // Send another packet. Use a fresh one, since writePacket destroys. + packet = make([]byte, 5) + packet[0] = msgRequestSuccess + if err := trC.writePacket(packet); err != nil { + t.Errorf("writePacket: %v", err) + } + + // 2nd key change. + _, _, err = trC.sendKexInit() + if err != nil { + t.Errorf("sendKexInit: %v", err) + } + + packet = make([]byte, 5) + packet[0] = msgRequestSuccess + if err := trC.writePacket(packet); err != nil { + t.Errorf("writePacket: %v", err) + } + + packet = make([]byte, 5) + packet[0] = msgRequestSuccess + for i := 0; i < 5; i++ { + msg, err := trS.readPacket() + if err != nil { + t.Fatalf("server closed too soon: %v", err) + } + if msg[0] == msgNewKeys { + continue + } + + if bytes.Compare(msg, packet) != 0 { + t.Errorf("packet %d: got %q want %q", i, msg, packet) + } + } + if len(checker.calls) != 2 { + t.Errorf("got %d key changes, want 2", len(checker.calls)) + } +} + +func TestHandshakeAutoRekeyWrite(t *testing.T) { + checker := &testChecker{} + clientConf := &ClientConfig{HostKeyCallback: checker.Check} + clientConf.RekeyThreshold = 500 + trC, trS, err := handshakePair(clientConf, "addr") + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + defer trC.Close() + defer trS.Close() + + for i := 0; i < 5; i++ { + packet := make([]byte, 251) + packet[0] = msgRequestSuccess + if err := trC.writePacket(packet); err != nil { + t.Errorf("writePacket: %v", err) + } + } + + j := 0 + for ; j < 5; j++ { + _, err := trS.readPacket() + if err != nil { + break + } + } + + if j != 5 { + t.Errorf("got %d, want 5 messages", j) + } + + if len(checker.calls) != 2 { + t.Errorf("got %d key changes, wanted 2", len(checker.calls)) + } +} + +type syncChecker struct { + called chan int +} + +func (t *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { + t.called <- 1 + return nil +} + +func TestHandshakeAutoRekeyRead(t *testing.T) { + sync := &syncChecker{make(chan int, 2)} + clientConf := &ClientConfig{ + HostKeyCallback: sync.Check, + } + clientConf.RekeyThreshold = 500 + + trC, trS, err := handshakePair(clientConf, "addr") + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + defer trC.Close() + defer trS.Close() + + packet := make([]byte, 501) + packet[0] = msgRequestSuccess + if err := trS.writePacket(packet); err != nil { + t.Fatalf("writePacket: %v", err) + } + // While we read out the packet, a key change will be + // initiated. + if _, err := trC.readPacket(); err != nil { + t.Fatalf("readPacket(client): %v", err) + } + + <-sync.called +} @@ -30,10 +30,10 @@ type kexResult struct { // Shared secret. See also RFC 4253, section 8. K []byte - // Host key as hashed into H + // Host key as hashed into H. HostKey []byte - // Signature of H + // Signature of H. Signature []byte // A cryptographic hash function that matches the security @@ -94,7 +94,7 @@ func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handsha kexDHInit := kexDHInitMsg{ X: X, } - if err := c.writePacket(marshal(msgKexDHInit, kexDHInit)); err != nil { + if err := c.writePacket(Marshal(&kexDHInit)); err != nil { return nil, err } @@ -104,7 +104,7 @@ func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handsha } var kexDHReply kexDHReplyMsg - if err = unmarshal(&kexDHReply, packet, msgKexDHReply); err != nil { + if err = Unmarshal(packet, &kexDHReply); err != nil { return nil, err } @@ -138,7 +138,7 @@ func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handsha return } var kexDHInit kexDHInitMsg - if err = unmarshal(&kexDHInit, packet, msgKexDHInit); err != nil { + if err = Unmarshal(packet, &kexDHInit); err != nil { return } @@ -153,7 +153,7 @@ func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handsha return nil, err } - hostKeyBytes := MarshalPublicKey(priv.PublicKey()) + hostKeyBytes := priv.PublicKey().Marshal() h := hashFunc.New() magics.write(h) @@ -179,7 +179,7 @@ func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handsha Y: Y, Signature: sig, } - packet = marshal(msgKexDHReply, kexDHReply) + packet = Marshal(&kexDHReply) err = c.writePacket(packet) return &kexResult{ @@ -207,7 +207,7 @@ func (kex *ecdh) Client(c packetConn, rand io.Reader, magics *handshakeMagics) ( ClientPubKey: elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y), } - serialized := marshal(msgKexECDHInit, kexInit) + serialized := Marshal(&kexInit) if err := c.writePacket(serialized); err != nil { return nil, err } @@ -218,7 +218,7 @@ func (kex *ecdh) Client(c packetConn, rand io.Reader, magics *handshakeMagics) ( } var reply kexECDHReplyMsg - if err = unmarshal(&reply, packet, msgKexECDHReply); err != nil { + if err = Unmarshal(packet, &reply); err != nil { return nil, err } @@ -297,7 +297,7 @@ func (kex *ecdh) Server(c packetConn, rand io.Reader, magics *handshakeMagics, p } var kexECDHInit kexECDHInitMsg - if err = unmarshal(&kexECDHInit, packet, msgKexECDHInit); err != nil { + if err = Unmarshal(packet, &kexECDHInit); err != nil { return nil, err } @@ -314,7 +314,7 @@ func (kex *ecdh) Server(c packetConn, rand io.Reader, magics *handshakeMagics, p return nil, err } - hostKeyBytes := MarshalPublicKey(priv.PublicKey()) + hostKeyBytes := priv.PublicKey().Marshal() serializedEphKey := elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y) @@ -346,7 +346,7 @@ func (kex *ecdh) Server(c packetConn, rand io.Reader, magics *handshakeMagics, p Signature: sig, } - serialized := marshal(msgKexECDHReply, reply) + serialized := Marshal(&reply) if err := c.writePacket(serialized); err != nil { return nil, err } diff --git a/ssh/kex_test.go b/ssh/kex_test.go index 1e931a3..0db5f9b 100644 --- a/ssh/kex_test.go +++ b/ssh/kex_test.go @@ -29,7 +29,7 @@ func TestKexes(t *testing.T) { c <- kexResultErr{r, e} }() go func() { - r, e := kex.Server(b, rand.Reader, &magics, ecdsaKey) + r, e := kex.Server(b, rand.Reader, &magics, testSigners["ecdsa"]) s <- kexResultErr{r, e} }() diff --git a/ssh/keys.go b/ssh/keys.go index b41fefc..e8af511 100644 --- a/ssh/keys.go +++ b/ssh/keys.go @@ -33,7 +33,7 @@ const ( // parsePubKey parses a public key of the given algorithm. // Use ParsePublicKey for keys with prepended algorithm. -func parsePubKey(in []byte, algo string) (pubKey PublicKey, rest []byte, ok bool) { +func parsePubKey(in []byte, algo string) (pubKey PublicKey, rest []byte, err error) { switch algo { case KeyAlgoRSA: return parseRSA(in) @@ -42,15 +42,19 @@ func parsePubKey(in []byte, algo string) (pubKey PublicKey, rest []byte, ok bool case KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521: return parseECDSA(in) case CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01: - return parseOpenSSHCertV01(in, algo) + cert, err := parseCert(in, certToPrivAlgo(algo)) + if err != nil { + return nil, nil, err + } + return cert, nil, nil } - return nil, nil, false + return nil, nil, fmt.Errorf("ssh: unknown key algorithm: %v", err) } // parseAuthorizedKey parses a public key in OpenSSH authorized_keys format // (see sshd(8) manual page) once the options and key type fields have been // removed. -func parseAuthorizedKey(in []byte) (out PublicKey, comment string, ok bool) { +func parseAuthorizedKey(in []byte) (out PublicKey, comment string, err error) { in = bytes.TrimSpace(in) i := bytes.IndexAny(in, " \t") @@ -62,20 +66,20 @@ func parseAuthorizedKey(in []byte) (out PublicKey, comment string, ok bool) { key := make([]byte, base64.StdEncoding.DecodedLen(len(base64Key))) n, err := base64.StdEncoding.Decode(key, base64Key) if err != nil { - return + return nil, "", err } key = key[:n] - out, _, ok = ParsePublicKey(key) - if !ok { - return nil, "", false + out, err = ParsePublicKey(key) + if err != nil { + return nil, "", err } comment = string(bytes.TrimSpace(in[i:])) - return + return out, comment, nil } // ParseAuthorizedKeys parses a public key from an authorized_keys // file used in OpenSSH according to the sshd(8) manual page. -func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, ok bool) { +func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) { for len(in) > 0 { end := bytes.IndexByte(in, '\n') if end != -1 { @@ -102,8 +106,8 @@ func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []str continue } - if out, comment, ok = parseAuthorizedKey(in[i:]); ok { - return + if out, comment, err = parseAuthorizedKey(in[i:]); err == nil { + return out, comment, options, rest, nil } // No key type recognised. Maybe there's an options field at @@ -143,38 +147,42 @@ func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []str continue } - if out, comment, ok = parseAuthorizedKey(in[i:]); ok { + if out, comment, err = parseAuthorizedKey(in[i:]); err == nil { options = candidateOptions - return + return out, comment, options, rest, nil } in = rest continue } - return + return nil, "", nil, nil, errors.New("ssh: no key found") } // ParsePublicKey parses an SSH public key formatted for use in // the SSH wire protocol according to RFC 4253, section 6.6. -func ParsePublicKey(in []byte) (out PublicKey, rest []byte, ok bool) { +func ParsePublicKey(in []byte) (out PublicKey, err error) { algo, in, ok := parseString(in) if !ok { - return + return nil, errShortRead + } + var rest []byte + out, rest, err = parsePubKey(in, string(algo)) + if len(rest) > 0 { + return nil, errors.New("ssh: trailing junk in public key") } - return parsePubKey(in, string(algo)) + return out, err } -// MarshalAuthorizedKey returns a byte stream suitable for inclusion -// in an OpenSSH authorized_keys file following the format specified -// in the sshd(8) manual page. +// MarshalAuthorizedKey serializes key for inclusion in an OpenSSH +// authorized_keys file. The return value ends with newline. func MarshalAuthorizedKey(key PublicKey) []byte { b := &bytes.Buffer{} - b.WriteString(key.PublicKeyAlgo()) + b.WriteString(key.Type()) b.WriteByte(' ') e := base64.NewEncoder(base64.StdEncoding, b) - e.Write(MarshalPublicKey(key)) + e.Write(key.Marshal()) e.Close() b.WriteByte('\n') return b.Bytes() @@ -182,84 +190,81 @@ func MarshalAuthorizedKey(key PublicKey) []byte { // PublicKey is an abstraction of different types of public keys. type PublicKey interface { - // PrivateKeyAlgo returns the name of the encryption system. - PrivateKeyAlgo() string - - // PublicKeyAlgo returns the algorithm for the public key, - // which may be different from PrivateKeyAlgo for certificates. - PublicKeyAlgo() string + // Type returns the key's type, e.g. "ssh-rsa". + Type() string // Marshal returns the serialized key data in SSH wire format, - // without the name prefix. Callers should typically use - // MarshalPublicKey(). + // with the name prefix. Marshal() []byte // Verify that sig is a signature on the given data using this // key. This function will hash the data appropriately first. - Verify(data []byte, sigBlob []byte) bool + Verify(data []byte, sig *Signature) error } -// A Signer is can create signatures that verify against a public key. +// A Signer can create signatures that verify against a public key. type Signer interface { // PublicKey returns an associated PublicKey instance. PublicKey() PublicKey // Sign returns raw signature for the given data. This method // will apply the hash specified for the keytype to the data. - Sign(rand io.Reader, data []byte) ([]byte, error) + Sign(rand io.Reader, data []byte) (*Signature, error) } type rsaPublicKey rsa.PublicKey -func (r *rsaPublicKey) PrivateKeyAlgo() string { +func (r *rsaPublicKey) Type() string { return "ssh-rsa" } -func (r *rsaPublicKey) PublicKeyAlgo() string { - return r.PrivateKeyAlgo() -} - // parseRSA parses an RSA key according to RFC 4253, section 6.6. -func parseRSA(in []byte) (out PublicKey, rest []byte, ok bool) { - key := new(rsa.PublicKey) - - bigE, in, ok := parseInt(in) - if !ok || bigE.BitLen() > 24 { - return +func parseRSA(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + E *big.Int + N *big.Int + Rest []byte `ssh:"rest"` } - e := bigE.Int64() - if e < 3 || e&1 == 0 { - ok = false - return + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err } - key.E = int(e) - if key.N, in, ok = parseInt(in); !ok { - return + if w.E.BitLen() > 24 { + return nil, nil, errors.New("ssh: exponent too large") + } + e := w.E.Int64() + if e < 3 || e&1 == 0 { + return nil, nil, errors.New("ssh: incorrect exponent") } - ok = true - return (*rsaPublicKey)(key), in, ok + var key rsa.PublicKey + key.E = int(e) + key.N = w.N + return (*rsaPublicKey)(&key), w.Rest, nil } func (r *rsaPublicKey) Marshal() []byte { - // See RFC 4253, section 6.6. e := new(big.Int).SetInt64(int64(r.E)) - length := intLength(e) - length += intLength(r.N) - - ret := make([]byte, length) - rest := marshalInt(ret, e) - marshalInt(rest, r.N) - - return ret + wirekey := struct { + Name string + E *big.Int + N *big.Int + }{ + KeyAlgoRSA, + e, + r.N, + } + return Marshal(&wirekey) } -func (r *rsaPublicKey) Verify(data []byte, sig []byte) bool { +func (r *rsaPublicKey) Verify(data []byte, sig *Signature) error { + if sig.Format != r.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, r.Type()) + } h := crypto.SHA1.New() h.Write(data) digest := h.Sum(nil) - return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), crypto.SHA1, digest, sig) == nil + return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), crypto.SHA1, digest, sig.Blob) } type rsaPrivateKey struct { @@ -270,64 +275,66 @@ func (r *rsaPrivateKey) PublicKey() PublicKey { return (*rsaPublicKey)(&r.PrivateKey.PublicKey) } -func (r *rsaPrivateKey) Sign(rand io.Reader, data []byte) ([]byte, error) { +func (r *rsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) { h := crypto.SHA1.New() h.Write(data) digest := h.Sum(nil) - return rsa.SignPKCS1v15(rand, r.PrivateKey, crypto.SHA1, digest) + blob, err := rsa.SignPKCS1v15(rand, r.PrivateKey, crypto.SHA1, digest) + if err != nil { + return nil, err + } + return &Signature{ + Format: r.PublicKey().Type(), + Blob: blob, + }, nil } type dsaPublicKey dsa.PublicKey -func (r *dsaPublicKey) PrivateKeyAlgo() string { +func (r *dsaPublicKey) Type() string { return "ssh-dss" } -func (r *dsaPublicKey) PublicKeyAlgo() string { - return r.PrivateKeyAlgo() -} - // parseDSA parses an DSA key according to RFC 4253, section 6.6. -func parseDSA(in []byte) (out PublicKey, rest []byte, ok bool) { - key := new(dsa.PublicKey) - - if key.P, in, ok = parseInt(in); !ok { - return - } - - if key.Q, in, ok = parseInt(in); !ok { - return +func parseDSA(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + P, Q, G, Y *big.Int + Rest []byte `ssh:"rest"` } - - if key.G, in, ok = parseInt(in); !ok { - return + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err } - if key.Y, in, ok = parseInt(in); !ok { - return + key := &dsaPublicKey{ + Parameters: dsa.Parameters{ + P: w.P, + Q: w.Q, + G: w.G, + }, + Y: w.Y, } - - ok = true - return (*dsaPublicKey)(key), in, ok + return key, w.Rest, nil } -func (r *dsaPublicKey) Marshal() []byte { - // See RFC 4253, section 6.6. - length := intLength(r.P) - length += intLength(r.Q) - length += intLength(r.G) - length += intLength(r.Y) - - ret := make([]byte, length) - rest := marshalInt(ret, r.P) - rest = marshalInt(rest, r.Q) - rest = marshalInt(rest, r.G) - marshalInt(rest, r.Y) +func (k *dsaPublicKey) Marshal() []byte { + w := struct { + Name string + P, Q, G, Y *big.Int + }{ + k.Type(), + k.P, + k.Q, + k.G, + k.Y, + } - return ret + return Marshal(&w) } -func (k *dsaPublicKey) Verify(data []byte, sigBlob []byte) bool { +func (k *dsaPublicKey) Verify(data []byte, sig *Signature) error { + if sig.Format != k.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type()) + } h := crypto.SHA1.New() h.Write(data) digest := h.Sum(nil) @@ -337,12 +344,15 @@ func (k *dsaPublicKey) Verify(data []byte, sigBlob []byte) bool { // r, followed by s (which are 160-bit integers, without lengths or // padding, unsigned, and in network byte order). // For DSS purposes, sig.Blob should be exactly 40 bytes in length. - if len(sigBlob) != 40 { - return false + if len(sig.Blob) != 40 { + return errors.New("ssh: DSA signature parse error") } - r := new(big.Int).SetBytes(sigBlob[:20]) - s := new(big.Int).SetBytes(sigBlob[20:]) - return dsa.Verify((*dsa.PublicKey)(k), digest, r, s) + r := new(big.Int).SetBytes(sig.Blob[:20]) + s := new(big.Int).SetBytes(sig.Blob[20:]) + if dsa.Verify((*dsa.PublicKey)(k), digest, r, s) { + return nil + } + return errors.New("ssh: signature did not verify") } type dsaPrivateKey struct { @@ -353,7 +363,7 @@ func (k *dsaPrivateKey) PublicKey() PublicKey { return (*dsaPublicKey)(&k.PrivateKey.PublicKey) } -func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) ([]byte, error) { +func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) { h := crypto.SHA1.New() h.Write(data) digest := h.Sum(nil) @@ -363,14 +373,21 @@ func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) ([]byte, error) { } sig := make([]byte, 40) - copy(sig[:20], r.Bytes()) - copy(sig[20:], s.Bytes()) - return sig, nil + rb := r.Bytes() + sb := s.Bytes() + + copy(sig[20-len(rb):20], rb) + copy(sig[40-len(sb):], sb) + + return &Signature{ + Format: k.PublicKey().Type(), + Blob: sig, + }, nil } type ecdsaPublicKey ecdsa.PublicKey -func (key *ecdsaPublicKey) PrivateKeyAlgo() string { +func (key *ecdsaPublicKey) Type() string { return "ecdsa-sha2-" + key.nistID() } @@ -387,7 +404,7 @@ func (key *ecdsaPublicKey) nistID() string { } func supportedEllipticCurve(curve elliptic.Curve) bool { - return (curve == elliptic.P256() || curve == elliptic.P384() || curve == elliptic.P521()) + return curve == elliptic.P256() || curve == elliptic.P384() || curve == elliptic.P521() } // ecHash returns the hash to match the given elliptic curve, see RFC @@ -403,15 +420,11 @@ func ecHash(curve elliptic.Curve) crypto.Hash { return crypto.SHA512 } -func (key *ecdsaPublicKey) PublicKeyAlgo() string { - return key.PrivateKeyAlgo() -} - // parseECDSA parses an ECDSA key according to RFC 5656, section 3.1. -func parseECDSA(in []byte) (out PublicKey, rest []byte, ok bool) { - var identifier []byte - if identifier, in, ok = parseString(in); !ok { - return +func parseECDSA(in []byte) (out PublicKey, rest []byte, err error) { + identifier, in, ok := parseString(in) + if !ok { + return nil, nil, errShortRead } key := new(ecdsa.PublicKey) @@ -424,38 +437,42 @@ func parseECDSA(in []byte) (out PublicKey, rest []byte, ok bool) { case "nistp521": key.Curve = elliptic.P521() default: - ok = false - return + return nil, nil, errors.New("ssh: unsupported curve") } var keyBytes []byte if keyBytes, in, ok = parseString(in); !ok { - return + return nil, nil, errShortRead } key.X, key.Y = elliptic.Unmarshal(key.Curve, keyBytes) if key.X == nil || key.Y == nil { - ok = false - return + return nil, nil, errors.New("ssh: invalid curve point") } - return (*ecdsaPublicKey)(key), in, ok + return (*ecdsaPublicKey)(key), in, nil } func (key *ecdsaPublicKey) Marshal() []byte { // See RFC 5656, section 3.1. keyBytes := elliptic.Marshal(key.Curve, key.X, key.Y) + w := struct { + Name string + ID string + Key []byte + }{ + key.Type(), + key.nistID(), + keyBytes, + } - ID := key.nistID() - length := stringLength(len(ID)) - length += stringLength(len(keyBytes)) - - ret := make([]byte, length) - r := marshalString(ret, []byte(ID)) - r = marshalString(r, keyBytes) - return ret + return Marshal(&w) } -func (key *ecdsaPublicKey) Verify(data []byte, sigBlob []byte) bool { +func (key *ecdsaPublicKey) Verify(data []byte, sig *Signature) error { + if sig.Format != key.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, key.Type()) + } + h := ecHash(key.Curve).New() h.Write(data) digest := h.Sum(nil) @@ -464,15 +481,19 @@ func (key *ecdsaPublicKey) Verify(data []byte, sigBlob []byte) bool { // The ecdsa_signature_blob value has the following specific encoding: // mpint r // mpint s - r, rest, ok := parseInt(sigBlob) - if !ok { - return false + var ecSig struct { + R *big.Int + S *big.Int + } + + if err := Unmarshal(sig.Blob, &ecSig); err != nil { + return err } - s, rest, ok := parseInt(rest) - if !ok || len(rest) > 0 { - return false + + if ecdsa.Verify((*ecdsa.PublicKey)(key), digest, ecSig.R, ecSig.S) { + return nil } - return ecdsa.Verify((*ecdsa.PublicKey)(key), digest, r, s) + return errors.New("ssh: signature did not verify") } type ecdsaPrivateKey struct { @@ -483,7 +504,7 @@ func (k *ecdsaPrivateKey) PublicKey() PublicKey { return (*ecdsaPublicKey)(&k.PrivateKey.PublicKey) } -func (k *ecdsaPrivateKey) Sign(rand io.Reader, data []byte) ([]byte, error) { +func (k *ecdsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) { h := ecHash(k.PrivateKey.PublicKey.Curve).New() h.Write(data) digest := h.Sum(nil) @@ -495,10 +516,13 @@ func (k *ecdsaPrivateKey) Sign(rand io.Reader, data []byte) ([]byte, error) { sig := make([]byte, intLength(r)+intLength(s)) rest := marshalInt(sig, r) marshalInt(rest, s) - return sig, nil + return &Signature{ + Format: k.PublicKey().Type(), + Blob: sig, + }, nil } -// NewPrivateKey takes a pointer to rsa, dsa or ecdsa PrivateKey +// NewSignerFromKey takes a pointer to rsa, dsa or ecdsa PrivateKey // returns a corresponding Signer instance. EC keys should use P256, // P384 or P521. func NewSignerFromKey(k interface{}) (Signer, error) { @@ -540,54 +564,49 @@ func NewPublicKey(k interface{}) (PublicKey, error) { return sshKey, nil } -// ParsePublicKey parses a PEM encoded private key. It supports -// PKCS#1, RSA, DSA and ECDSA private keys. +// ParsePrivateKey returns a Signer from a PEM encoded private key. It supports +// the same keys as ParseRawPrivateKey. func ParsePrivateKey(pemBytes []byte) (Signer, error) { + key, err := ParseRawPrivateKey(pemBytes) + if err != nil { + return nil, err + } + + return NewSignerFromKey(key) +} + +// ParseRawPrivateKey returns a private key from a PEM encoded private key. It +// supports RSA (PKCS#1), DSA (OpenSSL), and ECDSA private keys. +func ParseRawPrivateKey(pemBytes []byte) (interface{}, error) { block, _ := pem.Decode(pemBytes) if block == nil { return nil, errors.New("ssh: no key found") } - var rawkey interface{} switch block.Type { case "RSA PRIVATE KEY": - rsa, err := x509.ParsePKCS1PrivateKey(block.Bytes) - if err != nil { - return nil, err - } - rawkey = rsa + return x509.ParsePKCS1PrivateKey(block.Bytes) case "EC PRIVATE KEY": - ec, err := x509.ParseECPrivateKey(block.Bytes) - if err != nil { - return nil, err - } - rawkey = ec + return x509.ParseECPrivateKey(block.Bytes) case "DSA PRIVATE KEY": - ec, err := parseDSAPrivate(block.Bytes) - if err != nil { - return nil, err - } - rawkey = ec + return ParseDSAPrivateKey(block.Bytes) default: return nil, fmt.Errorf("ssh: unsupported key type %q", block.Type) } - - return NewSignerFromKey(rawkey) } -// parseDSAPrivate parses a DSA key in ASN.1 DER encoding, as -// documented in the OpenSSL DSA manpage. -// TODO(hanwen): move this in to crypto/x509 after the Go 1.2 freeze. -func parseDSAPrivate(p []byte) (*dsa.PrivateKey, error) { - k := struct { +// ParseDSAPrivateKey returns a DSA private key from its ASN.1 DER encoding, as +// specified by the OpenSSL DSA man page. +func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) { + var k struct { Version int P *big.Int Q *big.Int G *big.Int Priv *big.Int Pub *big.Int - }{} - rest, err := asn1.Unmarshal(p, &k) + } + rest, err := asn1.Unmarshal(der, &k) if err != nil { return nil, errors.New("ssh: failed to parse DSA key: " + err.Error()) } diff --git a/ssh/keys_test.go b/ssh/keys_test.go index 3c4b735..cd49565 100644 --- a/ssh/keys_test.go +++ b/ssh/keys_test.go @@ -1,66 +1,25 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package ssh import ( + "bytes" "crypto/dsa" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "crypto/rsa" + "encoding/base64" + "fmt" "reflect" "strings" "testing" -) -var ( - ecdsaKey Signer - ecdsa384Key Signer - ecdsa521Key Signer - testCertKey Signer + "code.google.com/p/go.crypto/ssh/testdata" ) -type testSigner struct { - Signer - pub PublicKey -} - -func (ts *testSigner) PublicKey() PublicKey { - if ts.pub != nil { - return ts.pub - } - return ts.Signer.PublicKey() -} - -func init() { - raw256, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - ecdsaKey, _ = NewSignerFromKey(raw256) - - raw384, _ := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) - ecdsa384Key, _ = NewSignerFromKey(raw384) - - raw521, _ := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) - ecdsa521Key, _ = NewSignerFromKey(raw521) - - // Create a cert and sign it for use in tests. - testCert := &OpenSSHCertV01{ - Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil - Key: ecdsaKey.PublicKey(), - ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage - ValidAfter: 0, // unix epoch - ValidBefore: maxUint64, // The end of currently representable time. - Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil - SignatureKey: rsaKey.PublicKey(), - } - sigBytes, _ := rsaKey.Sign(rand.Reader, testCert.BytesForSigning()) - testCert.Signature = &signature{ - Format: testCert.SignatureKey.PublicKeyAlgo(), - Blob: sigBytes, - } - testCertKey = &testSigner{ - Signer: ecdsaKey, - pub: testCert, - } -} - func rawKey(pub PublicKey) interface{} { switch k := pub.(type) { case *rsaPublicKey: @@ -69,23 +28,18 @@ func rawKey(pub PublicKey) interface{} { return (*dsa.PublicKey)(k) case *ecdsaPublicKey: return (*ecdsa.PublicKey)(k) - case *OpenSSHCertV01: + case *Certificate: return k } panic("unknown key type") } func TestKeyMarshalParse(t *testing.T) { - keys := []Signer{rsaKey, dsaKey, ecdsaKey, ecdsa384Key, ecdsa521Key, testCertKey} - for _, priv := range keys { + for _, priv := range testSigners { pub := priv.PublicKey() - roundtrip, rest, ok := ParsePublicKey(MarshalPublicKey(pub)) - if !ok { - t.Errorf("ParsePublicKey(%T) failed", pub) - } - - if len(rest) > 0 { - t.Errorf("ParsePublicKey(%T): trailing junk", pub) + roundtrip, err := ParsePublicKey(pub.Marshal()) + if err != nil { + t.Errorf("ParsePublicKey(%T): %v", pub, err) } k1 := rawKey(pub) @@ -113,9 +67,12 @@ func TestUnsupportedCurves(t *testing.T) { } func TestNewPublicKey(t *testing.T) { - keys := []Signer{rsaKey, dsaKey, ecdsaKey} - for _, k := range keys { + for _, k := range testSigners { raw := rawKey(k.PublicKey()) + // Skip certificates, as NewPublicKey does not support them. + if _, ok := raw.(*Certificate); ok { + continue + } pub, err := NewPublicKey(raw) if err != nil { t.Errorf("NewPublicKey(%#v): %v", raw, err) @@ -127,8 +84,7 @@ func TestNewPublicKey(t *testing.T) { } func TestKeySignVerify(t *testing.T) { - keys := []Signer{rsaKey, dsaKey, ecdsaKey, testCertKey} - for _, priv := range keys { + for _, priv := range testSigners { pub := priv.PublicKey() data := []byte("sign me") @@ -137,19 +93,20 @@ func TestKeySignVerify(t *testing.T) { t.Fatalf("Sign(%T): %v", priv, err) } - if !pub.Verify(data, sig) { - t.Errorf("publicKey.Verify(%T) failed", priv) + if err := pub.Verify(data, sig); err != nil { + t.Errorf("publicKey.Verify(%T): %v", priv, err) + } + sig.Blob[5]++ + if err := pub.Verify(data, sig); err == nil { + t.Errorf("publicKey.Verify on broken sig did not fail") } } } func TestParseRSAPrivateKey(t *testing.T) { - key, err := ParsePrivateKey([]byte(testServerPrivateKey)) - if err != nil { - t.Fatalf("ParsePrivateKey: %v", err) - } + key := testPrivateKeys["rsa"] - rsa, ok := key.(*rsaPrivateKey) + rsa, ok := key.(*rsa.PrivateKey) if !ok { t.Fatalf("got %T, want *rsa.PrivateKey", rsa) } @@ -160,21 +117,11 @@ func TestParseRSAPrivateKey(t *testing.T) { } func TestParseECPrivateKey(t *testing.T) { - // Taken from the data in test/ . - pem := []byte(`-----BEGIN EC PRIVATE KEY----- -MHcCAQEEINGWx0zo6fhJ/0EAfrPzVFyFC9s18lBt3cRoEDhS3ARooAoGCCqGSM49 -AwEHoUQDQgAEi9Hdw6KvZcWxfg2IDhA7UkpDtzzt6ZqJXSsFdLd+Kx4S3Sx4cVO+ -6/ZOXRnPmNAlLUqjShUsUBBngG0u2fqEqA== ------END EC PRIVATE KEY-----`) + key := testPrivateKeys["ecdsa"] - key, err := ParsePrivateKey(pem) - if err != nil { - t.Fatalf("ParsePrivateKey: %v", err) - } - - ecKey, ok := key.(*ecdsaPrivateKey) + ecKey, ok := key.(*ecdsa.PrivateKey) if !ok { - t.Fatalf("got %T, want *ecdsaPrivateKey", ecKey) + t.Fatalf("got %T, want *ecdsa.PrivateKey", ecKey) } if !validateECPublicKey(ecKey.Curve, ecKey.X, ecKey.Y) { @@ -182,22 +129,11 @@ AwEHoUQDQgAEi9Hdw6KvZcWxfg2IDhA7UkpDtzzt6ZqJXSsFdLd+Kx4S3Sx4cVO+ } } -// ssh-keygen -t dsa -f /tmp/idsa.pem -var dsaPEM = `-----BEGIN DSA PRIVATE KEY----- -MIIBuwIBAAKBgQD6PDSEyXiI9jfNs97WuM46MSDCYlOqWw80ajN16AohtBncs1YB -lHk//dQOvCYOsYaE+gNix2jtoRjwXhDsc25/IqQbU1ahb7mB8/rsaILRGIbA5WH3 -EgFtJmXFovDz3if6F6TzvhFpHgJRmLYVR8cqsezL3hEZOvvs2iH7MorkxwIVAJHD -nD82+lxh2fb4PMsIiaXudAsBAoGAQRf7Q/iaPRn43ZquUhd6WwvirqUj+tkIu6eV -2nZWYmXLlqFQKEy4Tejl7Wkyzr2OSYvbXLzo7TNxLKoWor6ips0phYPPMyXld14r -juhT24CrhOzuLMhDduMDi032wDIZG4Y+K7ElU8Oufn8Sj5Wge8r6ANmmVgmFfynr -FhdYCngCgYEA3ucGJ93/Mx4q4eKRDxcWD3QzWyqpbRVRRV1Vmih9Ha/qC994nJFz -DQIdjxDIT2Rk2AGzMqFEB68Zc3O+Wcsmz5eWWzEwFxaTwOGWTyDqsDRLm3fD+QYj -nOwuxb0Kce+gWI8voWcqC9cyRm09jGzu2Ab3Bhtpg8JJ8L7gS3MRZK4CFEx4UAfY -Fmsr0W6fHB9nhS4/UXM8 ------END DSA PRIVATE KEY-----` - func TestParseDSA(t *testing.T) { - s, err := ParsePrivateKey([]byte(dsaPEM)) + // We actually exercise the ParsePrivateKey codepath here, as opposed to + // using the ParseRawPrivateKey+NewSignerFromKey path that testdata_test.go + // uses. + s, err := ParsePrivateKey(testdata.PEMBytes["dsa"]) if err != nil { t.Fatalf("ParsePrivateKey returned error: %s", err) } @@ -208,7 +144,163 @@ func TestParseDSA(t *testing.T) { t.Fatalf("dsa.Sign: %v", err) } - if !s.PublicKey().Verify(data, sig) { - t.Error("Verify failed.") + if err := s.PublicKey().Verify(data, sig); err != nil { + t.Errorf("Verify failed: %v", err) + } +} + +// Tests for authorized_keys parsing. + +// getTestKey returns a public key, and its base64 encoding. +func getTestKey() (PublicKey, string) { + k := testPublicKeys["rsa"] + + b := &bytes.Buffer{} + e := base64.NewEncoder(base64.StdEncoding, b) + e.Write(k.Marshal()) + e.Close() + + return k, b.String() +} + +func TestMarshalParsePublicKey(t *testing.T) { + pub, pubSerialized := getTestKey() + line := fmt.Sprintf("%s %s user@host", pub.Type(), pubSerialized) + + authKeys := MarshalAuthorizedKey(pub) + actualFields := strings.Fields(string(authKeys)) + if len(actualFields) == 0 { + t.Fatalf("failed authKeys: %v", authKeys) + } + + // drop the comment + expectedFields := strings.Fields(line)[0:2] + + if !reflect.DeepEqual(actualFields, expectedFields) { + t.Errorf("got %v, expected %v", actualFields, expectedFields) + } + + actPub, _, _, _, err := ParseAuthorizedKey([]byte(line)) + if err != nil { + t.Fatalf("cannot parse %v: %v", line, err) + } + if !reflect.DeepEqual(actPub, pub) { + t.Errorf("got %v, expected %v", actPub, pub) + } +} + +type authResult struct { + pubKey PublicKey + options []string + comments string + rest string + ok bool +} + +func testAuthorizedKeys(t *testing.T, authKeys []byte, expected []authResult) { + rest := authKeys + var values []authResult + for len(rest) > 0 { + var r authResult + var err error + r.pubKey, r.comments, r.options, rest, err = ParseAuthorizedKey(rest) + r.ok = (err == nil) + t.Log(err) + r.rest = string(rest) + values = append(values, r) + } + + if !reflect.DeepEqual(values, expected) { + t.Errorf("got %#v, expected %#v", values, expected) + } +} + +func TestAuthorizedKeyBasic(t *testing.T) { + pub, pubSerialized := getTestKey() + line := "ssh-rsa " + pubSerialized + " user@host" + testAuthorizedKeys(t, []byte(line), + []authResult{ + {pub, nil, "user@host", "", true}, + }) +} + +func TestAuth(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithOptions := []string{ + `# comments to ignore before any keys...`, + ``, + `env="HOME=/home/root",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`, + `# comments to ignore, along with a blank line`, + ``, + `env="HOME=/home/root2" ssh-rsa ` + pubSerialized + ` user2@host2`, + ``, + `# more comments, plus a invalid entry`, + `ssh-rsa data-that-will-not-parse user@host3`, + } + for _, eol := range []string{"\n", "\r\n"} { + authOptions := strings.Join(authWithOptions, eol) + rest2 := strings.Join(authWithOptions[3:], eol) + rest3 := strings.Join(authWithOptions[6:], eol) + testAuthorizedKeys(t, []byte(authOptions), []authResult{ + {pub, []string{`env="HOME=/home/root"`, "no-port-forwarding"}, "user@host", rest2, true}, + {pub, []string{`env="HOME=/home/root2"`}, "user2@host2", rest3, true}, + {nil, nil, "", "", false}, + }) + } +} + +func TestAuthWithQuotedSpaceInEnv(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithQuotedSpaceInEnv := []byte(`env="HOME=/home/root dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`) + testAuthorizedKeys(t, []byte(authWithQuotedSpaceInEnv), []authResult{ + {pub, []string{`env="HOME=/home/root dir"`, "no-port-forwarding"}, "user@host", "", true}, + }) +} + +func TestAuthWithQuotedCommaInEnv(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithQuotedCommaInEnv := []byte(`env="HOME=/home/root,dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`) + testAuthorizedKeys(t, []byte(authWithQuotedCommaInEnv), []authResult{ + {pub, []string{`env="HOME=/home/root,dir"`, "no-port-forwarding"}, "user@host", "", true}, + }) +} + +func TestAuthWithQuotedQuoteInEnv(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithQuotedQuoteInEnv := []byte(`env="HOME=/home/\"root dir",no-port-forwarding` + "\t" + `ssh-rsa` + "\t" + pubSerialized + ` user@host`) + authWithDoubleQuotedQuote := []byte(`no-port-forwarding,env="HOME=/home/ \"root dir\"" ssh-rsa ` + pubSerialized + "\t" + `user@host`) + testAuthorizedKeys(t, []byte(authWithQuotedQuoteInEnv), []authResult{ + {pub, []string{`env="HOME=/home/\"root dir"`, "no-port-forwarding"}, "user@host", "", true}, + }) + + testAuthorizedKeys(t, []byte(authWithDoubleQuotedQuote), []authResult{ + {pub, []string{"no-port-forwarding", `env="HOME=/home/ \"root dir\""`}, "user@host", "", true}, + }) +} + +func TestAuthWithInvalidSpace(t *testing.T) { + _, pubSerialized := getTestKey() + authWithInvalidSpace := []byte(`env="HOME=/home/root dir", no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host +#more to follow but still no valid keys`) + testAuthorizedKeys(t, []byte(authWithInvalidSpace), []authResult{ + {nil, nil, "", "", false}, + }) +} + +func TestAuthWithMissingQuote(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithMissingQuote := []byte(`env="HOME=/home/root,no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host +env="HOME=/home/root",shared-control ssh-rsa ` + pubSerialized + ` user@host`) + + testAuthorizedKeys(t, []byte(authWithMissingQuote), []authResult{ + {pub, []string{`env="HOME=/home/root"`, `shared-control`}, "user@host", "", true}, + }) +} + +func TestInvalidEntry(t *testing.T) { + authInvalid := []byte(`ssh-rsa`) + _, _, _, _, err := ParseAuthorizedKey(authInvalid) + if err == nil { + t.Errorf("got valid entry for %q", authInvalid) } } @@ -43,11 +43,6 @@ func (t truncatingMAC) Size() int { func (t truncatingMAC) BlockSize() int { return t.hmac.BlockSize() } -// Specifies a default set of MAC algorithms and a preference order. -// This is based on RFC 4253, section 6.4, with the removal of the -// hmac-md5 variants as they have reached the end of their useful life. -var DefaultMACOrder = []string{"hmac-sha1", "hmac-sha1-96"} - var macModes = map[string]*macMode{ "hmac-sha1": {20, func(key []byte) hash.Hash { return hmac.New(sha1.New, key) diff --git a/ssh/mempipe_test.go b/ssh/mempipe_test.go index ec1b854..69217a4 100644 --- a/ssh/mempipe_test.go +++ b/ssh/mempipe_test.go @@ -36,17 +36,23 @@ func (t *memTransport) readPacket() ([]byte, error) { } } -func (t *memTransport) Close() error { - t.write.Lock() - defer t.write.Unlock() - if t.write.eof { +func (t *memTransport) closeSelf() error { + t.Lock() + defer t.Unlock() + if t.eof { return io.EOF } - t.write.eof = true - t.write.Cond.Broadcast() + t.eof = true + t.Cond.Broadcast() return nil } +func (t *memTransport) Close() error { + err := t.write.closeSelf() + t.closeSelf() + return err +} + func (t *memTransport) writePacket(p []byte) error { t.write.Lock() defer t.write.Unlock() diff --git a/ssh/messages.go b/ssh/messages.go index 94c3ea0..f9e44bb 100644 --- a/ssh/messages.go +++ b/ssh/messages.go @@ -7,58 +7,25 @@ package ssh import ( "bytes" "encoding/binary" + "errors" + "fmt" "io" "math/big" "reflect" + "strconv" ) // These are SSH message type numbers. They are scattered around several // documents but many were taken from [SSH-PARAMETERS]. const ( - msgDisconnect = 1 - msgIgnore = 2 - msgUnimplemented = 3 - msgDebug = 4 - msgServiceRequest = 5 - msgServiceAccept = 6 - - msgKexInit = 20 - msgNewKeys = 21 - - // Diffie-Helman - msgKexDHInit = 30 - msgKexDHReply = 31 - - msgKexECDHInit = 30 - msgKexECDHReply = 31 + msgIgnore = 2 + msgUnimplemented = 3 + msgDebug = 4 + msgNewKeys = 21 // Standard authentication messages - msgUserAuthRequest = 50 - msgUserAuthFailure = 51 - msgUserAuthSuccess = 52 - msgUserAuthBanner = 53 - msgUserAuthPubKeyOk = 60 - - // Method specific messages - msgUserAuthInfoRequest = 60 - msgUserAuthInfoResponse = 61 - - msgGlobalRequest = 80 - msgRequestSuccess = 81 - msgRequestFailure = 82 - - // Channel manipulation - msgChannelOpen = 90 - msgChannelOpenConfirm = 91 - msgChannelOpenFailure = 92 - msgChannelWindowAdjust = 93 - msgChannelData = 94 - msgChannelExtendedData = 95 - msgChannelEOF = 96 - msgChannelClose = 97 - msgChannelRequest = 98 - msgChannelSuccess = 99 - msgChannelFailure = 100 + msgUserAuthSuccess = 52 + msgUserAuthBanner = 53 ) // SSH messages: @@ -69,15 +36,25 @@ const ( // ssh tag of "rest" receives the remainder of a packet when unmarshaling. // See RFC 4253, section 11.1. +const msgDisconnect = 1 + +// disconnectMsg is the message that signals a disconnect. It is also +// the error type returned from mux.Wait() type disconnectMsg struct { - Reason uint32 + Reason uint32 `sshtype:"1"` Message string Language string } +func (d *disconnectMsg) Error() string { + return fmt.Sprintf("ssh: disconnect reason %d: %s", d.Reason, d.Message) +} + // See RFC 4253, section 7.1. +const msgKexInit = 20 + type kexInitMsg struct { - Cookie [16]byte + Cookie [16]byte `sshtype:"20"` KexAlgos []string ServerHostKeyAlgos []string CiphersClientServer []string @@ -93,53 +70,74 @@ type kexInitMsg struct { } // See RFC 4253, section 8. + +// Diffie-Helman +const msgKexDHInit = 30 + type kexDHInitMsg struct { - X *big.Int + X *big.Int `sshtype:"30"` } +const msgKexECDHInit = 30 + type kexECDHInitMsg struct { - ClientPubKey []byte + ClientPubKey []byte `sshtype:"30"` } +const msgKexECDHReply = 31 + type kexECDHReplyMsg struct { - HostKey []byte + HostKey []byte `sshtype:"31"` EphemeralPubKey []byte Signature []byte } +const msgKexDHReply = 31 + type kexDHReplyMsg struct { - HostKey []byte + HostKey []byte `sshtype:"31"` Y *big.Int Signature []byte } // See RFC 4253, section 10. +const msgServiceRequest = 5 + type serviceRequestMsg struct { - Service string + Service string `sshtype:"5"` } // See RFC 4253, section 10. +const msgServiceAccept = 6 + type serviceAcceptMsg struct { - Service string + Service string `sshtype:"6"` } // See RFC 4252, section 5. +const msgUserAuthRequest = 50 + type userAuthRequestMsg struct { - User string + User string `sshtype:"50"` Service string Method string Payload []byte `ssh:"rest"` } // See RFC 4252, section 5.1 +const msgUserAuthFailure = 51 + type userAuthFailureMsg struct { - Methods []string + Methods []string `sshtype:"51"` PartialSuccess bool } // See RFC 4256, section 3.2 +const msgUserAuthInfoRequest = 60 +const msgUserAuthInfoResponse = 61 + type userAuthInfoRequestMsg struct { - User string + User string `sshtype:"60"` Instruction string DeprecatedLanguage string NumPrompts uint32 @@ -147,17 +145,24 @@ type userAuthInfoRequestMsg struct { } // See RFC 4254, section 5.1. +const msgChannelOpen = 90 + type channelOpenMsg struct { - ChanType string + ChanType string `sshtype:"90"` PeersId uint32 PeersWindow uint32 MaxPacketSize uint32 TypeSpecificData []byte `ssh:"rest"` } +const msgChannelExtendedData = 95 +const msgChannelData = 94 + // See RFC 4254, section 5.1. +const msgChannelOpenConfirm = 91 + type channelOpenConfirmMsg struct { - PeersId uint32 + PeersId uint32 `sshtype:"91"` MyId uint32 MyWindow uint32 MaxPacketSize uint32 @@ -165,172 +170,239 @@ type channelOpenConfirmMsg struct { } // See RFC 4254, section 5.1. +const msgChannelOpenFailure = 92 + type channelOpenFailureMsg struct { - PeersId uint32 + PeersId uint32 `sshtype:"92"` Reason RejectionReason Message string Language string } +const msgChannelRequest = 98 + type channelRequestMsg struct { - PeersId uint32 + PeersId uint32 `sshtype:"98"` Request string WantReply bool RequestSpecificData []byte `ssh:"rest"` } // See RFC 4254, section 5.4. +const msgChannelSuccess = 99 + type channelRequestSuccessMsg struct { - PeersId uint32 + PeersId uint32 `sshtype:"99"` } // See RFC 4254, section 5.4. +const msgChannelFailure = 100 + type channelRequestFailureMsg struct { - PeersId uint32 + PeersId uint32 `sshtype:"100"` } // See RFC 4254, section 5.3 +const msgChannelClose = 97 + type channelCloseMsg struct { - PeersId uint32 + PeersId uint32 `sshtype:"97"` } // See RFC 4254, section 5.3 +const msgChannelEOF = 96 + type channelEOFMsg struct { - PeersId uint32 + PeersId uint32 `sshtype:"96"` } // See RFC 4254, section 4 +const msgGlobalRequest = 80 + type globalRequestMsg struct { - Type string + Type string `sshtype:"80"` WantReply bool + Data []byte `ssh:"rest"` } // See RFC 4254, section 4 +const msgRequestSuccess = 81 + type globalRequestSuccessMsg struct { - Data []byte `ssh:"rest"` + Data []byte `ssh:"rest" sshtype:"81"` } // See RFC 4254, section 4 +const msgRequestFailure = 82 + type globalRequestFailureMsg struct { - Data []byte `ssh:"rest"` + Data []byte `ssh:"rest" sshtype:"82"` } // See RFC 4254, section 5.2 +const msgChannelWindowAdjust = 93 + type windowAdjustMsg struct { - PeersId uint32 + PeersId uint32 `sshtype:"93"` AdditionalBytes uint32 } // See RFC 4252, section 7 +const msgUserAuthPubKeyOk = 60 + type userAuthPubKeyOkMsg struct { - Algo string - PubKey string + Algo string `sshtype:"60"` + PubKey []byte } -// unmarshal parses the SSH wire data in packet into out using -// reflection. expectedType, if non-zero, is the SSH message type that -// the packet is expected to start with. unmarshal either returns nil -// on success, or a ParseError or UnexpectedMessageError on error. -func unmarshal(out interface{}, packet []byte, expectedType uint8) error { - if len(packet) == 0 { - return ParseError{expectedType} +// typeTag returns the type byte for the given type. The type should +// be struct. +func typeTag(structType reflect.Type) byte { + var tag byte + var tagStr string + tagStr = structType.Field(0).Tag.Get("sshtype") + i, err := strconv.Atoi(tagStr) + if err == nil { + tag = byte(i) } - if expectedType > 0 { - if packet[0] != expectedType { - return UnexpectedMessageError{expectedType, packet[0]} - } - packet = packet[1:] + return tag +} + +func fieldError(t reflect.Type, field int, problem string) error { + if problem != "" { + problem = ": " + problem } + return fmt.Errorf("ssh: unmarshal error for field %s of type %s%s", t.Field(field).Name, t.Name(), problem) +} + +var errShortRead = errors.New("ssh: short read") +// Unmarshal parses data in SSH wire format into a structure. The out +// argument should be a pointer to struct. If the first member of the +// struct has the "sshtype" tag set to a number in decimal, the packet +// must start that number. In case of error, Unmarshal returns a +// ParseError or UnexpectedMessageError. +func Unmarshal(data []byte, out interface{}) error { v := reflect.ValueOf(out).Elem() structType := v.Type() + expectedType := typeTag(structType) + if len(data) == 0 { + return parseError(expectedType) + } + if expectedType > 0 { + if data[0] != expectedType { + return unexpectedMessageError(expectedType, data[0]) + } + data = data[1:] + } + var ok bool for i := 0; i < v.NumField(); i++ { field := v.Field(i) t := field.Type() switch t.Kind() { case reflect.Bool: - if len(packet) < 1 { - return ParseError{expectedType} + if len(data) < 1 { + return errShortRead } - field.SetBool(packet[0] != 0) - packet = packet[1:] + field.SetBool(data[0] != 0) + data = data[1:] case reflect.Array: if t.Elem().Kind() != reflect.Uint8 { - panic("array of non-uint8") + return fieldError(structType, i, "array of unsupported type") } - if len(packet) < t.Len() { - return ParseError{expectedType} + if len(data) < t.Len() { + return errShortRead } for j, n := 0, t.Len(); j < n; j++ { - field.Index(j).Set(reflect.ValueOf(packet[j])) + field.Index(j).Set(reflect.ValueOf(data[j])) + } + data = data[t.Len():] + case reflect.Uint64: + var u64 uint64 + if u64, data, ok = parseUint64(data); !ok { + return errShortRead } - packet = packet[t.Len():] + field.SetUint(u64) case reflect.Uint32: var u32 uint32 - if u32, packet, ok = parseUint32(packet); !ok { - return ParseError{expectedType} + if u32, data, ok = parseUint32(data); !ok { + return errShortRead } field.SetUint(uint64(u32)) + case reflect.Uint8: + if len(data) < 1 { + return errShortRead + } + field.SetUint(uint64(data[0])) + data = data[1:] case reflect.String: var s []byte - if s, packet, ok = parseString(packet); !ok { - return ParseError{expectedType} + if s, data, ok = parseString(data); !ok { + return fieldError(structType, i, "") } field.SetString(string(s)) case reflect.Slice: switch t.Elem().Kind() { case reflect.Uint8: if structType.Field(i).Tag.Get("ssh") == "rest" { - field.Set(reflect.ValueOf(packet)) - packet = nil + field.Set(reflect.ValueOf(data)) + data = nil } else { var s []byte - if s, packet, ok = parseString(packet); !ok { - return ParseError{expectedType} + if s, data, ok = parseString(data); !ok { + return errShortRead } field.Set(reflect.ValueOf(s)) } case reflect.String: var nl []string - if nl, packet, ok = parseNameList(packet); !ok { - return ParseError{expectedType} + if nl, data, ok = parseNameList(data); !ok { + return errShortRead } field.Set(reflect.ValueOf(nl)) default: - panic("slice of unknown type") + return fieldError(structType, i, "slice of unsupported type") } case reflect.Ptr: if t == bigIntType { var n *big.Int - if n, packet, ok = parseInt(packet); !ok { - return ParseError{expectedType} + if n, data, ok = parseInt(data); !ok { + return errShortRead } field.Set(reflect.ValueOf(n)) } else { - panic("pointer to unknown type") + return fieldError(structType, i, "pointer to unsupported type") } default: - panic("unknown type") + return fieldError(structType, i, "unsupported type") } } - if len(packet) != 0 { - return ParseError{expectedType} + if len(data) != 0 { + return parseError(expectedType) } return nil } -// marshal serializes the message in msg. The given message type is -// prepended if it is non-zero. -func marshal(msgType uint8, msg interface{}) []byte { +// Marshal serializes the message in msg to SSH wire format. The msg +// argument should be a struct or pointer to struct. If the first +// member has the "sshtype" tag set to a number in decimal, that +// number is prepended to the result. If the last of member has the +// "ssh" tag set to "rest", its contents are appended to the output. +func Marshal(msg interface{}) []byte { out := make([]byte, 0, 64) + return marshalStruct(out, msg) +} + +func marshalStruct(out []byte, msg interface{}) []byte { + v := reflect.Indirect(reflect.ValueOf(msg)) + msgType := typeTag(v.Type()) if msgType > 0 { out = append(out, msgType) } - v := reflect.ValueOf(msg) for i, n := 0, v.NumField(); i < n; i++ { field := v.Field(i) switch t := field.Type(); t.Kind() { @@ -342,13 +414,17 @@ func marshal(msgType uint8, msg interface{}) []byte { out = append(out, v) case reflect.Array: if t.Elem().Kind() != reflect.Uint8 { - panic("array of non-uint8") + panic(fmt.Sprintf("array of non-uint8 in field %d: %T", i, field.Interface())) } for j, l := 0, t.Len(); j < l; j++ { out = append(out, uint8(field.Index(j).Uint())) } case reflect.Uint32: out = appendU32(out, uint32(field.Uint())) + case reflect.Uint64: + out = appendU64(out, uint64(field.Uint())) + case reflect.Uint8: + out = append(out, uint8(field.Uint())) case reflect.String: s := field.String() out = appendInt(out, len(s)) @@ -375,7 +451,7 @@ func marshal(msgType uint8, msg interface{}) []byte { binary.BigEndian.PutUint32(out[offset:], uint32(len(out)-offset-4)) } default: - panic("slice of unknown type") + panic(fmt.Sprintf("slice of unknown type in field %d: %T", i, field.Interface())) } case reflect.Ptr: if t == bigIntType { @@ -393,7 +469,7 @@ func marshal(msgType uint8, msg interface{}) []byte { out = out[:oldLength+needed] marshalInt(out[oldLength:], n) } else { - panic("pointer to unknown type") + panic(fmt.Sprintf("pointer to unknown type in field %d: %T", i, field.Interface())) } } } @@ -477,17 +553,6 @@ func parseUint64(in []byte) (uint64, []byte, bool) { return binary.BigEndian.Uint64(in), in[8:], true } -func nameListLength(namelist []string) int { - length := 4 /* uint32 length prefix */ - for i, name := range namelist { - if i != 0 { - length++ /* comma */ - } - length += len(name) - } - return length -} - func intLength(n *big.Int) int { length := 4 /* length bytes */ if n.Sign() < 0 { @@ -650,9 +715,9 @@ func decode(packet []byte) (interface{}, error) { case msgChannelFailure: msg = new(channelRequestFailureMsg) default: - return nil, UnexpectedMessageError{0, packet[0]} + return nil, unexpectedMessageError(0, packet[0]) } - if err := unmarshal(msg, packet, packet[0]); err != nil { + if err := Unmarshal(packet, msg); err != nil { return nil, err } return msg, nil diff --git a/ssh/messages_test.go b/ssh/messages_test.go index ec1d7be..f14c8a2 100644 --- a/ssh/messages_test.go +++ b/ssh/messages_test.go @@ -5,6 +5,7 @@ package ssh import ( + "bytes" "math/big" "math/rand" "reflect" @@ -32,48 +33,62 @@ func TestIntLength(t *testing.T) { } } -var messageTypes = []interface{}{ - &kexInitMsg{}, - &kexDHInitMsg{}, - &serviceRequestMsg{}, - &serviceAcceptMsg{}, - &userAuthRequestMsg{}, - &channelOpenMsg{}, - &channelOpenConfirmMsg{}, - &channelOpenFailureMsg{}, - &channelRequestMsg{}, - &channelRequestSuccessMsg{}, +type msgAllTypes struct { + Bool bool `sshtype:"21"` + Array [16]byte + Uint64 uint64 + Uint32 uint32 + Uint8 uint8 + String string + Strings []string + Bytes []byte + Int *big.Int + Rest []byte `ssh:"rest"` +} + +func (t *msgAllTypes) Generate(rand *rand.Rand, size int) reflect.Value { + m := &msgAllTypes{} + m.Bool = rand.Intn(2) == 1 + randomBytes(m.Array[:], rand) + m.Uint64 = uint64(rand.Int63n(1<<63 - 1)) + m.Uint32 = uint32(rand.Intn(1 << 32)) + m.Uint8 = uint8(rand.Intn(1 << 8)) + m.String = string(m.Array[:]) + m.Strings = randomNameList(rand) + m.Bytes = m.Array[:] + m.Int = randomInt(rand) + m.Rest = m.Array[:] + return reflect.ValueOf(m) } func TestMarshalUnmarshal(t *testing.T) { rand := rand.New(rand.NewSource(0)) - for i, iface := range messageTypes { - ty := reflect.ValueOf(iface).Type() + iface := &msgAllTypes{} + ty := reflect.ValueOf(iface).Type() - n := 100 - if testing.Short() { - n = 5 + n := 100 + if testing.Short() { + n = 5 + } + for j := 0; j < n; j++ { + v, ok := quick.Value(ty, rand) + if !ok { + t.Errorf("failed to create value") + break } - for j := 0; j < n; j++ { - v, ok := quick.Value(ty, rand) - if !ok { - t.Errorf("#%d: failed to create value", i) - break - } - m1 := v.Elem().Interface() - m2 := iface + m1 := v.Elem().Interface() + m2 := iface - marshaled := marshal(msgIgnore, m1) - if err := unmarshal(m2, marshaled, msgIgnore); err != nil { - t.Errorf("#%d failed to unmarshal %#v: %s", i, m1, err) - break - } + marshaled := Marshal(m1) + if err := Unmarshal(marshaled, m2); err != nil { + t.Errorf("Unmarshal %#v: %s", m1, err) + break + } - if !reflect.DeepEqual(v.Interface(), m2) { - t.Errorf("#%d\ngot: %#v\nwant:%#v\n%x", i, m2, m1, marshaled) - break - } + if !reflect.DeepEqual(v.Interface(), m2) { + t.Errorf("got: %#v\nwant:%#v\n%x", m2, m1, marshaled) + break } } } @@ -81,33 +96,37 @@ func TestMarshalUnmarshal(t *testing.T) { func TestUnmarshalEmptyPacket(t *testing.T) { var b []byte var m channelRequestSuccessMsg - err := unmarshal(&m, b, msgChannelRequest) - want := ParseError{msgChannelRequest} - if _, ok := err.(ParseError); !ok { - t.Fatalf("got %T, want %T", err, want) - } - if got := err.(ParseError); want != got { - t.Fatal("got %#v, want %#v", got, want) + if err := Unmarshal(b, &m); err == nil { + t.Fatalf("unmarshal of empty slice succeeded") } } func TestUnmarshalUnexpectedPacket(t *testing.T) { type S struct { - I uint32 + I uint32 `sshtype:"43"` S string B bool } - s := S{42, "hello", true} - packet := marshal(42, s) + s := S{11, "hello", true} + packet := Marshal(s) + packet[0] = 42 roundtrip := S{} - err := unmarshal(&roundtrip, packet, 43) + err := Unmarshal(packet, &roundtrip) if err == nil { t.Fatal("expected error, not nil") } - want := UnexpectedMessageError{43, 42} - if got, ok := err.(UnexpectedMessageError); !ok || want != got { - t.Fatal("expected %q, got %q", want, got) +} + +func TestMarshalPtr(t *testing.T) { + s := struct { + S string + }{"hello"} + + m1 := Marshal(s) + m2 := Marshal(&s) + if !bytes.Equal(m1, m2) { + t.Errorf("got %q, want %q for marshaled pointer", m2, m1) } } @@ -119,9 +138,9 @@ func TestBareMarshalUnmarshal(t *testing.T) { } s := S{42, "hello", true} - packet := marshal(0, s) + packet := Marshal(s) roundtrip := S{} - unmarshal(&roundtrip, packet, 0) + Unmarshal(packet, &roundtrip) if !reflect.DeepEqual(s, roundtrip) { t.Errorf("got %#v, want %#v", roundtrip, s) @@ -133,7 +152,7 @@ func TestBareMarshal(t *testing.T) { I uint32 } s := S2{42} - packet := marshal(0, s) + packet := Marshal(s) i, rest, ok := parseUint32(packet) if len(rest) > 0 || !ok { t.Errorf("parseInt(%q): parse error", packet) @@ -190,43 +209,36 @@ func (*kexDHInitMsg) Generate(rand *rand.Rand, size int) reflect.Value { return reflect.ValueOf(dhi) } -// TODO(dfc) maybe this can be removed in the future if testing/quick can handle -// derived basic types. -func (RejectionReason) Generate(rand *rand.Rand, size int) reflect.Value { - m := RejectionReason(Prohibited) - return reflect.ValueOf(m) -} - var ( _kexInitMsg = new(kexInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface() _kexDHInitMsg = new(kexDHInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface() - _kexInit = marshal(msgKexInit, _kexInitMsg) - _kexDHInit = marshal(msgKexDHInit, _kexDHInitMsg) + _kexInit = Marshal(_kexInitMsg) + _kexDHInit = Marshal(_kexDHInitMsg) ) func BenchmarkMarshalKexInitMsg(b *testing.B) { for i := 0; i < b.N; i++ { - marshal(msgKexInit, _kexInitMsg) + Marshal(_kexInitMsg) } } func BenchmarkUnmarshalKexInitMsg(b *testing.B) { m := new(kexInitMsg) for i := 0; i < b.N; i++ { - unmarshal(m, _kexInit, msgKexInit) + Unmarshal(_kexInit, m) } } func BenchmarkMarshalKexDHInitMsg(b *testing.B) { for i := 0; i < b.N; i++ { - marshal(msgKexDHInit, _kexDHInitMsg) + Marshal(_kexDHInitMsg) } } func BenchmarkUnmarshalKexDHInitMsg(b *testing.B) { m := new(kexDHInitMsg) for i := 0; i < b.N; i++ { - unmarshal(m, _kexDHInit, msgKexDHInit) + Unmarshal(_kexDHInit, m) } } diff --git a/ssh/mux.go b/ssh/mux.go new file mode 100644 index 0000000..5af7c16 --- /dev/null +++ b/ssh/mux.go @@ -0,0 +1,352 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "encoding/binary" + "fmt" + "io" + "log" + "sync" + "sync/atomic" +) + +// debugMux, if set, causes messages in the connection protocol to be +// logged. +const debugMux = false + +// chanList is a thread safe channel list. +type chanList struct { + // protects concurrent access to chans + sync.Mutex + + // chans are indexed by the local id of the channel, which the + // other side should send in the PeersId field. + chans []*channel + + // This is a debugging aid: it offsets all IDs by this + // amount. This helps distinguish otherwise identical + // server/client muxes + offset uint32 +} + +// Assigns a channel ID to the given channel. +func (c *chanList) add(ch *channel) uint32 { + c.Lock() + defer c.Unlock() + for i := range c.chans { + if c.chans[i] == nil { + c.chans[i] = ch + return uint32(i) + c.offset + } + } + c.chans = append(c.chans, ch) + return uint32(len(c.chans)-1) + c.offset +} + +// getChan returns the channel for the given ID. +func (c *chanList) getChan(id uint32) *channel { + id -= c.offset + + c.Lock() + defer c.Unlock() + if id < uint32(len(c.chans)) { + return c.chans[id] + } + return nil +} + +func (c *chanList) remove(id uint32) { + id -= c.offset + c.Lock() + if id < uint32(len(c.chans)) { + c.chans[id] = nil + } + c.Unlock() +} + +// dropAll forgets all channels it knows, returning them in a slice. +func (c *chanList) dropAll() []*channel { + c.Lock() + defer c.Unlock() + var r []*channel + + for _, ch := range c.chans { + if ch == nil { + continue + } + r = append(r, ch) + } + c.chans = nil + return r +} + +// mux represents the state for the SSH connection protocol, which +// multiplexes many channels onto a single packet transport. +type mux struct { + conn packetConn + chanList chanList + + incomingChannels chan NewChannel + + globalSentMu sync.Mutex + globalResponses chan interface{} + incomingRequests chan *Request + + errCond *sync.Cond + err error +} + +// Each new chanList instantiation has a different offset. +var globalOff uint32 + +func (m *mux) Wait() error { + m.errCond.L.Lock() + defer m.errCond.L.Unlock() + for m.err == nil { + m.errCond.Wait() + } + return m.err +} + +// newMux returns a mux that runs over the given connection. +func newMux(p packetConn) *mux { + m := &mux{ + conn: p, + incomingChannels: make(chan NewChannel, 16), + globalResponses: make(chan interface{}, 1), + incomingRequests: make(chan *Request, 16), + errCond: newCond(), + } + m.chanList.offset = atomic.AddUint32(&globalOff, 1) + go m.loop() + return m +} + +func (m *mux) sendMessage(msg interface{}) error { + p := Marshal(msg) + return m.conn.writePacket(p) +} + +func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) { + if wantReply { + m.globalSentMu.Lock() + defer m.globalSentMu.Unlock() + } + + if err := m.sendMessage(globalRequestMsg{ + Type: name, + WantReply: wantReply, + Data: payload, + }); err != nil { + return false, nil, err + } + + if !wantReply { + return false, nil, nil + } + + msg, ok := <-m.globalResponses + if !ok { + return false, nil, io.EOF + } + switch msg := msg.(type) { + case *globalRequestFailureMsg: + return false, msg.Data, nil + case *globalRequestSuccessMsg: + return true, msg.Data, nil + default: + return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg) + } +} + +// ackRequest must be called after processing a global request that +// has WantReply set. +func (m *mux) ackRequest(ok bool, data []byte) error { + if ok { + return m.sendMessage(globalRequestSuccessMsg{Data: data}) + } + return m.sendMessage(globalRequestFailureMsg{Data: data}) +} + +// TODO(hanwen): Disconnect is a transport layer message. We should +// probably send and receive Disconnect somewhere in the transport +// code. + +// Disconnect sends a disconnect message. +func (m *mux) Disconnect(reason uint32, message string) error { + return m.sendMessage(disconnectMsg{ + Reason: reason, + Message: message, + }) +} + +func (m *mux) Close() error { + return m.conn.Close() +} + +// loop runs the connection machine. It will process packets until an +// error is encountered. To synchronize on loop exit, use mux.Wait. +func (m *mux) loop() { + var err error + for err == nil { + err = m.onePacket() + } + + for _, ch := range m.chanList.dropAll() { + ch.close() + } + + close(m.incomingChannels) + close(m.incomingRequests) + close(m.globalResponses) + + m.conn.Close() + + m.errCond.L.Lock() + m.err = err + m.errCond.Broadcast() + m.errCond.L.Unlock() + + if debugMux { + log.Println("loop exit", err) + } +} + +// onePacket reads and processes one packet. +func (m *mux) onePacket() error { + packet, err := m.conn.readPacket() + if err != nil { + return err + } + + if debugMux { + if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData { + log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet)) + } else { + p, _ := decode(packet) + log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet)) + } + } + + switch packet[0] { + case msgNewKeys: + // Ignore notification of key change. + return nil + case msgDisconnect: + return m.handleDisconnect(packet) + case msgChannelOpen: + return m.handleChannelOpen(packet) + case msgGlobalRequest, msgRequestSuccess, msgRequestFailure: + return m.handleGlobalPacket(packet) + } + + // assume a channel packet. + if len(packet) < 5 { + return parseError(packet[0]) + } + id := binary.BigEndian.Uint32(packet[1:]) + ch := m.chanList.getChan(id) + if ch == nil { + return fmt.Errorf("ssh: invalid channel %d", id) + } + + return ch.handlePacket(packet) +} + +func (m *mux) handleDisconnect(packet []byte) error { + var d disconnectMsg + if err := Unmarshal(packet, &d); err != nil { + return err + } + + if debugMux { + log.Printf("caught disconnect: %v", d) + } + return &d +} + +func (m *mux) handleGlobalPacket(packet []byte) error { + msg, err := decode(packet) + if err != nil { + return err + } + + switch msg := msg.(type) { + case *globalRequestMsg: + m.incomingRequests <- &Request{ + Type: msg.Type, + WantReply: msg.WantReply, + Payload: msg.Data, + mux: m, + } + case *globalRequestSuccessMsg, *globalRequestFailureMsg: + m.globalResponses <- msg + default: + panic(fmt.Sprintf("not a global message %#v", msg)) + } + + return nil +} + +// handleChannelOpen schedules a channel to be Accept()ed. +func (m *mux) handleChannelOpen(packet []byte) error { + var msg channelOpenMsg + if err := Unmarshal(packet, &msg); err != nil { + return err + } + + if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { + failMsg := channelOpenFailureMsg{ + PeersId: msg.PeersId, + Reason: ConnectionFailed, + Message: "invalid request", + Language: "en_US.UTF-8", + } + return m.sendMessage(failMsg) + } + + c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData) + c.remoteId = msg.PeersId + c.maxRemotePayload = msg.MaxPacketSize + c.remoteWin.add(msg.PeersWindow) + m.incomingChannels <- c + return nil +} + +func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) { + ch, err := m.openChannel(chanType, extra) + if err != nil { + return nil, nil, err + } + + return ch, ch.incomingRequests, nil +} + +func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) { + ch := m.newChannel(chanType, channelOutbound, extra) + + ch.maxIncomingPayload = channelMaxPacket + + open := channelOpenMsg{ + ChanType: chanType, + PeersWindow: ch.myWindow, + MaxPacketSize: ch.maxIncomingPayload, + TypeSpecificData: extra, + PeersId: ch.localId, + } + if err := m.sendMessage(open); err != nil { + return nil, err + } + + switch msg := (<-ch.msg).(type) { + case *channelOpenConfirmMsg: + return ch, nil + case *channelOpenFailureMsg: + return nil, &OpenChannelError{msg.Reason, msg.Message} + default: + return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg) + } +} diff --git a/ssh/mux_test.go b/ssh/mux_test.go new file mode 100644 index 0000000..e18afe7 --- /dev/null +++ b/ssh/mux_test.go @@ -0,0 +1,483 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "io" + "io/ioutil" + "sync" + "testing" +) + +func muxPair() (*mux, *mux) { + a, b := memPipe() + + s := newMux(a) + c := newMux(b) + + return s, c +} + +// Returns both ends of a channel, and the mux for the the 2nd +// channel. +func channelPair(t *testing.T) (*channel, *channel, *mux) { + c, s := muxPair() + + res := make(chan *channel, 1) + go func() { + newCh, ok := <-s.incomingChannels + if !ok { + t.Fatalf("No incoming channel") + } + if newCh.ChannelType() != "chan" { + t.Fatalf("got type %q want chan", newCh.ChannelType()) + } + ch, _, err := newCh.Accept() + if err != nil { + t.Fatalf("Accept %v", err) + } + res <- ch.(*channel) + }() + + ch, err := c.openChannel("chan", nil) + if err != nil { + t.Fatalf("OpenChannel: %v", err) + } + + return <-res, ch, c +} + +func TestMuxReadWrite(t *testing.T) { + s, c, mux := channelPair(t) + defer s.Close() + defer c.Close() + defer mux.Close() + + magic := "hello world" + magicExt := "hello stderr" + go func() { + _, err := s.Write([]byte(magic)) + if err != nil { + t.Fatalf("Write: %v", err) + } + _, err = s.Extended(1).Write([]byte(magicExt)) + if err != nil { + t.Fatalf("Write: %v", err) + } + err = s.Close() + if err != nil { + t.Fatalf("Close: %v", err) + } + }() + + var buf [1024]byte + n, err := c.Read(buf[:]) + if err != nil { + t.Fatalf("server Read: %v", err) + } + got := string(buf[:n]) + if got != magic { + t.Fatalf("server: got %q want %q", got, magic) + } + + n, err = c.Extended(1).Read(buf[:]) + if err != nil { + t.Fatalf("server Read: %v", err) + } + + got = string(buf[:n]) + if got != magicExt { + t.Fatalf("server: got %q want %q", got, magic) + } +} + +func TestMuxChannelOverflow(t *testing.T) { + reader, writer, mux := channelPair(t) + defer reader.Close() + defer writer.Close() + defer mux.Close() + + wDone := make(chan int, 1) + go func() { + if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { + t.Errorf("could not fill window: %v", err) + } + writer.Write(make([]byte, 1)) + wDone <- 1 + }() + writer.remoteWin.waitWriterBlocked() + + // Send 1 byte. + packet := make([]byte, 1+4+4+1) + packet[0] = msgChannelData + marshalUint32(packet[1:], writer.remoteId) + marshalUint32(packet[5:], uint32(1)) + packet[9] = 42 + + if err := writer.mux.conn.writePacket(packet); err != nil { + t.Errorf("could not send packet") + } + if _, err := reader.SendRequest("hello", true, nil); err == nil { + t.Errorf("SendRequest succeeded.") + } + <-wDone +} + +func TestMuxChannelCloseWriteUnblock(t *testing.T) { + reader, writer, mux := channelPair(t) + defer reader.Close() + defer writer.Close() + defer mux.Close() + + wDone := make(chan int, 1) + go func() { + if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { + t.Errorf("could not fill window: %v", err) + } + if _, err := writer.Write(make([]byte, 1)); err != io.EOF { + t.Errorf("got %v, want EOF for unblock write", err) + } + wDone <- 1 + }() + + writer.remoteWin.waitWriterBlocked() + reader.Close() + <-wDone +} + +func TestMuxConnectionCloseWriteUnblock(t *testing.T) { + reader, writer, mux := channelPair(t) + defer reader.Close() + defer writer.Close() + defer mux.Close() + + wDone := make(chan int, 1) + go func() { + if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { + t.Errorf("could not fill window: %v", err) + } + if _, err := writer.Write(make([]byte, 1)); err != io.EOF { + t.Errorf("got %v, want EOF for unblock write", err) + } + wDone <- 1 + }() + + writer.remoteWin.waitWriterBlocked() + mux.Close() + <-wDone +} + +func TestMuxReject(t *testing.T) { + client, server := muxPair() + defer server.Close() + defer client.Close() + + go func() { + ch, ok := <-server.incomingChannels + if !ok { + t.Fatalf("Accept") + } + if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" { + t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData()) + } + ch.Reject(RejectionReason(42), "message") + }() + + ch, err := client.openChannel("ch", []byte("extra")) + if ch != nil { + t.Fatal("openChannel not rejected") + } + + ocf, ok := err.(*OpenChannelError) + if !ok { + t.Errorf("got %#v want *OpenChannelError", err) + } else if ocf.Reason != 42 || ocf.Message != "message" { + t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message") + } + + want := "ssh: rejected: unknown reason 42 (message)" + if err.Error() != want { + t.Errorf("got %q, want %q", err.Error(), want) + } +} + +func TestMuxChannelRequest(t *testing.T) { + client, server, mux := channelPair(t) + defer server.Close() + defer client.Close() + defer mux.Close() + + var received int + var wg sync.WaitGroup + wg.Add(1) + go func() { + for r := range server.incomingRequests { + received++ + r.Reply(r.Type == "yes", nil) + } + wg.Done() + }() + _, err := client.SendRequest("yes", false, nil) + if err != nil { + t.Fatalf("SendRequest: %v", err) + } + ok, err := client.SendRequest("yes", true, nil) + if err != nil { + t.Fatalf("SendRequest: %v", err) + } + + if !ok { + t.Errorf("SendRequest(yes): %v", ok) + + } + + ok, err = client.SendRequest("no", true, nil) + if err != nil { + t.Fatalf("SendRequest: %v", err) + } + if ok { + t.Errorf("SendRequest(no): %v", ok) + + } + + client.Close() + wg.Wait() + + if received != 3 { + t.Errorf("got %d requests, want %d", received, 3) + } +} + +func TestMuxGlobalRequest(t *testing.T) { + clientMux, serverMux := muxPair() + defer serverMux.Close() + defer clientMux.Close() + + var seen bool + go func() { + for r := range serverMux.incomingRequests { + seen = seen || r.Type == "peek" + if r.WantReply { + err := r.Reply(r.Type == "yes", + append([]byte(r.Type), r.Payload...)) + if err != nil { + t.Errorf("AckRequest: %v", err) + } + } + } + }() + + _, _, err := clientMux.SendRequest("peek", false, nil) + if err != nil { + t.Errorf("SendRequest: %v", err) + } + + ok, data, err := clientMux.SendRequest("yes", true, []byte("a")) + if !ok || string(data) != "yesa" || err != nil { + t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v", + ok, data, err) + } + if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil { + t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v", + ok, data, err) + } + + if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok || string(data) != "noa" || err != nil { + t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v", + ok, data, err) + } + + clientMux.Disconnect(0, "") + if !seen { + t.Errorf("never saw 'peek' request") + } +} + +func TestMuxGlobalRequestUnblock(t *testing.T) { + clientMux, serverMux := muxPair() + defer serverMux.Close() + defer clientMux.Close() + + result := make(chan error, 1) + go func() { + _, _, err := clientMux.SendRequest("hello", true, nil) + result <- err + }() + + <-serverMux.incomingRequests + serverMux.conn.Close() + err := <-result + + if err != io.EOF { + t.Errorf("want EOF, got %v", io.EOF) + } +} + +func TestMuxChannelRequestUnblock(t *testing.T) { + a, b, connB := channelPair(t) + defer a.Close() + defer b.Close() + defer connB.Close() + + result := make(chan error, 1) + go func() { + _, err := a.SendRequest("hello", true, nil) + result <- err + }() + + <-b.incomingRequests + connB.conn.Close() + err := <-result + + if err != io.EOF { + t.Errorf("want EOF, got %v", err) + } +} + +func TestMuxDisconnect(t *testing.T) { + a, b := muxPair() + defer a.Close() + defer b.Close() + + go func() { + for r := range b.incomingRequests { + r.Reply(true, nil) + } + }() + + a.Disconnect(42, "whatever") + ok, _, err := a.SendRequest("hello", true, nil) + if ok || err == nil { + t.Errorf("got reply after disconnecting") + } + err = b.Wait() + if d, ok := err.(*disconnectMsg); !ok || d.Reason != 42 { + t.Errorf("got %#v, want disconnectMsg{Reason:42}", err) + } +} + +func TestMuxCloseChannel(t *testing.T) { + r, w, mux := channelPair(t) + defer mux.Close() + defer r.Close() + defer w.Close() + + result := make(chan error, 1) + go func() { + var b [1024]byte + _, err := r.Read(b[:]) + result <- err + }() + if err := w.Close(); err != nil { + t.Errorf("w.Close: %v", err) + } + + if _, err := w.Write([]byte("hello")); err != io.EOF { + t.Errorf("got err %v, want io.EOF after Close", err) + } + + if err := <-result; err != io.EOF { + t.Errorf("got %v (%T), want io.EOF", err, err) + } +} + +func TestMuxCloseWriteChannel(t *testing.T) { + r, w, mux := channelPair(t) + defer mux.Close() + + result := make(chan error, 1) + go func() { + var b [1024]byte + _, err := r.Read(b[:]) + result <- err + }() + if err := w.CloseWrite(); err != nil { + t.Errorf("w.CloseWrite: %v", err) + } + + if _, err := w.Write([]byte("hello")); err != io.EOF { + t.Errorf("got err %v, want io.EOF after CloseWrite", err) + } + + if err := <-result; err != io.EOF { + t.Errorf("got %v (%T), want io.EOF", err, err) + } +} + +func TestMuxInvalidRecord(t *testing.T) { + a, b := muxPair() + defer a.Close() + defer b.Close() + + packet := make([]byte, 1+4+4+1) + packet[0] = msgChannelData + marshalUint32(packet[1:], 29348723 /* invalid channel id */) + marshalUint32(packet[5:], 1) + packet[9] = 42 + + a.conn.writePacket(packet) + go a.SendRequest("hello", false, nil) + // 'a' wrote an invalid packet, so 'b' has exited. + req, ok := <-b.incomingRequests + if ok { + t.Errorf("got request %#v after receiving invalid packet", req) + } +} + +func TestZeroWindowAdjust(t *testing.T) { + a, b, mux := channelPair(t) + defer a.Close() + defer b.Close() + defer mux.Close() + + go func() { + io.WriteString(a, "hello") + // bogus adjust. + a.sendMessage(windowAdjustMsg{}) + io.WriteString(a, "world") + a.Close() + }() + + want := "helloworld" + c, _ := ioutil.ReadAll(b) + if string(c) != want { + t.Errorf("got %q want %q", c, want) + } +} + +func TestMuxMaxPacketSize(t *testing.T) { + a, b, mux := channelPair(t) + defer a.Close() + defer b.Close() + defer mux.Close() + + large := make([]byte, a.maxRemotePayload+1) + packet := make([]byte, 1+4+4+1+len(large)) + packet[0] = msgChannelData + marshalUint32(packet[1:], a.remoteId) + marshalUint32(packet[5:], uint32(len(large))) + packet[9] = 42 + + if err := a.mux.conn.writePacket(packet); err != nil { + t.Errorf("could not send packet") + } + + go a.SendRequest("hello", false, nil) + + _, ok := <-b.incomingRequests + if ok { + t.Errorf("connection still alive after receiving large packet.") + } +} + +// Don't ship code with debug=true. +func TestDebug(t *testing.T) { + if debugMux { + t.Error("mux debug switched on") + } + if debugHandshake { + t.Error("handshake debug switched on") + } +} diff --git a/ssh/server.go b/ssh/server.go index b4defbe..7a53d57 100644 --- a/ssh/server.go +++ b/ssh/server.go @@ -6,38 +6,52 @@ package ssh import ( "bytes" - "crypto/rand" - "encoding/binary" "errors" "fmt" "io" "net" - "sync" - - _ "crypto/sha1" ) +// The Permissions type holds fine-grained permissions that are +// specific to a user or a specific authentication method for a +// user. Permissions, except for "source-address", must be enforced in +// the server application layer, after successful authentication. The +// Permissions are passed on in ServerConn so a server implementation +// can honor them. +type Permissions struct { + // Critical options restrict default permissions. Common + // restrictions are "source-address" and "force-command". If + // the server cannot enforce the restriction, or does not + // recognize it, the user should not authenticate. + CriticalOptions map[string]string + + // Extensions are extra functionality that the server may + // offer on authenticated connections. Common extensions are + // "permit-agent-forwarding", "permit-X11-forwarding". Lack of + // support for an extension does not preclude authenticating a + // user. + Extensions map[string]string +} + +// ServerConfig holds server specific configuration data. type ServerConfig struct { - hostKeys []Signer + // Config contains configuration shared between client and server. + Config - // Rand provides the source of entropy for key exchange. If Rand is - // nil, the cryptographic random reader in package crypto/rand will - // be used. - Rand io.Reader + hostKeys []Signer // NoClientAuth is true if clients are allowed to connect without // authenticating. NoClientAuth bool - // PasswordCallback, if non-nil, is called when a user attempts to - // authenticate using a password. It may be called concurrently from - // several goroutines. - PasswordCallback func(conn *ServerConn, user, password string) bool + // PasswordCallback, if non-nil, is called when a user + // attempts to authenticate using a password. + PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error) // PublicKeyCallback, if non-nil, is called when a client attempts public // key authentication. It must return true if the given public key is - // valid for the given user. - PublicKeyCallback func(conn *ServerConn, user, algo string, pubkey []byte) bool + // valid for the given user. For example, see CertChecker.Authenticate. + PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error) // KeyboardInteractiveCallback, if non-nil, is called when // keyboard-interactive authentication is selected (RFC @@ -46,24 +60,19 @@ type ServerConfig struct { // Challenge rounds. To avoid information leaks, the client // should be presented a challenge even if the user is // unknown. - KeyboardInteractiveCallback func(conn *ServerConn, user string, client ClientKeyboardInteractive) bool + KeyboardInteractiveCallback func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error) - // Cryptographic-related configuration. - Crypto CryptoConfig -} - -func (c *ServerConfig) rand() io.Reader { - if c.Rand == nil { - return rand.Reader - } - return c.Rand + // AuthLogCallback, if non-nil, is called to log all authentication + // attempts. + AuthLogCallback func(conn ConnMetadata, method string, err error) } // AddHostKey adds a private key as a host key. If an existing host -// key exists with the same algorithm, it is overwritten. +// key exists with the same algorithm, it is overwritten. Each server +// config must have at least one host key. func (s *ServerConfig) AddHostKey(key Signer) { for i, k := range s.hostKeys { - if k.PublicKey().PublicKeyAlgo() == key.PublicKey().PublicKeyAlgo() { + if k.PublicKey().Type() == key.PublicKey().Type() { s.hostKeys[i] = key return } @@ -72,68 +81,73 @@ func (s *ServerConfig) AddHostKey(key Signer) { s.hostKeys = append(s.hostKeys, key) } -// SetRSAPrivateKey sets the private key for a Server. A Server must have a -// private key configured in order to accept connections. The private key must -// be in the form of a PEM encoded, PKCS#1, RSA private key. The file "id_rsa" -// typically contains such a key. -func (s *ServerConfig) SetRSAPrivateKey(pemBytes []byte) error { - priv, err := ParsePrivateKey(pemBytes) - if err != nil { - return err - } - s.AddHostKey(priv) - return nil -} - // cachedPubKey contains the results of querying whether a public key is -// acceptable for a user. The cache only applies to a single ServerConn. +// acceptable for a user. type cachedPubKey struct { - user, algo string - pubKey []byte - result bool + user string + pubKeyData []byte + result error + perms *Permissions } -const maxCachedPubKeys = 16 - -// A ServerConn represents an incoming connection. -type ServerConn struct { - transport *transport - config *ServerConfig +func (k1 *cachedPubKey) Equal(k2 *cachedPubKey) bool { + return k1.user == k2.user && bytes.Equal(k1.pubKeyData, k2.pubKeyData) +} - channels map[uint32]*serverChan - nextChanId uint32 +const maxCachedPubKeys = 16 - // lock protects err and channels. - lock sync.Mutex - err error +// pubKeyCache caches tests for public keys. Since SSH clients +// will query whether a public key is acceptable before attempting to +// authenticate with it, we end up with duplicate queries for public +// key validity. The cache only applies to a single ServerConn. +type pubKeyCache struct { + keys []cachedPubKey +} - // cachedPubKeys contains the cache results of tests for public keys. - // Since SSH clients will query whether a public key is acceptable - // before attempting to authenticate with it, we end up with duplicate - // queries for public key validity. - cachedPubKeys []cachedPubKey +// get returns the result for a given user/algo/key tuple. +func (c *pubKeyCache) get(candidate cachedPubKey) (result error, ok bool) { + for _, k := range c.keys { + if k.Equal(&candidate) { + return k.result, true + } + } + return errors.New("ssh: not in cache"), false +} - // User holds the successfully authenticated user name. - // It is empty if no authentication is used. It is populated before - // any authentication callback is called and not assigned to after that. - User string +// add adds the given tuple to the cache. +func (c *pubKeyCache) add(candidate cachedPubKey) { + if len(c.keys) < maxCachedPubKeys { + c.keys = append(c.keys, candidate) + } +} - // ClientVersion is the client's version, populated after - // Handshake is called. It should not be modified. - ClientVersion []byte +// ServerConn is an authenticated SSH connection, as seen from the +// server +type ServerConn struct { + Conn - // Our version. - serverVersion []byte + // If the succeeding authentication callback returned a + // non-nil Permissions pointer, it is stored here. + Permissions *Permissions } -// Server returns a new SSH server connection -// using c as the underlying transport. -func Server(c net.Conn, config *ServerConfig) *ServerConn { - return &ServerConn{ - transport: newTransport(c, config.rand(), false /* not client */), - channels: make(map[uint32]*serverChan), - config: config, +// NewServerConn starts a new SSH server with c as the underlying +// transport. It starts with a handshake and, if the handshake is +// unsuccessful, it closes the connection and returns an error. The +// Request and NewChannel channels must be serviced, or the connection +// will hang. +func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewChannel, <-chan *Request, error) { + fullConf := *config + fullConf.SetDefaults() + s := &connection{ + sshConn: sshConn{conn: c}, } + perms, err := s.serverHandshake(&fullConf) + if err != nil { + c.Close() + return nil, nil, nil, err + } + return &ServerConn{s, perms}, s.mux.incomingChannels, s.mux.incomingRequests, nil } // signAndMarshal signs the data with the appropriate algorithm, @@ -144,134 +158,60 @@ func signAndMarshal(k Signer, rand io.Reader, data []byte) ([]byte, error) { return nil, err } - return serializeSignature(k.PublicKey().PrivateKeyAlgo(), sig), nil + return Marshal(sig), nil } -// Close closes the connection. -func (s *ServerConn) Close() error { return s.transport.Close() } - -// LocalAddr returns the local network address. -func (c *ServerConn) LocalAddr() net.Addr { return c.transport.LocalAddr() } - -// RemoteAddr returns the remote network address. -func (c *ServerConn) RemoteAddr() net.Addr { return c.transport.RemoteAddr() } +// handshake performs key exchange and user authentication. +func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error) { + if len(config.hostKeys) == 0 { + return nil, errors.New("ssh: server has no host keys") + } -// Handshake performs an SSH transport and client authentication on the given ServerConn. -func (s *ServerConn) Handshake() error { var err error s.serverVersion = []byte(packageVersion) - s.ClientVersion, err = exchangeVersions(s.transport.Conn, s.serverVersion) + s.clientVersion, err = exchangeVersions(s.sshConn.conn, s.serverVersion) if err != nil { - return err + return nil, err + } + + tr := newTransport(s.sshConn.conn, config.Rand, false /* not client */) + s.transport = newServerTransport(tr, s.clientVersion, s.serverVersion, config) + + if err := s.transport.requestKeyChange(); err != nil { + return nil, err } - if err := s.clientInitHandshake(nil, nil); err != nil { - return err + + if packet, err := s.transport.readPacket(); err != nil { + return nil, err + } else if packet[0] != msgNewKeys { + return nil, unexpectedMessageError(msgNewKeys, packet[0]) } var packet []byte if packet, err = s.transport.readPacket(); err != nil { - return err + return nil, err } + var serviceRequest serviceRequestMsg - if err := unmarshal(&serviceRequest, packet, msgServiceRequest); err != nil { - return err + if err = Unmarshal(packet, &serviceRequest); err != nil { + return nil, err } if serviceRequest.Service != serviceUserAuth { - return errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating") + return nil, errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating") } serviceAccept := serviceAcceptMsg{ Service: serviceUserAuth, } - if err := s.transport.writePacket(marshal(msgServiceAccept, serviceAccept)); err != nil { - return err - } - - if err := s.authenticate(); err != nil { - return err - } - return err -} - -func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexInitPacket []byte) (err error) { - serverKexInit := kexInitMsg{ - KexAlgos: s.config.Crypto.kexes(), - CiphersClientServer: s.config.Crypto.ciphers(), - CiphersServerClient: s.config.Crypto.ciphers(), - MACsClientServer: s.config.Crypto.macs(), - MACsServerClient: s.config.Crypto.macs(), - CompressionClientServer: supportedCompressions, - CompressionServerClient: supportedCompressions, - } - for _, k := range s.config.hostKeys { - serverKexInit.ServerHostKeyAlgos = append( - serverKexInit.ServerHostKeyAlgos, k.PublicKey().PublicKeyAlgo()) - } - - serverKexInitPacket := marshal(msgKexInit, serverKexInit) - if err = s.transport.writePacket(serverKexInitPacket); err != nil { - return - } - - if clientKexInitPacket == nil { - clientKexInit = new(kexInitMsg) - if clientKexInitPacket, err = s.transport.readPacket(); err != nil { - return - } - if err = unmarshal(clientKexInit, clientKexInitPacket, msgKexInit); err != nil { - return - } - } - - algs := findAgreedAlgorithms(clientKexInit, &serverKexInit) - if algs == nil { - return errors.New("ssh: no common algorithms") - } - - if clientKexInit.FirstKexFollows && algs.kex != clientKexInit.KexAlgos[0] { - // The client sent a Kex message for the wrong algorithm, - // which we have to ignore. - if _, err = s.transport.readPacket(); err != nil { - return - } - } - - var hostKey Signer - for _, k := range s.config.hostKeys { - if algs.hostKey == k.PublicKey().PublicKeyAlgo() { - hostKey = k - } - } - - kex, ok := kexAlgoMap[algs.kex] - if !ok { - return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex) + if err := s.transport.writePacket(Marshal(&serviceAccept)); err != nil { + return nil, err } - magics := handshakeMagics{ - serverVersion: s.serverVersion, - clientVersion: s.ClientVersion, - serverKexInit: marshal(msgKexInit, serverKexInit), - clientKexInit: clientKexInitPacket, - } - result, err := kex.Server(s.transport, s.config.rand(), &magics, hostKey) + perms, err := s.serverAuthenticate(config) if err != nil { - return err - } - - if err = s.transport.prepareKeyChange(algs, result); err != nil { - return err - } - - if err = s.transport.writePacket([]byte{msgNewKeys}); err != nil { - return - } - if packet, err := s.transport.readPacket(); err != nil { - return err - } else if packet[0] != msgNewKeys { - return UnexpectedMessageError{msgNewKeys, packet[0]} + return nil, err } - - return + s.mux = newMux(s.transport) + return perms, err } func isAcceptableAlgo(algo string) bool { @@ -283,181 +223,213 @@ func isAcceptableAlgo(algo string) bool { return false } -// testPubKey returns true if the given public key is acceptable for the user. -func (s *ServerConn) testPubKey(user, algo string, pubKey []byte) bool { - if s.config.PublicKeyCallback == nil || !isAcceptableAlgo(algo) { - return false +func checkSourceAddress(addr net.Addr, sourceAddr string) error { + if addr == nil { + return errors.New("ssh: no address known for client, but source-address match required") } - for _, c := range s.cachedPubKeys { - if c.user == user && c.algo == algo && bytes.Equal(c.pubKey, pubKey) { - return c.result - } + tcpAddr, ok := addr.(*net.TCPAddr) + if !ok { + return fmt.Errorf("ssh: remote address %v is not an TCP address when checking source-address match", addr) } - result := s.config.PublicKeyCallback(s, user, algo, pubKey) - if len(s.cachedPubKeys) < maxCachedPubKeys { - c := cachedPubKey{ - user: user, - algo: algo, - pubKey: make([]byte, len(pubKey)), - result: result, + if allowedIP := net.ParseIP(sourceAddr); allowedIP != nil { + if bytes.Equal(allowedIP, tcpAddr.IP) { + return nil + } + } else { + _, ipNet, err := net.ParseCIDR(sourceAddr) + if err != nil { + return fmt.Errorf("ssh: error parsing source-address restriction %q: %v", sourceAddr, err) + } + + if ipNet.Contains(tcpAddr.IP) { + return nil } - copy(c.pubKey, pubKey) - s.cachedPubKeys = append(s.cachedPubKeys, c) } - return result + return fmt.Errorf("ssh: remote address %v is not allowed because of source-address restriction", addr) } -func (s *ServerConn) authenticate() error { - var userAuthReq userAuthRequestMsg +func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) { var err error - var packet []byte + var cache pubKeyCache + var perms *Permissions userAuthLoop: for { - if packet, err = s.transport.readPacket(); err != nil { - return err - } - if err = unmarshal(&userAuthReq, packet, msgUserAuthRequest); err != nil { - return err + var userAuthReq userAuthRequestMsg + if packet, err := s.transport.readPacket(); err != nil { + return nil, err + } else if err = Unmarshal(packet, &userAuthReq); err != nil { + return nil, err } if userAuthReq.Service != serviceSSH { - return errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service) + return nil, errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service) } + s.user = userAuthReq.User + perms = nil + authErr := errors.New("no auth passed yet") + switch userAuthReq.Method { case "none": - if s.config.NoClientAuth { - break userAuthLoop + if config.NoClientAuth { + s.user = "" + authErr = nil } case "password": - if s.config.PasswordCallback == nil { + if config.PasswordCallback == nil { + authErr = errors.New("ssh: password auth not configured") break } payload := userAuthReq.Payload if len(payload) < 1 || payload[0] != 0 { - return ParseError{msgUserAuthRequest} + return nil, parseError(msgUserAuthRequest) } payload = payload[1:] password, payload, ok := parseString(payload) if !ok || len(payload) > 0 { - return ParseError{msgUserAuthRequest} + return nil, parseError(msgUserAuthRequest) } - s.User = userAuthReq.User - if s.config.PasswordCallback(s, userAuthReq.User, string(password)) { - break userAuthLoop - } + perms, authErr = config.PasswordCallback(s, password) case "keyboard-interactive": - if s.config.KeyboardInteractiveCallback == nil { + if config.KeyboardInteractiveCallback == nil { + authErr = errors.New("ssh: keyboard-interactive auth not configubred") break } - s.User = userAuthReq.User - if s.config.KeyboardInteractiveCallback(s, s.User, &sshClientKeyboardInteractive{s}) { - break userAuthLoop - } + prompter := &sshClientKeyboardInteractive{s} + perms, authErr = config.KeyboardInteractiveCallback(s, prompter.Challenge) case "publickey": - if s.config.PublicKeyCallback == nil { + if config.PublicKeyCallback == nil { + authErr = errors.New("ssh: publickey auth not configured") break } payload := userAuthReq.Payload if len(payload) < 1 { - return ParseError{msgUserAuthRequest} + return nil, parseError(msgUserAuthRequest) } isQuery := payload[0] == 0 payload = payload[1:] algoBytes, payload, ok := parseString(payload) if !ok { - return ParseError{msgUserAuthRequest} + return nil, parseError(msgUserAuthRequest) } algo := string(algoBytes) + if !isAcceptableAlgo(algo) { + authErr = fmt.Errorf("ssh: algorithm %q not accepted", algo) + break + } - pubKey, payload, ok := parseString(payload) + pubKeyData, payload, ok := parseString(payload) + if !ok { + return nil, parseError(msgUserAuthRequest) + } + + pubKey, err := ParsePublicKey(pubKeyData) + if err != nil { + return nil, err + } + candidate := cachedPubKey{ + user: s.user, + pubKeyData: pubKeyData, + } + candidate.result, ok = cache.get(candidate) if !ok { - return ParseError{msgUserAuthRequest} + candidate.perms, candidate.result = config.PublicKeyCallback(s, pubKey) + if candidate.result == nil && candidate.perms != nil && candidate.perms.CriticalOptions != nil && candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" { + candidate.result = checkSourceAddress( + s.RemoteAddr(), + candidate.perms.CriticalOptions[sourceAddressCriticalOption]) + } + cache.add(candidate) } + if isQuery { // The client can query if the given public key // would be okay. if len(payload) > 0 { - return ParseError{msgUserAuthRequest} + return nil, parseError(msgUserAuthRequest) } - if s.testPubKey(userAuthReq.User, algo, pubKey) { + + if candidate.result == nil { okMsg := userAuthPubKeyOkMsg{ Algo: algo, - PubKey: string(pubKey), + PubKey: pubKeyData, } - if err = s.transport.writePacket(marshal(msgUserAuthPubKeyOk, okMsg)); err != nil { - return err + if err = s.transport.writePacket(Marshal(&okMsg)); err != nil { + return nil, err } continue userAuthLoop } + authErr = candidate.result } else { sig, payload, ok := parseSignature(payload) if !ok || len(payload) > 0 { - return ParseError{msgUserAuthRequest} + return nil, parseError(msgUserAuthRequest) } // Ensure the public key algo and signature algo // are supported. Compare the private key // algorithm name that corresponds to algo with // sig.Format. This is usually the same, but // for certs, the names differ. - if !isAcceptableAlgo(algo) || !isAcceptableAlgo(sig.Format) || pubAlgoToPrivAlgo(algo) != sig.Format { + if !isAcceptableAlgo(sig.Format) { break } - signedData := buildDataSignedForAuth(s.transport.sessionID, userAuthReq, algoBytes, pubKey) - key, _, ok := ParsePublicKey(pubKey) - if !ok { - return ParseError{msgUserAuthRequest} - } + signedData := buildDataSignedForAuth(s.transport.getSessionID(), userAuthReq, algoBytes, pubKeyData) - if !key.Verify(signedData, sig.Blob) { - return ParseError{msgUserAuthRequest} - } - // TODO(jmpittman): Implement full validation for certificates. - s.User = userAuthReq.User - if s.testPubKey(userAuthReq.User, algo, pubKey) { - break userAuthLoop + if err := pubKey.Verify(signedData, sig); err != nil { + return nil, err } + + authErr = candidate.result + perms = candidate.perms } + default: + authErr = fmt.Errorf("ssh: unknown method %q", userAuthReq.Method) + } + + if config.AuthLogCallback != nil { + config.AuthLogCallback(s, userAuthReq.Method, authErr) + } + + if authErr == nil { + break userAuthLoop } var failureMsg userAuthFailureMsg - if s.config.PasswordCallback != nil { + if config.PasswordCallback != nil { failureMsg.Methods = append(failureMsg.Methods, "password") } - if s.config.PublicKeyCallback != nil { + if config.PublicKeyCallback != nil { failureMsg.Methods = append(failureMsg.Methods, "publickey") } - if s.config.KeyboardInteractiveCallback != nil { + if config.KeyboardInteractiveCallback != nil { failureMsg.Methods = append(failureMsg.Methods, "keyboard-interactive") } if len(failureMsg.Methods) == 0 { - return errors.New("ssh: no authentication methods configured but NoClientAuth is also false") + return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false") } - if err = s.transport.writePacket(marshal(msgUserAuthFailure, failureMsg)); err != nil { - return err + if err = s.transport.writePacket(Marshal(&failureMsg)); err != nil { + return nil, err } } - packet = []byte{msgUserAuthSuccess} - if err = s.transport.writePacket(packet); err != nil { - return err + if err = s.transport.writePacket([]byte{msgUserAuthSuccess}); err != nil { + return nil, err } - - return nil + return perms, nil } // sshClientKeyboardInteractive implements a ClientKeyboardInteractive by // asking the client on the other side of a ServerConn. type sshClientKeyboardInteractive struct { - *ServerConn + *connection } func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, questions []string, echos []bool) (answers []string, err error) { @@ -471,7 +443,7 @@ func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, quest prompts = appendBool(prompts, echos[i]) } - if err := c.transport.writePacket(marshal(msgUserAuthInfoRequest, userAuthInfoRequestMsg{ + if err := c.transport.writePacket(Marshal(&userAuthInfoRequestMsg{ Instruction: instruction, NumPrompts: uint32(len(questions)), Prompts: prompts, @@ -484,19 +456,19 @@ func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, quest return nil, err } if packet[0] != msgUserAuthInfoResponse { - return nil, UnexpectedMessageError{msgUserAuthInfoResponse, packet[0]} + return nil, unexpectedMessageError(msgUserAuthInfoResponse, packet[0]) } packet = packet[1:] n, packet, ok := parseUint32(packet) if !ok || int(n) != len(questions) { - return nil, &ParseError{msgUserAuthInfoResponse} + return nil, parseError(msgUserAuthInfoResponse) } for i := uint32(0); i < n; i++ { ans, rest, ok := parseString(packet) if !ok { - return nil, &ParseError{msgUserAuthInfoResponse} + return nil, parseError(msgUserAuthInfoResponse) } answers = append(answers, string(ans)) @@ -508,185 +480,3 @@ func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, quest return answers, nil } - -const defaultWindowSize = 32768 - -// Accept reads and processes messages on a ServerConn. It must be called -// in order to demultiplex messages to any resulting Channels. -func (s *ServerConn) Accept() (Channel, error) { - // TODO(dfc) s.lock is not held here so visibility of s.err is not guaranteed. - if s.err != nil { - return nil, s.err - } - - for { - packet, err := s.transport.readPacket() - if err != nil { - - s.lock.Lock() - s.err = err - s.lock.Unlock() - - // TODO(dfc) s.lock protects s.channels but isn't being held here. - for _, c := range s.channels { - c.setDead() - c.handleData(nil) - } - - return nil, err - } - - switch packet[0] { - case msgChannelData: - if len(packet) < 9 { - // malformed data packet - return nil, ParseError{msgChannelData} - } - remoteId := binary.BigEndian.Uint32(packet[1:5]) - s.lock.Lock() - c, ok := s.channels[remoteId] - if !ok { - s.lock.Unlock() - continue - } - if length := binary.BigEndian.Uint32(packet[5:9]); length > 0 { - packet = packet[9:] - c.handleData(packet[:length]) - } - s.lock.Unlock() - default: - decoded, err := decode(packet) - if err != nil { - return nil, err - } - switch msg := decoded.(type) { - case *channelOpenMsg: - if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { - return nil, errors.New("ssh: invalid MaxPacketSize from peer") - } - c := &serverChan{ - channel: channel{ - packetConn: s.transport, - remoteId: msg.PeersId, - remoteWin: window{Cond: newCond()}, - maxPacket: msg.MaxPacketSize, - }, - chanType: msg.ChanType, - extraData: msg.TypeSpecificData, - myWindow: defaultWindowSize, - serverConn: s, - cond: newCond(), - pendingData: make([]byte, defaultWindowSize), - } - c.remoteWin.add(msg.PeersWindow) - s.lock.Lock() - c.localId = s.nextChanId - s.nextChanId++ - s.channels[c.localId] = c - s.lock.Unlock() - return c, nil - - case *channelRequestMsg: - s.lock.Lock() - c, ok := s.channels[msg.PeersId] - if !ok { - s.lock.Unlock() - continue - } - c.handlePacket(msg) - s.lock.Unlock() - - case *windowAdjustMsg: - s.lock.Lock() - c, ok := s.channels[msg.PeersId] - if !ok { - s.lock.Unlock() - continue - } - c.handlePacket(msg) - s.lock.Unlock() - - case *channelEOFMsg: - s.lock.Lock() - c, ok := s.channels[msg.PeersId] - if !ok { - s.lock.Unlock() - continue - } - c.handlePacket(msg) - s.lock.Unlock() - - case *channelCloseMsg: - s.lock.Lock() - c, ok := s.channels[msg.PeersId] - if !ok { - s.lock.Unlock() - continue - } - c.handlePacket(msg) - s.lock.Unlock() - - case *globalRequestMsg: - if msg.WantReply { - if err := s.transport.writePacket([]byte{msgRequestFailure}); err != nil { - return nil, err - } - } - - case *kexInitMsg: - s.lock.Lock() - if err := s.clientInitHandshake(msg, packet); err != nil { - s.lock.Unlock() - return nil, err - } - s.lock.Unlock() - case *disconnectMsg: - return nil, io.EOF - default: - // Unknown message. Ignore. - } - } - } - - panic("unreachable") -} - -// A Listener implements a network listener (net.Listener) for SSH connections. -type Listener struct { - listener net.Listener - config *ServerConfig -} - -// Addr returns the listener's network address. -func (l *Listener) Addr() net.Addr { - return l.listener.Addr() -} - -// Close closes the listener. -func (l *Listener) Close() error { - return l.listener.Close() -} - -// Accept waits for and returns the next incoming SSH connection. -// The receiver should call Handshake() in another goroutine -// to avoid blocking the accepter. -func (l *Listener) Accept() (*ServerConn, error) { - c, err := l.listener.Accept() - if err != nil { - return nil, err - } - return Server(c, l.config), nil -} - -// Listen creates an SSH listener accepting connections on -// the given network address using net.Listen. -func Listen(network, addr string, config *ServerConfig) (*Listener, error) { - l, err := net.Listen(network, addr) - if err != nil { - return nil, err - } - return &Listener{ - l, - config, - }, nil -} diff --git a/ssh/session.go b/ssh/session.go index 39f2d22..3b42b50 100644 --- a/ssh/session.go +++ b/ssh/session.go @@ -129,128 +129,126 @@ type Session struct { Stdout io.Writer Stderr io.Writer - *clientChan // the channel backing this session - - started bool // true once Start, Run or Shell is invoked. + ch Channel // the channel backing this session + started bool // true once Start, Run or Shell is invoked. copyFuncs []func() error errors chan error // one send per copyFunc // true if pipe method is active stdinpipe, stdoutpipe, stderrpipe bool + + // stdinPipeWriter is non-nil if StdinPipe has not been called + // and Stdin was specified by the user; it is the write end of + // a pipe connecting Session.Stdin to the stdin channel. + stdinPipeWriter io.WriteCloser + + exitStatus chan error } -// RFC 4254 Section 6.4. -type setenvRequest struct { - PeersId uint32 - Request string - WantReply bool - Name string - Value string +// SendRequest sends an out-of-band channel request on the SSH channel +// underlying the session. +func (s *Session) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { + return s.ch.SendRequest(name, wantReply, payload) } -// RFC 4254 Section 6.5. -type subsystemRequestMsg struct { - PeersId uint32 - Request string - WantReply bool - Subsystem string +func (s *Session) Close() error { + return s.ch.Close() +} + +// RFC 4254 Section 6.4. +type setenvRequest struct { + Name string + Value string } // Setenv sets an environment variable that will be applied to any // command executed by Shell or Run. func (s *Session) Setenv(name, value string) error { - req := setenvRequest{ - PeersId: s.remoteId, - Request: "env", - WantReply: true, - Name: name, - Value: value, + msg := setenvRequest{ + Name: name, + Value: value, } - if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil { - return err + ok, err := s.ch.SendRequest("env", true, Marshal(&msg)) + if err == nil && !ok { + err = errors.New("ssh: setenv failed") } - return s.waitForResponse() + return err } // RFC 4254 Section 6.2. type ptyRequestMsg struct { - PeersId uint32 - Request string - WantReply bool - Term string - Columns uint32 - Rows uint32 - Width uint32 - Height uint32 - Modelist string + Term string + Columns uint32 + Rows uint32 + Width uint32 + Height uint32 + Modelist string } // RequestPty requests the association of a pty with the session on the remote host. func (s *Session) RequestPty(term string, h, w int, termmodes TerminalModes) error { var tm []byte for k, v := range termmodes { - tm = append(tm, k) - tm = appendU32(tm, v) + kv := struct { + Key byte + Val uint32 + }{k, v} + + tm = append(tm, Marshal(&kv)...) } tm = append(tm, tty_OP_END) req := ptyRequestMsg{ - PeersId: s.remoteId, - Request: "pty-req", - WantReply: true, - Term: term, - Columns: uint32(w), - Rows: uint32(h), - Width: uint32(w * 8), - Height: uint32(h * 8), - Modelist: string(tm), + Term: term, + Columns: uint32(w), + Rows: uint32(h), + Width: uint32(w * 8), + Height: uint32(h * 8), + Modelist: string(tm), } - if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil { - return err + ok, err := s.ch.SendRequest("pty-req", true, Marshal(&req)) + if err == nil && !ok { + err = errors.New("ssh: pty-req failed") } - return s.waitForResponse() + return err +} + +// RFC 4254 Section 6.5. +type subsystemRequestMsg struct { + Subsystem string } // RequestSubsystem requests the association of a subsystem with the session on the remote host. // A subsystem is a predefined command that runs in the background when the ssh session is initiated func (s *Session) RequestSubsystem(subsystem string) error { - req := subsystemRequestMsg{ - PeersId: s.remoteId, - Request: "subsystem", - WantReply: true, + msg := subsystemRequestMsg{ Subsystem: subsystem, } - if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil { - return err + ok, err := s.ch.SendRequest("subsystem", true, Marshal(&msg)) + if err == nil && !ok { + err = errors.New("ssh: subsystem request failed") } - return s.waitForResponse() + return err } // RFC 4254 Section 6.9. type signalMsg struct { - PeersId uint32 - Request string - WantReply bool - Signal string + Signal string } // Signal sends the given signal to the remote process. // sig is one of the SIG* constants. func (s *Session) Signal(sig Signal) error { - req := signalMsg{ - PeersId: s.remoteId, - Request: "signal", - WantReply: false, - Signal: string(sig), + msg := signalMsg{ + Signal: string(sig), } - return s.writePacket(marshal(msgChannelRequest, req)) + + _, err := s.ch.SendRequest("signal", false, Marshal(&msg)) + return err } // RFC 4254 Section 6.5. type execMsg struct { - PeersId uint32 - Request string - WantReply bool - Command string + Command string } // Start runs cmd on the remote host. Typically, the remote @@ -261,16 +259,15 @@ func (s *Session) Start(cmd string) error { return errors.New("ssh: session already started") } req := execMsg{ - PeersId: s.remoteId, - Request: "exec", - WantReply: true, - Command: cmd, + Command: cmd, } - if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil { - return err + + ok, err := s.ch.SendRequest("exec", true, Marshal(&req)) + if err == nil && !ok { + err = fmt.Errorf("ssh: command %v failed", cmd) } - if err := s.waitForResponse(); err != nil { - return fmt.Errorf("ssh: could not execute command %s: %v", cmd, err) + if err != nil { + return err } return s.start() } @@ -339,31 +336,17 @@ func (s *Session) Shell() error { if s.started { return errors.New("ssh: session already started") } - req := channelRequestMsg{ - PeersId: s.remoteId, - Request: "shell", - WantReply: true, + + ok, err := s.ch.SendRequest("shell", true, nil) + if err == nil && !ok { + return fmt.Errorf("ssh: cound not start shell") } - if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil { + if err != nil { return err } - if err := s.waitForResponse(); err != nil { - return fmt.Errorf("ssh: could not execute shell: %v", err) - } return s.start() } -func (s *Session) waitForResponse() error { - msg := <-s.msg - switch msg.(type) { - case *channelRequestSuccessMsg: - return nil - case *channelRequestFailureMsg: - return errors.New("ssh: request failed") - } - return fmt.Errorf("ssh: unknown packet %T received: %v", msg, msg) -} - func (s *Session) start() error { s.started = true @@ -394,8 +377,11 @@ func (s *Session) Wait() error { if !s.started { return errors.New("ssh: session not started") } - waitErr := s.wait() + waitErr := <-s.exitStatus + if s.stdinPipeWriter != nil { + s.stdinPipeWriter.Close() + } var copyError error for _ = range s.copyFuncs { if err := <-s.errors; err != nil && copyError == nil { @@ -408,52 +394,35 @@ func (s *Session) Wait() error { return copyError } -func (s *Session) wait() error { +func (s *Session) wait(reqs <-chan *Request) error { wm := Waitmsg{status: -1} - // Wait for msg channel to be closed before returning. - for msg := range s.msg { - switch msg := msg.(type) { - case *channelRequestMsg: - switch msg.Request { - case "exit-status": - d := msg.RequestSpecificData - wm.status = int(d[0])<<24 | int(d[1])<<16 | int(d[2])<<8 | int(d[3]) - case "exit-signal": - signal, rest, ok := parseString(msg.RequestSpecificData) - if !ok { - return fmt.Errorf("wait: could not parse request data: %v", msg.RequestSpecificData) - } - wm.signal = safeString(string(signal)) - - // skip coreDumped bool - if len(rest) == 0 { - return fmt.Errorf("wait: could not parse request data: %v", msg.RequestSpecificData) - } - rest = rest[1:] - - errmsg, rest, ok := parseString(rest) - if !ok { - return fmt.Errorf("wait: could not parse request data: %v", msg.RequestSpecificData) - } - wm.msg = safeString(string(errmsg)) - - lang, _, ok := parseString(rest) - if !ok { - return fmt.Errorf("wait: could not parse request data: %v", msg.RequestSpecificData) - } - wm.lang = safeString(string(lang)) - default: - // This handles keepalives and matches - // OpenSSH's behaviour. - if msg.WantReply { - s.writePacket(marshal(msgChannelFailure, channelRequestFailureMsg{ - PeersId: s.remoteId, - })) - } + for msg := range reqs { + switch msg.Type { + case "exit-status": + d := msg.Payload + wm.status = int(d[0])<<24 | int(d[1])<<16 | int(d[2])<<8 | int(d[3]) + case "exit-signal": + var sigval struct { + Signal string + CoreDumped bool + Error string + Lang string + } + if err := Unmarshal(msg.Payload, &sigval); err != nil { + return err } + + // Must sanitize strings? + wm.signal = sigval.Signal + wm.msg = sigval.Error + wm.lang = sigval.Lang default: - return fmt.Errorf("wait: unexpected packet %T received: %v", msg, msg) + // This handles keepalives and matches + // OpenSSH's behaviour. + if msg.WantReply { + msg.Reply(false, nil) + } } } if wm.status == 0 { @@ -476,12 +445,20 @@ func (s *Session) stdin() { if s.stdinpipe { return } + var stdin io.Reader if s.Stdin == nil { - s.Stdin = new(bytes.Buffer) + stdin = new(bytes.Buffer) + } else { + r, w := io.Pipe() + go func() { + _, err := io.Copy(w, s.Stdin) + w.CloseWithError(err) + }() + stdin, s.stdinPipeWriter = r, w } s.copyFuncs = append(s.copyFuncs, func() error { - _, err := io.Copy(s.clientChan.stdin, s.Stdin) - if err1 := s.clientChan.stdin.Close(); err == nil && err1 != io.EOF { + _, err := io.Copy(s.ch, stdin) + if err1 := s.ch.CloseWrite(); err == nil && err1 != io.EOF { err = err1 } return err @@ -496,7 +473,7 @@ func (s *Session) stdout() { s.Stdout = ioutil.Discard } s.copyFuncs = append(s.copyFuncs, func() error { - _, err := io.Copy(s.Stdout, s.clientChan.stdout) + _, err := io.Copy(s.Stdout, s.ch) return err }) } @@ -509,11 +486,21 @@ func (s *Session) stderr() { s.Stderr = ioutil.Discard } s.copyFuncs = append(s.copyFuncs, func() error { - _, err := io.Copy(s.Stderr, s.clientChan.stderr) + _, err := io.Copy(s.Stderr, s.ch.Stderr()) return err }) } +// sessionStdin reroutes Close to CloseWrite. +type sessionStdin struct { + io.Writer + ch Channel +} + +func (s *sessionStdin) Close() error { + return s.ch.CloseWrite() +} + // StdinPipe returns a pipe that will be connected to the // remote command's standard input when the command starts. func (s *Session) StdinPipe() (io.WriteCloser, error) { @@ -524,7 +511,7 @@ func (s *Session) StdinPipe() (io.WriteCloser, error) { return nil, errors.New("ssh: StdinPipe after process started") } s.stdinpipe = true - return s.clientChan.stdin, nil + return &sessionStdin{s.ch, s.ch}, nil } // StdoutPipe returns a pipe that will be connected to the @@ -541,7 +528,7 @@ func (s *Session) StdoutPipe() (io.Reader, error) { return nil, errors.New("ssh: StdoutPipe after process started") } s.stdoutpipe = true - return s.clientChan.stdout, nil + return s.ch, nil } // StderrPipe returns a pipe that will be connected to the @@ -558,28 +545,20 @@ func (s *Session) StderrPipe() (io.Reader, error) { return nil, errors.New("ssh: StderrPipe after process started") } s.stderrpipe = true - return s.clientChan.stderr, nil + return s.ch.Stderr(), nil } -// NewSession returns a new interactive session on the remote host. -func (c *ClientConn) NewSession() (*Session, error) { - ch := c.newChan(c.transport) - if err := c.transport.writePacket(marshal(msgChannelOpen, channelOpenMsg{ - ChanType: "session", - PeersId: ch.localId, - PeersWindow: channelWindowSize, - MaxPacketSize: channelMaxPacketSize, - })); err != nil { - c.chanList.remove(ch.localId) - return nil, err +// newSession returns a new interactive session on the remote host. +func newSession(ch Channel, reqs <-chan *Request) (*Session, error) { + s := &Session{ + ch: ch, } - if err := ch.waitForChannelOpenResponse(); err != nil { - c.chanList.remove(ch.localId) - return nil, fmt.Errorf("ssh: unable to open session: %v", err) - } - return &Session{ - clientChan: ch, - }, nil + s.exitStatus = make(chan error, 1) + go func() { + s.exitStatus <- s.wait(reqs) + }() + + return s, nil } // An ExitError reports unsuccessful completion of a remote command. diff --git a/ssh/session_test.go b/ssh/session_test.go index 5cff58a..cc26573 100644 --- a/ssh/session_test.go +++ b/ssh/session_test.go @@ -12,71 +12,60 @@ import ( "io" "io/ioutil" "math/rand" - "net" "testing" "code.google.com/p/go.crypto/ssh/terminal" ) -type serverType func(*serverChan, *testing.T) +type serverType func(Channel, <-chan *Request, *testing.T) // dial constructs a new test server and returns a *ClientConn. -func dial(handler serverType, t *testing.T) *ClientConn { - l, err := Listen("tcp", "127.0.0.1:0", serverConfig) +func dial(handler serverType, t *testing.T) *Client { + c1, c2, err := netPipe() if err != nil { - t.Fatalf("unable to listen: %v", err) + t.Fatalf("netPipe: %v", err) } + go func() { - defer l.Close() - conn, err := l.Accept() - if err != nil { - t.Errorf("Unable to accept: %v", err) - return + defer c1.Close() + conf := ServerConfig{ + NoClientAuth: true, } - defer conn.Close() - if err := conn.Handshake(); err != nil { - t.Errorf("Unable to handshake: %v", err) - return + conf.AddHostKey(testSigners["rsa"]) + + _, chans, reqs, err := NewServerConn(c1, &conf) + if err != nil { + t.Fatalf("Unable to handshake: %v", err) } - done := make(chan struct{}) - for { - ch, err := conn.Accept() - if err == io.EOF || err == io.ErrUnexpectedEOF { - return - } - // We sometimes get ECONNRESET rather than EOF. - if _, ok := err.(*net.OpError); ok { - return + go DiscardRequests(reqs) + + for newCh := range chans { + if newCh.ChannelType() != "session" { + newCh.Reject(UnknownChannelType, "unknown channel type") + continue } + + ch, inReqs, err := newCh.Accept() if err != nil { - t.Errorf("Unable to accept incoming channel request: %v", err) - return - } - if ch.ChannelType() != "session" { - ch.Reject(UnknownChannelType, "unknown channel type") + t.Errorf("Accept: %v", err) continue } - ch.Accept() go func() { - defer close(done) - handler(ch.(*serverChan), t) + handler(ch, inReqs, t) }() } - <-done }() config := &ClientConfig{ User: "testuser", - Auth: []ClientAuth{ - ClientAuthPassword(clientPassword), - }, } - c, err := Dial("tcp", l.Addr().String(), config) + conn, chans, reqs, err := NewClientConn(c2, "", config) if err != nil { t.Fatalf("unable to dial remote side: %v", err) } - return c + + return NewClient(conn, chans, reqs) } // Test a simple string is returned to session.Stdout. @@ -330,164 +319,6 @@ func TestExitWithoutStatusOrSignal(t *testing.T) { } } -func TestInvalidServerMessage(t *testing.T) { - conn := dial(sendInvalidRecord, t) - defer conn.Close() - session, err := conn.NewSession() - if err != nil { - t.Fatalf("Unable to request new session: %v", err) - } - // Make sure that we closed all the clientChans when the connection - // failed. - session.wait() - - defer session.Close() -} - -// In the wild some clients (and servers) send zero sized window updates. -// Test that the client can continue after receiving a zero sized update. -func TestClientZeroWindowAdjust(t *testing.T) { - conn := dial(sendZeroWindowAdjust, t) - defer conn.Close() - session, err := conn.NewSession() - if err != nil { - t.Fatalf("Unable to request new session: %v", err) - } - defer session.Close() - - if err := session.Shell(); err != nil { - t.Fatalf("Unable to execute command: %v", err) - } - err = session.Wait() - if err != nil { - t.Fatalf("expected nil but got %v", err) - } -} - -// In the wild some clients (and servers) send zero sized window updates. -// Test that the server can continue after receiving a zero size update. -func TestServerZeroWindowAdjust(t *testing.T) { - conn := dial(exitStatusZeroHandler, t) - defer conn.Close() - session, err := conn.NewSession() - if err != nil { - t.Fatalf("Unable to request new session: %v", err) - } - defer session.Close() - - if err := session.Shell(); err != nil { - t.Fatalf("Unable to execute command: %v", err) - } - - // send a bogus zero sized window update - session.clientChan.sendWindowAdj(0) - - err = session.Wait() - if err != nil { - t.Fatalf("expected nil but got %v", err) - } -} - -// Verify that the client never sends a packet larger than maxpacket. -func TestClientStdinRespectsMaxPacketSize(t *testing.T) { - conn := dial(discardHandler, t) - defer conn.Close() - session, err := conn.NewSession() - if err != nil { - t.Fatalf("failed to request new session: %v", err) - } - defer session.Close() - stdin, err := session.StdinPipe() - if err != nil { - t.Fatalf("failed to obtain stdinpipe: %v", err) - } - const size = 100 * 1000 - for i := 0; i < 10; i++ { - n, err := stdin.Write(make([]byte, size)) - if n != size || err != nil { - t.Fatalf("failed to write: %d, %v", n, err) - } - } -} - -// Verify that the client never accepts a packet larger than maxpacket. -func TestServerStdoutRespectsMaxPacketSize(t *testing.T) { - conn := dial(largeSendHandler, t) - defer conn.Close() - session, err := conn.NewSession() - if err != nil { - t.Fatalf("Unable to request new session: %v", err) - } - defer session.Close() - out, err := session.StdoutPipe() - if err != nil { - t.Fatalf("Unable to connect to Stdout: %v", err) - } - if err := session.Shell(); err != nil { - t.Fatalf("Unable to execute command: %v", err) - } - if _, err := ioutil.ReadAll(out); err != nil { - t.Fatalf("failed to read: %v", err) - } -} - -func TestClientCannotSendAfterEOF(t *testing.T) { - conn := dial(exitWithoutSignalOrStatus, t) - defer conn.Close() - session, err := conn.NewSession() - if err != nil { - t.Fatalf("Unable to request new session: %v", err) - } - defer session.Close() - in, err := session.StdinPipe() - if err != nil { - t.Fatalf("Unable to connect channel stdin: %v", err) - } - if err := session.Shell(); err != nil { - t.Fatalf("Unable to execute command: %v", err) - } - if err := in.Close(); err != nil { - t.Fatalf("Unable to close stdin: %v", err) - } - if _, err := in.Write([]byte("foo")); err == nil { - t.Fatalf("Session write should fail") - } -} - -func TestClientCannotSendAfterClose(t *testing.T) { - conn := dial(exitWithoutSignalOrStatus, t) - defer conn.Close() - session, err := conn.NewSession() - if err != nil { - t.Fatalf("Unable to request new session: %v", err) - } - defer session.Close() - in, err := session.StdinPipe() - if err != nil { - t.Fatalf("Unable to connect channel stdin: %v", err) - } - if err := session.Shell(); err != nil { - t.Fatalf("Unable to execute command: %v", err) - } - // close underlying channel - if err := session.channel.Close(); err != nil { - t.Fatalf("Unable to close session: %v", err) - } - if _, err := in.Write([]byte("foo")); err == nil { - t.Fatalf("Session write should fail") - } -} - -func TestClientCannotSendHugePacket(t *testing.T) { - // client and server use the same transport write code so this - // test suffices for both. - conn := dial(shellHandler, t) - defer conn.Close() - if err := conn.transport.writePacket(make([]byte, maxPacket*2)); err == nil { - t.Fatalf("huge packet write should fail") - } -} - // windowTestBytes is the number of bytes that we'll send to the SSH server. const windowTestBytes = 16000 * 200 @@ -560,93 +391,104 @@ func TestClientHandlesKeepalives(t *testing.T) { } type exitStatusMsg struct { - PeersId uint32 - Request string - WantReply bool - Status uint32 + Status uint32 } type exitSignalMsg struct { - PeersId uint32 - Request string - WantReply bool Signal string CoreDumped bool Errmsg string Lang string } -func newServerShell(ch *serverChan, prompt string) *ServerTerminal { - term := terminal.NewTerminal(ch, prompt) - return &ServerTerminal{ - Term: term, - Channel: ch, +func handleTerminalRequests(in <-chan *Request) { + for req := range in { + ok := false + switch req.Type { + case "shell": + ok = true + if len(req.Payload) > 0 { + // We don't accept any commands, only the default shell. + ok = false + } + case "env": + ok = true + } + req.Reply(ok, nil) } } -func exitStatusZeroHandler(ch *serverChan, t *testing.T) { +func newServerShell(ch Channel, in <-chan *Request, prompt string) *terminal.Terminal { + term := terminal.NewTerminal(ch, prompt) + go handleTerminalRequests(in) + return term +} + +func exitStatusZeroHandler(ch Channel, in <-chan *Request, t *testing.T) { defer ch.Close() // this string is returned to stdout - shell := newServerShell(ch, "> ") + shell := newServerShell(ch, in, "> ") readLine(shell, t) sendStatus(0, ch, t) } -func exitStatusNonZeroHandler(ch *serverChan, t *testing.T) { +func exitStatusNonZeroHandler(ch Channel, in <-chan *Request, t *testing.T) { defer ch.Close() - shell := newServerShell(ch, "> ") + shell := newServerShell(ch, in, "> ") readLine(shell, t) sendStatus(15, ch, t) } -func exitSignalAndStatusHandler(ch *serverChan, t *testing.T) { +func exitSignalAndStatusHandler(ch Channel, in <-chan *Request, t *testing.T) { defer ch.Close() - shell := newServerShell(ch, "> ") + shell := newServerShell(ch, in, "> ") readLine(shell, t) sendStatus(15, ch, t) sendSignal("TERM", ch, t) } -func exitSignalHandler(ch *serverChan, t *testing.T) { +func exitSignalHandler(ch Channel, in <-chan *Request, t *testing.T) { defer ch.Close() - shell := newServerShell(ch, "> ") + shell := newServerShell(ch, in, "> ") readLine(shell, t) sendSignal("TERM", ch, t) } -func exitSignalUnknownHandler(ch *serverChan, t *testing.T) { +func exitSignalUnknownHandler(ch Channel, in <-chan *Request, t *testing.T) { defer ch.Close() - shell := newServerShell(ch, "> ") + shell := newServerShell(ch, in, "> ") readLine(shell, t) sendSignal("SYS", ch, t) } -func exitWithoutSignalOrStatus(ch *serverChan, t *testing.T) { +func exitWithoutSignalOrStatus(ch Channel, in <-chan *Request, t *testing.T) { defer ch.Close() - shell := newServerShell(ch, "> ") + shell := newServerShell(ch, in, "> ") readLine(shell, t) } -func shellHandler(ch *serverChan, t *testing.T) { +func shellHandler(ch Channel, in <-chan *Request, t *testing.T) { defer ch.Close() // this string is returned to stdout - shell := newServerShell(ch, "golang") + shell := newServerShell(ch, in, "golang") readLine(shell, t) sendStatus(0, ch, t) } // Ignores the command, writes fixed strings to stderr and stdout. // Strings are "this-is-stdout." and "this-is-stderr.". -func fixedOutputHandler(ch *serverChan, t *testing.T) { +func fixedOutputHandler(ch Channel, in <-chan *Request, t *testing.T) { defer ch.Close() + _, err := ch.Read(nil) - _, err := ch.Read(make([]byte, 0)) - if _, ok := err.(ChannelRequest); !ok { + req, ok := <-in + if !ok { t.Fatalf("error: expected channel request, got: %#v", err) return } + // ignore request, always send some text - ch.AckRequest(true) + req.Reply(true, nil) _, err = io.WriteString(ch, "this-is-stdout.") if err != nil { @@ -659,84 +501,39 @@ func fixedOutputHandler(ch *serverChan, t *testing.T) { sendStatus(0, ch, t) } -func readLine(shell *ServerTerminal, t *testing.T) { +func readLine(shell *terminal.Terminal, t *testing.T) { if _, err := shell.ReadLine(); err != nil && err != io.EOF { t.Errorf("unable to read line: %v", err) } } -func sendStatus(status uint32, ch *serverChan, t *testing.T) { +func sendStatus(status uint32, ch Channel, t *testing.T) { msg := exitStatusMsg{ - PeersId: ch.remoteId, - Request: "exit-status", - WantReply: false, - Status: status, + Status: status, } - if err := ch.writePacket(marshal(msgChannelRequest, msg)); err != nil { + if _, err := ch.SendRequest("exit-status", false, Marshal(&msg)); err != nil { t.Errorf("unable to send status: %v", err) } } -func sendSignal(signal string, ch *serverChan, t *testing.T) { +func sendSignal(signal string, ch Channel, t *testing.T) { sig := exitSignalMsg{ - PeersId: ch.remoteId, - Request: "exit-signal", - WantReply: false, Signal: signal, CoreDumped: false, Errmsg: "Process terminated", Lang: "en-GB-oed", } - if err := ch.writePacket(marshal(msgChannelRequest, sig)); err != nil { + if _, err := ch.SendRequest("exit-signal", false, Marshal(&sig)); err != nil { t.Errorf("unable to send signal: %v", err) } } -func sendInvalidRecord(ch *serverChan, t *testing.T) { +func discardHandler(ch Channel, t *testing.T) { defer ch.Close() - packet := make([]byte, 1+4+4+1) - packet[0] = msgChannelData - marshalUint32(packet[1:], 29348723 /* invalid channel id */) - marshalUint32(packet[5:], 1) - packet[9] = 42 - - if err := ch.writePacket(packet); err != nil { - t.Errorf("unable send invalid record: %v", err) - } -} - -func sendZeroWindowAdjust(ch *serverChan, t *testing.T) { - defer ch.Close() - // send a bogus zero sized window update - ch.sendWindowAdj(0) - shell := newServerShell(ch, "> ") - readLine(shell, t) - sendStatus(0, ch, t) -} - -func discardHandler(ch *serverChan, t *testing.T) { - defer ch.Close() - // grow the window to avoid being fooled by - // the initial 1 << 14 window. - ch.sendWindowAdj(1024 * 1024) io.Copy(ioutil.Discard, ch) } -func largeSendHandler(ch *serverChan, t *testing.T) { - defer ch.Close() - // grow the window to avoid being fooled by - // the initial 1 << 14 window. - ch.sendWindowAdj(1024 * 1024) - shell := newServerShell(ch, "> ") - readLine(shell, t) - // try to send more than the 32k window - // will allow - if err := ch.writePacket(make([]byte, 128*1024)); err == nil { - t.Errorf("wrote packet larger than 32k") - } -} - -func echoHandler(ch *serverChan, t *testing.T) { +func echoHandler(ch Channel, in <-chan *Request, t *testing.T) { defer ch.Close() if n, err := copyNRandomly("echohandler", ch, ch, windowTestBytes); err != nil { t.Errorf("short write, wrote %d, expected %d: %v ", n, windowTestBytes, err) @@ -773,17 +570,59 @@ func copyNRandomly(title string, dst io.Writer, src io.Reader, n int) (int, erro return written, nil } -func channelKeepaliveSender(ch *serverChan, t *testing.T) { +func channelKeepaliveSender(ch Channel, in <-chan *Request, t *testing.T) { defer ch.Close() - shell := newServerShell(ch, "> ") + shell := newServerShell(ch, in, "> ") readLine(shell, t) - msg := channelRequestMsg{ - PeersId: ch.remoteId, - Request: "keepalive@openssh.com", - WantReply: true, - } - if err := ch.writePacket(marshal(msgChannelRequest, msg)); err != nil { + if _, err := ch.SendRequest("keepalive@openssh.com", true, nil); err != nil { t.Errorf("unable to send channel keepalive request: %v", err) } sendStatus(0, ch, t) } + +func TestClientWriteEOF(t *testing.T) { + conn := dial(simpleEchoHandler, t) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatal(err) + } + defer session.Close() + stdin, err := session.StdinPipe() + if err != nil { + t.Fatalf("StdinPipe failed: %v", err) + } + stdout, err := session.StdoutPipe() + if err != nil { + t.Fatalf("StdoutPipe failed: %v", err) + } + + data := []byte(`0000`) + _, err = stdin.Write(data) + if err != nil { + t.Fatalf("Write failed: %v", err) + } + stdin.Close() + + res, err := ioutil.ReadAll(stdout) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + + if !bytes.Equal(data, res) { + t.Fatalf("Read differed from write, wrote: %v, read: %v", data, res) + } +} + +func simpleEchoHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + data, err := ioutil.ReadAll(ch) + if err != nil { + t.Errorf("handler read error: %v", err) + } + _, err = ch.Write(data) + if err != nil { + t.Errorf("handler write error: %v", err) + } +} diff --git a/ssh/tcpip.go b/ssh/tcpip.go index 74fc1a7..5a4fa8b 100644 --- a/ssh/tcpip.go +++ b/ssh/tcpip.go @@ -16,10 +16,11 @@ import ( "time" ) -// Listen requests the remote peer open a listening socket -// on addr. Incoming connections will be available by calling -// Accept on the returned net.Listener. -func (c *ClientConn) Listen(n, addr string) (net.Listener, error) { +// Listen requests the remote peer open a listening socket on +// addr. Incoming connections will be available by calling Accept on +// the returned net.Listener. The listener must be serviced, or the +// SSH connection may hang. +func (c *Client) Listen(n, addr string) (net.Listener, error) { laddr, err := net.ResolveTCPAddr(n, addr) if err != nil { return nil, err @@ -59,7 +60,7 @@ func isBrokenOpenSSHVersion(versionStr string) bool { // autoPortListenWorkaround simulates automatic port allocation by // trying random ports repeatedly. -func (c *ClientConn) autoPortListenWorkaround(laddr *net.TCPAddr) (net.Listener, error) { +func (c *Client) autoPortListenWorkaround(laddr *net.TCPAddr) (net.Listener, error) { var sshListener net.Listener var err error const tries = 10 @@ -77,44 +78,45 @@ func (c *ClientConn) autoPortListenWorkaround(laddr *net.TCPAddr) (net.Listener, // RFC 4254 7.1 type channelForwardMsg struct { - Message string - WantReply bool - raddr string - rport uint32 + addr string + rport uint32 } // ListenTCP requests the remote peer open a listening socket // on laddr. Incoming connections will be available by calling // Accept on the returned net.Listener. -func (c *ClientConn) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) { - if laddr.Port == 0 && isBrokenOpenSSHVersion(c.serverVersion) { +func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) { + if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) { return c.autoPortListenWorkaround(laddr) } m := channelForwardMsg{ - "tcpip-forward", - true, // sendGlobalRequest waits for a reply laddr.IP.String(), uint32(laddr.Port), } // send message - resp, err := c.sendGlobalRequest(m) + ok, resp, err := c.SendRequest("tcpip-forward", true, Marshal(&m)) if err != nil { return nil, err } + if !ok { + return nil, errors.New("ssh: tcpip-forward request denied by peer") + } // If the original port was 0, then the remote side will // supply a real port number in the response. if laddr.Port == 0 { - port, _, ok := parseUint32(resp.Data) - if !ok { - return nil, errors.New("unable to parse response") + var p struct { + Port uint32 } - laddr.Port = int(port) + if err := Unmarshal(resp, &p); err != nil { + return nil, err + } + laddr.Port = int(p.Port) } // Register this forward, using the port number we obtained. - ch := c.forwardList.add(*laddr) + ch := c.forwards.add(*laddr) return &tcpListener{laddr, c, ch}, nil } @@ -137,7 +139,7 @@ type forwardEntry struct { // arguments to add/remove/lookup should be address as specified in // the original forward-request. type forward struct { - c *clientChan // the ssh client channel underlying this forward + newCh NewChannel // the ssh client channel underlying this forward raddr *net.TCPAddr // the raddr of the incoming connection } @@ -152,6 +154,31 @@ func (l *forwardList) add(addr net.TCPAddr) chan forward { return f.c } +func (l *forwardList) handleChannels(in <-chan NewChannel) { + for ch := range in { + laddr, rest, ok := parseTCPAddr(ch.ExtraData()) + if !ok { + // invalid request + ch.Reject(ConnectionFailed, "could not parse TCP address") + continue + } + + raddr, rest, ok := parseTCPAddr(rest) + if !ok { + // invalid request + ch.Reject(ConnectionFailed, "could not parse TCP address") + continue + } + + if ok := l.forward(*laddr, *raddr, ch); !ok { + // Section 7.2, implementations MUST reject spurious incoming + // connections. + ch.Reject(Prohibited, "no forward for address") + continue + } + } +} + // remove removes the forward entry, and the channel feeding its // listener. func (l *forwardList) remove(addr net.TCPAddr) { @@ -176,21 +203,22 @@ func (l *forwardList) closeAll() { l.entries = nil } -func (l *forwardList) lookup(addr net.TCPAddr) (chan forward, bool) { +func (l *forwardList) forward(laddr, raddr net.TCPAddr, ch NewChannel) bool { l.Lock() defer l.Unlock() for _, f := range l.entries { - if addr.IP.Equal(f.laddr.IP) && addr.Port == f.laddr.Port { - return f.c, true + if laddr.IP.Equal(f.laddr.IP) && laddr.Port == f.laddr.Port { + f.c <- forward{ch, &raddr} + return true } } - return nil, false + return false } type tcpListener struct { laddr *net.TCPAddr - conn *ClientConn + conn *Client in <-chan forward } @@ -200,30 +228,33 @@ func (l *tcpListener) Accept() (net.Conn, error) { if !ok { return nil, io.EOF } + ch, incoming, err := s.newCh.Accept() + if err != nil { + return nil, err + } + go DiscardRequests(incoming) + return &tcpChanConn{ - tcpChan: &tcpChan{ - clientChan: s.c, - Reader: s.c.stdout, - Writer: s.c.stdin, - }, - laddr: l.laddr, - raddr: s.raddr, + Channel: ch, + laddr: l.laddr, + raddr: s.raddr, }, nil } // Close closes the listener. func (l *tcpListener) Close() error { m := channelForwardMsg{ - "cancel-tcpip-forward", - true, l.laddr.IP.String(), uint32(l.laddr.Port), } - l.conn.forwardList.remove(*l.laddr) - if _, err := l.conn.sendGlobalRequest(m); err != nil { - return err + + // this also closes the listener. + l.conn.forwards.remove(*l.laddr) + ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m)) + if err == nil && !ok { + err = errors.New("ssh: cancel-tcpip-forward failed") } - return nil + return err } // Addr returns the listener's network address. @@ -233,7 +264,7 @@ func (l *tcpListener) Addr() net.Addr { // Dial initiates a connection to the addr from the remote host. // The resulting connection has a zero LocalAddr() and RemoteAddr(). -func (c *ClientConn) Dial(n, addr string) (net.Conn, error) { +func (c *Client) Dial(n, addr string) (net.Conn, error) { // Parse the address into host and numeric port. host, portString, err := net.SplitHostPort(addr) if err != nil { @@ -253,7 +284,7 @@ func (c *ClientConn) Dial(n, addr string) (net.Conn, error) { return nil, err } return &tcpChanConn{ - tcpChan: ch, + Channel: ch, laddr: zeroAddr, raddr: zeroAddr, }, nil @@ -262,7 +293,7 @@ func (c *ClientConn) Dial(n, addr string) (net.Conn, error) { // DialTCP connects to the remote address raddr on the network net, // which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used // as the local address for the connection. -func (c *ClientConn) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) { +func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) { if laddr == nil { laddr = &net.TCPAddr{ IP: net.IPv4zero, @@ -274,7 +305,7 @@ func (c *ClientConn) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, err return nil, err } return &tcpChanConn{ - tcpChan: ch, + Channel: ch, laddr: laddr, raddr: raddr, }, nil @@ -282,54 +313,32 @@ func (c *ClientConn) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, err // RFC 4254 7.2 type channelOpenDirectMsg struct { - ChanType string - PeersId uint32 - PeersWindow uint32 - MaxPacketSize uint32 - raddr string - rport uint32 - laddr string - lport uint32 + raddr string + rport uint32 + laddr string + lport uint32 } -// dial opens a direct-tcpip connection to the remote server. laddr and raddr are passed as -// strings and are expected to be resolvable at the remote end. -func (c *ClientConn) dial(laddr string, lport int, raddr string, rport int) (*tcpChan, error) { - ch := c.newChan(c.transport) - if err := c.transport.writePacket(marshal(msgChannelOpen, channelOpenDirectMsg{ - ChanType: "direct-tcpip", - PeersId: ch.localId, - PeersWindow: channelWindowSize, - MaxPacketSize: channelMaxPacketSize, - raddr: raddr, - rport: uint32(rport), - laddr: laddr, - lport: uint32(lport), - })); err != nil { - c.chanList.remove(ch.localId) - return nil, err - } - if err := ch.waitForChannelOpenResponse(); err != nil { - c.chanList.remove(ch.localId) - return nil, fmt.Errorf("ssh: unable to open direct tcpip connection: %v", err) +func (c *Client) dial(laddr string, lport int, raddr string, rport int) (Channel, error) { + msg := channelOpenDirectMsg{ + raddr: raddr, + rport: uint32(rport), + laddr: laddr, + lport: uint32(lport), } - return &tcpChan{ - clientChan: ch, - Reader: ch.stdout, - Writer: ch.stdin, - }, nil + ch, in, err := c.OpenChannel("direct-tcpip", Marshal(&msg)) + go DiscardRequests(in) + return ch, err } type tcpChan struct { - *clientChan // the backing channel - io.Reader - io.Writer + Channel // the backing channel } // tcpChanConn fulfills the net.Conn interface without // the tcpChan having to hold laddr or raddr directly. type tcpChanConn struct { - *tcpChan + Channel laddr, raddr net.Addr } diff --git a/ssh/tcpip_test.go b/ssh/tcpip_test.go index 7fa9fc4..f1265cb 100644 --- a/ssh/tcpip_test.go +++ b/ssh/tcpip_test.go @@ -1,3 +1,7 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package ssh import ( diff --git a/ssh/terminal/util_windows.go b/ssh/terminal/util_windows.go new file mode 100644 index 0000000..0a454e0 --- /dev/null +++ b/ssh/terminal/util_windows.go @@ -0,0 +1,171 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build windows + +// Package terminal provides support functions for dealing with terminals, as +// commonly found on UNIX systems. +// +// Putting a terminal into raw mode is the most common requirement: +// +// oldState, err := terminal.MakeRaw(0) +// if err != nil { +// panic(err) +// } +// defer terminal.Restore(0, oldState) +package terminal + +import ( + "io" + "syscall" + "unsafe" +) + +const ( + enableLineInput = 2 + enableEchoInput = 4 + enableProcessedInput = 1 + enableWindowInput = 8 + enableMouseInput = 16 + enableInsertMode = 32 + enableQuickEditMode = 64 + enableExtendedFlags = 128 + enableAutoPosition = 256 + enableProcessedOutput = 1 + enableWrapAtEolOutput = 2 +) + +var kernel32 = syscall.NewLazyDLL("kernel32.dll") + +var ( + procGetConsoleMode = kernel32.NewProc("GetConsoleMode") + procSetConsoleMode = kernel32.NewProc("SetConsoleMode") + procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo") +) + +type ( + short int16 + word uint16 + + coord struct { + x short + y short + } + smallRect struct { + left short + top short + right short + bottom short + } + consoleScreenBufferInfo struct { + size coord + cursorPosition coord + attributes word + window smallRect + maximumWindowSize coord + } +) + +type State struct { + mode uint32 +} + +// IsTerminal returns true if the given file descriptor is a terminal. +func IsTerminal(fd int) bool { + var st uint32 + r, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) + return r != 0 && e == 0 +} + +// MakeRaw put the terminal connected to the given file descriptor into raw +// mode and returns the previous state of the terminal so that it can be +// restored. +func MakeRaw(fd int) (*State, error) { + var st uint32 + _, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) + if e != 0 { + return nil, error(e) + } + st &^= (enableEchoInput | enableProcessedInput | enableLineInput | enableProcessedOutput) + _, _, e = syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(st), 0) + if e != 0 { + return nil, error(e) + } + return &State{st}, nil +} + +// GetState returns the current state of a terminal which may be useful to +// restore the terminal after a signal. +func GetState(fd int) (*State, error) { + var st uint32 + _, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) + if e != 0 { + return nil, error(e) + } + return &State{st}, nil +} + +// Restore restores the terminal connected to the given file descriptor to a +// previous state. +func Restore(fd int, state *State) error { + _, _, err := syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(state.mode), 0) + return err +} + +// GetSize returns the dimensions of the given terminal. +func GetSize(fd int) (width, height int, err error) { + var info consoleScreenBufferInfo + _, _, e := syscall.Syscall(procGetConsoleScreenBufferInfo.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&info)), 0) + if e != 0 { + return 0, 0, error(e) + } + return int(info.size.x), int(info.size.y), nil +} + +// ReadPassword reads a line of input from a terminal without local echo. This +// is commonly used for inputting passwords and other sensitive data. The slice +// returned does not include the \n. +func ReadPassword(fd int) ([]byte, error) { + var st uint32 + _, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) + if e != 0 { + return nil, error(e) + } + old := st + + st &^= (enableEchoInput) + st |= (enableProcessedInput | enableLineInput | enableProcessedOutput) + _, _, e = syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(st), 0) + if e != 0 { + return nil, error(e) + } + + defer func() { + syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(old), 0) + }() + + var buf [16]byte + var ret []byte + for { + n, err := syscall.Read(syscall.Handle(fd), buf[:]) + if err != nil { + return nil, err + } + if n == 0 { + if len(ret) == 0 { + return nil, io.EOF + } + break + } + if buf[n-1] == '\n' { + n-- + } + ret = append(ret, buf[:n]...) + if n < len(buf) { + break + } + } + + return ret, nil +} diff --git a/ssh/test/agent_unix_test.go b/ssh/test/agent_unix_test.go new file mode 100644 index 0000000..26c88eb --- /dev/null +++ b/ssh/test/agent_unix_test.go @@ -0,0 +1,50 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin freebsd linux netbsd openbsd + +package test + +import ( + "bytes" + "testing" + + "code.google.com/p/go.crypto/ssh" + "code.google.com/p/go.crypto/ssh/agent" +) + +func TestAgentForward(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + keyring := agent.NewKeyring() + keyring.Add(testPrivateKeys["dsa"], nil, "") + pub := testPublicKeys["dsa"] + + sess, err := conn.NewSession() + if err != nil { + t.Fatalf("NewSession: %v", err) + } + if err := agent.RequestAgentForwarding(sess); err != nil { + t.Fatalf("RequestAgentForwarding: %v", err) + } + + if err := agent.ForwardToAgent(conn, keyring); err != nil { + t.Fatalf("SetupForwardKeyring: %v", err) + } + out, err := sess.CombinedOutput("ssh-add -L") + if err != nil { + t.Fatalf("running ssh-add: %v, out %s", err, out) + } + key, _, _, _, err := ssh.ParseAuthorizedKey(out) + if err != nil { + t.Fatalf("ParseAuthorizedKey(%q): %v", out, err) + } + + if !bytes.Equal(key.Marshal(), pub.Marshal()) { + t.Fatalf("got key %s, want %s", ssh.MarshalAuthorizedKey(key), ssh.MarshalAuthorizedKey(pub)) + } +} diff --git a/ssh/test/cert_test.go b/ssh/test/cert_test.go new file mode 100644 index 0000000..d4f7226 --- /dev/null +++ b/ssh/test/cert_test.go @@ -0,0 +1,47 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin freebsd linux netbsd openbsd + +package test + +import ( + "crypto/rand" + "testing" + + "code.google.com/p/go.crypto/ssh" +) + +func TestCertLogin(t *testing.T) { + s := newServer(t) + defer s.Shutdown() + + // Use a key different from the default. + clientKey := testSigners["dsa"] + caAuthKey := testSigners["ecdsa"] + cert := &ssh.Certificate{ + Key: clientKey.PublicKey(), + ValidPrincipals: []string{username()}, + CertType: ssh.UserCert, + ValidBefore: ssh.CertTimeInfinity, + } + if err := cert.SignCert(rand.Reader, caAuthKey); err != nil { + t.Fatalf("SetSignature: %v", err) + } + + certSigner, err := ssh.NewCertSigner(cert, clientKey) + if err != nil { + t.Fatalf("NewCertSigner: %v", err) + } + + conf := &ssh.ClientConfig{ + User: username(), + } + conf.Auth = append(conf.Auth, ssh.PublicKeys(certSigner)) + client, err := s.TryDial(conf) + if err != nil { + t.Fatalf("TryDial: %v", err) + } + client.Close() +} diff --git a/ssh/test/forward_unix_test.go b/ssh/test/forward_unix_test.go index 3a57c10..881a9da 100644 --- a/ssh/test/forward_unix_test.go +++ b/ssh/test/forward_unix_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux netbsd openbsd plan9 +// +build darwin freebsd linux netbsd openbsd package test diff --git a/ssh/test/session_test.go b/ssh/test/session_test.go index bd7307d..d8d35a5 100644 --- a/ssh/test/session_test.go +++ b/ssh/test/session_test.go @@ -11,6 +11,7 @@ package test import ( "bytes" "code.google.com/p/go.crypto/ssh" + "errors" "io" "strings" "testing" @@ -38,12 +39,13 @@ func TestHostKeyCheck(t *testing.T) { defer server.Shutdown() conf := clientConfig() - k := conf.HostKeyChecker.(*storedHostKey) + hostDB := hostKeyDB() + conf.HostKeyCallback = hostDB.Check // change the keys. - k.keys[ssh.KeyAlgoRSA][25]++ - k.keys[ssh.KeyAlgoDSA][25]++ - k.keys[ssh.KeyAlgoECDSA256][25]++ + hostDB.keys[ssh.KeyAlgoRSA][25]++ + hostDB.keys[ssh.KeyAlgoDSA][25]++ + hostDB.keys[ssh.KeyAlgoECDSA256][25]++ conn, err := server.TryDial(conf) if err == nil { @@ -54,6 +56,53 @@ func TestHostKeyCheck(t *testing.T) { } } +func TestRunCommandStdin(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatalf("session failed: %v", err) + } + defer session.Close() + + r, w := io.Pipe() + defer r.Close() + defer w.Close() + session.Stdin = r + + err = session.Run("true") + if err != nil { + t.Fatalf("session failed: %v", err) + } +} + +func TestRunCommandStdinError(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatalf("session failed: %v", err) + } + defer session.Close() + + r, w := io.Pipe() + defer r.Close() + session.Stdin = r + pipeErr := errors.New("closing write end of pipe") + w.CloseWithError(pipeErr) + + err = session.Run("true") + if err != pipeErr { + t.Fatalf("expected %v, found %v", pipeErr, err) + } +} + func TestRunCommandFailed(t *testing.T) { server := newServer(t) defer server.Shutdown() @@ -107,7 +156,7 @@ func TestFuncLargeRead(t *testing.T) { t.Fatalf("unable to acquire stdout pipe: %s", err) } - err = session.Start("dd if=/dev/urandom bs=2048 count=1") + err = session.Start("dd if=/dev/urandom bs=2048 count=1024") if err != nil { t.Fatalf("unable to execute remote command: %s", err) } @@ -118,11 +167,53 @@ func TestFuncLargeRead(t *testing.T) { t.Fatalf("error reading from remote stdout: %s", err) } - if n != 2048 { + if n != 2048*1024 { t.Fatalf("Expected %d bytes but read only %d from remote command", 2048, n) } } +func TestKeyChange(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conf := clientConfig() + hostDB := hostKeyDB() + conf.HostKeyCallback = hostDB.Check + conf.RekeyThreshold = 1024 + conn := server.Dial(conf) + defer conn.Close() + + for i := 0; i < 4; i++ { + session, err := conn.NewSession() + if err != nil { + t.Fatalf("unable to create new session: %s", err) + } + + stdout, err := session.StdoutPipe() + if err != nil { + t.Fatalf("unable to acquire stdout pipe: %s", err) + } + + err = session.Start("dd if=/dev/urandom bs=1024 count=1") + if err != nil { + t.Fatalf("unable to execute remote command: %s", err) + } + buf := new(bytes.Buffer) + n, err := io.Copy(buf, stdout) + if err != nil { + t.Fatalf("error reading from remote stdout: %s", err) + } + + want := int64(1024) + if n != want { + t.Fatalf("Expected %d bytes but read only %d from remote command", want, n) + } + } + + if changes := hostDB.checkCount; changes < 4 { + t.Errorf("got %d key changes, want 4", changes) + } +} + func TestInvalidTerminalMode(t *testing.T) { server := newServer(t) defer server.Shutdown() @@ -183,3 +274,44 @@ func TestValidTerminalMode(t *testing.T) { t.Fatalf("terminal mode failure: expected -echo in stty output, got %s", sttyOutput) } } + +func TestCiphers(t *testing.T) { + var config ssh.Config + config.SetDefaults() + cipherOrder := config.Ciphers + + for _, ciph := range cipherOrder { + server := newServer(t) + defer server.Shutdown() + conf := clientConfig() + conf.Ciphers = []string{ciph} + // Don't fail if sshd doesnt have the cipher. + conf.Ciphers = append(conf.Ciphers, cipherOrder...) + conn, err := server.TryDial(conf) + if err == nil { + conn.Close() + } else { + t.Fatalf("failed for cipher %q", ciph) + } + } +} + +func TestMACs(t *testing.T) { + var config ssh.Config + config.SetDefaults() + macOrder := config.MACs + + for _, mac := range macOrder { + server := newServer(t) + defer server.Shutdown() + conf := clientConfig() + conf.MACs = []string{mac} + // Don't fail if sshd doesnt have the MAC. + conf.MACs = append(conf.MACs, macOrder...) + if conn, err := server.TryDial(conf); err == nil { + conn.Close() + } else { + t.Fatalf("failed for MAC %q", mac) + } + } +} diff --git a/ssh/test/tcpip_test.go b/ssh/test/tcpip_test.go index ee06b60..a2eb935 100644 --- a/ssh/test/tcpip_test.go +++ b/ssh/test/tcpip_test.go @@ -9,39 +9,38 @@ package test // direct-tcpip functional tests import ( + "io" "net" - "net/http" "testing" ) -func TestTCPIPHTTP(t *testing.T) { - // google.com will generate at least one redirect, possibly three - // depending on your location. - doTest(t, "http://google.com") -} - -func TestTCPIPHTTPS(t *testing.T) { - doTest(t, "https://encrypted.google.com/") -} - -func doTest(t *testing.T, url string) { +func TestDial(t *testing.T) { server := newServer(t) defer server.Shutdown() - conn := server.Dial(clientConfig()) - defer conn.Close() + sshConn := server.Dial(clientConfig()) + defer sshConn.Close() - tr := &http.Transport{ - Dial: func(n, addr string) (net.Conn, error) { - return conn.Dial(n, addr) - }, - } - client := &http.Client{ - Transport: tr, + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen: %v", err) } - resp, err := client.Get(url) + defer l.Close() + + go func() { + for { + c, err := l.Accept() + if err != nil { + break + } + + io.WriteString(c, c.RemoteAddr().String()) + c.Close() + } + }() + + conn, err := sshConn.Dial("tcp", l.Addr().String()) if err != nil { - t.Fatalf("unable to proxy: %s", err) + t.Fatalf("Dial: %v", err) } - // got a body without error - t.Log(resp) + defer conn.Close() } diff --git a/ssh/test/test_unix_test.go b/ssh/test/test_unix_test.go index 86df3f4..f44c65d 100644 --- a/ssh/test/test_unix_test.go +++ b/ssh/test/test_unix_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux netbsd openbsd plan9 +// +build darwin freebsd linux netbsd openbsd package test @@ -11,7 +11,6 @@ package test import ( "bytes" "fmt" - "io" "io/ioutil" "log" "net" @@ -23,13 +22,14 @@ import ( "text/template" "code.google.com/p/go.crypto/ssh" + "code.google.com/p/go.crypto/ssh/testdata" ) const sshd_config = ` Protocol 2 -HostKey {{.Dir}}/ssh_host_rsa_key -HostKey {{.Dir}}/ssh_host_dsa_key -HostKey {{.Dir}}/ssh_host_ecdsa_key +HostKey {{.Dir}}/id_rsa +HostKey {{.Dir}}/id_dsa +HostKey {{.Dir}}/id_ecdsa Pidfile {{.Dir}}/sshd.pid #UsePrivilegeSeparation no KeyRegenerationInterval 3600 @@ -41,41 +41,14 @@ PermitRootLogin no StrictModes no RSAAuthentication yes PubkeyAuthentication yes -AuthorizedKeysFile {{.Dir}}/authorized_keys +AuthorizedKeysFile {{.Dir}}/id_user.pub +TrustedUserCAKeys {{.Dir}}/id_ecdsa.pub IgnoreRhosts yes RhostsRSAAuthentication no HostbasedAuthentication no ` -var ( - configTmpl template.Template - privateKey ssh.Signer - hostKeyRSA ssh.Signer - hostKeyECDSA ssh.Signer - hostKeyDSA ssh.Signer -) - -func init() { - template.Must(configTmpl.Parse(sshd_config)) - - for n, k := range map[string]*ssh.Signer{ - "ssh_host_ecdsa_key": &hostKeyECDSA, - "ssh_host_rsa_key": &hostKeyRSA, - "ssh_host_dsa_key": &hostKeyDSA, - } { - var err error - *k, err = ssh.ParsePrivateKey([]byte(keys[n])) - if err != nil { - panic(fmt.Sprintf("ParsePrivateKey(%q): %v", n, err)) - } - } - - var err error - privateKey, err = ssh.ParsePrivateKey([]byte(testClientPrivateKey)) - if err != nil { - panic(fmt.Sprintf("ParsePrivateKey: %v", err)) - } -} +var configTmpl = template.Must(template.New("").Parse(sshd_config)) type server struct { t *testing.T @@ -107,36 +80,44 @@ func username() string { type storedHostKey struct { // keys map from an algorithm string to binary key data. keys map[string][]byte + + // checkCount counts the Check calls. Used for testing + // rekeying. + checkCount int } func (k *storedHostKey) Add(key ssh.PublicKey) { if k.keys == nil { k.keys = map[string][]byte{} } - k.keys[key.PublicKeyAlgo()] = ssh.MarshalPublicKey(key) + k.keys[key.Type()] = key.Marshal() } -func (k *storedHostKey) Check(addr string, remote net.Addr, algo string, key []byte) error { - if k.keys == nil || bytes.Compare(key, k.keys[algo]) != 0 { +func (k *storedHostKey) Check(addr string, remote net.Addr, key ssh.PublicKey) error { + k.checkCount++ + algo := key.Type() + + if k.keys == nil || bytes.Compare(key.Marshal(), k.keys[algo]) != 0 { return fmt.Errorf("host key mismatch. Got %q, want %q", key, k.keys[algo]) } return nil } -func clientConfig() *ssh.ClientConfig { - keyChecker := storedHostKey{} - keyChecker.Add(hostKeyECDSA.PublicKey()) - keyChecker.Add(hostKeyRSA.PublicKey()) - keyChecker.Add(hostKeyDSA.PublicKey()) +func hostKeyDB() *storedHostKey { + keyChecker := &storedHostKey{} + keyChecker.Add(testPublicKeys["ecdsa"]) + keyChecker.Add(testPublicKeys["rsa"]) + keyChecker.Add(testPublicKeys["dsa"]) + return keyChecker +} - kc := new(keychain) - kc.keys = append(kc.keys, privateKey) +func clientConfig() *ssh.ClientConfig { config := &ssh.ClientConfig{ User: username(), - Auth: []ssh.ClientAuth{ - ssh.ClientAuthKeyring(kc), + Auth: []ssh.AuthMethod{ + ssh.PublicKeys(testSigners["user"]), }, - HostKeyChecker: &keyChecker, + HostKeyCallback: hostKeyDB().Check, } return config } @@ -171,7 +152,7 @@ func unixConnection() (*net.UnixConn, *net.UnixConn, error) { return c1.(*net.UnixConn), c2.(*net.UnixConn), nil } -func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.ClientConn, error) { +func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.Client, error) { sshd, err := exec.LookPath("sshd") if err != nil { s.t.Skipf("skipping test: %v", err) @@ -197,10 +178,14 @@ func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.ClientConn, error) { s.t.Fatalf("s.cmd.Start: %v", err) } s.clientConn = c1 - return ssh.Client(c1, config) + conn, chans, reqs, err := ssh.NewClientConn(c1, "", config) + if err != nil { + return nil, err + } + return ssh.NewClient(conn, chans, reqs), nil } -func (s *server) Dial(config *ssh.ClientConfig) *ssh.ClientConn { +func (s *server) Dial(config *ssh.ClientConfig) *ssh.Client { conn, err := s.TryDial(config) if err != nil { s.t.Fail() @@ -226,6 +211,17 @@ func (s *server) Shutdown() { s.cleanup() } +func writeFile(path string, contents []byte) { + f, err := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0600) + if err != nil { + panic(err) + } + defer f.Close() + if _, err := f.Write(contents); err != nil { + panic(err) + } +} + // newServer returns a new mock ssh server. func newServer(t *testing.T) *server { dir, err := ioutil.TempDir("", "sshtest") @@ -244,15 +240,10 @@ func newServer(t *testing.T) *server { } f.Close() - for k, v := range keys { - f, err := os.OpenFile(filepath.Join(dir, k), os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0600) - if err != nil { - t.Fatal(err) - } - if _, err := f.Write([]byte(v)); err != nil { - t.Fatal(err) - } - f.Close() + for k, v := range testdata.PEMBytes { + filename := "id_" + k + writeFile(filepath.Join(dir, filename), v) + writeFile(filepath.Join(dir, filename+".pub"), ssh.MarshalAuthorizedKey(testPublicKeys[k])) } return &server{ @@ -265,32 +256,3 @@ func newServer(t *testing.T) *server { }, } } - -// keychain implements the ClientKeyring interface. -type keychain struct { - keys []ssh.Signer -} - -func (k *keychain) Key(i int) (ssh.PublicKey, error) { - if i < 0 || i >= len(k.keys) { - return nil, nil - } - return k.keys[i].PublicKey(), nil -} - -func (k *keychain) Sign(i int, rand io.Reader, data []byte) (sig []byte, err error) { - return k.keys[i].Sign(rand, data) -} - -func (k *keychain) loadPEM(file string) error { - buf, err := ioutil.ReadFile(file) - if err != nil { - return err - } - key, err := ssh.ParsePrivateKey(buf) - if err != nil { - return err - } - k.keys = append(k.keys, key) - return nil -} diff --git a/ssh/test/testdata_test.go b/ssh/test/testdata_test.go new file mode 100644 index 0000000..7f50fbe --- /dev/null +++ b/ssh/test/testdata_test.go @@ -0,0 +1,64 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// IMPLEMENTOR NOTE: To avoid a package loop, this file is in three places: +// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three +// instances. + +package test + +import ( + "crypto/rand" + "fmt" + + "code.google.com/p/go.crypto/ssh" + "code.google.com/p/go.crypto/ssh/testdata" +) + +var ( + testPrivateKeys map[string]interface{} + testSigners map[string]ssh.Signer + testPublicKeys map[string]ssh.PublicKey +) + +func init() { + var err error + + n := len(testdata.PEMBytes) + testPrivateKeys = make(map[string]interface{}, n) + testSigners = make(map[string]ssh.Signer, n) + testPublicKeys = make(map[string]ssh.PublicKey, n) + for t, k := range testdata.PEMBytes { + testPrivateKeys[t], err = ssh.ParseRawPrivateKey(k) + if err != nil { + panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err)) + } + testSigners[t], err = ssh.NewSignerFromKey(testPrivateKeys[t]) + if err != nil { + panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err)) + } + testPublicKeys[t] = testSigners[t].PublicKey() + } + + // Create a cert and sign it for use in tests. + testCert := &ssh.Certificate{ + Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil + ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage + ValidAfter: 0, // unix epoch + ValidBefore: ssh.CertTimeInfinity, // The end of currently representable time. + Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil + Key: testPublicKeys["ecdsa"], + SignatureKey: testPublicKeys["rsa"], + Permissions: ssh.Permissions{ + CriticalOptions: map[string]string{}, + Extensions: map[string]string{}, + }, + } + testCert.SignCert(rand.Reader, testSigners["rsa"]) + testPrivateKeys["cert"] = testPrivateKeys["ecdsa"] + testSigners["cert"], err = ssh.NewCertSigner(testCert, testSigners["ecdsa"]) + if err != nil { + panic(fmt.Sprintf("Unable to create certificate signer: %v", err)) + } +} diff --git a/ssh/testdata/doc.go b/ssh/testdata/doc.go new file mode 100644 index 0000000..4302486 --- /dev/null +++ b/ssh/testdata/doc.go @@ -0,0 +1,8 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This package contains test data shared between the various subpackages of +// the code.google.com/p/go.crypto/ssh package. Under no circumstance should +// this data be used for production code. +package testdata diff --git a/ssh/testdata/keys.go b/ssh/testdata/keys.go new file mode 100644 index 0000000..5ff1c0e --- /dev/null +++ b/ssh/testdata/keys.go @@ -0,0 +1,43 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package testdata + +var PEMBytes = map[string][]byte{ + "dsa": []byte(`-----BEGIN DSA PRIVATE KEY----- +MIIBuwIBAAKBgQD6PDSEyXiI9jfNs97WuM46MSDCYlOqWw80ajN16AohtBncs1YB +lHk//dQOvCYOsYaE+gNix2jtoRjwXhDsc25/IqQbU1ahb7mB8/rsaILRGIbA5WH3 +EgFtJmXFovDz3if6F6TzvhFpHgJRmLYVR8cqsezL3hEZOvvs2iH7MorkxwIVAJHD +nD82+lxh2fb4PMsIiaXudAsBAoGAQRf7Q/iaPRn43ZquUhd6WwvirqUj+tkIu6eV +2nZWYmXLlqFQKEy4Tejl7Wkyzr2OSYvbXLzo7TNxLKoWor6ips0phYPPMyXld14r +juhT24CrhOzuLMhDduMDi032wDIZG4Y+K7ElU8Oufn8Sj5Wge8r6ANmmVgmFfynr +FhdYCngCgYEA3ucGJ93/Mx4q4eKRDxcWD3QzWyqpbRVRRV1Vmih9Ha/qC994nJFz +DQIdjxDIT2Rk2AGzMqFEB68Zc3O+Wcsmz5eWWzEwFxaTwOGWTyDqsDRLm3fD+QYj +nOwuxb0Kce+gWI8voWcqC9cyRm09jGzu2Ab3Bhtpg8JJ8L7gS3MRZK4CFEx4UAfY +Fmsr0W6fHB9nhS4/UXM8 +-----END DSA PRIVATE KEY----- +`), + "ecdsa": []byte(`-----BEGIN EC PRIVATE KEY----- +MHcCAQEEINGWx0zo6fhJ/0EAfrPzVFyFC9s18lBt3cRoEDhS3ARooAoGCCqGSM49 +AwEHoUQDQgAEi9Hdw6KvZcWxfg2IDhA7UkpDtzzt6ZqJXSsFdLd+Kx4S3Sx4cVO+ +6/ZOXRnPmNAlLUqjShUsUBBngG0u2fqEqA== +-----END EC PRIVATE KEY----- +`), + "rsa": []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIBOwIBAAJBALdGZxkXDAjsYk10ihwU6Id2KeILz1TAJuoq4tOgDWxEEGeTrcld +r/ZwVaFzjWzxaf6zQIJbfaSEAhqD5yo72+sCAwEAAQJBAK8PEVU23Wj8mV0QjwcJ +tZ4GcTUYQL7cF4+ezTCE9a1NrGnCP2RuQkHEKxuTVrxXt+6OF15/1/fuXnxKjmJC +nxkCIQDaXvPPBi0c7vAxGwNY9726x01/dNbHCE0CBtcotobxpwIhANbbQbh3JHVW +2haQh4fAG5mhesZKAGcxTyv4mQ7uMSQdAiAj+4dzMpJWdSzQ+qGHlHMIBvVHLkqB +y2VdEyF7DPCZewIhAI7GOI/6LDIFOvtPo6Bj2nNmyQ1HU6k/LRtNIXi4c9NJAiAr +rrxx26itVhJmcvoUhOjwuzSlP2bE5VHAvkGB352YBg== +-----END RSA PRIVATE KEY----- +`), + "user": []byte(`-----BEGIN EC PRIVATE KEY----- +MHcCAQEEILYCAeq8f7V4vSSypRw7pxy8yz3V5W4qg8kSC3zJhqpQoAoGCCqGSM49 +AwEHoUQDQgAEYcO2xNKiRUYOLEHM7VYAp57HNyKbOdYtHD83Z4hzNPVC4tM5mdGD +PLL8IEwvYu2wq+lpXfGQnNMbzYf9gspG0w== +-----END EC PRIVATE KEY----- +`), +} diff --git a/ssh/testdata_test.go b/ssh/testdata_test.go new file mode 100644 index 0000000..302fdc8 --- /dev/null +++ b/ssh/testdata_test.go @@ -0,0 +1,63 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// IMPLEMENTOR NOTE: To avoid a package loop, this file is in three places: +// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three +// instances. + +package ssh + +import ( + "crypto/rand" + "fmt" + + "code.google.com/p/go.crypto/ssh/testdata" +) + +var ( + testPrivateKeys map[string]interface{} + testSigners map[string]Signer + testPublicKeys map[string]PublicKey +) + +func init() { + var err error + + n := len(testdata.PEMBytes) + testPrivateKeys = make(map[string]interface{}, n) + testSigners = make(map[string]Signer, n) + testPublicKeys = make(map[string]PublicKey, n) + for t, k := range testdata.PEMBytes { + testPrivateKeys[t], err = ParseRawPrivateKey(k) + if err != nil { + panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err)) + } + testSigners[t], err = NewSignerFromKey(testPrivateKeys[t]) + if err != nil { + panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err)) + } + testPublicKeys[t] = testSigners[t].PublicKey() + } + + // Create a cert and sign it for use in tests. + testCert := &Certificate{ + Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil + ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage + ValidAfter: 0, // unix epoch + ValidBefore: CertTimeInfinity, // The end of currently representable time. + Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil + Key: testPublicKeys["ecdsa"], + SignatureKey: testPublicKeys["rsa"], + Permissions: Permissions{ + CriticalOptions: map[string]string{}, + Extensions: map[string]string{}, + }, + } + testCert.SignCert(rand.Reader, testSigners["rsa"]) + testPrivateKeys["cert"] = testPrivateKeys["ecdsa"] + testSigners["cert"], err = NewCertSigner(testCert, testSigners["ecdsa"]) + if err != nil { + panic(fmt.Sprintf("Unable to create certificate signer: %v", err)) + } +} diff --git a/ssh/transport.go b/ssh/transport.go index 46fa262..4f68b04 100644 --- a/ssh/transport.go +++ b/ssh/transport.go @@ -6,26 +6,12 @@ package ssh import ( "bufio" - "crypto/cipher" - "crypto/subtle" - "encoding/binary" "errors" - "hash" "io" - "net" - "sync" ) const ( - packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher. - - // RFC 4253 section 6.1 defines a minimum packet size of 32768 that implementations - // MUST be able to process (plus a few more kilobytes for padding and mac). The RFC - // indicates implementations SHOULD be able to handle larger packet sizes, but then - // waffles on about reasonable limits. - // - // OpenSSH caps their maxPacket at 256kb so we choose to do the same. - maxPacket = 256 * 1024 + gcmCipherID = "aes128-gcm@openssh.com" ) // packetConn represents a transport that implements packet based @@ -41,225 +27,128 @@ type packetConn interface { Close() error } -// transport represents the SSH connection to the remote peer. +// transport is the keyingTransport that implements the SSH packet +// protocol. type transport struct { - reader - writer + reader connectionState + writer connectionState + + bufReader *bufio.Reader + bufWriter *bufio.Writer + rand io.Reader - net.Conn + io.Closer // Initial H used for the session ID. Once assigned this does // not change, even during subsequent key exchanges. sessionID []byte } -// reader represents the incoming connection state. -type reader struct { - io.Reader - common +func (t *transport) getSessionID() []byte { + if t.sessionID == nil { + panic("session ID not set yet") + } + s := make([]byte, len(t.sessionID)) + copy(s, t.sessionID) + return s } -// writer represents the outgoing connection state. -type writer struct { - sync.Mutex // protects writer.Writer from concurrent writes - *bufio.Writer - rand io.Reader - common +// packetCipher represents a combination of SSH encryption/MAC +// protocol. A single instance should be used for one direction only. +type packetCipher interface { + // writePacket encrypts the packet and writes it to w. The + // contents of the packet are generally scrambled. + writePacket(seqnum uint32, w io.Writer, rand io.Reader, packet []byte) error + + // readPacket reads and decrypts a packet of data. The + // returned packet may be overwritten by future calls of + // readPacket. + readPacket(seqnum uint32, r io.Reader) ([]byte, error) +} + +// connectionState represents one side (read or write) of the +// connection. This is necessary because each direction has its own +// keys, and can even have its own algorithms +type connectionState struct { + packetCipher + seqNum uint32 + dir direction + pendingKeyChange chan packetCipher } // prepareKeyChange sets up key material for a keychange. The key changes in // both directions are triggered by reading and writing a msgNewKey packet // respectively. func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error { - t.writer.cipherAlgo = algs.wCipher - t.writer.macAlgo = algs.wMAC - t.writer.compressionAlgo = algs.wCompression - - t.reader.cipherAlgo = algs.rCipher - t.reader.macAlgo = algs.rMAC - t.reader.compressionAlgo = algs.rCompression - if t.sessionID == nil { t.sessionID = kexResult.H } kexResult.SessionID = t.sessionID - t.reader.pendingKeyChange <- kexResult - t.writer.pendingKeyChange <- kexResult - return nil -} - -// common represents the cipher state needed to process messages in a single -// direction. -type common struct { - seqNum uint32 - mac hash.Hash - cipher cipher.Stream - - cipherAlgo string - macAlgo string - compressionAlgo string - - dir direction - pendingKeyChange chan *kexResult -} - -// Read and decrypt a single packet from the remote peer. -func (r *reader) readPacket() ([]byte, error) { - var lengthBytes = make([]byte, 5) - var macSize uint32 - if _, err := io.ReadFull(r, lengthBytes); err != nil { - return nil, err - } - - r.cipher.XORKeyStream(lengthBytes, lengthBytes) - if r.mac != nil { - r.mac.Reset() - seqNumBytes := []byte{ - byte(r.seqNum >> 24), - byte(r.seqNum >> 16), - byte(r.seqNum >> 8), - byte(r.seqNum), - } - r.mac.Write(seqNumBytes) - r.mac.Write(lengthBytes) - macSize = uint32(r.mac.Size()) + if ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult); err != nil { + return err + } else { + t.reader.pendingKeyChange <- ciph } - length := binary.BigEndian.Uint32(lengthBytes[0:4]) - paddingLength := uint32(lengthBytes[4]) - - if length <= paddingLength+1 { - return nil, errors.New("ssh: invalid packet length, packet too small") + if ciph, err := newPacketCipher(t.writer.dir, algs.w, kexResult); err != nil { + return err + } else { + t.writer.pendingKeyChange <- ciph } - if length > maxPacket { - return nil, errors.New("ssh: invalid packet length, packet too large") - } + return nil +} - packet := make([]byte, length-1+macSize) - if _, err := io.ReadFull(r, packet); err != nil { - return nil, err - } - mac := packet[length-1:] - r.cipher.XORKeyStream(packet, packet[:length-1]) +// Read and decrypt next packet. +func (t *transport) readPacket() ([]byte, error) { + return t.reader.readPacket(t.bufReader) +} - if r.mac != nil { - r.mac.Write(packet[:length-1]) - if subtle.ConstantTimeCompare(r.mac.Sum(nil), mac) != 1 { - return nil, errors.New("ssh: MAC failure") - } +func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) { + packet, err := s.packetCipher.readPacket(s.seqNum, r) + s.seqNum++ + if err == nil && len(packet) == 0 { + err = errors.New("ssh: zero length packet") } - r.seqNum++ - packet = packet[:length-paddingLength-1] - if len(packet) > 0 && packet[0] == msgNewKeys { select { - case k := <-r.pendingKeyChange: - if err := r.setupKeys(r.dir, k); err != nil { - return nil, err - } + case cipher := <-s.pendingKeyChange: + s.packetCipher = cipher default: return nil, errors.New("ssh: got bogus newkeys message.") } } - return packet, nil -} -// Read and decrypt next packet discarding debug and noop messages. -func (t *transport) readPacket() ([]byte, error) { - for { - packet, err := t.reader.readPacket() - if err != nil { - return nil, err - } - if len(packet) == 0 { - return nil, errors.New("ssh: zero length packet") - } + // The packet may point to an internal buffer, so copy the + // packet out here. + fresh := make([]byte, len(packet)) + copy(fresh, packet) - if packet[0] != msgIgnore && packet[0] != msgDebug { - return packet, nil - } - } - panic("unreachable") + return fresh, err } -// Encrypt and send a packet of data to the remote peer. -func (w *writer) writePacket(packet []byte) error { - changeKeys := len(packet) > 0 && packet[0] == msgNewKeys - - if len(packet) > maxPacket { - return errors.New("ssh: packet too large") - } - w.Mutex.Lock() - defer w.Mutex.Unlock() +func (t *transport) writePacket(packet []byte) error { + return t.writer.writePacket(t.bufWriter, t.rand, packet) +} - paddingLength := packetSizeMultiple - (5+len(packet))%packetSizeMultiple - if paddingLength < 4 { - paddingLength += packetSizeMultiple - } +func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte) error { + changeKeys := len(packet) > 0 && packet[0] == msgNewKeys - length := len(packet) + 1 + paddingLength - lengthBytes := []byte{ - byte(length >> 24), - byte(length >> 16), - byte(length >> 8), - byte(length), - byte(paddingLength), - } - padding := make([]byte, paddingLength) - _, err := io.ReadFull(w.rand, padding) + err := s.packetCipher.writePacket(s.seqNum, w, rand, packet) if err != nil { return err } - - if w.mac != nil { - w.mac.Reset() - seqNumBytes := []byte{ - byte(w.seqNum >> 24), - byte(w.seqNum >> 16), - byte(w.seqNum >> 8), - byte(w.seqNum), - } - w.mac.Write(seqNumBytes) - w.mac.Write(lengthBytes) - w.mac.Write(packet) - w.mac.Write(padding) - } - - // TODO(dfc) lengthBytes, packet and padding should be - // subslices of a single buffer - w.cipher.XORKeyStream(lengthBytes, lengthBytes) - w.cipher.XORKeyStream(packet, packet) - w.cipher.XORKeyStream(padding, padding) - - if _, err := w.Write(lengthBytes); err != nil { - return err - } - if _, err := w.Write(packet); err != nil { - return err - } - if _, err := w.Write(padding); err != nil { - return err - } - - if w.mac != nil { - if _, err := w.Write(w.mac.Sum(nil)); err != nil { - return err - } - } - - w.seqNum++ if err = w.Flush(); err != nil { return err } - + s.seqNum++ if changeKeys { select { - case k := <-w.pendingKeyChange: - err = w.setupKeys(w.dir, k) + case cipher := <-s.pendingKeyChange: + s.packetCipher = cipher default: panic("ssh: no key material for msgNewKeys") } @@ -267,24 +156,20 @@ func (w *writer) writePacket(packet []byte) error { return err } -func newTransport(conn net.Conn, rand io.Reader, isClient bool) *transport { +func newTransport(rwc io.ReadWriteCloser, rand io.Reader, isClient bool) *transport { t := &transport{ - reader: reader{ - Reader: bufio.NewReader(conn), - common: common{ - cipher: noneCipher{}, - pendingKeyChange: make(chan *kexResult, 1), - }, + bufReader: bufio.NewReader(rwc), + bufWriter: bufio.NewWriter(rwc), + rand: rand, + reader: connectionState{ + packetCipher: &streamPacketCipher{cipher: noneCipher{}}, + pendingKeyChange: make(chan packetCipher, 1), }, - writer: writer{ - Writer: bufio.NewWriter(conn), - rand: rand, - common: common{ - cipher: noneCipher{}, - pendingKeyChange: make(chan *kexResult, 1), - }, + writer: connectionState{ + packetCipher: &streamPacketCipher{cipher: noneCipher{}}, + pendingKeyChange: make(chan packetCipher, 1), }, - Conn: conn, + Closer: rwc, } if isClient { t.reader.dir = serverKeys @@ -303,48 +188,64 @@ type direction struct { macKeyTag []byte } -// TODO(dfc) can this be made a constant ? var ( serverKeys = direction{[]byte{'B'}, []byte{'D'}, []byte{'F'}} clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}} ) +// generateKeys generates key material for IV, MAC and encryption. +func generateKeys(d direction, algs directionAlgorithms, kex *kexResult) (iv, key, macKey []byte) { + cipherMode := cipherModes[algs.Cipher] + macMode := macModes[algs.MAC] + + iv = make([]byte, cipherMode.ivSize) + key = make([]byte, cipherMode.keySize) + macKey = make([]byte, macMode.keySize) + + generateKeyMaterial(iv, d.ivTag, kex) + generateKeyMaterial(key, d.keyTag, kex) + generateKeyMaterial(macKey, d.macKeyTag, kex) + return +} + // setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as // described in RFC 4253, section 6.4. direction should either be serverKeys // (to setup server->client keys) or clientKeys (for client->server keys). -func (c *common) setupKeys(d direction, r *kexResult) error { - cipherMode := cipherModes[c.cipherAlgo] - macMode := macModes[c.macAlgo] +func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (packetCipher, error) { + iv, key, macKey := generateKeys(d, algs, kex) - iv := make([]byte, cipherMode.ivSize) - key := make([]byte, cipherMode.keySize) - macKey := make([]byte, macMode.keySize) - - h := r.Hash.New() - generateKeyMaterial(iv, d.ivTag, r.K, r.H, r.SessionID, h) - generateKeyMaterial(key, d.keyTag, r.K, r.H, r.SessionID, h) - generateKeyMaterial(macKey, d.macKeyTag, r.K, r.H, r.SessionID, h) + if algs.Cipher == gcmCipherID { + return newGCMCipher(iv, key, macKey) + } - c.mac = macMode.new(macKey) + c := &streamPacketCipher{ + mac: macModes[algs.MAC].new(macKey), + } + c.macResult = make([]byte, c.mac.Size()) var err error - c.cipher, err = cipherMode.createCipher(key, iv) - return err + c.cipher, err = cipherModes[algs.Cipher].createStream(key, iv) + if err != nil { + return nil, err + } + + return c, nil } // generateKeyMaterial fills out with key material generated from tag, K, H // and sessionId, as specified in RFC 4253, section 7.2. -func generateKeyMaterial(out, tag []byte, K, H, sessionId []byte, h hash.Hash) { +func generateKeyMaterial(out, tag []byte, r *kexResult) { var digestsSoFar []byte + h := r.Hash.New() for len(out) > 0 { h.Reset() - h.Write(K) - h.Write(H) + h.Write(r.K) + h.Write(r.H) if len(digestsSoFar) == 0 { h.Write(tag) - h.Write(sessionId) + h.Write(r.SessionID) } else { h.Write(digestsSoFar) } diff --git a/ssh/transport_test.go b/ssh/transport_test.go index 3320114..92d83ab 100644 --- a/ssh/transport_test.go +++ b/ssh/transport_test.go @@ -6,6 +6,8 @@ package ssh import ( "bytes" + "crypto/rand" + "encoding/binary" "strings" "testing" ) @@ -67,3 +69,41 @@ func TestExchangeVersions(t *testing.T) { } } } + +type closerBuffer struct { + bytes.Buffer +} + +func (b *closerBuffer) Close() error { + return nil +} + +func TestTransportMaxPacketWrite(t *testing.T) { + buf := &closerBuffer{} + tr := newTransport(buf, rand.Reader, true) + huge := make([]byte, maxPacket+1) + err := tr.writePacket(huge) + if err == nil { + t.Errorf("transport accepted write for a huge packet.") + } +} + +func TestTransportMaxPacketReader(t *testing.T) { + var header [5]byte + huge := make([]byte, maxPacket+128) + binary.BigEndian.PutUint32(header[0:], uint32(len(huge))) + // padding. + header[4] = 0 + + buf := &closerBuffer{} + buf.Write(header[:]) + buf.Write(huge) + + tr := newTransport(buf, rand.Reader, true) + _, err := tr.readPacket() + if err == nil { + t.Errorf("transport succeeded reading huge packet.") + } else if !strings.Contains(err.Error(), "large") { + t.Errorf("got %q, should mention %q", err.Error(), "large") + } +} |
