aboutsummaryrefslogtreecommitdiff
path: root/session.go
diff options
context:
space:
mode:
authorShulhan <ms@kilabit.info>2023-09-24 01:31:26 +0700
committerShulhan <ms@kilabit.info>2023-09-26 00:24:02 +0700
commit8cc52027d243946c03c6b0d1016ca7cc3d7de09a (patch)
treecbb1db5fbc9a5f48e24b64e391ea1d01adffee1c /session.go
parent9c709996d9519d6552e44182440438d080b8789d (diff)
downloadawwan-8cc52027d243946c03c6b0d1016ca7cc3d7de09a.tar.xz
all: make the magic word "#put" able to copy encrypted file
When issuing "#put:" or "#put!" command in the script, if the input file is not exist it will check for the encrypted file, the one with ".vault" extension. If it exists, the encrypted file will be used as input for copy operation.
Diffstat (limited to 'session.go')
-rw-r--r--session.go158
1 files changed, 112 insertions, 46 deletions
diff --git a/session.go b/session.go
index fb848b8..7470624 100644
--- a/session.go
+++ b/session.go
@@ -4,10 +4,12 @@
package awwan
import (
+ "bytes"
"crypto/rsa"
"errors"
"fmt"
"io/fs"
+ "log"
"os"
"os/exec"
"path/filepath"
@@ -94,12 +96,7 @@ func (ses *Session) Vals(keyPath string) []string {
// Copy file in local system.
func (ses *Session) Copy(stmt *Statement) (err error) {
- var (
- logp = "Copy"
-
- src string
- dest string
- )
+ var logp = `Copy`
if len(stmt.cmd) == 0 {
return fmt.Errorf("%s: missing source argument", logp)
@@ -111,16 +108,26 @@ func (ses *Session) Copy(stmt *Statement) (err error) {
return fmt.Errorf("%s: two or more destination arguments is given", logp)
}
- src, err = ses.generateFileInput(stmt.cmd)
+ var (
+ src string
+ isVault bool
+ )
+
+ src, isVault, err = ses.generateFileInput(stmt.cmd)
if err != nil {
return fmt.Errorf("%s: %w", logp, err)
}
- dest = stmt.args[0]
-
- err = libos.Copy(dest, src)
+ err = libos.Copy(stmt.args[0], src)
+ if isVault {
+ // Delete the decrypted file on exit.
+ var errRemove = os.Remove(src)
+ if errRemove != nil {
+ log.Printf(`%s: %s`, logp, errRemove)
+ }
+ }
if err != nil {
- return fmt.Errorf("%s: %w", logp, err)
+ return fmt.Errorf(`%s: %w`, logp, err)
}
return nil
}
@@ -160,12 +167,7 @@ func (ses *Session) Get(stmt *Statement) (err error) {
// Put copy file from local to remote system.
func (ses *Session) Put(stmt *Statement) (err error) {
- var (
- logp = "Put"
-
- local string
- remote string
- )
+ var logp = `Put`
if len(stmt.cmd) == 0 {
return fmt.Errorf("%s: missing source argument", logp)
@@ -177,18 +179,29 @@ func (ses *Session) Put(stmt *Statement) (err error) {
return fmt.Errorf("%s: two or more destination arguments is given", logp)
}
- local, err = ses.generateFileInput(stmt.cmd)
+ var (
+ local string
+ isVault bool
+ )
+
+ local, isVault, err = ses.generateFileInput(stmt.cmd)
if err != nil {
return fmt.Errorf("%s: %w", logp, err)
}
- remote = stmt.args[0]
+ var remote = stmt.args[0]
if ses.sftpc == nil {
err = ses.sshClient.ScpPut(local, remote)
} else {
err = ses.sftpc.Put(local, remote)
}
+ if isVault {
+ var errRemove = os.Remove(local)
+ if errRemove != nil {
+ log.Printf(`%s: %s`, logp, errRemove)
+ }
+ }
if err != nil {
return fmt.Errorf("%s: %w", logp, err)
}
@@ -215,7 +228,15 @@ func (ses *Session) SudoCopy(req *Request, stmt *Statement, withParseInput bool)
}
if withParseInput {
- src, err = ses.generateFileInput(stmt.cmd)
+ var isVault bool
+
+ src, isVault, err = ses.generateFileInput(stmt.cmd)
+ if isVault {
+ var errRemove = os.Remove(src)
+ if errRemove != nil {
+ log.Printf(`%s: %s`, logp, errRemove)
+ }
+ }
if err != nil {
return fmt.Errorf("%s: %w", logp, err)
}
@@ -297,9 +318,8 @@ func (ses *Session) SudoGet(stmt *Statement) (err error) {
// SudoPut copy file from local to remote using sudo.
func (ses *Session) SudoPut(stmt *Statement) (err error) {
var (
- logp = "SudoPut"
+ logp = `SudoPut`
- local string
baseName string
tmp string
remote string
@@ -316,9 +336,19 @@ func (ses *Session) SudoPut(stmt *Statement) (err error) {
return fmt.Errorf("%s: two or more destination arguments is given", logp)
}
+ var (
+ local string
+ isVault bool
+ )
// Apply the session variables into local file to be copied first, and
// save them into cache directory.
- local, err = ses.generateFileInput(stmt.cmd)
+ local, isVault, err = ses.generateFileInput(stmt.cmd)
+ if isVault {
+ var errRemove = os.Remove(local)
+ if errRemove != nil {
+ log.Printf(`%s: %s`, logp, errRemove)
+ }
+ }
if err != nil {
return fmt.Errorf("%s: %w", logp, err)
}
@@ -474,50 +504,56 @@ func (ses *Session) executeScriptOnRemote(req *Request, pos linePosition) {
//
// For example, if the input file path is "{{.BaseDir}}/a/b/script" then the
// output file path would be "{{.BaseDir}}/.cache/a/b/script".
-func (ses *Session) generateFileInput(in string) (out string, err error) {
+func (ses *Session) generateFileInput(in string) (out string, isVault bool, err error) {
+ // Check if the file is binary first, since binary file will not get
+ // encrypted.
+ if libos.IsBinary(in) {
+ return in, false, nil
+ }
+
var (
logp = `generateFileInput`
- tmpl *template.Template
- f *os.File
- outDir string
- base string
+ contentInput []byte
)
- if libos.IsBinary(in) {
- return in, nil
+ contentInput, isVault, err = ses.loadFileInput(in)
+ if err != nil {
+ return ``, false, fmt.Errorf(`%s: %w`, logp, err)
}
- outDir = filepath.Join(ses.BaseDir, defCacheDir, filepath.Dir(in))
- base = filepath.Base(in)
- out = filepath.Join(outDir, base)
+ var tmpl = template.New(in)
- err = os.MkdirAll(outDir, 0700)
+ tmpl, err = tmpl.Parse(string(contentInput))
if err != nil {
- return "", fmt.Errorf("%s %s: %w", logp, in, err)
+ return ``, false, fmt.Errorf(`%s: %w`, logp, err)
}
- tmpl, err = template.ParseFiles(in)
- if err != nil {
- return "", fmt.Errorf("%s %s: %w", logp, in, err)
- }
+ var contentOut bytes.Buffer
- f, err = os.Create(out)
+ err = tmpl.Execute(&contentOut, ses)
if err != nil {
- return "", fmt.Errorf("%s %s: %w", logp, in, err)
+ return ``, false, fmt.Errorf(`%s: %w`, logp, err)
}
- err = tmpl.Execute(f, ses)
+ var (
+ outDir = filepath.Join(ses.BaseDir, defCacheDir, filepath.Dir(in))
+ base = filepath.Base(in)
+ )
+
+ err = os.MkdirAll(outDir, 0700)
if err != nil {
- return "", fmt.Errorf("%s %s: %w", logp, in, err)
+ return ``, false, fmt.Errorf(`%s: %s: %w`, logp, outDir, err)
}
- err = f.Close()
+ out = filepath.Join(outDir, base)
+
+ err = os.WriteFile(out, contentOut.Bytes(), 0600)
if err != nil {
- return "", fmt.Errorf("%s %s: %w", logp, in, err)
+ return ``, false, fmt.Errorf(`%s: %s: %w`, logp, out, err)
}
- return out, nil
+ return out, isVault, nil
}
// generatePaths using baseDir return all paths from BaseDir to ScriptDir.
@@ -645,6 +681,36 @@ func (ses *Session) loadFileEnv(awwanEnv string, isVault bool) (err error) {
return nil
}
+// loadFileInput read the input file for Copy or Put operation.
+// If the original input file does not exist, try loading the encrypted file
+// with ".vault" extension.
+//
+// On success, it will return the content of file and true if the file is
+// from encrypted file .vault.
+func (ses *Session) loadFileInput(path string) (content []byte, isVault bool, err error) {
+ content, err = os.ReadFile(path)
+ if err == nil {
+ return content, false, nil
+ }
+ if !errors.Is(err, fs.ErrNotExist) {
+ return nil, false, err
+ }
+
+ path = path + defEncryptExt
+
+ content, err = os.ReadFile(path)
+ if err != nil {
+ return nil, false, err
+ }
+
+ content, err = decrypt(ses.privateKey, content)
+ if err != nil {
+ return nil, false, err
+ }
+
+ return content, true, nil
+}
+
func (ses *Session) loadRawEnv(content []byte) (err error) {
var in *ini.Ini