diff --git a/config.go b/config.go index 67c6484..11f4cfc 100644 --- a/config.go +++ b/config.go @@ -121,6 +121,11 @@ func (c *Config) GetConfig(name string) (cfg Config, err error) { return } +func (c *Config) Has(name string) bool { + _, ok := c.Map[name] + return ok +} + func (c *Config) Get(name string, dest interface{}) error { if obj, ok := c.Map[name]; !ok { return NewConfigError(ErrMissing, c.Name+"."+name) @@ -141,6 +146,14 @@ func (c *Config) Get(name string, dest interface{}) error { return nil } +func (c *Config) GetBool(name string) bool { + var b bool + if err := c.Get(name, &b); err != nil { + return false + } + return b +} + func convertBool(in interface{}, val reflect.Value) *ConfigError { if b, ok := in.(bool); !ok { return &ConfigError{ErrInvalidType, ""} diff --git a/tcp_tunnel.go b/tcp_tunnel.go index e3aa2e1..de0931b 100644 --- a/tcp_tunnel.go +++ b/tcp_tunnel.go @@ -2,6 +2,7 @@ package secretun import ( "bytes" + "crypto/tls" "encoding/binary" "io" "log" @@ -69,14 +70,34 @@ func packetTunnel(conn net.Conn, cli_ch ClientChan) { } func (t *RawTCP_ST) Init(cfg Config) (err error) { - var addr string + var certFile, keyFile, addr string + var cert tls.Certificate if err = cfg.Get("addr", &addr); err != nil { return } + if cfg.GetBool("tls") { + if err := cfg.Get("cert", &certFile); err != nil { + return err + } else if err := cfg.Get("key", &keyFile); err != nil { + return err + } else if cert, err = tls.LoadX509KeyPair(certFile, keyFile); err != nil { + return err + } + tls_cfg := &tls.Config{} + tls_cfg.NextProtos = []string{"http/1.1"} + tls_cfg.Certificates = []tls.Certificate{cert} + if l, err := net.Listen("tcp", addr); err != nil { + return err + } else { + t.conn = tls.NewListener(l, tls_cfg) + } + } else { + t.conn, err = net.Listen("tcp", addr) + } + log.Println("listen on ", addr) - t.conn, err = net.Listen("tcp", addr) return } @@ -87,13 +108,16 @@ func (t *RawTCP_ST) Accept() (cli_ch ClientChan, err error) { if err != nil { return } - err = conn.(*net.TCPConn).SetNoDelay(true) - if err != nil { - return - } - err = conn.(*net.TCPConn).SetKeepAlive(true) - if err != nil { - return + + if tcp_conn, ok := conn.(*net.TCPConn); ok { + err = tcp_conn.SetNoDelay(true) + if err != nil { + return + } + err = tcp_conn.SetKeepAlive(true) + if err != nil { + return + } } cli_ch = NewClientChan() @@ -114,15 +138,25 @@ func (t *RawTCP_CT) Init(cfg Config) (err error) { log.Println("connect to ", addr) - if t.conn, err = net.Dial("tcp", addr); err != nil { - return + if cfg.GetBool("tls") { + tls_cfg := &tls.Config{} + tls_cfg.InsecureSkipVerify = true + if t.conn, err = tls.Dial("tcp", addr, tls_cfg); err != nil { + return + } + } else { + if t.conn, err = net.Dial("tcp", addr); err != nil { + return + } } - if err = t.conn.(*net.TCPConn).SetNoDelay(true); err != nil { - return - } + if tcp_conn, ok := t.conn.(*net.TCPConn); ok { + if err = tcp_conn.SetNoDelay(true); err != nil { + return + } - err = t.conn.(*net.TCPConn).SetKeepAlive(true) + err = tcp_conn.SetKeepAlive(true) + } return }