aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorShulhan <m.shulhan@gmail.com>2023-12-22 13:42:41 +0700
committerShulhan <m.shulhan@gmail.com>2026-03-27 02:56:30 +0700
commitd25ba3de70123bbe0659bfab221fc850b892dd17 (patch)
tree8d11078a675f54eea93191ef424731eacf91ae3b
parentf88a1e842ae6c96508e445694e94e1b7d84415a5 (diff)
downloadgo-x-crypto-d25ba3de70123bbe0659bfab221fc850b892dd17.tar.xz
ssh: implement Session Run with context
The RunWithContext similar to Run but terminate the remote command with SIGKILL when its receive context cancellation.
-rw-r--r--ssh/session.go30
1 files changed, 28 insertions, 2 deletions
diff --git a/ssh/session.go b/ssh/session.go
index acef622..f940686 100644
--- a/ssh/session.go
+++ b/ssh/session.go
@@ -9,6 +9,7 @@ package ssh
import (
"bytes"
+ "context"
"encoding/binary"
"errors"
"fmt"
@@ -314,6 +315,16 @@ func (s *Session) Run(cmd string) error {
return s.Wait()
}
+// RunWithContext similar to [Session.Run] but with context.
+func (s *Session) RunWithContext(ctx context.Context, cmd string) (err error) {
+ err = s.Start(cmd)
+ if err != nil {
+ return err
+ }
+
+ return s.waitWithContext(ctx)
+}
+
// Output runs cmd on the remote host and returns its standard output.
func (s *Session) Output(cmd string) ([]byte, error) {
if s.Stdout != nil {
@@ -397,10 +408,22 @@ func (s *Session) start() error {
// unsuccessfully or is interrupted by a signal, the error is of type
// *ExitError. Other error types may be returned for I/O problems.
func (s *Session) Wait() error {
+ return s.waitWithContext(context.Background())
+}
+
+// waitWithContext wait for remote command to exit or terminate with SIGKILL
+// when its receive context cancellation.
+func (s *Session) waitWithContext(ctx context.Context) (err error) {
if !s.started {
return errors.New("ssh: session not started")
}
- waitErr := <-s.exitStatus
+
+ var waitErr error
+ select {
+ case <-ctx.Done():
+ waitErr = s.Signal(SIGKILL)
+ case waitErr = <-s.exitStatus:
+ }
if s.stdinPipeWriter != nil {
s.stdinPipeWriter.Close()
@@ -414,7 +437,10 @@ func (s *Session) Wait() error {
if waitErr != nil {
return waitErr
}
- return copyError
+ if copyError != nil {
+ return copyError
+ }
+ return context.Cause(ctx)
}
func (s *Session) wait(reqs <-chan *Request) error {