diff options
| author | Shulhan <m.shulhan@gmail.com> | 2023-12-22 13:42:41 +0700 |
|---|---|---|
| committer | Shulhan <m.shulhan@gmail.com> | 2026-03-27 02:56:30 +0700 |
| commit | d25ba3de70123bbe0659bfab221fc850b892dd17 (patch) | |
| tree | 8d11078a675f54eea93191ef424731eacf91ae3b | |
| parent | f88a1e842ae6c96508e445694e94e1b7d84415a5 (diff) | |
| download | go-x-crypto-d25ba3de70123bbe0659bfab221fc850b892dd17.tar.xz | |
ssh: implement Session Run with context
The RunWithContext similar to Run but terminate the remote command
with SIGKILL when its receive context cancellation.
| -rw-r--r-- | ssh/session.go | 30 |
1 files changed, 28 insertions, 2 deletions
diff --git a/ssh/session.go b/ssh/session.go index acef622..f940686 100644 --- a/ssh/session.go +++ b/ssh/session.go @@ -9,6 +9,7 @@ package ssh import ( "bytes" + "context" "encoding/binary" "errors" "fmt" @@ -314,6 +315,16 @@ func (s *Session) Run(cmd string) error { return s.Wait() } +// RunWithContext similar to [Session.Run] but with context. +func (s *Session) RunWithContext(ctx context.Context, cmd string) (err error) { + err = s.Start(cmd) + if err != nil { + return err + } + + return s.waitWithContext(ctx) +} + // Output runs cmd on the remote host and returns its standard output. func (s *Session) Output(cmd string) ([]byte, error) { if s.Stdout != nil { @@ -397,10 +408,22 @@ func (s *Session) start() error { // unsuccessfully or is interrupted by a signal, the error is of type // *ExitError. Other error types may be returned for I/O problems. func (s *Session) Wait() error { + return s.waitWithContext(context.Background()) +} + +// waitWithContext wait for remote command to exit or terminate with SIGKILL +// when its receive context cancellation. +func (s *Session) waitWithContext(ctx context.Context) (err error) { if !s.started { return errors.New("ssh: session not started") } - waitErr := <-s.exitStatus + + var waitErr error + select { + case <-ctx.Done(): + waitErr = s.Signal(SIGKILL) + case waitErr = <-s.exitStatus: + } if s.stdinPipeWriter != nil { s.stdinPipeWriter.Close() @@ -414,7 +437,10 @@ func (s *Session) Wait() error { if waitErr != nil { return waitErr } - return copyError + if copyError != nil { + return copyError + } + return context.Cause(ctx) } func (s *Session) wait(reqs <-chan *Request) error { |
