From d25ba3de70123bbe0659bfab221fc850b892dd17 Mon Sep 17 00:00:00 2001 From: Shulhan Date: Fri, 22 Dec 2023 13:42:41 +0700 Subject: ssh: implement Session Run with context The RunWithContext similar to Run but terminate the remote command with SIGKILL when its receive context cancellation. --- ssh/session.go | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/ssh/session.go b/ssh/session.go index acef622..f940686 100644 --- a/ssh/session.go +++ b/ssh/session.go @@ -9,6 +9,7 @@ package ssh import ( "bytes" + "context" "encoding/binary" "errors" "fmt" @@ -314,6 +315,16 @@ func (s *Session) Run(cmd string) error { return s.Wait() } +// RunWithContext similar to [Session.Run] but with context. +func (s *Session) RunWithContext(ctx context.Context, cmd string) (err error) { + err = s.Start(cmd) + if err != nil { + return err + } + + return s.waitWithContext(ctx) +} + // Output runs cmd on the remote host and returns its standard output. func (s *Session) Output(cmd string) ([]byte, error) { if s.Stdout != nil { @@ -397,10 +408,22 @@ func (s *Session) start() error { // unsuccessfully or is interrupted by a signal, the error is of type // *ExitError. Other error types may be returned for I/O problems. func (s *Session) Wait() error { + return s.waitWithContext(context.Background()) +} + +// waitWithContext wait for remote command to exit or terminate with SIGKILL +// when its receive context cancellation. +func (s *Session) waitWithContext(ctx context.Context) (err error) { if !s.started { return errors.New("ssh: session not started") } - waitErr := <-s.exitStatus + + var waitErr error + select { + case <-ctx.Done(): + waitErr = s.Signal(SIGKILL) + case waitErr = <-s.exitStatus: + } if s.stdinPipeWriter != nil { s.stdinPipeWriter.Close() @@ -414,7 +437,10 @@ func (s *Session) Wait() error { if waitErr != nil { return waitErr } - return copyError + if copyError != nil { + return copyError + } + return context.Cause(ctx) } func (s *Session) wait(reqs <-chan *Request) error { -- cgit v1.3