diff --git a/cli/cli.go b/cli/cli.go index b9aec58..dc32e73 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -10,15 +10,15 @@ import ( var re = regexp.MustCompile(`(?P.+@)?(?P[[:alpha:][:digit:]\_\-\.]+)?(?P:[0-9]+)?`) -// App contains main settings of application. +// App contains all supported CLI arguments given by the user. type App struct { args []string flag *flag.FlagSet Command string - Local HostInput - Remote HostInput - Server HostInput + Local AddressInputList + Remote AddressInputList + Server AddressInput Key string Verbose bool Help bool @@ -47,8 +47,8 @@ func (c *App) Parse() error { f.BoolVar(&c.AliasDelete, "delete", false, "delete a tunnel alias (must be used with -alias)") f.BoolVar(&c.AliasList, "aliases", false, "list all aliases") f.StringVar(&c.Start, "start", "", "Start a tunnel using a given alias") - f.Var(&c.Local, "local", "(optional) Set local endpoint address: []:") - f.Var(&c.Remote, "remote", "set remote endpoint address: []:") + f.Var(&c.Local, "local", "(optional) Set local endpoint address: []:. Multiple -local args can be provided.") + f.Var(&c.Remote, "remote", "(optional) Set remote endpoint address: []:. Multiple -remote args can be provided.") f.Var(&c.Server, "server", "set server address: [@][:]") f.StringVar(&c.Key, "key", "", "(optional) Set server authentication key file path") f.BoolVar(&c.Verbose, "v", false, "(optional) Increase log verbosity") @@ -95,17 +95,15 @@ func (c App) Validate() error { switch c.Command { case "new-alias": - if c.Remote.String() == "" { - return fmt.Errorf("required flag is missing: -remote") - } else if c.Server.String() == "" { + if c.Server.String() == "" { return fmt.Errorf("required flag is missing: -server") } case "start": if c.Server.String() == "" { return fmt.Errorf("required flag is missing: -server") } - } + return nil } @@ -113,8 +111,8 @@ func (c App) Validate() error { // use the tool. func (c *App) PrintUsage() { fmt.Fprintf(os.Stderr, "%s\n\n", `usage: - mole [-v] [-insecure] [-detach] [-local []:] -remote []: -server [@][:] [-key ] - mole -alias [-v] [-local []:] -remote []: -server [@][:] [-key ] + mole [-v] [-insecure] [-detach] (-local []:)... (-remote []:)... -server [@][:] [-key ] + mole -alias [-v] (-local []:)... (-remote []:)... -server [@][:] [-key ] mole -alias -delete mole -start mole -help @@ -128,15 +126,15 @@ func (c App) String() string { c.Local, c.Remote, c.Server, c.Key, c.Verbose, c.Help, c.Version, c.Detach) } -// HostInput holds information about a host -type HostInput struct { +// AddressInput holds information about a host +type AddressInput struct { User string Host string Port string } -// String returns a string representation of a HostInput -func (h HostInput) String() string { +// String returns a string representation of a AddressInput +func (h AddressInput) String() string { var s string if h.User == "" { s = h.Address() @@ -147,8 +145,8 @@ func (h HostInput) String() string { return s } -// Set parses a string representation of HostInput into its proper attributes. -func (h *HostInput) Set(value string) error { +// Set parses a string representation of AddressInput into its proper attributes. +func (h *AddressInput) Set(value string) error { result := parseServerInput(value) h.User = strings.Trim(result["user"], "@") h.Host = result["host"] @@ -157,9 +155,9 @@ func (h *HostInput) Set(value string) error { return nil } -// Address returns a string representation of HostInput to be used to perform +// Address returns a string representation of AddressInput to be used to perform // network connections. -func (h HostInput) Address() string { +func (h AddressInput) Address() string { if h.Port == "" { return fmt.Sprintf("%s", h.Host) } @@ -180,3 +178,39 @@ func parseServerInput(input string) map[string]string { return result } + +type AddressInputList []AddressInput + +func (il AddressInputList) String() string { + ils := []string{} + + for _, i := range il { + ils = append(ils, i.String()) + } + + return strings.Join(ils, ",") +} + +func (il *AddressInputList) Set(value string) error { + i := AddressInput{} + + err := i.Set(value) + if err != nil { + return err + } + + *il = append(*il, i) + + return nil +} + +func (il AddressInputList) List() []string { + sl := []string{} + + for _, i := range il { + sl = append(sl, i.String()) + } + + return sl + +} diff --git a/cli/cli_test.go b/cli/cli_test.go index 3df4b7e..8c41000 100644 --- a/cli/cli_test.go +++ b/cli/cli_test.go @@ -7,52 +7,52 @@ import ( "github.com/davrodpin/mole/cli" ) -func TestHostInput(t *testing.T) { +func TestAddressInput(t *testing.T) { tests := []struct { input string - expected cli.HostInput + expected cli.AddressInput }{ { "test", - cli.HostInput{User: "", Host: "test", Port: ""}, + cli.AddressInput{User: "", Host: "test", Port: ""}, }, { "user@test", - cli.HostInput{User: "user", Host: "test", Port: ""}, + cli.AddressInput{User: "user", Host: "test", Port: ""}, }, { "user@test:2222", - cli.HostInput{User: "user", Host: "test", Port: "2222"}, + cli.AddressInput{User: "user", Host: "test", Port: "2222"}, }, { "test-1", - cli.HostInput{User: "", Host: "test-1", Port: ""}, + cli.AddressInput{User: "", Host: "test-1", Port: ""}, }, { "test-1-2-xy", - cli.HostInput{User: "", Host: "test-1-2-xy", Port: ""}, + cli.AddressInput{User: "", Host: "test-1-2-xy", Port: ""}, }, { "test.com", - cli.HostInput{User: "", Host: "test.com", Port: ""}, + cli.AddressInput{User: "", Host: "test.com", Port: ""}, }, { "test_1", - cli.HostInput{User: "", Host: "test_1", Port: ""}, + cli.AddressInput{User: "", Host: "test_1", Port: ""}, }, { "user@test_1", - cli.HostInput{User: "user", Host: "test_1", Port: ""}, + cli.AddressInput{User: "user", Host: "test_1", Port: ""}, }, { "user@test_1:2222", - cli.HostInput{User: "user", Host: "test_1", Port: "2222"}, + cli.AddressInput{User: "user", Host: "test_1", Port: "2222"}, }, } - var h cli.HostInput + var h cli.AddressInput for _, test := range tests { - h = cli.HostInput{} + h = cli.AddressInput{} h.Set(test.input) if !reflect.DeepEqual(test.expected, h) { @@ -128,11 +128,7 @@ func TestValidate(t *testing.T) { }, { []string{"./mole", "-alias", "xyz", "-server", "example1"}, - false, - }, - { - []string{"./mole", "-alias", "xyz", "-server", "example1"}, - false, + true, }, { []string{"./mole", "-alias", "xyz", "-remote", ":443"}, @@ -142,6 +138,18 @@ func TestValidate(t *testing.T) { []string{"./mole", "-alias", "xyz"}, false, }, + { + []string{"./mole", "-local", ":8080", "-remote", ":80", "-server", "example1"}, + true, + }, + { + []string{"./mole", "-remote", ":3366", "-remote", ":443", "-server", "example1"}, + true, + }, + { + []string{"./mole", "-local", ":1234", "-remote", ":3366", "-remote", ":443", "-server", "example1"}, + true, + }, } var c *cli.App diff --git a/cmd/mole/alias.go b/cmd/mole/alias.go new file mode 100644 index 0000000..2df915d --- /dev/null +++ b/cmd/mole/alias.go @@ -0,0 +1,75 @@ +package main + +import ( + "fmt" + "strings" + + "github.com/davrodpin/mole/cli" + "github.com/davrodpin/mole/storage" +) + +func lsAliases() error { + aliases, err := storage.FindAll() + if err != nil { + return err + } + + as := []string{} + for a := range aliases { + as = append(as, a) + } + + fmt.Printf("alias list: %s\n", strings.Join(as, ", ")) + + return nil +} + +func app2alias(app cli.App) *storage.Alias { + return &storage.Alias{ + Local: app.Local.List(), + Remote: app.Remote.List(), + Server: app.Server.String(), + Key: app.Key, + Verbose: app.Verbose, + Help: app.Help, + Version: app.Version, + Detach: app.Detach, + } +} + +func alias2app(t *storage.Alias) (*cli.App, error) { + sla, err := t.ReadLocal() + if err != nil { + return nil, err + } + + lal := cli.AddressInputList{} + for _, la := range sla { + lal.Set(la) + } + + sra, err := t.ReadRemote() + if err != nil { + return nil, err + } + + ral := cli.AddressInputList{} + for _, ra := range sra { + ral.Set(ra) + } + + server := cli.AddressInput{} + server.Set(t.Server) + + return &cli.App{ + Command: "start", + Local: lal, + Remote: ral, + Server: server, + Key: t.Key, + Verbose: t.Verbose, + Help: t.Help, + Version: t.Version, + Detach: t.Detach, + }, nil +} diff --git a/cmd/mole/main.go b/cmd/mole/main.go index a61f42f..402dac7 100644 --- a/cmd/mole/main.go +++ b/cmd/mole/main.go @@ -3,7 +3,6 @@ package main import ( "fmt" "os" - "strings" "syscall" "github.com/awnumar/memguard" @@ -56,7 +55,7 @@ func main() { case "version": fmt.Printf("mole %s\n", version) case "start": - err := start(*app) + err := start(app) if err != nil { os.Exit(1) } @@ -81,7 +80,7 @@ func main() { os.Exit(1) } case "aliases": - err := lsAliases(*app) + err := lsAliases() if err != nil { os.Exit(1) } @@ -187,7 +186,15 @@ func startFromAlias(app cli.App) error { return err } - appFromAlias := alias2app(conf) + appFromAlias, err := alias2app(conf) + if err != nil { + log.WithFields(log.Fields{ + "alias": app.Alias, + }).Errorf("error starting mole: %v", err) + + return err + } + appFromAlias.Alias = app.Alias // if use -detach when -start but none -detach in storage if app.Detach { @@ -197,7 +204,7 @@ func startFromAlias(app cli.App) error { return start(appFromAlias) } -func start(app cli.App) error { +func start(app *cli.App) error { if app.Detach { var alias string if app.Alias != "" { @@ -241,7 +248,26 @@ func start(app cli.App) error { log.Debugf("server: %s", s) - t := tunnel.New(app.Local.String(), s, app.Remote.String()) + local := make([]string, len(app.Local)) + for i, r := range app.Local { + local[i] = r.String() + } + + remote := make([]string, len(app.Remote)) + for i, r := range app.Remote { + remote[i] = r.String() + } + + channels, err := tunnel.BuildSSHChannels(s.Name, local, remote) + if err != nil { + return err + } + + t, err := tunnel.New(s, channels) + if err != nil { + log.Errorf("%v", err) + return err + } if err = t.Start(); err != nil { log.WithFields(log.Fields{ @@ -256,7 +282,6 @@ func start(app cli.App) error { func newAlias(app cli.App) error { _, err := storage.Save(app.Alias, app2alias(app)) - if err != nil { log.WithFields(log.Fields{ "alias": app.Alias, @@ -276,55 +301,3 @@ func rmAlias(app cli.App) error { return nil } - -func lsAliases(app cli.App) error { - tunnels, err := storage.FindAll() - if err != nil { - return err - } - - aliases := []string{} - for alias := range tunnels { - aliases = append(aliases, alias) - } - - fmt.Printf("alias list: %s\n", strings.Join(aliases, ", ")) - - return nil -} - -func app2alias(app cli.App) *storage.Tunnel { - return &storage.Tunnel{ - Local: app.Local.String(), - Remote: app.Remote.String(), - Server: app.Server.String(), - Key: app.Key, - Verbose: app.Verbose, - Help: app.Help, - Version: app.Version, - Detach: app.Detach, - } -} - -func alias2app(t *storage.Tunnel) cli.App { - local := cli.HostInput{} - local.Set(t.Local) - - remote := cli.HostInput{} - remote.Set(t.Remote) - - server := cli.HostInput{} - server.Set(t.Server) - - return cli.App{ - Command: "start", - Local: local, - Remote: remote, - Server: server, - Key: t.Key, - Verbose: t.Verbose, - Help: t.Help, - Version: t.Version, - Detach: t.Detach, - } -} diff --git a/go.mod b/go.mod index b729a72..8e3e468 100644 --- a/go.mod +++ b/go.mod @@ -5,10 +5,27 @@ go 1.12 require ( github.com/BurntSushi/toml v0.3.1 github.com/awnumar/memguard v0.15.1 + github.com/davidrjenni/reftools v0.0.0-20190411195930-981bbac422f8 // indirect + github.com/fatih/gomodifytags v0.0.0-20190517204355-df91c5bc7551 // indirect + github.com/fatih/motion v0.0.0-20180408211639-218875ebe238 // indirect + github.com/josharian/impl v0.0.0-20180228163738-3d0f908298c4 // indirect + github.com/jstemmer/gotags v1.4.1 // indirect github.com/kevinburke/ssh_config v0.0.0-20180830205328-81db2a75821e + github.com/klauspost/asmfmt v1.2.0 // indirect + github.com/koron/iferr v0.0.0-20180615142939-bb332a3b1d91 // indirect + github.com/kr/pretty v0.1.0 // indirect + github.com/mdempsky/gocode v0.0.0-20190203001940-7fb65232883f // indirect github.com/pelletier/go-buffruneio v0.2.0 // indirect + github.com/rogpeppe/godef v1.1.1 // indirect github.com/satori/go.uuid v1.2.0 github.com/sevlyar/go-daemon v0.1.5 github.com/sirupsen/logrus v1.4.1 + github.com/zmb3/gogetdoc v0.0.0-20190228002656-b37376c5da6a // indirect golang.org/x/crypto v0.0.0-20190513172903-22d7a77e9e5f + golang.org/x/net v0.0.0-20190520210107-018c4d40a106 // indirect + golang.org/x/sys v0.0.0-20190520201301-c432e742b0af // indirect + golang.org/x/text v0.3.2 // indirect + golang.org/x/tools v0.0.0-20190521171243-7927dbab1be7 // indirect + gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect + honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a // indirect ) diff --git a/go.sum b/go.sum index 8848969..6946d9c 100644 --- a/go.sum +++ b/go.sum @@ -1,15 +1,47 @@ +9fans.net/go v0.0.0-20181112161441-237454027057 h1:OcHlKWkAMJEF1ndWLGxp5dnJQkYM/YImUOvsBoz6h5E= +9fans.net/go v0.0.0-20181112161441-237454027057/go.mod h1:diCsxrliIURU9xsYtjCp5AbpQKqdhKmf0ujWDUSkfoY= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/awnumar/memguard v0.15.1 h1:RDPYo+e6rm65NLKJqmSVQpO9LuLh7R/hfC6Ed3ahbEU= github.com/awnumar/memguard v0.15.1/go.mod h1:77EUD6uwfgcd6zTmn++i5ujEFviGRQfE8ELbDJO1rpA= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davidrjenni/reftools v0.0.0-20190411195930-981bbac422f8 h1:Eu7kPTwAOeiiot8fjH/WXIYSsoaYB7Xs7sPF8NYeKhs= +github.com/davidrjenni/reftools v0.0.0-20190411195930-981bbac422f8/go.mod h1:0qWLWApvobxwtd9/A8fS62VkRImuquIgtCv/ye+KnxA= +github.com/fatih/camelcase v1.0.0 h1:hxNvNX/xYBp0ovncs8WyWZrOrpBNub/JfaMvbURyft8= +github.com/fatih/camelcase v1.0.0/go.mod h1:yN2Sb0lFhZJUdVvtELVWefmrXpuZESvPmqwoZc+/fpc= +github.com/fatih/gomodifytags v0.0.0-20190517204355-df91c5bc7551 h1:/fvatMHXeYKMzBfS7ZAWJxAUVdNvorVap9/T7agWgW8= +github.com/fatih/gomodifytags v0.0.0-20190517204355-df91c5bc7551/go.mod h1:p2/x7bnOQsbq/deXsDIlj2yLiKFGPkD2nuoYqwn8R4Y= +github.com/fatih/motion v0.0.0-20180408211639-218875ebe238 h1:Qo4RxRMFag+fvDqQ6A3MblYBormptQUZ1ssOtV+EeQ8= +github.com/fatih/motion v0.0.0-20180408211639-218875ebe238/go.mod h1:pseIrV+t9A4+po+KJ1LheSnYH8m1qs6WhKx2zFiGi9I= +github.com/fatih/structtag v1.0.0 h1:pTHj65+u3RKWYPSGaU290FpI/dXxTaHdVwVwbcPKmEc= +github.com/fatih/structtag v1.0.0/go.mod h1:IKitwq45uXL/yqi5mYghiD3w9H6eTOvI9vnk8tXMphA= +github.com/josharian/impl v0.0.0-20180228163738-3d0f908298c4 h1:gmIVMdGlVf5e6Yo6+ZklxdOrvtOvyrAjJyXAbmOznyo= +github.com/josharian/impl v0.0.0-20180228163738-3d0f908298c4/go.mod h1:t4Tr0tn92eq5ISef4cS5plFAMYAqZlAXtgUcKE6y8nw= +github.com/jstemmer/gotags v1.4.1 h1:aWIyXsU3lTDqhsEC49MP85p2cUUWr2ptvdGNqqGA3r4= +github.com/jstemmer/gotags v1.4.1/go.mod h1:b6J3X0bsLbR4C5SgSx3V3KjuWTtmRzcmWPbTkWZ49PA= github.com/kevinburke/ssh_config v0.0.0-20180830205328-81db2a75821e h1:RgQk53JHp/Cjunrr1WlsXSZpqXn+uREuHvUVcK82CV8= github.com/kevinburke/ssh_config v0.0.0-20180830205328-81db2a75821e/go.mod h1:CT57kijsi8u/K/BOFA39wgDQJ9CxiF4nAY/ojJ6r6mM= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/klauspost/asmfmt v1.2.0 h1:zwsyBYgEdabg32alMful/5pRtMTcR5C5w1LKNg9OD78= +github.com/klauspost/asmfmt v1.2.0/go.mod h1:RAoUvqkWr2rUa2I19qKMEVZQe4BVtcHGTMCUOcCU2Lg= github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/koron/iferr v0.0.0-20180615142939-bb332a3b1d91 h1:hunjgdb3b21ZdRmzDPXii0EcnHpjH7uCP+kODoE1JH0= +github.com/koron/iferr v0.0.0-20180615142939-bb332a3b1d91/go.mod h1:C2tFh8w3I6i4lnUJfoBx2Hwku3mgu4wPNTtUNp1i5KI= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/mdempsky/gocode v0.0.0-20190203001940-7fb65232883f h1:ee+twVCignaZjt7jpbMSLxAeTN/Nfq9W/nm91E7QO1A= +github.com/mdempsky/gocode v0.0.0-20190203001940-7fb65232883f/go.mod h1:hltEC42XzfMNgg0S1v6JTywwra2Mu6F6cLR03debVQ8= github.com/pelletier/go-buffruneio v0.2.0 h1:U4t4R6YkofJ5xHm3dJzuRpPZ0mr5MMCoAWooScCR7aA= github.com/pelletier/go-buffruneio v0.2.0/go.mod h1:JkE26KsDizTr40EUHkXVtNPvgGtbSNq5BcowyYOWdKo= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/godef v1.1.1 h1:NujOtt9q9vIClRTB3sCZpavac+NMRaIayzrcz1h4fSE= +github.com/rogpeppe/godef v1.1.1/go.mod h1:oEo1eMy1VUEHUzUIX4F7IqvMJRiz9UId44mvnR8oPlQ= github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/sevlyar/go-daemon v0.1.5 h1:Zy/6jLbM8CfqJ4x4RPr7MJlSKt90f00kNM1D401C+Qk= @@ -17,13 +49,33 @@ github.com/sevlyar/go-daemon v0.1.5/go.mod h1:6dJpPatBT9eUwM5VCw9Bt6CdX9Tk6UWvhW github.com/sirupsen/logrus v1.4.1 h1:GL2rEmy6nsikmW0r8opw9JIRScdMF5hA8cOYLH7In1k= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/zmb3/gogetdoc v0.0.0-20190228002656-b37376c5da6a h1:00UFliGZl2UciXe8o/2iuEsRQ9u7z0rzDTVzuj6EYY0= +github.com/zmb3/gogetdoc v0.0.0-20190228002656-b37376c5da6a/go.mod h1:ofmGw6LrMypycsiWcyug6516EXpIxSbZ+uI9ppGypfY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190513172903-22d7a77e9e5f h1:R423Cnkcp5JABoeemiGEPlt9tHXFfw5kvc0yqlxRPWo= golang.org/x/crypto v0.0.0-20190513172903-22d7a77e9e5f/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190520210107-018c4d40a106/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190520201301-c432e742b0af h1:NXfmMfXz6JqGfG3ikSxcz2N93j6DgScr19Oo2uwFu88= +golang.org/x/sys v0.0.0-20190520201301-c432e742b0af/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/tools v0.0.0-20180824175216-6c1c5e93cdc1/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20181130195746-895048a75ecf/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20181207195948-8634b1ecd393/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190408220357-e5b8258f4918/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190521171243-7927dbab1be7 h1:+ef02iDoPU4j54NNvxgyVjdhaWHJ4da+lhWX18ayOok= +golang.org/x/tools v0.0.0-20190521171243-7927dbab1be7/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a h1:LJwr7TCTghdatWv40WobzlKXc9c4s8oGa7QKJUtHhWA= +honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/storage/storage.go b/storage/storage.go index c364997..12093ac 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -9,15 +9,24 @@ import ( "github.com/BurntSushi/toml" ) -// Store contains the map of tunnels, where key is string tunnel alias and value is Tunnel. +// Store contains the map of aliases, where key contains the alias name. type Store struct { - Tunnels map[string]*Tunnel `toml:"tunnels"` + Aliases map[string]*Alias `toml:"tunnels"` } -// Tunnel represents settings of the ssh tunnel. -type Tunnel struct { - Local string `toml:"local"` - Remote string `toml:"remote"` +// Alias represents settings of the ssh tunnel. +type Alias struct { + // Local holds all local addresses configured on an alias. + // + // The type is specified as `interface{}` for backward-compatibility reasons + // since only a single value was supported before. + Local interface{} `toml:"local"` + // Remote holds all remote addresses configured on an alias. + // + // The type is specified as `interface{}` for backward-compatibility reasons + // since only a single value was supported before. + Remote interface{} `toml:"remote"` + Server string `toml:"server"` Key string `toml:"key"` Verbose bool `toml:"verbose"` @@ -26,36 +35,60 @@ type Tunnel struct { Detach bool `toml:"detach"` } -func (t Tunnel) String() string { +func (t Alias) ReadLocal() ([]string, error) { + return readAddress(t.Local) +} + +func (t Alias) ReadRemote() ([]string, error) { + return readAddress(t.Remote) +} + +func readAddress(address interface{}) ([]string, error) { + switch v := address.(type) { + case string: + return []string{v}, nil + case []interface{}: + sv := []string{} + for _, e := range v { + sv = append(sv, e.(string)) + } + return sv, nil + default: + return nil, fmt.Errorf("couldn't load addresses: %v", address) + } + +} + +func (t Alias) String() string { return fmt.Sprintf("[local=%s, remote=%s, server=%s, key=%s, verbose=%t, help=%t, version=%t, detach=%t]", t.Local, t.Remote, t.Server, t.Key, t.Verbose, t.Help, t.Version, t.Detach) } -// Save stores Tunnel to the Store. -func Save(name string, tunnel *Tunnel) (*Tunnel, error) { +// Save stores Alias to the Store. +func Save(name string, alias *Alias) (*Alias, error) { store, err := loadStore() if err != nil { return nil, fmt.Errorf("error while loading mole configuration: %v", err) } - store.Tunnels[name] = tunnel + store.Aliases[name] = alias _, err = createStore(store) if err != nil { return nil, fmt.Errorf("error while saving mole configuration: %v", err) } - return tunnel, nil + return alias, nil } -// FindByName finds the Tunnel in Store by name. -func FindByName(name string) (*Tunnel, error) { +// FindByName finds the Alias in Store by name. +func FindByName(name string) (*Alias, error) { store, err := loadStore() if err != nil { return nil, fmt.Errorf("error while loading mole configuration: %v", err) } - tun := store.Tunnels[name] + tun := store.Aliases[name] if tun == nil { return nil, fmt.Errorf("alias could not be found: %s", name) @@ -64,27 +97,27 @@ func FindByName(name string) (*Tunnel, error) { return tun, nil } -// FindAll finds all the Tunnels in Store. -func FindAll() (map[string]*Tunnel, error) { +// FindAll finds all the Aliass in Store. +func FindAll() (map[string]*Alias, error) { store, err := loadStore() if err != nil { return nil, fmt.Errorf("error while loading mole configuration: %v", err) } - return store.Tunnels, nil + return store.Aliases, nil } -// Remove deletes Tunnel from the Store by name. -func Remove(name string) (*Tunnel, error) { +// Remove deletes Alias from the Store by name. +func Remove(name string) (*Alias, error) { store, err := loadStore() if err != nil { return nil, fmt.Errorf("error while loading mole configuration: %v", err) } - tun := store.Tunnels[name] + tun := store.Aliases[name] if tun != nil { - delete(store.Tunnels, name) + delete(store.Aliases, name) _, err := createStore(store) if err != nil { return nil, err @@ -103,7 +136,7 @@ func loadStore() (*Store, error) { } if _, err := os.Stat(sp); err != nil { - store = &Store{Tunnels: make(map[string]*Tunnel)} + store = &Store{Aliases: make(map[string]*Alias)} store, err = createStore(store) if err != nil { return nil, err diff --git a/storage/storage_test.go b/storage/storage_test.go index ad3a4ef..7df3b79 100644 --- a/storage/storage_test.go +++ b/storage/storage_test.go @@ -9,9 +9,9 @@ import ( "github.com/davrodpin/mole/storage" ) -func TestSaveTunnel(t *testing.T) { +func TestSaveAlias(t *testing.T) { alias := "example-save-443" - expected := &storage.Tunnel{ + expected := &storage.Alias{ Local: "", Remote: ":443", Server: "example", @@ -33,9 +33,9 @@ func TestSaveTunnel(t *testing.T) { } } -func TestRemoveTunnel(t *testing.T) { +func TestRemoveAlias(t *testing.T) { alias := "example-rm-443" - expected := &storage.Tunnel{ + expected := &storage.Alias{ Local: "", Remote: ":443", Server: "example", @@ -62,7 +62,7 @@ func TestRemoveTunnel(t *testing.T) { func TestFindAll(t *testing.T) { alias1 := "example-save-443" - expected1 := &storage.Tunnel{ + expected1 := &storage.Alias{ Local: "", Remote: ":443", Server: "example", @@ -70,7 +70,7 @@ func TestFindAll(t *testing.T) { } alias2 := "example-save-80" - expected2 := &storage.Tunnel{ + expected2 := &storage.Alias{ Local: "", Remote: ":80", Server: "example", @@ -80,17 +80,17 @@ func TestFindAll(t *testing.T) { storage.Save(alias1, expected1) storage.Save(alias2, expected2) - expectedTunnelList := make(map[string]*storage.Tunnel) - expectedTunnelList[alias1] = expected1 - expectedTunnelList[alias2] = expected2 + expectedAliasList := make(map[string]*storage.Alias) + expectedAliasList[alias1] = expected1 + expectedAliasList[alias2] = expected2 tunnels, err := storage.FindAll() if err != nil { t.Errorf("Test failed while retrieving all tunnels: %v", err) } - if !reflect.DeepEqual(expectedTunnelList, tunnels) { - t.Errorf("Test failed.\n\texpected: %v\n\tvalue : %v", expectedTunnelList, tunnels) + if !reflect.DeepEqual(expectedAliasList, tunnels) { + t.Errorf("Test failed.\n\texpected: %v\n\tvalue : %v", expectedAliasList, tunnels) } } @@ -108,3 +108,54 @@ func TestMain(m *testing.M) { os.Exit(code) } + +func TestReadLocalAndRemote(t *testing.T) { + + tests := []struct { + alias *storage.Alias + expectedLocal []string + expectedRemote []string + }{ + { + alias: &storage.Alias{Local: ":3306", Remote: ":3306"}, + expectedLocal: []string{":3306"}, + expectedRemote: []string{":3306"}, + }, + { + alias: &storage.Alias{Local: []interface{}{":3306", ":8080"}, Remote: []interface{}{":3306", ":8080"}}, + expectedLocal: []string{":3306", ":8080"}, + expectedRemote: []string{":3306", ":8080"}, + }, + { + alias: &storage.Alias{Local: ":3306", Remote: []interface{}{":3306", ":8080"}}, + expectedLocal: []string{":3306"}, + expectedRemote: []string{":3306", ":8080"}, + }, + { + alias: &storage.Alias{Local: []interface{}{":3306", ":8080"}, Remote: []interface{}{":3306"}}, + expectedLocal: []string{":3306", ":8080"}, + expectedRemote: []string{":3306"}, + }, + } + + for testId, test := range tests { + local, err := test.alias.ReadLocal() + if err != nil { + t.Errorf("unexpected error while reading local address from alias for %d: %v", testId, err) + } + + remote, err := test.alias.ReadRemote() + if err != nil { + t.Errorf("unexpected error while reading remote address from alias for %d: %v", testId, err) + } + + if !reflect.DeepEqual(test.expectedLocal, local) { + t.Errorf("unexpected local address for %d: expected: %v, value: %v", testId, test.expectedLocal, local) + } + + if !reflect.DeepEqual(test.expectedRemote, remote) { + t.Errorf("unexpected remote address for %d: expected: %v, value: %v", testId, test.expectedRemote, remote) + } + + } +} diff --git a/test-env/README.md b/test-env/README.md index ff1f41d..d28200a 100644 --- a/test-env/README.md +++ b/test-env/README.md @@ -4,8 +4,8 @@ This provides a small envorinment where `mole` functions can be tested and debugged. Once created, the test environment will provide a container running a ssh -server and another container running an http server, so ssh tunnels can be -created using `mole`. +server and another container running two (2) http servers, so ssh tunnels can +be created using `mole`. In addition to that, the test environment provides the infrastructure to analyze the traffic going through the ssh traffic. @@ -26,14 +26,14 @@ analyze the traffic going through the ssh traffic. | | | -+-----------+-------+ +---------------------+ -| | | | -| mole_ssh | | mole_http | -| SSH Server (:22) | | HTTP Server (:80) | -| (192.168.33.10) |----------| (192.168.33.11) | -| | | | -| | | | -+-------------------+ +---------------------+ ++-----------+-------+ +---------------------------+ +| | | | +| mole_ssh | | mole_http | +| SSH Server (:22) | | HTTP Server #1 (:80) | +| (192.168.33.10) |----------| HTTP Server #2 (:8080) | +| | | (192.168.33.11) | +| | | | ++-------------------+ +---------------------------+ ``` ## Required Software @@ -63,8 +63,9 @@ connections can be made using address `127.0.0.1:22122`. All ssh connection to `mole_ssh` should be done using the user `mole` and the key file located on `test-env/key` -`mole_http` runs a http server listening on port `80`. -The http server responds only to requests to `http://192.168.33.11/`. +`mole_http` runs two http servers listening on port `80` and `8080`, so clients +would be able to access the using the following urls: `http://192.168.33.11:80/` +and `http://192.168.33.11:8080/` ### Teardown @@ -83,14 +84,17 @@ The ssh authentication key files, `test-env/key` and `test-env/key,pub` will ```sh $ make test-env -$ mole -insecure -local :21112 -remote 192.168.33.11:80 -server mole@127.0.0.1:22122 -key test-env/ssh-server/keys/key -INFO[0000] listening on local address local_address="127.0.0.1:21112" +$ mole -insecure -local :21112 -local :21113 -remote 192.168.33.11:80 -remote 192.168.33.11:8080 -server mole@127.0.0.1:22122 -key test-env/ssh-server/keys/key +INFO[0000] tunnel is ready local="127.0.0.1:21113" remote="192.168.33.11:8080" +INFO[0000] tunnel is ready local="127.0.0.1:21112" remote="192.168.33.11:80" $ curl 127.0.0.1:21112 :) +$ curl 127.0.0.1:21113 +:) ``` -NOTE: If you're wondering about the smile face, that is the response from the -http server. +NOTE: If you're wondering about the smile face, that is the response from both +http servers. ## Packet Analisys diff --git a/test-env/http-server/Dockerfile b/test-env/http-server/Dockerfile index 4c81d68..438d1a5 100644 --- a/test-env/http-server/Dockerfile +++ b/test-env/http-server/Dockerfile @@ -1,6 +1,6 @@ FROM alpine:3.6 -RUN apk update && apk add nginx +RUN apk update && apk add nginx curl RUN mkdir -p /run/nginx RUN mkdir -p /data/www diff --git a/test-env/http-server/default.conf b/test-env/http-server/default.conf index 0ff33e2..5136709 100644 --- a/test-env/http-server/default.conf +++ b/test-env/http-server/default.conf @@ -3,7 +3,7 @@ server { listen 80 default_server; - listen [::]:80 default_server; + listen 8080 default_server; # Everything is a 404 location / { @@ -15,3 +15,4 @@ server { internal; } } + diff --git a/tunnel/config.go b/tunnel/config.go index 93fb5e4..24d7fd1 100644 --- a/tunnel/config.go +++ b/tunnel/config.go @@ -59,8 +59,7 @@ func (r SSHConfigFile) Get(host string) *SSHHost { localForward, err := r.getLocalForward(host) if err != nil { - localForward = &LocalForward{Local: "", Remote: ""} - log.Warningf("error reading LocalForward configuration from ssh config file. This option will not be used: %v", err) + log.Warningf("error reading LocalForward configuration from ssh config file: %v", err) } key := r.getKey(host) @@ -92,7 +91,7 @@ func (r SSHConfigFile) getLocalForward(host string) (*LocalForward, error) { } if c == "" { - return &LocalForward{Local: "", Remote: ""}, nil + return nil, nil } l := strings.Fields(c) diff --git a/tunnel/config_test.go b/tunnel/config_test.go index 1548222..17db4e1 100644 --- a/tunnel/config_test.go +++ b/tunnel/config_test.go @@ -36,7 +36,7 @@ Host example3 Port: "3306", User: "john", Key: "/path/.ssh/id_rsa", - LocalForward: &LocalForward{Local: "", Remote: ""}, + LocalForward: nil, }, }, { diff --git a/tunnel/example_test.go b/tunnel/example_test.go index dfb16e8..9ea019a 100644 --- a/tunnel/example_test.go +++ b/tunnel/example_test.go @@ -11,8 +11,7 @@ import ( // exchange data from the local address to the remote address through the // established ssh channel. func Example() { - local := "127.0.0.1:8080" - remote := "user@example.com:22" + sshChan := &tunnel.SSHChannel{Local: "127.0.0.1:8080", Remote: "user@example.com:22"} // Initialize the SSH Server configuration providing all values so // tunnel.NewServer will not try to lookup any value using $HOME/.ssh/config @@ -21,7 +20,10 @@ func Example() { log.Fatalf("error processing server options: %v\n", err) } - t := tunnel.New(local, server, remote) + t, err := tunnel.New(server, []*tunnel.SSHChannel{sshChan}) + if err != nil { + log.Fatalf("error creating tunnel: %v\n", err) + } // Start the tunnel err = t.Start() diff --git a/tunnel/testdata/dotssh/config b/tunnel/testdata/dotssh/config index bdf07b5..3a9d0ab 100644 --- a/tunnel/testdata/dotssh/config +++ b/tunnel/testdata/dotssh/config @@ -9,3 +9,11 @@ Host test* Port 2223 User mole_test2 IdentityFile ~/.ssh/other_key + +Host hostWithLocalForward + Hostname 127.0.0.1 + Port 2222 + LocalForward 8080 172.17.0.1:8080 + User mole_test + IdentityFile ~/.ssh/id_rsa + diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 1635885..56c479a 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -16,7 +16,9 @@ import ( ) const ( - HostMissing = "server host has to be provided as part of the server address" + HostMissing = "server host has to be provided as part of the server address" + RandomPortAddress = "127.0.0.1:0" + NoRemoteGiven = "cannot create a tunnel without at least one remote address" ) // Server holds the SSH Server attributes used for the client to connect to it. @@ -98,80 +100,126 @@ func (s Server) String() string { return fmt.Sprintf("[name=%s, address=%s, user=%s]", s.Name, s.Address, s.User) } -// Tunnel represents the ssh tunnel used to forward a local connection to a -// a remote endpoint through a ssh server. +type SSHChannel struct { + Local string + Remote string + listener net.Listener + conn net.Conn +} + +func (ch *SSHChannel) Close() { + if ch.listener != nil { + ch.listener.Close() + } +} + +func (ch SSHChannel) String() string { + return fmt.Sprintf("[local=%s, remote=%s]", ch.Local, ch.Remote) +} + +// Tunnel represents the ssh tunnel and the channels connecting local and +// remote endpoints. type Tunnel struct { // Ready tells when the Tunnel is ready to accept connections Ready chan bool - local string server *Server - remote string + channels []*SSHChannel done chan error client *ssh.Client - listener net.Listener } // New creates a new instance of Tunnel. -func New(localAddress string, server *Server, remoteAddress string) *Tunnel { - cfg, err := NewSSHConfigFile() - if err != nil { - log.Warningf("error to read ssh config: %v", err) - } +func New(server *Server, channels []*SSHChannel) (*Tunnel, error) { - sh := cfg.Get(server.Name) - localAddress = reconcileLocal(localAddress, sh.LocalForward.Local) - remoteAddress = reconcileRemote(remoteAddress, sh.LocalForward.Remote) + for _, channel := range channels { + if channel.Local == "" || channel.Remote == "" { + return nil, fmt.Errorf("invalid ssh channel: local=%s, remote=%s", channel.Local, channel.Remote) + } + } return &Tunnel{ - Ready: make(chan bool, 1), - local: localAddress, - server: server, - remote: remoteAddress, - done: make(chan error), - } + Ready: make(chan bool, 1), + channels: channels, + server: server, + done: make(chan error), + }, nil } -// Start creates a new ssh tunnel, allowing data exchange between the local and -// remote endpoints. +// Start creates the ssh tunnel and initialized all channels allowing data +// exchange between local and remote enpoints. func (t *Tunnel) Start() error { - var once sync.Once - - _, err := t.Listen() + err := t.Listen() if err != nil { return err } - defer t.listener.Close() + defer func() { + for _, ch := range t.channels { + ch.Close() + } + }() log.Debugf("tunnel: %s", t) - log.WithFields(log.Fields{ - "local_address": t.local, - }).Info("listening on local address") + //connecting to ssh server + err = t.dial() + if err != nil { + return err + } + + ready := make(chan *SSHChannel) - go func(t *Tunnel) { + //TODO: use waitgroup! + // wait for all port startChannels to be ready to accept connections then sends a + // message signalling the tunnel is ready + go func(tunnel *Tunnel) { + n := 0 for { + select { + case ch := <-ready: + log.WithFields(log.Fields{ + "local": ch.Local, + "remote": ch.Remote, + }).Info("tunnel is ready") + + n = n + 1 + + if n == len(tunnel.channels) { + tunnel.Ready <- true + return + } + } + } + }(t) - once.Do(func() { - t.Ready <- true - }) + for _, ch := range t.channels { + go func(channel *SSHChannel) { + var once sync.Once + var err error - conn, err := t.listener.Accept() - if err != nil { - t.done <- fmt.Errorf("error while establishing new connection: %v", err) - return - } + for { - log.WithFields(log.Fields{ - "address": conn.RemoteAddr(), - }).Debug("new connection") + once.Do(func() { + ready <- channel + }) - err = t.forward(conn) - if err != nil { - t.done <- err - return + channel.conn, err = channel.listener.Accept() + if err != nil { + t.done <- fmt.Errorf("error while establishing new connection: %v", err) + return + } + + log.WithFields(log.Fields{ + "address": channel.conn.RemoteAddr(), + }).Debug("new connection") + + err = t.startChannel(channel) + if err != nil { + t.done <- err + return + } } - } - }(t) + }(ch) + } select { case err = <-t.done: @@ -182,32 +230,40 @@ func (t *Tunnel) Start() error { } } -// Listen binds the local address configured on Tunnel. -func (t *Tunnel) Listen() (net.Listener, error) { - - if t.listener != nil { - return t.listener, nil - } +// Listen creates tcp listeners for each channel defined. +func (t *Tunnel) Listen() error { + for _, ch := range t.channels { + if ch.listener == nil { + l, err := net.Listen("tcp", ch.Local) + if err != nil { + return err + } - local, err := net.Listen("tcp", t.local) - if err != nil { - return nil, err + ch.listener = l + ch.Local = l.Addr().String() + } } - t.listener = local - t.local = local.Addr().String() - - return t.listener, nil + return nil } -func (t *Tunnel) forward(localConn net.Conn) error { - remoteConn, err := t.proxy() +func (t *Tunnel) startChannel(channel *SSHChannel) error { + if t.client == nil { + return fmt.Errorf("new channel can't be established: missing connection to the ssh server") + } + + remoteConn, err := t.client.Dial("tcp", channel.Remote) if err != nil { - return err + return fmt.Errorf("remote dial error: %s", err) } - go copyConn(localConn, remoteConn) - go copyConn(remoteConn, localConn) + go copyConn(channel.conn, remoteConn) + go copyConn(remoteConn, channel.conn) + + log.WithFields(log.Fields{ + "remote": channel.Remote, + "server": t.server, + }).Debug("new connection established to remote") return nil } @@ -219,38 +275,29 @@ func (t Tunnel) Stop() { // String returns a string representation of a Tunnel. func (t Tunnel) String() string { - return fmt.Sprintf("[local:%s, server:%s, remote:%s]", t.local, t.server.Address, t.remote) + return fmt.Sprintf("[channels:%s, server:%s]", t.channels, t.server.Address) } -func (t *Tunnel) proxy() (net.Conn, error) { - c, err := sshClientConfig(*t.server) - if err != nil { - return nil, fmt.Errorf("error generating ssh client config: %s", err) +func (t *Tunnel) dial() error { + if t.client != nil { + return nil } - if t.client == nil { - t.client, err = ssh.Dial("tcp", t.server.Address, c) - if err != nil { - return nil, fmt.Errorf("server dial error: %s", err) - } - - log.WithFields(log.Fields{ - "server": t.server, - }).Debug("new connection established to server") - + c, err := sshClientConfig(*t.server) + if err != nil { + return fmt.Errorf("error generating ssh client config: %s", err) } - remoteConn, err := t.client.Dial("tcp", t.remote) + t.client, err = ssh.Dial("tcp", t.server.Address, c) if err != nil { - return nil, fmt.Errorf("remote dial error: %s", err) + return fmt.Errorf("server dial error: %s", err) } log.WithFields(log.Fields{ - "remote": t.remote, "server": t.server, - }).Debug("new connection established to remote") + }).Debug("new connection established to server") - return remoteConn, nil + return nil } func sshClientConfig(server Server) (*ssh.ClientConfig, error) { @@ -355,29 +402,88 @@ func reconcileKey(givenKey, resolvedKey string) string { return "" } -func reconcileLocal(givenLocal, resolvedLocal string) string { +func expandAddress(address string) string { + if strings.HasPrefix(address, ":") { + return fmt.Sprintf("127.0.0.1%s", address) + } - if givenLocal == "" && resolvedLocal != "" { - return resolvedLocal + return address +} + +// BuildSSHChannels normalizes the given set of local and remote addresses, +// combining them to build a set of ssh channel objects. +func BuildSSHChannels(serverName string, local, remote []string) ([]*SSHChannel, error) { + // if not local and remote were given, try to find the addresses from the SSH + // configuration file. + if len(local) == 0 && len(remote) == 0 { + lf, err := getLocalForward(serverName) + if err != nil { + return nil, err + } + + local = []string{lf.Local} + remote = []string{lf.Remote} + } else { + + lSize := len(local) + rSize := len(remote) + + if lSize > rSize { + // if there are more local than remote addresses given, the additional + // addresses must be removed. + if rSize == 0 { + return nil, fmt.Errorf(NoRemoteGiven) + } + + local = local[0:rSize] + } else if lSize < rSize { + // if there are more remote than local addresses given, the missing local + // addresses should be configured as localhost with random ports. + nl := make([]string, rSize) + + for i, _ := range remote { + if i < lSize { + if local[i] != "" { + nl[i] = local[i] + } else { + nl[i] = RandomPortAddress + } + } else { + nl[i] = RandomPortAddress + } + } + + local = nl + } } - if givenLocal == "" { - return "127.0.0.1:0" + + for i, addr := range local { + local[i] = expandAddress(addr) } - if strings.HasPrefix(givenLocal, ":") { - return fmt.Sprintf("127.0.0.1%s", givenLocal) + + for i, addr := range remote { + remote[i] = expandAddress(addr) } - return givenLocal -} + channels := make([]*SSHChannel, len(remote)) + for i, r := range remote { + channels[i] = &SSHChannel{Local: local[i], Remote: r} + } -func reconcileRemote(givenRemote, resolvedRemote string) string { + return channels, nil +} - if givenRemote == "" && resolvedRemote != "" { - return resolvedRemote +func getLocalForward(serverName string) (*LocalForward, error) { + cfg, err := NewSSHConfigFile() + if err != nil { + return nil, fmt.Errorf("error reading ssh configuration file: %v", err) } - if strings.HasPrefix(givenRemote, ":") { - return fmt.Sprintf("127.0.0.1%s", givenRemote) + + sh := cfg.Get(serverName) + + if sh.LocalForward == nil { + return nil, fmt.Errorf("LocalForward could not be found or has invalid syntax for host %s", serverName) } - return givenRemote + return sh.LocalForward, nil } diff --git a/tunnel/tunnel_test.go b/tunnel/tunnel_test.go index dda136f..772d12f 100644 --- a/tunnel/tunnel_test.go +++ b/tunnel/tunnel_test.go @@ -110,138 +110,90 @@ func TestServerOptions(t *testing.T) { } } -func TestTunnelOptions(t *testing.T) { - server := &Server{Name: "s"} - tests := []struct { - local string - server *Server - remote string - expected *Tunnel - }{ - { - "172.17.0.10:2222", - server, - "172.17.0.10:2222", - &Tunnel{ - local: "172.17.0.10:2222", - server: server, - remote: "172.17.0.10:2222", - }, - }, - { - "", - server, - "172.17.0.10:2222", - &Tunnel{ - local: "127.0.0.1:0", - server: server, - remote: "172.17.0.10:2222", - }, - }, - { - ":8443", - server, - ":443", - &Tunnel{ - local: "127.0.0.1:8443", - server: server, - remote: "127.0.0.1:443", - }, - }, - } - - for _, test := range tests { - tun := New(test.local, test.server, test.remote) - - if test.expected.local != tun.local { - t.Errorf("unexpected local result : expected: %s, result: %s", test.expected, tun) - } - - if test.expected.remote != tun.remote { - t.Errorf("unexpected remote result : expected: %s, result: %s", test.expected, tun) - } - - if !reflect.DeepEqual(test.expected.server, tun.server) { - t.Errorf("unexpected server result : expected: %s, result: %s", test.expected, tun) - } - - } - -} - func TestTunnel(t *testing.T) { - expected := "ABC" - tun := prepareTunnel(t, false) + tun := prepareTunnel(t, 1, false) select { case <-tun.Ready: t.Log("tunnel is ready to accept connections") case <-time.After(1 * time.Second): - t.Errorf("no connection after a while") + t.Errorf("error waiting for tunnel to be ready") return } - resp, err := http.Get(fmt.Sprintf("http://%s/%s", tun.listener.Addr(), expected)) + expected := "ABC" + err := validateTunnelConnectivity(expected, tun) if err != nil { - t.Errorf("error while making local connection: %v", err) - return - } - defer resp.Body.Close() - - body, _ := ioutil.ReadAll(resp.Body) - response := string(body) - - if expected != response { - t.Errorf("expected: %s, value: %s", expected, response) + t.Errorf("%v", err) } tun.Stop() } -func TestInsecureTunnel(t *testing.T) { +func TestTunnelInsecure(t *testing.T) { expected := "ABC" - tun := prepareTunnel(t, true) + tun := prepareTunnel(t, 1, true) select { case <-tun.Ready: t.Log("tunnel is ready to accept connections") case <-time.After(1 * time.Second): - t.Errorf("no connection after a while") + t.Errorf("error waiting for tunnel to be ready") return } - resp, err := http.Get(fmt.Sprintf("http://%s/%s", tun.listener.Addr(), expected)) + err := validateTunnelConnectivity(expected, tun) if err != nil { - t.Errorf("error while making local connection: %v", err) - return + t.Errorf("%v", err) } - defer resp.Body.Close() - body, _ := ioutil.ReadAll(resp.Body) - response := string(body) + tun.Stop() +} + +func TestTunnelMultipleRemotes(t *testing.T) { + expected := "ABC" + tun := prepareTunnel(t, 2, false) + + for i := 1; i <= 1; i++ { + select { + case <-tun.Ready: + t.Log("tunnel is ready to accept connections") + case <-time.After(1 * time.Second): + t.Errorf("error waiting for tunnel to be ready") + return + } + } - if expected != response { - t.Errorf("expected: %s, value: %s", expected, response) + err := validateTunnelConnectivity(expected, tun) + if err != nil { + t.Errorf("%v", err) } tun.Stop() } -func TestRandomLocalPort(t *testing.T) { - expected := "127.0.0.1:0" - local := "" - remote := "172.17.0.1:80" - server, _ := NewServer("", "test", "") +func validateTunnelConnectivity(expected string, tun *Tunnel) error { + for _, sshChan := range tun.channels { + resp, err := http.Get(fmt.Sprintf("http://%s/%s", sshChan.listener.Addr(), expected)) + if err != nil { + return fmt.Errorf("error while making http request: %v", err) + } + defer resp.Body.Close() - tun := New(local, server, remote) + body, _ := ioutil.ReadAll(resp.Body) + response := string(body) - if tun.local != expected { - t.Errorf("unexpected local endpoint: expected: %s, value: %s", expected, tun.local) + if expected != response { + return fmt.Errorf("expected: %s, value: %s", expected, response) + } } + + return nil } func TestCloseServerConn(t *testing.T) { - tun := &Tunnel{local: "127.0.0.1:0", done: make(chan error, 1), Ready: make(chan bool, 1)} + sshChans := []*SSHChannel{&SSHChannel{Local: "127.0.0.1:0", Remote: "172.17.0.1:80"}} + tun := &Tunnel{channels: sshChans, done: make(chan error, 1), Ready: make(chan bool, 1)} tun.client = &ssh.Client{Conn: MockConn{isConnectionOpen: true}} result := make(chan error) @@ -276,9 +228,111 @@ func TestMain(m *testing.M) { os.Exit(code) } +func TestBuildSSHChannels(t *testing.T) { + tests := []struct { + serverName string + local []string + remote []string + expected int + expectedError error + }{ + { + serverName: "test", + local: []string{":3360"}, + remote: []string{":3360"}, + expected: 1, + expectedError: nil, + }, + { + serverName: "test", + local: []string{":3360", ":8080"}, + remote: []string{":3360", ":8080"}, + expected: 2, + expectedError: nil, + }, + { + serverName: "test", + local: []string{}, + remote: []string{":3360"}, + expected: 1, + expectedError: nil, + }, + { + serverName: "test", + local: []string{":3360"}, + remote: []string{":3360", ":8080"}, + expected: 2, + expectedError: nil, + }, + { + serverName: "hostWithLocalForward", + local: []string{}, + remote: []string{}, + expected: 1, + expectedError: nil, + }, + { + serverName: "test", + local: []string{":3360", ":8080"}, + remote: []string{":3360"}, + expected: 1, + expectedError: nil, + }, + { + serverName: "test", + local: []string{":3360"}, + remote: []string{}, + expected: 0, + expectedError: fmt.Errorf(NoRemoteGiven), + }, + } + + for testId, test := range tests { + sshChannels, err := BuildSSHChannels(test.serverName, test.local, test.remote) + if err != nil { + if test.expectedError != nil { + if test.expectedError.Error() != err.Error() { + t.Errorf("error '%v' was expected, but got '%v'", test.expectedError, err) + } + } else { + t.Errorf("unable to build ssh channels objects for test %d: %v", testId, err) + } + } + + if test.expected != len(sshChannels) { + t.Errorf("wrong number of ssh channel objects created for test %d: expected: %d, value: %d", testId, test.expected, len(sshChannels)) + } + + localSize := len(test.local) + remoteSize := len(test.remote) + + // check if the local addresses match only if any address is given + if localSize > 0 && remoteSize > 0 { + for i, sshChannel := range sshChannels { + local := "" + if i < localSize { + local = test.local[i] + } else { + local = RandomPortAddress + } + + local = expandAddress(local) + + if sshChannel.Local != local { + t.Errorf("local address don't match for test %d: expected: %s, value: %s", testId, sshChannel.Local, local) + } + + } + } + } +} + // prepareTunnel creates a Tunnel object making sure all infrastructure // dependencies (ssh and http servers) are ready. -func prepareTunnel(t *testing.T, insecure bool) *Tunnel { +// +// The 'remotes' argument tells how many remote endpoints will be available +// through the tunnel. +func prepareTunnel(t *testing.T, remotes int, insecure bool) *Tunnel { ssh := createSSHServer(keyPath) srv, _ := NewServer("mole", ssh.Addr().String(), "") @@ -288,15 +342,20 @@ func prepareTunnel(t *testing.T, insecure bool) *Tunnel { generateKnownHosts(ssh.Addr().String(), publicKeyPath, knownHostsPath) } - web := createWebServer() - tun := &Tunnel{local: "127.0.0.1:0", server: srv, remote: web.Addr().String(), done: make(chan error), Ready: make(chan bool, 1)} + sshChannels := []*SSHChannel{} + for i := 1; i <= remotes; i++ { + web := createWebServer() + sshChannels = append(sshChannels, &SSHChannel{Local: "127.0.0.1:0", Remote: web.Addr().String()}) + } + + tun := &Tunnel{server: srv, channels: sshChannels, done: make(chan error), Ready: make(chan bool, 1)} - go func(t *testing.T) { + go func(t *testing.T, tun *Tunnel) { err := tun.Start() if err != nil { t.Errorf("tunnel could not be started: %v", err) } - }(t) + }(t, tun) return tun }