From c0d336ca456093b2b7c0b585dbe08f62cbc8ca83 Mon Sep 17 00:00:00 2001 From: Shulhan Date: Fri, 22 Dec 2023 02:04:11 +0700 Subject: all: add [context.Context] to Local and Play Passing context allow the command Local or Play to be cancelled when running in asynchronous mode, in this case when awwan run with WUI. --- awwan.go | 45 +++++++++------ awwan_local_test.go | 72 ++++++++++++++++++++++-- awwan_play_test.go | 20 +++++-- awwan_sudo_test.go | 7 ++- cmd/awwan/main.go | 7 ++- http_server.go | 21 ++++++- session.go | 122 ++++++++++++++++++++++------------------ sudo_test.go | 5 +- testdata/local/cancel.aww | 1 + testdata/local/cancel_test.data | 3 + 10 files changed, 214 insertions(+), 89 deletions(-) create mode 100644 testdata/local/cancel.aww create mode 100644 testdata/local/cancel_test.data diff --git a/awwan.go b/awwan.go index bb8ded4..99cfec9 100644 --- a/awwan.go +++ b/awwan.go @@ -4,6 +4,7 @@ package awwan import ( + "context" "fmt" "log" "os" @@ -284,7 +285,7 @@ func (aww *Awwan) EnvSet(key, val, file string) (err error) { } // Local execute the script in the local machine using shell. -func (aww *Awwan) Local(req *ExecRequest) (err error) { +func (aww *Awwan) Local(ctx context.Context, req *ExecRequest) (err error) { var ( logp = `Local` sessionDir = filepath.Dir(req.scriptPath) @@ -320,14 +321,20 @@ func (aww *Awwan) Local(req *ExecRequest) (err error) { req.mlog.Outf(`=== BEGIN: %s %s %s`, req.Mode, req.Script, req.LineRange) for _, pos = range req.lineRange.list { - err = ses.executeRequires(req, pos) - if err != nil { - goto out - } - - err = ses.executeScriptOnLocal(req, pos) - if err != nil { + select { + case <-ctx.Done(): + err = ctx.Err() goto out + default: + err = ses.executeRequires(ctx, req, pos) + if err != nil { + goto out + } + + err = ses.executeScriptOnLocal(ctx, req, pos) + if err != nil { + goto out + } } } req.mlog.Outf(`=== END: %s %s %s`, req.Mode, req.Script, req.LineRange) @@ -341,7 +348,7 @@ out: } // Play execute the script in the remote machine using SSH. -func (aww *Awwan) Play(req *ExecRequest) (err error) { +func (aww *Awwan) Play(ctx context.Context, req *ExecRequest) (err error) { var ( logp = `Play` sessionDir = filepath.Dir(req.scriptPath) @@ -391,14 +398,20 @@ func (aww *Awwan) Play(req *ExecRequest) (err error) { req.mlog.Outf(`=== BEGIN: %s %s %s`, req.Mode, req.Script, req.LineRange) for _, pos = range req.lineRange.list { - err = ses.executeRequires(req, pos) - if err != nil { - goto out - } - - err = ses.executeScriptOnRemote(req, pos) - if err != nil { + select { + case <-ctx.Done(): + err = ctx.Err() goto out + default: + err = ses.executeRequires(ctx, req, pos) + if err != nil { + goto out + } + + err = ses.executeScriptOnRemote(ctx, req, pos) + if err != nil { + goto out + } } } req.mlog.Outf(`=== END: %s %s %s`, req.Mode, req.Script, req.LineRange) diff --git a/awwan_local_test.go b/awwan_local_test.go index 5755846..55b7812 100644 --- a/awwan_local_test.go +++ b/awwan_local_test.go @@ -7,10 +7,12 @@ package awwan import ( "bytes" + "context" "io/fs" "os" "path/filepath" "testing" + "time" "github.com/shuLhan/share/lib/test" "github.com/shuLhan/share/lib/test/mock" @@ -66,6 +68,8 @@ func TestAwwanLocal(t *testing.T) { }} var ( + ctx = context.Background() + req *ExecRequest logw bytes.Buffer c testCase @@ -80,7 +84,7 @@ func TestAwwanLocal(t *testing.T) { logw.Reset() req.registerLogWriter(`output`, &logw) - err = aww.Local(req) + err = aww.Local(ctx, req) if err != nil { test.Assert(t, `error`, c.expError, err.Error()) continue @@ -90,6 +94,62 @@ func TestAwwanLocal(t *testing.T) { } } +func TestAwwanLocalCancel(t *testing.T) { + var ( + baseDir = `testdata/local` + scriptFile = filepath.Join(baseDir, `cancel.aww`) + tdataFile = filepath.Join(baseDir, `cancel_test.data`) + mockrw = mock.ReadWriter{} + + tdata *test.Data + aww *Awwan + err error + ) + + tdata, err = test.LoadData(tdataFile) + if err != nil { + t.Fatal(err) + } + + aww, err = New(baseDir) + if err != nil { + t.Fatal(err) + } + + // Mock terminal to read passphrase for private key. + aww.cryptoc.termrw = &mockrw + + var execReq *ExecRequest + + execReq, err = NewExecRequest(CommandModeLocal, scriptFile, `1`) + if err != nil { + t.Fatal(err) + } + + var logw bytes.Buffer + + execReq.registerLogWriter(`output`, &logw) + + var ( + ctx = context.Background() + ctxDoCancel context.CancelFunc + ) + + ctx, ctxDoCancel = context.WithCancel(ctx) + + go func() { + var err2 = aww.Local(ctx, execReq) + t.Logf(`LocalCancel: error: %s`, err2) + }() + + // Wait for actual exec.CommandContext to run ... + time.Sleep(500 * time.Millisecond) + + ctxDoCancel() + + test.Assert(t, `stdout`, string(tdata.Output[`cancel`]), logw.String()) +} + func TestAwwanLocal_Get(t *testing.T) { type testCase struct { desc string @@ -148,6 +208,7 @@ func TestAwwanLocal_Get(t *testing.T) { }} var ( + ctx = context.Background() script = filepath.Join(baseDir, `get.aww`) c testCase @@ -166,7 +227,7 @@ func TestAwwanLocal_Get(t *testing.T) { t.Fatal(err) } - err = aww.Local(req) + err = aww.Local(ctx, req) if err != nil { test.Assert(t, `Local: error`, c.expError, err.Error()) continue @@ -255,6 +316,7 @@ func TestAwwanLocal_Put(t *testing.T) { }} var ( + ctx = context.Background() mockrw = mock.ReadWriter{} aww *Awwan @@ -289,7 +351,7 @@ func TestAwwanLocal_Put(t *testing.T) { t.Fatal(err) } - err = aww.Local(req) + err = aww.Local(ctx, req) if err != nil { test.Assert(t, `Local error`, c.expError, err.Error()) continue @@ -378,6 +440,8 @@ func TestAwwanLocal_withEncryption(t *testing.T) { }} var ( + ctx = context.Background() + c testCase logw bytes.Buffer req *ExecRequest @@ -399,7 +463,7 @@ func TestAwwanLocal_withEncryption(t *testing.T) { mockrw.BufRead.WriteString(c.pass) aww.cryptoc.privateKey = nil - err = aww.Local(req) + err = aww.Local(ctx, req) if err != nil { test.Assert(t, `Local error`, c.expError, err.Error()) } diff --git a/awwan_play_test.go b/awwan_play_test.go index f36d021..d0f206e 100644 --- a/awwan_play_test.go +++ b/awwan_play_test.go @@ -7,6 +7,7 @@ package awwan import ( "bytes" + "context" "io/fs" "os" "path/filepath" @@ -66,6 +67,8 @@ func TestAwwan_Play_withLocal(t *testing.T) { }} var ( + ctx = context.Background() + c testCase req *ExecRequest logw bytes.Buffer @@ -81,7 +84,7 @@ func TestAwwan_Play_withLocal(t *testing.T) { logw.Reset() req.registerLogWriter(`output`, &logw) - err = aww.Play(req) + err = aww.Play(ctx, req) if err != nil { test.Assert(t, `Play: error`, c.expError, err.Error()) continue @@ -132,6 +135,8 @@ func TestAwwan_Play_Get(t *testing.T) { }} var ( + ctx = context.Background() + req *ExecRequest c testCaseGetPut fi os.FileInfo @@ -150,7 +155,7 @@ func TestAwwan_Play_Get(t *testing.T) { t.Fatal(err) } - err = aww.Play(req) + err = aww.Play(ctx, req) if err != nil { test.Assert(t, `play error`, c.expError, err.Error()) } @@ -219,6 +224,8 @@ func TestAwwan_Play_Put(t *testing.T) { }} var ( + ctx = context.Background() + req *ExecRequest c testCaseGetPut fi os.FileInfo @@ -237,7 +244,7 @@ func TestAwwan_Play_Put(t *testing.T) { t.Fatal(err) } - err = aww.Play(req) + err = aww.Play(ctx, req) if err != nil { test.Assert(t, `play error`, c.expError, err.Error()) } @@ -310,6 +317,7 @@ func TestAwwan_Play_SudoGet(t *testing.T) { }} var ( + ctx = context.Background() mockin = &mockStdin{} req *ExecRequest @@ -335,7 +343,7 @@ func TestAwwan_Play_SudoGet(t *testing.T) { mockin.buf.WriteString(c.sudoPass) req.stdin = mockin - err = aww.Play(req) + err = aww.Play(ctx, req) if err != nil { test.Assert(t, `play error`, c.expError, err.Error()) } @@ -405,6 +413,8 @@ func TestAwwan_Play_SudoPut(t *testing.T) { }} var ( + ctx = context.Background() + req *ExecRequest c testCaseGetPut fi os.FileInfo @@ -423,7 +433,7 @@ func TestAwwan_Play_SudoPut(t *testing.T) { t.Fatal(err) } - err = aww.Play(req) + err = aww.Play(ctx, req) if err != nil { test.Assert(t, `play error`, c.expError, err.Error()) } diff --git a/awwan_sudo_test.go b/awwan_sudo_test.go index cc7bcc3..ae112f7 100644 --- a/awwan_sudo_test.go +++ b/awwan_sudo_test.go @@ -7,6 +7,7 @@ package awwan import ( "bytes" + "context" "io/fs" "os" "path/filepath" @@ -84,6 +85,7 @@ func TestAwwan_Local_SudoGet(t *testing.T) { }} var ( + ctx = context.Background() script = filepath.Join(baseDir, `get.aww`) mockin = &mockStdin{} @@ -105,7 +107,7 @@ func TestAwwan_Local_SudoGet(t *testing.T) { mockin.buf.WriteString(c.sudoPass) req.stdin = mockin - err = aww.Local(req) + err = aww.Local(ctx, req) if err != nil { test.Assert(t, `Local: error`, c.expError, err.Error()) continue @@ -179,6 +181,7 @@ func TestAwwan_Local_SudoPut(t *testing.T) { }} var ( + ctx = context.Background() mockin = &mockStdin{} mockout = &bytes.Buffer{} mockTerm = mock.ReadWriter{} @@ -214,7 +217,7 @@ func TestAwwan_Local_SudoPut(t *testing.T) { mockin.buf.WriteString(c.sudoPass) req.stdin = mockin - err = aww.Local(req) + err = aww.Local(ctx, req) if err != nil { test.Assert(t, `Local error`, c.expError, err.Error()) continue diff --git a/cmd/awwan/main.go b/cmd/awwan/main.go index c4e525b..8f70784 100644 --- a/cmd/awwan/main.go +++ b/cmd/awwan/main.go @@ -6,6 +6,7 @@ package main import ( + "context" "flag" "fmt" "log" @@ -199,6 +200,8 @@ func main() { log.Fatalf(`%s: %s`, logp, err) } + var ctx = context.Background() + switch cmdMode { case awwan.CommandModeDecrypt: var filePlain string @@ -241,9 +244,9 @@ func main() { case awwan.CommandModeEnvSet: err = aww.EnvSet(flag.Arg(1), flag.Arg(2), flag.Arg(3)) case awwan.CommandModeLocal: - err = aww.Local(req) + err = aww.Local(ctx, req) case awwan.CommandModePlay: - err = aww.Play(req) + err = aww.Play(ctx, req) case awwan.CommandModeServe: err = aww.Serve(*serveAddress, *isDev) } diff --git a/http_server.go b/http_server.go index 70dc7a8..c45cde2 100644 --- a/http_server.go +++ b/http_server.go @@ -5,6 +5,7 @@ package awwan import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -49,6 +50,10 @@ type httpServer struct { // idExecRes contains the execution ID and its response. idExecRes map[string]*ExecResponse + // idContextCancel contains the execution ID and its context + // cancellation function. + idContextCancel map[string]context.CancelFunc + aww *Awwan memfsBase *memfs.MemFS // The files caches. @@ -63,7 +68,8 @@ func newHTTPServer(aww *Awwan, address string) (httpd *httpServer, err error) { ) httpd = &httpServer{ - idExecRes: make(map[string]*ExecResponse), + idExecRes: make(map[string]*ExecResponse), + idContextCancel: make(map[string]context.CancelFunc), aww: aww, baseDir: aww.BaseDir, @@ -701,11 +707,20 @@ func (httpd *httpServer) Execute(epr *libhttp.EndpointRequest) (resb []byte, err httpd.idExecRes[execRes.ID] = execRes + var ( + ctx = context.Background() + ctxDoCancel context.CancelFunc + ) + + ctx, ctxDoCancel = context.WithCancel(ctx) + + httpd.idContextCancel[execRes.ID] = ctxDoCancel + go func() { if req.Mode == CommandModeLocal { - err = httpd.aww.Local(req) + err = httpd.aww.Local(ctx, req) } else { - err = httpd.aww.Play(req) + err = httpd.aww.Play(ctx, req) } execRes.end(err) }() diff --git a/session.go b/session.go index dc07e15..f4c14da 100644 --- a/session.go +++ b/session.go @@ -5,6 +5,7 @@ package awwan import ( "bytes" + "context" "errors" "fmt" "io/fs" @@ -22,8 +23,7 @@ import ( "github.com/shuLhan/share/lib/ssh/config" ) -// Session manage and cache SSH client and list of scripts. -// One session have one SSH client, but may contains more than one script. +// Session manage environment and SSH client. type Session struct { cryptoc *cryptoContext @@ -228,7 +228,7 @@ func (ses *Session) Put(req *ExecRequest, stmt *Statement) (err error) { } // SudoCopy copy file in local system using sudo. -func (ses *Session) SudoCopy(req *ExecRequest, stmt *Statement) (err error) { +func (ses *Session) SudoCopy(ctx context.Context, req *ExecRequest, stmt *Statement) (err error) { var ( logp = `SudoCopy` src = stmt.args[0] @@ -254,7 +254,7 @@ func (ses *Session) SudoCopy(req *ExecRequest, stmt *Statement) (err error) { raw: []byte(fmt.Sprintf(`sudo cp %q %q`, src, dst)), } - err = ExecLocal(req, sudoCp) + err = ExecLocal(ctx, req, sudoCp) if isVault { var errRemove = os.Remove(src) if errRemove != nil { @@ -275,7 +275,7 @@ func (ses *Session) SudoCopy(req *ExecRequest, stmt *Statement) (err error) { raw: []byte(fmt.Sprintf(`sudo chmod %o %q`, stmt.mode, dst)), } ) - err = ExecLocal(req, sudoChmod) + err = ExecLocal(ctx, req, sudoChmod) if err != nil { return fmt.Errorf(`%s: chmod: %w`, logp, err) } @@ -287,7 +287,7 @@ func (ses *Session) SudoCopy(req *ExecRequest, stmt *Statement) (err error) { args: []string{`chown`, stmt.owner, dst}, raw: []byte(fmt.Sprintf(`sudo chown %s %q`, stmt.owner, dst)), } - err = ExecLocal(req, sudoChown) + err = ExecLocal(ctx, req, sudoChown) if err != nil { return fmt.Errorf(`%s: chown: %w`, logp, err) } @@ -300,7 +300,7 @@ func (ses *Session) SudoCopy(req *ExecRequest, stmt *Statement) (err error) { // local using sudo. // If the owner and/or mode is set, it will also applied using sudo on local // host, after the file has been retrieved. -func (ses *Session) SudoGet(req *ExecRequest, stmt *Statement) (err error) { +func (ses *Session) SudoGet(ctx context.Context, req *ExecRequest, stmt *Statement) (err error) { var ( logp = `SudoGet` src = stmt.args[0] @@ -313,13 +313,13 @@ func (ses *Session) SudoGet(req *ExecRequest, stmt *Statement) (err error) { } if stmt.mode != 0 { - err = ses.localSudoChmod(req, dst, stmt.mode) + err = ses.localSudoChmod(ctx, req, dst, stmt.mode) if err != nil { return fmt.Errorf(`%s: %w`, logp, err) } } if len(stmt.owner) != 0 { - err = ses.localSudoChown(req, dst, stmt.owner) + err = ses.localSudoChown(ctx, req, dst, stmt.owner) if err != nil { return fmt.Errorf(`%s: %w`, logp, err) } @@ -378,7 +378,7 @@ func (ses *Session) SudoPut(req *ExecRequest, stmt *Statement) (err error) { // // The raw field must be used when generating Command to handle arguments // with quotes. -func ExecLocal(req *ExecRequest, stmt *Statement) (err error) { +func ExecLocal(ctx context.Context, req *ExecRequest, stmt *Statement) (err error) { if stmt.cmd == `sudo` { if req.stdin != nil { var raw = make([]byte, 0, len(stmt.raw)) @@ -388,7 +388,7 @@ func ExecLocal(req *ExecRequest, stmt *Statement) (err error) { } } - var cmd = exec.Command(`/bin/sh`, `-c`, string(stmt.raw)) + var cmd = exec.CommandContext(ctx, `/bin/sh`, `-c`, string(stmt.raw)) cmd.Stdin = req.stdin cmd.Stdout = req.mlog @@ -413,7 +413,7 @@ func (ses *Session) close() (err error) { // executeRequires run the "#require:" statements from line 0 until // the start argument in the local system. -func (ses *Session) executeRequires(req *ExecRequest, pos linePosition) (err error) { +func (ses *Session) executeRequires(ctx context.Context, req *ExecRequest, pos linePosition) (err error) { if pos.start >= int64(len(req.script.requires)) { return nil } @@ -424,22 +424,27 @@ func (ses *Session) executeRequires(req *ExecRequest, pos linePosition) (err err ) for x = 0; x <= pos.start; x++ { - stmt = req.script.requires[x] - if stmt == nil { - continue - } + select { + case <-ctx.Done(): + return ctx.Err() + default: + stmt = req.script.requires[x] + if stmt == nil { + continue + } - req.mlog.Outf(`--- require %d: %v`, x, stmt) + req.mlog.Outf(`--- require %d: %v`, x, stmt) - err = ExecLocal(req, stmt) - if err != nil { - return err + err = ExecLocal(ctx, req, stmt) + if err != nil { + return err + } } } return nil } -func (ses *Session) executeScriptOnLocal(req *ExecRequest, pos linePosition) (err error) { +func (ses *Session) executeScriptOnLocal(ctx context.Context, req *ExecRequest, pos linePosition) (err error) { var max = int64(len(req.script.stmts)) if pos.start > max { return @@ -449,41 +454,46 @@ func (ses *Session) executeScriptOnLocal(req *ExecRequest, pos linePosition) (er } for x := pos.start; x <= pos.end; x++ { - stmt := req.script.stmts[x] - if stmt == nil { - continue - } - if stmt.kind == statementKindComment { - continue - } - if stmt.kind == statementKindRequire { - continue - } - - req.mlog.Outf(`--> %3d: %s`, x, stmt.String()) + select { + case <-ctx.Done(): + return ctx.Err() + default: + stmt := req.script.stmts[x] + if stmt == nil { + continue + } + if stmt.kind == statementKindComment { + continue + } + if stmt.kind == statementKindRequire { + continue + } - switch stmt.kind { - case statementKindDefault: - err = ExecLocal(req, stmt) - case statementKindGet: - err = ses.Copy(req, stmt) - case statementKindLocal: - err = ExecLocal(req, stmt) - case statementKindPut: - err = ses.Copy(req, stmt) - case statementKindSudoGet: - err = ses.SudoCopy(req, stmt) - case statementKindSudoPut: - err = ses.SudoCopy(req, stmt) - } - if err != nil { - return err + req.mlog.Outf(`--> %3d: %s`, x, stmt.String()) + + switch stmt.kind { + case statementKindDefault: + err = ExecLocal(ctx, req, stmt) + case statementKindGet: + err = ses.Copy(req, stmt) + case statementKindLocal: + err = ExecLocal(ctx, req, stmt) + case statementKindPut: + err = ses.Copy(req, stmt) + case statementKindSudoGet: + err = ses.SudoCopy(ctx, req, stmt) + case statementKindSudoPut: + err = ses.SudoCopy(ctx, req, stmt) + } + if err != nil { + return err + } } } return nil } -func (ses *Session) executeScriptOnRemote(req *ExecRequest, pos linePosition) (err error) { +func (ses *Session) executeScriptOnRemote(ctx context.Context, req *ExecRequest, pos linePosition) (err error) { var max = int64(len(req.script.stmts)) if pos.start > max { return @@ -512,11 +522,11 @@ func (ses *Session) executeScriptOnRemote(req *ExecRequest, pos linePosition) (e case statementKindGet: err = ses.Get(stmt) case statementKindLocal: - err = ExecLocal(req, stmt) + err = ExecLocal(ctx, req, stmt) case statementKindPut: err = ses.Put(req, stmt) case statementKindSudoGet: - err = ses.SudoGet(req, stmt) + err = ses.SudoGet(ctx, req, stmt) case statementKindSudoPut: err = ses.SudoPut(req, stmt) } @@ -742,7 +752,7 @@ func (ses *Session) loadRawEnv(content []byte) (err error) { // localSudoChmod change the file permission in local environment using // sudo. -func (ses *Session) localSudoChmod(req *ExecRequest, file string, mode fs.FileMode) (err error) { +func (ses *Session) localSudoChmod(ctx context.Context, req *ExecRequest, file string, mode fs.FileMode) (err error) { var ( fsmode = strconv.FormatUint(uint64(mode), 8) sudoChmod = &Statement{ @@ -752,7 +762,7 @@ func (ses *Session) localSudoChmod(req *ExecRequest, file string, mode fs.FileMo raw: []byte(fmt.Sprintf(`sudo chmod %o %q`, mode, file)), } ) - err = ExecLocal(req, sudoChmod) + err = ExecLocal(ctx, req, sudoChmod) if err != nil { return fmt.Errorf(`%s: %w`, sudoChmod.raw, err) } @@ -760,14 +770,14 @@ func (ses *Session) localSudoChmod(req *ExecRequest, file string, mode fs.FileMo } // localSudoChown change the file owner in local environment using sudo. -func (ses *Session) localSudoChown(req *ExecRequest, file, owner string) (err error) { +func (ses *Session) localSudoChown(ctx context.Context, req *ExecRequest, file, owner string) (err error) { var sudoChown = &Statement{ kind: statementKindDefault, cmd: `sudo`, args: []string{`chown`, owner, file}, raw: []byte(fmt.Sprintf(`sudo chown %s %q`, owner, file)), } - err = ExecLocal(req, sudoChown) + err = ExecLocal(ctx, req, sudoChown) if err != nil { return fmt.Errorf(`%s: %w`, sudoChown.raw, err) } diff --git a/sudo_test.go b/sudo_test.go index 2ad8be5..cf2de67 100644 --- a/sudo_test.go +++ b/sudo_test.go @@ -7,6 +7,7 @@ package awwan import ( "bytes" + "context" "testing" "github.com/shuLhan/share/lib/test" @@ -85,6 +86,8 @@ func TestExecLocal_sudo(t *testing.T) { }} var ( + ctx = context.Background() + c testCase stmt Statement x int @@ -98,7 +101,7 @@ func TestExecLocal_sudo(t *testing.T) { mockin.buf.WriteString(c.sudoPass) for x, stmt = range c.listStmt { - err = ExecLocal(req, &stmt) + err = ExecLocal(ctx, req, &stmt) if err != nil { t.Log(mockout.String()) var expError = c.expError[x] diff --git a/testdata/local/cancel.aww b/testdata/local/cancel.aww new file mode 100644 index 0000000..46a443f --- /dev/null +++ b/testdata/local/cancel.aww @@ -0,0 +1 @@ +sleep 300 diff --git a/testdata/local/cancel_test.data b/testdata/local/cancel_test.data new file mode 100644 index 0000000..357ec8b --- /dev/null +++ b/testdata/local/cancel_test.data @@ -0,0 +1,3 @@ +<<< cancel +----/--/-- --:--:-- === BEGIN: local testdata/local/cancel.aww 1 +----/--/-- --:--:-- --> 1: sleep 300 -- cgit v1.3