diff options
| author | Adam Langley <agl@golang.org> | 2014-04-09 13:57:52 -0700 |
|---|---|---|
| committer | Adam Langley <agl@golang.org> | 2014-04-09 13:57:52 -0700 |
| commit | fa50e7408b9ef89ff2965535b59f1a0010c0770b (patch) | |
| tree | e045a3f48f9ffd3bb712002f8f9f6fd489e8f7ef /ssh/session.go | |
| parent | 8f45c680ceb25c200b8c301d9184532aeb7cb36e (diff) | |
| download | go-x-crypto-fa50e7408b9ef89ff2965535b59f1a0010c0770b.tar.xz | |
go.crypto/ssh: import gosshnew.
See https://groups.google.com/d/msg/Golang-nuts/AoVxQ4bB5XQ/i8kpMxdbVlEJ
R=hanwen
CC=golang-codereviews
https://golang.org/cl/86190043
Diffstat (limited to 'ssh/session.go')
| -rw-r--r-- | ssh/session.go | 309 |
1 files changed, 144 insertions, 165 deletions
diff --git a/ssh/session.go b/ssh/session.go index 39f2d22..3b42b50 100644 --- a/ssh/session.go +++ b/ssh/session.go @@ -129,128 +129,126 @@ type Session struct { Stdout io.Writer Stderr io.Writer - *clientChan // the channel backing this session - - started bool // true once Start, Run or Shell is invoked. + ch Channel // the channel backing this session + started bool // true once Start, Run or Shell is invoked. copyFuncs []func() error errors chan error // one send per copyFunc // true if pipe method is active stdinpipe, stdoutpipe, stderrpipe bool + + // stdinPipeWriter is non-nil if StdinPipe has not been called + // and Stdin was specified by the user; it is the write end of + // a pipe connecting Session.Stdin to the stdin channel. + stdinPipeWriter io.WriteCloser + + exitStatus chan error } -// RFC 4254 Section 6.4. -type setenvRequest struct { - PeersId uint32 - Request string - WantReply bool - Name string - Value string +// SendRequest sends an out-of-band channel request on the SSH channel +// underlying the session. +func (s *Session) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { + return s.ch.SendRequest(name, wantReply, payload) } -// RFC 4254 Section 6.5. -type subsystemRequestMsg struct { - PeersId uint32 - Request string - WantReply bool - Subsystem string +func (s *Session) Close() error { + return s.ch.Close() +} + +// RFC 4254 Section 6.4. +type setenvRequest struct { + Name string + Value string } // Setenv sets an environment variable that will be applied to any // command executed by Shell or Run. func (s *Session) Setenv(name, value string) error { - req := setenvRequest{ - PeersId: s.remoteId, - Request: "env", - WantReply: true, - Name: name, - Value: value, + msg := setenvRequest{ + Name: name, + Value: value, } - if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil { - return err + ok, err := s.ch.SendRequest("env", true, Marshal(&msg)) + if err == nil && !ok { + err = errors.New("ssh: setenv failed") } - return s.waitForResponse() + return err } // RFC 4254 Section 6.2. type ptyRequestMsg struct { - PeersId uint32 - Request string - WantReply bool - Term string - Columns uint32 - Rows uint32 - Width uint32 - Height uint32 - Modelist string + Term string + Columns uint32 + Rows uint32 + Width uint32 + Height uint32 + Modelist string } // RequestPty requests the association of a pty with the session on the remote host. func (s *Session) RequestPty(term string, h, w int, termmodes TerminalModes) error { var tm []byte for k, v := range termmodes { - tm = append(tm, k) - tm = appendU32(tm, v) + kv := struct { + Key byte + Val uint32 + }{k, v} + + tm = append(tm, Marshal(&kv)...) } tm = append(tm, tty_OP_END) req := ptyRequestMsg{ - PeersId: s.remoteId, - Request: "pty-req", - WantReply: true, - Term: term, - Columns: uint32(w), - Rows: uint32(h), - Width: uint32(w * 8), - Height: uint32(h * 8), - Modelist: string(tm), + Term: term, + Columns: uint32(w), + Rows: uint32(h), + Width: uint32(w * 8), + Height: uint32(h * 8), + Modelist: string(tm), } - if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil { - return err + ok, err := s.ch.SendRequest("pty-req", true, Marshal(&req)) + if err == nil && !ok { + err = errors.New("ssh: pty-req failed") } - return s.waitForResponse() + return err +} + +// RFC 4254 Section 6.5. +type subsystemRequestMsg struct { + Subsystem string } // RequestSubsystem requests the association of a subsystem with the session on the remote host. // A subsystem is a predefined command that runs in the background when the ssh session is initiated func (s *Session) RequestSubsystem(subsystem string) error { - req := subsystemRequestMsg{ - PeersId: s.remoteId, - Request: "subsystem", - WantReply: true, + msg := subsystemRequestMsg{ Subsystem: subsystem, } - if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil { - return err + ok, err := s.ch.SendRequest("subsystem", true, Marshal(&msg)) + if err == nil && !ok { + err = errors.New("ssh: subsystem request failed") } - return s.waitForResponse() + return err } // RFC 4254 Section 6.9. type signalMsg struct { - PeersId uint32 - Request string - WantReply bool - Signal string + Signal string } // Signal sends the given signal to the remote process. // sig is one of the SIG* constants. func (s *Session) Signal(sig Signal) error { - req := signalMsg{ - PeersId: s.remoteId, - Request: "signal", - WantReply: false, - Signal: string(sig), + msg := signalMsg{ + Signal: string(sig), } - return s.writePacket(marshal(msgChannelRequest, req)) + + _, err := s.ch.SendRequest("signal", false, Marshal(&msg)) + return err } // RFC 4254 Section 6.5. type execMsg struct { - PeersId uint32 - Request string - WantReply bool - Command string + Command string } // Start runs cmd on the remote host. Typically, the remote @@ -261,16 +259,15 @@ func (s *Session) Start(cmd string) error { return errors.New("ssh: session already started") } req := execMsg{ - PeersId: s.remoteId, - Request: "exec", - WantReply: true, - Command: cmd, + Command: cmd, } - if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil { - return err + + ok, err := s.ch.SendRequest("exec", true, Marshal(&req)) + if err == nil && !ok { + err = fmt.Errorf("ssh: command %v failed", cmd) } - if err := s.waitForResponse(); err != nil { - return fmt.Errorf("ssh: could not execute command %s: %v", cmd, err) + if err != nil { + return err } return s.start() } @@ -339,31 +336,17 @@ func (s *Session) Shell() error { if s.started { return errors.New("ssh: session already started") } - req := channelRequestMsg{ - PeersId: s.remoteId, - Request: "shell", - WantReply: true, + + ok, err := s.ch.SendRequest("shell", true, nil) + if err == nil && !ok { + return fmt.Errorf("ssh: cound not start shell") } - if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil { + if err != nil { return err } - if err := s.waitForResponse(); err != nil { - return fmt.Errorf("ssh: could not execute shell: %v", err) - } return s.start() } -func (s *Session) waitForResponse() error { - msg := <-s.msg - switch msg.(type) { - case *channelRequestSuccessMsg: - return nil - case *channelRequestFailureMsg: - return errors.New("ssh: request failed") - } - return fmt.Errorf("ssh: unknown packet %T received: %v", msg, msg) -} - func (s *Session) start() error { s.started = true @@ -394,8 +377,11 @@ func (s *Session) Wait() error { if !s.started { return errors.New("ssh: session not started") } - waitErr := s.wait() + waitErr := <-s.exitStatus + if s.stdinPipeWriter != nil { + s.stdinPipeWriter.Close() + } var copyError error for _ = range s.copyFuncs { if err := <-s.errors; err != nil && copyError == nil { @@ -408,52 +394,35 @@ func (s *Session) Wait() error { return copyError } -func (s *Session) wait() error { +func (s *Session) wait(reqs <-chan *Request) error { wm := Waitmsg{status: -1} - // Wait for msg channel to be closed before returning. - for msg := range s.msg { - switch msg := msg.(type) { - case *channelRequestMsg: - switch msg.Request { - case "exit-status": - d := msg.RequestSpecificData - wm.status = int(d[0])<<24 | int(d[1])<<16 | int(d[2])<<8 | int(d[3]) - case "exit-signal": - signal, rest, ok := parseString(msg.RequestSpecificData) - if !ok { - return fmt.Errorf("wait: could not parse request data: %v", msg.RequestSpecificData) - } - wm.signal = safeString(string(signal)) - - // skip coreDumped bool - if len(rest) == 0 { - return fmt.Errorf("wait: could not parse request data: %v", msg.RequestSpecificData) - } - rest = rest[1:] - - errmsg, rest, ok := parseString(rest) - if !ok { - return fmt.Errorf("wait: could not parse request data: %v", msg.RequestSpecificData) - } - wm.msg = safeString(string(errmsg)) - - lang, _, ok := parseString(rest) - if !ok { - return fmt.Errorf("wait: could not parse request data: %v", msg.RequestSpecificData) - } - wm.lang = safeString(string(lang)) - default: - // This handles keepalives and matches - // OpenSSH's behaviour. - if msg.WantReply { - s.writePacket(marshal(msgChannelFailure, channelRequestFailureMsg{ - PeersId: s.remoteId, - })) - } + for msg := range reqs { + switch msg.Type { + case "exit-status": + d := msg.Payload + wm.status = int(d[0])<<24 | int(d[1])<<16 | int(d[2])<<8 | int(d[3]) + case "exit-signal": + var sigval struct { + Signal string + CoreDumped bool + Error string + Lang string + } + if err := Unmarshal(msg.Payload, &sigval); err != nil { + return err } + + // Must sanitize strings? + wm.signal = sigval.Signal + wm.msg = sigval.Error + wm.lang = sigval.Lang default: - return fmt.Errorf("wait: unexpected packet %T received: %v", msg, msg) + // This handles keepalives and matches + // OpenSSH's behaviour. + if msg.WantReply { + msg.Reply(false, nil) + } } } if wm.status == 0 { @@ -476,12 +445,20 @@ func (s *Session) stdin() { if s.stdinpipe { return } + var stdin io.Reader if s.Stdin == nil { - s.Stdin = new(bytes.Buffer) + stdin = new(bytes.Buffer) + } else { + r, w := io.Pipe() + go func() { + _, err := io.Copy(w, s.Stdin) + w.CloseWithError(err) + }() + stdin, s.stdinPipeWriter = r, w } s.copyFuncs = append(s.copyFuncs, func() error { - _, err := io.Copy(s.clientChan.stdin, s.Stdin) - if err1 := s.clientChan.stdin.Close(); err == nil && err1 != io.EOF { + _, err := io.Copy(s.ch, stdin) + if err1 := s.ch.CloseWrite(); err == nil && err1 != io.EOF { err = err1 } return err @@ -496,7 +473,7 @@ func (s *Session) stdout() { s.Stdout = ioutil.Discard } s.copyFuncs = append(s.copyFuncs, func() error { - _, err := io.Copy(s.Stdout, s.clientChan.stdout) + _, err := io.Copy(s.Stdout, s.ch) return err }) } @@ -509,11 +486,21 @@ func (s *Session) stderr() { s.Stderr = ioutil.Discard } s.copyFuncs = append(s.copyFuncs, func() error { - _, err := io.Copy(s.Stderr, s.clientChan.stderr) + _, err := io.Copy(s.Stderr, s.ch.Stderr()) return err }) } +// sessionStdin reroutes Close to CloseWrite. +type sessionStdin struct { + io.Writer + ch Channel +} + +func (s *sessionStdin) Close() error { + return s.ch.CloseWrite() +} + // StdinPipe returns a pipe that will be connected to the // remote command's standard input when the command starts. func (s *Session) StdinPipe() (io.WriteCloser, error) { @@ -524,7 +511,7 @@ func (s *Session) StdinPipe() (io.WriteCloser, error) { return nil, errors.New("ssh: StdinPipe after process started") } s.stdinpipe = true - return s.clientChan.stdin, nil + return &sessionStdin{s.ch, s.ch}, nil } // StdoutPipe returns a pipe that will be connected to the @@ -541,7 +528,7 @@ func (s *Session) StdoutPipe() (io.Reader, error) { return nil, errors.New("ssh: StdoutPipe after process started") } s.stdoutpipe = true - return s.clientChan.stdout, nil + return s.ch, nil } // StderrPipe returns a pipe that will be connected to the @@ -558,28 +545,20 @@ func (s *Session) StderrPipe() (io.Reader, error) { return nil, errors.New("ssh: StderrPipe after process started") } s.stderrpipe = true - return s.clientChan.stderr, nil + return s.ch.Stderr(), nil } -// NewSession returns a new interactive session on the remote host. -func (c *ClientConn) NewSession() (*Session, error) { - ch := c.newChan(c.transport) - if err := c.transport.writePacket(marshal(msgChannelOpen, channelOpenMsg{ - ChanType: "session", - PeersId: ch.localId, - PeersWindow: channelWindowSize, - MaxPacketSize: channelMaxPacketSize, - })); err != nil { - c.chanList.remove(ch.localId) - return nil, err +// newSession returns a new interactive session on the remote host. +func newSession(ch Channel, reqs <-chan *Request) (*Session, error) { + s := &Session{ + ch: ch, } - if err := ch.waitForChannelOpenResponse(); err != nil { - c.chanList.remove(ch.localId) - return nil, fmt.Errorf("ssh: unable to open session: %v", err) - } - return &Session{ - clientChan: ch, - }, nil + s.exitStatus = make(chan error, 1) + go func() { + s.exitStatus <- s.wait(reqs) + }() + + return s, nil } // An ExitError reports unsuccessful completion of a remote command. |
