diff options
| author | Shulhan <m.shulhan@gmail.com> | 2023-12-22 13:42:41 +0700 |
|---|---|---|
| committer | Shulhan <m.shulhan@gmail.com> | 2023-12-22 16:23:26 +0700 |
| commit | 10c3c1ae8a5e3c843f377034bae617b2ded59e0e (patch) | |
| tree | 395cb3c090655ed50d20c8d7b2975a618442c27c | |
| parent | 9d2ee975ef9fe627bf0a6f01c1f69e8ef1d4f05d (diff) | |
| download | go-x-crypto-ssh-run-with-context.tar.xz | |
ssh: implement Session Run with contextssh-run-with-context.mailedssh-run-with-context
The RunWithContext similar to Run but terminate the remote command
with SIGKILL when its receive context cancellation.
Change-Id: Ib82e23b77450bef222bba8576eca11b9d356688b
| -rw-r--r-- | ssh/session.go | 30 |
1 files changed, 28 insertions, 2 deletions
diff --git a/ssh/session.go b/ssh/session.go index acef622..afd46e6 100644 --- a/ssh/session.go +++ b/ssh/session.go @@ -9,6 +9,7 @@ package ssh import ( "bytes" + "context" "encoding/binary" "errors" "fmt" @@ -314,6 +315,16 @@ func (s *Session) Run(cmd string) error { return s.Wait() } +// RunWithContext similar to [Session.Run] with context. +func (s *Session) RunWithContext(ctx context.Context, cmd string) (err error) { + err = s.Start(cmd) + if err != nil { + return err + } + + return s.waitWithContext(ctx) +} + // Output runs cmd on the remote host and returns its standard output. func (s *Session) Output(cmd string) ([]byte, error) { if s.Stdout != nil { @@ -397,10 +408,22 @@ func (s *Session) start() error { // unsuccessfully or is interrupted by a signal, the error is of type // *ExitError. Other error types may be returned for I/O problems. func (s *Session) Wait() error { + return s.waitWithContext(context.Background()) +} + +// waitWithContext wait for remote command to exit or terminate with SIGKILL +// when its receive context cancellation. +func (s *Session) waitWithContext(ctx context.Context) (err error) { if !s.started { return errors.New("ssh: session not started") } - waitErr := <-s.exitStatus + + var waitErr error + select { + case <-ctx.Done(): + waitErr = s.Signal(SIGKILL) + case waitErr = <-s.exitStatus: + } if s.stdinPipeWriter != nil { s.stdinPipeWriter.Close() @@ -414,7 +437,10 @@ func (s *Session) Wait() error { if waitErr != nil { return waitErr } - return copyError + if copyError != nil { + return copyError + } + return context.Cause(ctx) } func (s *Session) wait(reqs <-chan *Request) error { |
