From 52369316969b97d20588b1a2e9c6ad3734fb2e4e Mon Sep 17 00:00:00 2001 From: Shulhan Date: Sun, 23 Jul 2023 23:40:36 +0700 Subject: [PATCH] ssh/config: make Config Get return Section with pattern Previously, if ssh config contains non-wildacard and wildcard pattern, and the requested hostname only return the first section that match. For example, given the following SSH config foo.local User foo *foo.local User allfoo If we request Get("foo.local"), tt will return all fields under "foo.local" only not "*foo.local". This changes fix this by returning new section that contains all fields from matched Section. --- lib/ssh/config/config.go | 12 ++-- lib/ssh/config/config_test.go | 110 +++++++++++++++++++-------------- lib/ssh/config/parser_test.go | 33 ++++++---- lib/ssh/config/section.go | 12 +++- lib/ssh/config/testdata/config | 10 +++ 5 files changed, 113 insertions(+), 64 deletions(-) diff --git a/lib/ssh/config/config.go b/lib/ssh/config/config.go index f9ad37a7..36b84dce 100644 --- a/lib/ssh/config/config.go +++ b/lib/ssh/config/config.go @@ -113,13 +113,17 @@ func Load(file string) (cfg *Config, err error) { } // Get the Host or Match configuration that match with the pattern "s". +// If no Host or Match found, it still return non-nil Section but with empty +// fields. func (cfg *Config) Get(s string) (section *Section) { - for _, section := range cfg.sections { - if section.isMatch(s) { - return section + section = newSection(s) + for _, hostMatch := range cfg.sections { + if hostMatch.isMatch(s) { + section.mergeField(cfg, hostMatch) } } - return nil + section.init(cfg.workDir, cfg.homeDir) + return section } // Prepend other Config's sections to this Config. diff --git a/lib/ssh/config/config_test.go b/lib/ssh/config/config_test.go index 790bb96e..81ce6ea2 100644 --- a/lib/ssh/config/config_test.go +++ b/lib/ssh/config/config_test.go @@ -55,74 +55,92 @@ func TestPatternToRegex(t *testing.T) { } func TestConfig_Get(t *testing.T) { - cfg, err := Load("./testdata/config") + type testCase struct { + exp func() Section + s string + } + + var ( + cfg *Config + err error + ) + + cfg, err = Load(`./testdata/config`) if err != nil { t.Fatal(err) } - cases := []struct { - exp func(def Section) *Section - s string - }{{ - s: "", - exp: func(def Section) *Section { - return nil + var listTestCase = []testCase{{ + s: ``, + exp: func() Section { + var sec = *testDefaultSection + return sec }, }, { - s: "example.local", - exp: func(def Section) *Section { - def.name = `example.local` - def.Hostname = "127.0.0.1" - def.User = "test" - def.PrivateKeyFile = "" - def.IdentityFile = []string{ - filepath.Join(def.homeDir, ".ssh", "notexist"), + s: `example.local`, + exp: func() Section { + var sec = *testDefaultSection + sec.name = `example.local` + sec.Hostname = `127.0.0.1` + sec.User = `test` + sec.IdentityFile = []string{ + filepath.Join(testDefaultSection.homeDir, `.ssh`, `notexist`), } - def.useDefaultIdentityFile = false - def.Field = map[string]string{ + sec.useDefaultIdentityFile = false + sec.Field = map[string]string{ `hostname`: `127.0.0.1`, `user`: `test`, `identityfile`: `~/.ssh/notexist`, } - return &def + return sec }, }, { - s: "my.example.local", - exp: func(def Section) *Section { - def.name = `*.example.local` - def.Hostname = "127.0.0.2" - def.User = "wildcard" - def.PrivateKeyFile = "" - def.IdentityFile = []string{ - filepath.Join(def.homeDir, ".ssh", "notexist"), + s: `my.example.local`, + exp: func() Section { + var sec = *testDefaultSection + sec.name = `my.example.local` + sec.Hostname = `127.0.0.2` + sec.User = `wildcard` + sec.IdentityFile = []string{ + filepath.Join(testDefaultSection.homeDir, `.ssh`, `notexist`), } - def.useDefaultIdentityFile = false - def.Field = map[string]string{ + sec.useDefaultIdentityFile = false + sec.Field = map[string]string{ `hostname`: `127.0.0.2`, `user`: `wildcard`, `identityfile`: `~/.ssh/notexist`, } - return &def + return sec + }, + }, { + s: `foo.local`, + exp: func() Section { + var sec = *testDefaultSection + sec.name = `foo.local` + sec.Hostname = `127.0.0.3` + sec.User = `allfoo` + sec.IdentityFile = []string{ + filepath.Join(testDefaultSection.homeDir, `.ssh`, `foo`), + filepath.Join(testDefaultSection.homeDir, `.ssh`, `allfoo`), + } + sec.useDefaultIdentityFile = false + sec.Field = map[string]string{ + `hostname`: `127.0.0.3`, + `user`: `allfoo`, + `identityfile`: `~/.ssh/allfoo`, + } + return sec }, }} - for _, c := range cases { - got := cfg.Get(c.s) - - // Clear the patterns and criteria for comparison. - if got != nil { - got.patterns = nil - got.criteria = nil - got.init(testParser.workDir, testParser.homeDir) - } + var ( + c testCase + got *Section + ) - exp := c.exp(*testDefaultSection) - if exp != nil { - exp.init(testParser.workDir, testParser.homeDir) - } else if got == nil { - continue - } - test.Assert(t, c.s, *exp, *got) + for _, c = range listTestCase { + got = cfg.Get(c.s) + test.Assert(t, c.s, c.exp(), *got) } } diff --git a/lib/ssh/config/parser_test.go b/lib/ssh/config/parser_test.go index f003776c..3e7d0a48 100644 --- a/lib/ssh/config/parser_test.go +++ b/lib/ssh/config/parser_test.go @@ -6,8 +6,9 @@ package config import ( "os" - "reflect" "testing" + + "github.com/shuLhan/share/lib/test" ) func TestIsIncludeDirective(t *testing.T) { @@ -57,9 +58,7 @@ func TestParseInclude(t *testing.T) { for _, c := range cases { got := parseInclude(c.line) - if !reflect.DeepEqual(c.exp, got) { - t.Fatalf("parseInclude: expecting %v, got %v", c.exp, got) - } + test.Assert(t, c.line, c.exp, got) } } @@ -77,6 +76,13 @@ func TestReadLines(t *testing.T) { `IdentityFile ~/.ssh/notexist`, `Host *.example.local`, `Include sub/config`, + `Host foo.local`, + `Hostname 127.0.0.3`, + `User foo`, + `IdentityFile ~/.ssh/foo`, + `Host *foo.local`, + `User allfoo`, + `IdentityFile ~/.ssh/allfoo`, }, }} @@ -86,9 +92,7 @@ func TestReadLines(t *testing.T) { t.Fatal(err) } - if !reflect.DeepEqual(c.exp, got) { - t.Fatalf("readLines: expecting %v, got %v", c.exp, got) - } + test.Assert(t, c.file, c.exp, got) } } @@ -117,6 +121,13 @@ func TestConfigParser_load(t *testing.T) { `Hostname 127.0.0.2`, `User wildcard`, `IdentityFile ~/.ssh/notexist`, + `Host foo.local`, + `Hostname 127.0.0.3`, + `User foo`, + `IdentityFile ~/.ssh/foo`, + `Host *foo.local`, + `User allfoo`, + `IdentityFile ~/.ssh/allfoo`, }, }} @@ -133,9 +144,7 @@ func TestConfigParser_load(t *testing.T) { } continue } - if !reflect.DeepEqual(c.exp, got) { - t.Fatalf("parser.load: expecting %v, got %v", c.exp, got) - } + test.Assert(t, c.pattern, c.exp, got) } } @@ -169,8 +178,6 @@ func TestParseArgs(t *testing.T) { for _, c := range cases { got := parseArgs(c.raw, ' ') - if !reflect.DeepEqual(c.exp, got) { - t.Fatalf("parseArgs: expecting %v, got %v", c.exp, got) - } + test.Assert(t, c.raw, c.exp, got) } } diff --git a/lib/ssh/config/section.go b/lib/ssh/config/section.go index ab08b23b..90c1cc58 100644 --- a/lib/ssh/config/section.go +++ b/lib/ssh/config/section.go @@ -371,7 +371,17 @@ func (section *Section) init(workDir, homeDir string) { } } -func (section *Section) merge(other *Section) { +// mergeField merge the Field from other Section. +func (section *Section) mergeField(cfg *Config, other *Section) { + var ( + key string + value string + ) + for key, value = range other.Field { + // The key and value in other should be valid, so no need to + // check for error. + _ = section.set(cfg, key, value) + } } // set the section field by raw key and value. diff --git a/lib/ssh/config/testdata/config b/lib/ssh/config/testdata/config index 79b1088f..1716e1af 100644 --- a/lib/ssh/config/testdata/config +++ b/lib/ssh/config/testdata/config @@ -9,3 +9,13 @@ Host example.local # comment Host *.example.local Include sub/config + +Host foo.local + Hostname 127.0.0.3 + User foo + IdentityFile ~/.ssh/foo + +## Override the foo.local using wildcard. +Host *foo.local + User allfoo + IdentityFile ~/.ssh/allfoo