diff options
| author | Adam Langley <agl@golang.org> | 2014-04-09 13:57:52 -0700 |
|---|---|---|
| committer | Adam Langley <agl@golang.org> | 2014-04-09 13:57:52 -0700 |
| commit | fa50e7408b9ef89ff2965535b59f1a0010c0770b (patch) | |
| tree | e045a3f48f9ffd3bb712002f8f9f6fd489e8f7ef /ssh/session_test.go | |
| parent | 8f45c680ceb25c200b8c301d9184532aeb7cb36e (diff) | |
| download | go-x-crypto-fa50e7408b9ef89ff2965535b59f1a0010c0770b.tar.xz | |
go.crypto/ssh: import gosshnew.
See https://groups.google.com/d/msg/Golang-nuts/AoVxQ4bB5XQ/i8kpMxdbVlEJ
R=hanwen
CC=golang-codereviews
https://golang.org/cl/86190043
Diffstat (limited to 'ssh/session_test.go')
| -rw-r--r-- | ssh/session_test.go | 411 |
1 files changed, 125 insertions, 286 deletions
diff --git a/ssh/session_test.go b/ssh/session_test.go index 5cff58a..cc26573 100644 --- a/ssh/session_test.go +++ b/ssh/session_test.go @@ -12,71 +12,60 @@ import ( "io" "io/ioutil" "math/rand" - "net" "testing" "code.google.com/p/go.crypto/ssh/terminal" ) -type serverType func(*serverChan, *testing.T) +type serverType func(Channel, <-chan *Request, *testing.T) // dial constructs a new test server and returns a *ClientConn. -func dial(handler serverType, t *testing.T) *ClientConn { - l, err := Listen("tcp", "127.0.0.1:0", serverConfig) +func dial(handler serverType, t *testing.T) *Client { + c1, c2, err := netPipe() if err != nil { - t.Fatalf("unable to listen: %v", err) + t.Fatalf("netPipe: %v", err) } + go func() { - defer l.Close() - conn, err := l.Accept() - if err != nil { - t.Errorf("Unable to accept: %v", err) - return + defer c1.Close() + conf := ServerConfig{ + NoClientAuth: true, } - defer conn.Close() - if err := conn.Handshake(); err != nil { - t.Errorf("Unable to handshake: %v", err) - return + conf.AddHostKey(testSigners["rsa"]) + + _, chans, reqs, err := NewServerConn(c1, &conf) + if err != nil { + t.Fatalf("Unable to handshake: %v", err) } - done := make(chan struct{}) - for { - ch, err := conn.Accept() - if err == io.EOF || err == io.ErrUnexpectedEOF { - return - } - // We sometimes get ECONNRESET rather than EOF. - if _, ok := err.(*net.OpError); ok { - return + go DiscardRequests(reqs) + + for newCh := range chans { + if newCh.ChannelType() != "session" { + newCh.Reject(UnknownChannelType, "unknown channel type") + continue } + + ch, inReqs, err := newCh.Accept() if err != nil { - t.Errorf("Unable to accept incoming channel request: %v", err) - return - } - if ch.ChannelType() != "session" { - ch.Reject(UnknownChannelType, "unknown channel type") + t.Errorf("Accept: %v", err) continue } - ch.Accept() go func() { - defer close(done) - handler(ch.(*serverChan), t) + handler(ch, inReqs, t) }() } - <-done }() config := &ClientConfig{ User: "testuser", - Auth: []ClientAuth{ - ClientAuthPassword(clientPassword), - }, } - c, err := Dial("tcp", l.Addr().String(), config) + conn, chans, reqs, err := NewClientConn(c2, "", config) if err != nil { t.Fatalf("unable to dial remote side: %v", err) } - return c + + return NewClient(conn, chans, reqs) } // Test a simple string is returned to session.Stdout. @@ -330,164 +319,6 @@ func TestExitWithoutStatusOrSignal(t *testing.T) { } } -func TestInvalidServerMessage(t *testing.T) { - conn := dial(sendInvalidRecord, t) - defer conn.Close() - session, err := conn.NewSession() - if err != nil { - t.Fatalf("Unable to request new session: %v", err) - } - // Make sure that we closed all the clientChans when the connection - // failed. - session.wait() - - defer session.Close() -} - -// In the wild some clients (and servers) send zero sized window updates. -// Test that the client can continue after receiving a zero sized update. -func TestClientZeroWindowAdjust(t *testing.T) { - conn := dial(sendZeroWindowAdjust, t) - defer conn.Close() - session, err := conn.NewSession() - if err != nil { - t.Fatalf("Unable to request new session: %v", err) - } - defer session.Close() - - if err := session.Shell(); err != nil { - t.Fatalf("Unable to execute command: %v", err) - } - err = session.Wait() - if err != nil { - t.Fatalf("expected nil but got %v", err) - } -} - -// In the wild some clients (and servers) send zero sized window updates. -// Test that the server can continue after receiving a zero size update. -func TestServerZeroWindowAdjust(t *testing.T) { - conn := dial(exitStatusZeroHandler, t) - defer conn.Close() - session, err := conn.NewSession() - if err != nil { - t.Fatalf("Unable to request new session: %v", err) - } - defer session.Close() - - if err := session.Shell(); err != nil { - t.Fatalf("Unable to execute command: %v", err) - } - - // send a bogus zero sized window update - session.clientChan.sendWindowAdj(0) - - err = session.Wait() - if err != nil { - t.Fatalf("expected nil but got %v", err) - } -} - -// Verify that the client never sends a packet larger than maxpacket. -func TestClientStdinRespectsMaxPacketSize(t *testing.T) { - conn := dial(discardHandler, t) - defer conn.Close() - session, err := conn.NewSession() - if err != nil { - t.Fatalf("failed to request new session: %v", err) - } - defer session.Close() - stdin, err := session.StdinPipe() - if err != nil { - t.Fatalf("failed to obtain stdinpipe: %v", err) - } - const size = 100 * 1000 - for i := 0; i < 10; i++ { - n, err := stdin.Write(make([]byte, size)) - if n != size || err != nil { - t.Fatalf("failed to write: %d, %v", n, err) - } - } -} - -// Verify that the client never accepts a packet larger than maxpacket. -func TestServerStdoutRespectsMaxPacketSize(t *testing.T) { - conn := dial(largeSendHandler, t) - defer conn.Close() - session, err := conn.NewSession() - if err != nil { - t.Fatalf("Unable to request new session: %v", err) - } - defer session.Close() - out, err := session.StdoutPipe() - if err != nil { - t.Fatalf("Unable to connect to Stdout: %v", err) - } - if err := session.Shell(); err != nil { - t.Fatalf("Unable to execute command: %v", err) - } - if _, err := ioutil.ReadAll(out); err != nil { - t.Fatalf("failed to read: %v", err) - } -} - -func TestClientCannotSendAfterEOF(t *testing.T) { - conn := dial(exitWithoutSignalOrStatus, t) - defer conn.Close() - session, err := conn.NewSession() - if err != nil { - t.Fatalf("Unable to request new session: %v", err) - } - defer session.Close() - in, err := session.StdinPipe() - if err != nil { - t.Fatalf("Unable to connect channel stdin: %v", err) - } - if err := session.Shell(); err != nil { - t.Fatalf("Unable to execute command: %v", err) - } - if err := in.Close(); err != nil { - t.Fatalf("Unable to close stdin: %v", err) - } - if _, err := in.Write([]byte("foo")); err == nil { - t.Fatalf("Session write should fail") - } -} - -func TestClientCannotSendAfterClose(t *testing.T) { - conn := dial(exitWithoutSignalOrStatus, t) - defer conn.Close() - session, err := conn.NewSession() - if err != nil { - t.Fatalf("Unable to request new session: %v", err) - } - defer session.Close() - in, err := session.StdinPipe() - if err != nil { - t.Fatalf("Unable to connect channel stdin: %v", err) - } - if err := session.Shell(); err != nil { - t.Fatalf("Unable to execute command: %v", err) - } - // close underlying channel - if err := session.channel.Close(); err != nil { - t.Fatalf("Unable to close session: %v", err) - } - if _, err := in.Write([]byte("foo")); err == nil { - t.Fatalf("Session write should fail") - } -} - -func TestClientCannotSendHugePacket(t *testing.T) { - // client and server use the same transport write code so this - // test suffices for both. - conn := dial(shellHandler, t) - defer conn.Close() - if err := conn.transport.writePacket(make([]byte, maxPacket*2)); err == nil { - t.Fatalf("huge packet write should fail") - } -} - // windowTestBytes is the number of bytes that we'll send to the SSH server. const windowTestBytes = 16000 * 200 @@ -560,93 +391,104 @@ func TestClientHandlesKeepalives(t *testing.T) { } type exitStatusMsg struct { - PeersId uint32 - Request string - WantReply bool - Status uint32 + Status uint32 } type exitSignalMsg struct { - PeersId uint32 - Request string - WantReply bool Signal string CoreDumped bool Errmsg string Lang string } -func newServerShell(ch *serverChan, prompt string) *ServerTerminal { - term := terminal.NewTerminal(ch, prompt) - return &ServerTerminal{ - Term: term, - Channel: ch, +func handleTerminalRequests(in <-chan *Request) { + for req := range in { + ok := false + switch req.Type { + case "shell": + ok = true + if len(req.Payload) > 0 { + // We don't accept any commands, only the default shell. + ok = false + } + case "env": + ok = true + } + req.Reply(ok, nil) } } -func exitStatusZeroHandler(ch *serverChan, t *testing.T) { +func newServerShell(ch Channel, in <-chan *Request, prompt string) *terminal.Terminal { + term := terminal.NewTerminal(ch, prompt) + go handleTerminalRequests(in) + return term +} + +func exitStatusZeroHandler(ch Channel, in <-chan *Request, t *testing.T) { defer ch.Close() // this string is returned to stdout - shell := newServerShell(ch, "> ") + shell := newServerShell(ch, in, "> ") readLine(shell, t) sendStatus(0, ch, t) } -func exitStatusNonZeroHandler(ch *serverChan, t *testing.T) { +func exitStatusNonZeroHandler(ch Channel, in <-chan *Request, t *testing.T) { defer ch.Close() - shell := newServerShell(ch, "> ") + shell := newServerShell(ch, in, "> ") readLine(shell, t) sendStatus(15, ch, t) } -func exitSignalAndStatusHandler(ch *serverChan, t *testing.T) { +func exitSignalAndStatusHandler(ch Channel, in <-chan *Request, t *testing.T) { defer ch.Close() - shell := newServerShell(ch, "> ") + shell := newServerShell(ch, in, "> ") readLine(shell, t) sendStatus(15, ch, t) sendSignal("TERM", ch, t) } -func exitSignalHandler(ch *serverChan, t *testing.T) { +func exitSignalHandler(ch Channel, in <-chan *Request, t *testing.T) { defer ch.Close() - shell := newServerShell(ch, "> ") + shell := newServerShell(ch, in, "> ") readLine(shell, t) sendSignal("TERM", ch, t) } -func exitSignalUnknownHandler(ch *serverChan, t *testing.T) { +func exitSignalUnknownHandler(ch Channel, in <-chan *Request, t *testing.T) { defer ch.Close() - shell := newServerShell(ch, "> ") + shell := newServerShell(ch, in, "> ") readLine(shell, t) sendSignal("SYS", ch, t) } -func exitWithoutSignalOrStatus(ch *serverChan, t *testing.T) { +func exitWithoutSignalOrStatus(ch Channel, in <-chan *Request, t *testing.T) { defer ch.Close() - shell := newServerShell(ch, "> ") + shell := newServerShell(ch, in, "> ") readLine(shell, t) } -func shellHandler(ch *serverChan, t *testing.T) { +func shellHandler(ch Channel, in <-chan *Request, t *testing.T) { defer ch.Close() // this string is returned to stdout - shell := newServerShell(ch, "golang") + shell := newServerShell(ch, in, "golang") readLine(shell, t) sendStatus(0, ch, t) } // Ignores the command, writes fixed strings to stderr and stdout. // Strings are "this-is-stdout." and "this-is-stderr.". -func fixedOutputHandler(ch *serverChan, t *testing.T) { +func fixedOutputHandler(ch Channel, in <-chan *Request, t *testing.T) { defer ch.Close() + _, err := ch.Read(nil) - _, err := ch.Read(make([]byte, 0)) - if _, ok := err.(ChannelRequest); !ok { + req, ok := <-in + if !ok { t.Fatalf("error: expected channel request, got: %#v", err) return } + // ignore request, always send some text - ch.AckRequest(true) + req.Reply(true, nil) _, err = io.WriteString(ch, "this-is-stdout.") if err != nil { @@ -659,84 +501,39 @@ func fixedOutputHandler(ch *serverChan, t *testing.T) { sendStatus(0, ch, t) } -func readLine(shell *ServerTerminal, t *testing.T) { +func readLine(shell *terminal.Terminal, t *testing.T) { if _, err := shell.ReadLine(); err != nil && err != io.EOF { t.Errorf("unable to read line: %v", err) } } -func sendStatus(status uint32, ch *serverChan, t *testing.T) { +func sendStatus(status uint32, ch Channel, t *testing.T) { msg := exitStatusMsg{ - PeersId: ch.remoteId, - Request: "exit-status", - WantReply: false, - Status: status, + Status: status, } - if err := ch.writePacket(marshal(msgChannelRequest, msg)); err != nil { + if _, err := ch.SendRequest("exit-status", false, Marshal(&msg)); err != nil { t.Errorf("unable to send status: %v", err) } } -func sendSignal(signal string, ch *serverChan, t *testing.T) { +func sendSignal(signal string, ch Channel, t *testing.T) { sig := exitSignalMsg{ - PeersId: ch.remoteId, - Request: "exit-signal", - WantReply: false, Signal: signal, CoreDumped: false, Errmsg: "Process terminated", Lang: "en-GB-oed", } - if err := ch.writePacket(marshal(msgChannelRequest, sig)); err != nil { + if _, err := ch.SendRequest("exit-signal", false, Marshal(&sig)); err != nil { t.Errorf("unable to send signal: %v", err) } } -func sendInvalidRecord(ch *serverChan, t *testing.T) { +func discardHandler(ch Channel, t *testing.T) { defer ch.Close() - packet := make([]byte, 1+4+4+1) - packet[0] = msgChannelData - marshalUint32(packet[1:], 29348723 /* invalid channel id */) - marshalUint32(packet[5:], 1) - packet[9] = 42 - - if err := ch.writePacket(packet); err != nil { - t.Errorf("unable send invalid record: %v", err) - } -} - -func sendZeroWindowAdjust(ch *serverChan, t *testing.T) { - defer ch.Close() - // send a bogus zero sized window update - ch.sendWindowAdj(0) - shell := newServerShell(ch, "> ") - readLine(shell, t) - sendStatus(0, ch, t) -} - -func discardHandler(ch *serverChan, t *testing.T) { - defer ch.Close() - // grow the window to avoid being fooled by - // the initial 1 << 14 window. - ch.sendWindowAdj(1024 * 1024) io.Copy(ioutil.Discard, ch) } -func largeSendHandler(ch *serverChan, t *testing.T) { - defer ch.Close() - // grow the window to avoid being fooled by - // the initial 1 << 14 window. - ch.sendWindowAdj(1024 * 1024) - shell := newServerShell(ch, "> ") - readLine(shell, t) - // try to send more than the 32k window - // will allow - if err := ch.writePacket(make([]byte, 128*1024)); err == nil { - t.Errorf("wrote packet larger than 32k") - } -} - -func echoHandler(ch *serverChan, t *testing.T) { +func echoHandler(ch Channel, in <-chan *Request, t *testing.T) { defer ch.Close() if n, err := copyNRandomly("echohandler", ch, ch, windowTestBytes); err != nil { t.Errorf("short write, wrote %d, expected %d: %v ", n, windowTestBytes, err) @@ -773,17 +570,59 @@ func copyNRandomly(title string, dst io.Writer, src io.Reader, n int) (int, erro return written, nil } -func channelKeepaliveSender(ch *serverChan, t *testing.T) { +func channelKeepaliveSender(ch Channel, in <-chan *Request, t *testing.T) { defer ch.Close() - shell := newServerShell(ch, "> ") + shell := newServerShell(ch, in, "> ") readLine(shell, t) - msg := channelRequestMsg{ - PeersId: ch.remoteId, - Request: "keepalive@openssh.com", - WantReply: true, - } - if err := ch.writePacket(marshal(msgChannelRequest, msg)); err != nil { + if _, err := ch.SendRequest("keepalive@openssh.com", true, nil); err != nil { t.Errorf("unable to send channel keepalive request: %v", err) } sendStatus(0, ch, t) } + +func TestClientWriteEOF(t *testing.T) { + conn := dial(simpleEchoHandler, t) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatal(err) + } + defer session.Close() + stdin, err := session.StdinPipe() + if err != nil { + t.Fatalf("StdinPipe failed: %v", err) + } + stdout, err := session.StdoutPipe() + if err != nil { + t.Fatalf("StdoutPipe failed: %v", err) + } + + data := []byte(`0000`) + _, err = stdin.Write(data) + if err != nil { + t.Fatalf("Write failed: %v", err) + } + stdin.Close() + + res, err := ioutil.ReadAll(stdout) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + + if !bytes.Equal(data, res) { + t.Fatalf("Read differed from write, wrote: %v, read: %v", data, res) + } +} + +func simpleEchoHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + data, err := ioutil.ReadAll(ch) + if err != nil { + t.Errorf("handler read error: %v", err) + } + _, err = ch.Write(data) + if err != nil { + t.Errorf("handler write error: %v", err) + } +} |
