From 8b0eb632e4b3cb47e4b97e154ef7abaf03c7c36b Mon Sep 17 00:00:00 2001 From: Shulhan Date: Mon, 25 Dec 2023 23:01:30 +0700 Subject: ssh/config: set the default UserKnownHostsFile in setDefaults While at it, unfold each value of IdentityFile and UserKnownHostsFile in setDefaults, by expanding "~" into user's home directory or joining with "config" directory if its relative. --- lib/ssh/config/config.go | 13 ++++++-- lib/ssh/config/parser_test.go | 2 ++ lib/ssh/config/section.go | 48 +++++++++++++++++++++++------ lib/ssh/config/section_test.go | 2 +- lib/ssh/config/testdata/config | 1 + lib/ssh/config/testdata/config_get_test.txt | 1 + 6 files changed, 54 insertions(+), 13 deletions(-) diff --git a/lib/ssh/config/config.go b/lib/ssh/config/config.go index 6ba14393..de5d4e34 100644 --- a/lib/ssh/config/config.go +++ b/lib/ssh/config/config.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "os" + "path/filepath" "strings" ) @@ -30,6 +31,9 @@ var ( type Config struct { envs map[string]string + // dir store the path to the "config" directory. + dir string + // workDir store the current working directory. workDir string @@ -42,6 +46,12 @@ type Config struct { func newConfig(file string) (cfg *Config, err error) { cfg = &Config{} + // If file is empty, the dir is set to ".". + cfg.dir, err = filepath.Abs(filepath.Dir(file)) + if err != nil { + return nil, err + } + cfg.workDir, err = os.Getwd() if err != nil { return nil, err @@ -161,8 +171,7 @@ func (cfg *Config) Get(host string) (section *Section) { // This function can be useful if we want to load another SSH config file // without using Include directive. func (cfg *Config) Prepend(other *Config) { - newSections := make([]*Section, 0, - len(cfg.sections)+len(other.sections)) + newSections := make([]*Section, 0, len(cfg.sections)+len(other.sections)) newSections = append(newSections, other.sections...) newSections = append(newSections, cfg.sections...) cfg.sections = newSections diff --git a/lib/ssh/config/parser_test.go b/lib/ssh/config/parser_test.go index 21bef8cd..1bb6f0e1 100644 --- a/lib/ssh/config/parser_test.go +++ b/lib/ssh/config/parser_test.go @@ -81,6 +81,7 @@ func TestReadLines(t *testing.T) { `Port 28022`, `User foo`, `IdentityFile ~/.ssh/foo`, + `UserKnownHostsFile known_hosts`, `Host *foo.local`, `User allfoo`, `IdentityFile ~/.ssh/allfoo`, @@ -127,6 +128,7 @@ func TestConfigParser_load(t *testing.T) { `Port 28022`, `User foo`, `IdentityFile ~/.ssh/foo`, + `UserKnownHostsFile known_hosts`, `Host *foo.local`, `User allfoo`, `IdentityFile ~/.ssh/allfoo`, diff --git a/lib/ssh/config/section.go b/lib/ssh/config/section.go index 6b359bf5..79158d92 100644 --- a/lib/ssh/config/section.go +++ b/lib/ssh/config/section.go @@ -206,6 +206,9 @@ type Section struct { // name contains the raw value after Host or Match. name string + // dir store the path to the "config" directory. + dir string + // WorkingDir contains the directory where the SSH client started. // This value is required when client want to copy file from/to // remote. @@ -248,6 +251,7 @@ func NewSection(cfg *Config, name string) (section *Section) { } if cfg != nil { + section.dir = cfg.dir section.homeDir = cfg.homeDir section.WorkingDir = cfg.workDir } @@ -481,11 +485,20 @@ func (section *Section) setDefaults() { if len(section.IdentityFile) == 0 { section.IdentityFile = defaultIdentityFile() } + var ( + file string + x int + ) + for x, file = range section.IdentityFile { + section.IdentityFile[x] = section.pathUnfold(file) + } - for x, identFile := range section.IdentityFile { - if identFile[0] == '~' { - section.IdentityFile[x] = strings.Replace(identFile, "~", section.homeDir, 1) - } + // Set and expand the UserKnownHostsFile. + if len(section.knownHostsFile) == 0 { + section.knownHostsFile = defaultUserKnownHostsFile() + } + for x, file = range section.knownHostsFile { + section.knownHostsFile[x] = section.pathUnfold(file) } var ( @@ -597,9 +610,6 @@ func (section *Section) User() string { // UserKnownHostsFile return list of user known_hosts file set in this // Section. func (section *Section) UserKnownHostsFile() []string { - if len(section.knownHostsFile) == 0 { - return defaultUserKnownHostsFile() - } return section.knownHostsFile } @@ -648,7 +658,7 @@ func (section *Section) MarshalText() (text []byte, err error) { buf.WriteString(` `) buf.WriteString(key) buf.WriteByte(' ') - buf.WriteString(section.pathUnfold(val)) + buf.WriteString(section.pathFold(val)) buf.WriteByte('\n') } continue @@ -674,8 +684,8 @@ func (section *Section) WriteTo(w io.Writer) (n int64, err error) { return int64(c), err } -// pathUnfold replace the home directory prefix with '~'. -func (section *Section) pathUnfold(in string) (out string) { +// pathFold replace the home directory prefix with '~'. +func (section *Section) pathFold(in string) (out string) { if !strings.HasPrefix(in, section.homeDir) { return in } @@ -683,6 +693,24 @@ func (section *Section) pathUnfold(in string) (out string) { return out } +// pathUnfold expand the file to make it absolute. +// If the file prefixed with '~', it will expanded into home directory. +// If the file is relative (does not start with '/'), it will expanded based +// on the "config" directory. +func (section *Section) pathUnfold(in string) string { + if len(in) == 0 { + return in + } + if in[0] == '/' { + return in + } + if in[0] == '~' { + return filepath.Join(section.homeDir, in[1:]) + } + // The path in is relative to the "config" directory. + return filepath.Join(section.dir, in) +} + // setEnv set the Environments with key and value of format "KEY=VALUE". func (section *Section) setEnv(env string) { kv := strings.SplitN(env, "=", 2) diff --git a/lib/ssh/config/section_test.go b/lib/ssh/config/section_test.go index e89cc7b5..f8326ec7 100644 --- a/lib/ssh/config/section_test.go +++ b/lib/ssh/config/section_test.go @@ -156,7 +156,7 @@ func TestSection_UserKnownHostsFile(t *testing.T) { } var listCase = []testCase{{ - exp: defaultUserKnownHostsFile(), + value: ``, }, { value: `~/.ssh/myhost ~/.ssh/myhost2`, exp: []string{ diff --git a/lib/ssh/config/testdata/config b/lib/ssh/config/testdata/config index 2f876502..58873a35 100644 --- a/lib/ssh/config/testdata/config +++ b/lib/ssh/config/testdata/config @@ -15,6 +15,7 @@ Host foo.local Port 28022 User foo IdentityFile ~/.ssh/foo + UserKnownHostsFile known_hosts ## Override the foo.local using wildcard. Host *foo.local diff --git a/lib/ssh/config/testdata/config_get_test.txt b/lib/ssh/config/testdata/config_get_test.txt index e87444c8..fcab297b 100644 --- a/lib/ssh/config/testdata/config_get_test.txt +++ b/lib/ssh/config/testdata/config_get_test.txt @@ -40,6 +40,7 @@ Host foo.local identityfile ~/.ssh/allfoo port 28022 user allfoo + userknownhostsfile known_hosts xauthlocation /usr/X11R6/bin/xauth <<< my.foo.local -- cgit v1.3