aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorShulhan <m.shulhan@gmail.com>2023-12-22 13:42:41 +0700
committerShulhan <m.shulhan@gmail.com>2023-12-22 16:23:26 +0700
commit10c3c1ae8a5e3c843f377034bae617b2ded59e0e (patch)
tree395cb3c090655ed50d20c8d7b2975a618442c27c
parent9d2ee975ef9fe627bf0a6f01c1f69e8ef1d4f05d (diff)
downloadgo-x-crypto-ssh-run-with-context.tar.xz
ssh: implement Session Run with contextssh-run-with-context.mailedssh-run-with-context
The RunWithContext similar to Run but terminate the remote command with SIGKILL when its receive context cancellation. Change-Id: Ib82e23b77450bef222bba8576eca11b9d356688b
-rw-r--r--ssh/session.go30
1 files changed, 28 insertions, 2 deletions
diff --git a/ssh/session.go b/ssh/session.go
index acef622..afd46e6 100644
--- a/ssh/session.go
+++ b/ssh/session.go
@@ -9,6 +9,7 @@ package ssh
import (
"bytes"
+ "context"
"encoding/binary"
"errors"
"fmt"
@@ -314,6 +315,16 @@ func (s *Session) Run(cmd string) error {
return s.Wait()
}
+// RunWithContext similar to [Session.Run] with context.
+func (s *Session) RunWithContext(ctx context.Context, cmd string) (err error) {
+ err = s.Start(cmd)
+ if err != nil {
+ return err
+ }
+
+ return s.waitWithContext(ctx)
+}
+
// Output runs cmd on the remote host and returns its standard output.
func (s *Session) Output(cmd string) ([]byte, error) {
if s.Stdout != nil {
@@ -397,10 +408,22 @@ func (s *Session) start() error {
// unsuccessfully or is interrupted by a signal, the error is of type
// *ExitError. Other error types may be returned for I/O problems.
func (s *Session) Wait() error {
+ return s.waitWithContext(context.Background())
+}
+
+// waitWithContext wait for remote command to exit or terminate with SIGKILL
+// when its receive context cancellation.
+func (s *Session) waitWithContext(ctx context.Context) (err error) {
if !s.started {
return errors.New("ssh: session not started")
}
- waitErr := <-s.exitStatus
+
+ var waitErr error
+ select {
+ case <-ctx.Done():
+ waitErr = s.Signal(SIGKILL)
+ case waitErr = <-s.exitStatus:
+ }
if s.stdinPipeWriter != nil {
s.stdinPipeWriter.Close()
@@ -414,7 +437,10 @@ func (s *Session) Wait() error {
if waitErr != nil {
return waitErr
}
- return copyError
+ if copyError != nil {
+ return copyError
+ }
+ return context.Cause(ctx)
}
func (s *Session) wait(reqs <-chan *Request) error {