Skip to content

Commit

Permalink
ssh/config: make Config Get return Section with pattern
Browse files Browse the repository at this point in the history
Previously, if ssh config contains non-wildacard and wildcard pattern,
and the requested hostname only return the first section that match.
For example, given the following SSH config

  foo.local
    User foo
  *foo.local
    User allfoo

If we request Get("foo.local"), tt will return all fields under
"foo.local" only not "*foo.local".

This changes fix this by returning new section that contains all fields
from matched Section.
  • Loading branch information
shuLhan committed Jul 23, 2023
1 parent 3538b21 commit 5236931
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 64 deletions.
12 changes: 8 additions & 4 deletions lib/ssh/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,17 @@ func Load(file string) (cfg *Config, err error) {
}

// Get the Host or Match configuration that match with the pattern "s".
// If no Host or Match found, it still return non-nil Section but with empty
// fields.
func (cfg *Config) Get(s string) (section *Section) {
for _, section := range cfg.sections {
if section.isMatch(s) {
return section
section = newSection(s)
for _, hostMatch := range cfg.sections {
if hostMatch.isMatch(s) {
section.mergeField(cfg, hostMatch)
}
}
return nil
section.init(cfg.workDir, cfg.homeDir)
return section
}

// Prepend other Config's sections to this Config.
Expand Down
110 changes: 64 additions & 46 deletions lib/ssh/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,74 +55,92 @@ func TestPatternToRegex(t *testing.T) {
}

func TestConfig_Get(t *testing.T) {
cfg, err := Load("./testdata/config")
type testCase struct {
exp func() Section
s string
}

var (
cfg *Config
err error
)

cfg, err = Load(`./testdata/config`)
if err != nil {
t.Fatal(err)
}

cases := []struct {
exp func(def Section) *Section
s string
}{{
s: "",
exp: func(def Section) *Section {
return nil
var listTestCase = []testCase{{
s: ``,
exp: func() Section {
var sec = *testDefaultSection
return sec
},
}, {
s: "example.local",
exp: func(def Section) *Section {
def.name = `example.local`
def.Hostname = "127.0.0.1"
def.User = "test"
def.PrivateKeyFile = ""
def.IdentityFile = []string{
filepath.Join(def.homeDir, ".ssh", "notexist"),
s: `example.local`,
exp: func() Section {
var sec = *testDefaultSection
sec.name = `example.local`
sec.Hostname = `127.0.0.1`
sec.User = `test`
sec.IdentityFile = []string{
filepath.Join(testDefaultSection.homeDir, `.ssh`, `notexist`),
}
def.useDefaultIdentityFile = false
def.Field = map[string]string{
sec.useDefaultIdentityFile = false
sec.Field = map[string]string{
`hostname`: `127.0.0.1`,
`user`: `test`,
`identityfile`: `~/.ssh/notexist`,
}
return &def
return sec
},
}, {
s: "my.example.local",
exp: func(def Section) *Section {
def.name = `*.example.local`
def.Hostname = "127.0.0.2"
def.User = "wildcard"
def.PrivateKeyFile = ""
def.IdentityFile = []string{
filepath.Join(def.homeDir, ".ssh", "notexist"),
s: `my.example.local`,
exp: func() Section {
var sec = *testDefaultSection
sec.name = `my.example.local`
sec.Hostname = `127.0.0.2`
sec.User = `wildcard`
sec.IdentityFile = []string{
filepath.Join(testDefaultSection.homeDir, `.ssh`, `notexist`),
}
def.useDefaultIdentityFile = false
def.Field = map[string]string{
sec.useDefaultIdentityFile = false
sec.Field = map[string]string{
`hostname`: `127.0.0.2`,
`user`: `wildcard`,
`identityfile`: `~/.ssh/notexist`,
}
return &def
return sec
},
}, {
s: `foo.local`,
exp: func() Section {
var sec = *testDefaultSection
sec.name = `foo.local`
sec.Hostname = `127.0.0.3`
sec.User = `allfoo`
sec.IdentityFile = []string{
filepath.Join(testDefaultSection.homeDir, `.ssh`, `foo`),
filepath.Join(testDefaultSection.homeDir, `.ssh`, `allfoo`),
}
sec.useDefaultIdentityFile = false
sec.Field = map[string]string{
`hostname`: `127.0.0.3`,
`user`: `allfoo`,
`identityfile`: `~/.ssh/allfoo`,
}
return sec
},
}}

for _, c := range cases {
got := cfg.Get(c.s)

// Clear the patterns and criteria for comparison.
if got != nil {
got.patterns = nil
got.criteria = nil
got.init(testParser.workDir, testParser.homeDir)
}
var (
c testCase
got *Section
)

exp := c.exp(*testDefaultSection)
if exp != nil {
exp.init(testParser.workDir, testParser.homeDir)
} else if got == nil {
continue
}
test.Assert(t, c.s, *exp, *got)
for _, c = range listTestCase {
got = cfg.Get(c.s)
test.Assert(t, c.s, c.exp(), *got)
}
}

Expand Down
33 changes: 20 additions & 13 deletions lib/ssh/config/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ package config

import (
"os"
"reflect"
"testing"

"github.com/shuLhan/share/lib/test"
)

func TestIsIncludeDirective(t *testing.T) {
Expand Down Expand Up @@ -57,9 +58,7 @@ func TestParseInclude(t *testing.T) {

for _, c := range cases {
got := parseInclude(c.line)
if !reflect.DeepEqual(c.exp, got) {
t.Fatalf("parseInclude: expecting %v, got %v", c.exp, got)
}
test.Assert(t, c.line, c.exp, got)
}
}

Expand All @@ -77,6 +76,13 @@ func TestReadLines(t *testing.T) {
`IdentityFile ~/.ssh/notexist`,
`Host *.example.local`,
`Include sub/config`,
`Host foo.local`,
`Hostname 127.0.0.3`,
`User foo`,
`IdentityFile ~/.ssh/foo`,
`Host *foo.local`,
`User allfoo`,
`IdentityFile ~/.ssh/allfoo`,
},
}}

Expand All @@ -86,9 +92,7 @@ func TestReadLines(t *testing.T) {
t.Fatal(err)
}

if !reflect.DeepEqual(c.exp, got) {
t.Fatalf("readLines: expecting %v, got %v", c.exp, got)
}
test.Assert(t, c.file, c.exp, got)
}
}

Expand Down Expand Up @@ -117,6 +121,13 @@ func TestConfigParser_load(t *testing.T) {
`Hostname 127.0.0.2`,
`User wildcard`,
`IdentityFile ~/.ssh/notexist`,
`Host foo.local`,
`Hostname 127.0.0.3`,
`User foo`,
`IdentityFile ~/.ssh/foo`,
`Host *foo.local`,
`User allfoo`,
`IdentityFile ~/.ssh/allfoo`,
},
}}

Expand All @@ -133,9 +144,7 @@ func TestConfigParser_load(t *testing.T) {
}
continue
}
if !reflect.DeepEqual(c.exp, got) {
t.Fatalf("parser.load: expecting %v, got %v", c.exp, got)
}
test.Assert(t, c.pattern, c.exp, got)
}
}

Expand Down Expand Up @@ -169,8 +178,6 @@ func TestParseArgs(t *testing.T) {
for _, c := range cases {
got := parseArgs(c.raw, ' ')

if !reflect.DeepEqual(c.exp, got) {
t.Fatalf("parseArgs: expecting %v, got %v", c.exp, got)
}
test.Assert(t, c.raw, c.exp, got)
}
}
12 changes: 11 additions & 1 deletion lib/ssh/config/section.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,17 @@ func (section *Section) init(workDir, homeDir string) {
}
}

func (section *Section) merge(other *Section) {
// mergeField merge the Field from other Section.
func (section *Section) mergeField(cfg *Config, other *Section) {
var (
key string
value string
)
for key, value = range other.Field {
// The key and value in other should be valid, so no need to
// check for error.
_ = section.set(cfg, key, value)
}
}

// set the section field by raw key and value.
Expand Down
10 changes: 10 additions & 0 deletions lib/ssh/config/testdata/config
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,13 @@ Host example.local
# comment
Host *.example.local
Include sub/config

Host foo.local
Hostname 127.0.0.3
User foo
IdentityFile ~/.ssh/foo

## Override the foo.local using wildcard.
Host *foo.local
User allfoo
IdentityFile ~/.ssh/allfoo

0 comments on commit 5236931

Please sign in to comment.