Skip to content

Commit

Permalink
Merge pull request #9 from Snawoot/dialer_chain
Browse files Browse the repository at this point in the history
Dialer chain
  • Loading branch information
Snawoot authored Oct 8, 2024
2 parents 7dec8bb + 54b476c commit 46bf324
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 14 deletions.
8 changes: 8 additions & 0 deletions conn/dialer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package conn

import (
"context"
"net"
)

type ContextDialer = func(ctx context.Context, network, address string) (net.Conn, error)
9 changes: 4 additions & 5 deletions conn/plainfactory.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,22 @@ import (
"fmt"
"net"
"strconv"
"time"
)

type PlainConnFactory struct {
addr string
dialer *net.Dialer
dialer ContextDialer
}

func NewPlainConnFactory(host string, port uint16, timeout time.Duration) *PlainConnFactory {
func NewPlainConnFactory(host string, port uint16, dialer ContextDialer) *PlainConnFactory {
return &PlainConnFactory{
addr: net.JoinHostPort(host, strconv.Itoa(int(port))),
dialer: &net.Dialer{Timeout: timeout},
dialer: dialer,
}
}

func (cf *PlainConnFactory) DialContext(ctx context.Context) (WrappedConn, error) {
conn, err := cf.dialer.DialContext(ctx, "tcp", cf.addr)
conn, err := cf.dialer(ctx, "tcp", cf.addr)
if err != nil {
return nil, fmt.Errorf("cf.dialer.DialContext(ctx, \"tcp\", %q) failed: %v", cf.addr, err)
}
Expand Down
9 changes: 4 additions & 5 deletions conn/tlsfactory.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"io/ioutil"
"net"
"strconv"
"time"

"golang.org/x/sync/semaphore"

Expand All @@ -19,11 +18,11 @@ import (
type TLSConnFactory struct {
addr string
tlsConfig *tls.Config
dialer *net.Dialer
dialer ContextDialer
sem *semaphore.Weighted
}

func NewTLSConnFactory(host string, port uint16, timeout time.Duration,
func NewTLSConnFactory(host string, port uint16, dialer ContextDialer,
certfile, keyfile string, cafile string, hostname_check bool,
tls_servername string, dialers uint, sessionCache tls.ClientSessionCache, logger *clog.CondLogger) (*TLSConnFactory, error) {
if !hostname_check && cafile == "" {
Expand Down Expand Up @@ -88,7 +87,7 @@ func NewTLSConnFactory(host string, port uint16, timeout time.Duration,
return &TLSConnFactory{
addr: net.JoinHostPort(host, strconv.Itoa(int(port))),
tlsConfig: &tlsConfig,
dialer: &net.Dialer{Timeout: timeout},
dialer: dialer,
sem: semaphore.NewWeighted(int64(dialers)),
}, nil
}
Expand All @@ -98,7 +97,7 @@ func (cf *TLSConnFactory) DialContext(ctx context.Context) (WrappedConn, error)
return nil, errors.New("Context was cancelled")
}
defer cf.sem.Release(1)
netConn, err := cf.dialer.DialContext(ctx, "tcp", cf.addr)
netConn, err := cf.dialer(ctx, "tcp", cf.addr)
if err != nil {
return nil, fmt.Errorf("cf.dialer.DialContext(ctx, \"tcp\", %q) failed: %v", cf.addr, err)
}
Expand Down
39 changes: 39 additions & 0 deletions dnscache/wrapper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package dnscache

import (
"context"
"fmt"
"net"
"time"

"github.com/Vonage/gosrvlib/pkg/dnscache"
)

type ContextDialer = func(ctx context.Context, network, address string) (net.Conn, error)

func WrapDialer(dialer ContextDialer, size int, ttl time.Duration) ContextDialer {
cache := dnscache.New(net.DefaultResolver, size, ttl)
wrapped := func(ctx context.Context, network, address string) (net.Conn, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, fmt.Errorf("failed to extract host and port from %s: %w", address, err)
}

ips, err := cache.LookupHost(ctx, host)
if err != nil {
return nil, err
}

var conn net.Conn

for _, ip := range ips {
conn, err = dialer(ctx, network, net.JoinHostPort(ip, port))
if err == nil {
return conn, nil
}
}

return nil, fmt.Errorf("failed to dial %s: %w", address, err)
}
return wrapped
}
5 changes: 4 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
module github.com/Snawoot/steady-tun

go 1.17
go 1.23

toolchain go1.23.2

require (
github.com/Vonage/gosrvlib v1.101.5
github.com/huandu/skiplist v1.2.1
golang.org/x/sync v0.8.0
)
13 changes: 12 additions & 1 deletion go.sum
Original file line number Diff line number Diff line change
@@ -1,14 +1,25 @@
github.com/Vonage/gosrvlib v1.101.5 h1:SQTkcSKOAXwIwSBJ+Db0wLllmyObUUJ0Pmi4FISPU/s=
github.com/Vonage/gosrvlib v1.101.5/go.mod h1:sbtl37NUGtcL/C87oKkuGhSvaB7XxBdEh6e0pwg8nAk=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/huandu/go-assert v1.1.5 h1:fjemmA7sSfYHJD7CUqs9qTwwfdNAx7/j2/ZlHXzNB3c=
github.com/huandu/go-assert v1.1.5/go.mod h1:yOLvuqZwmcHIC5rIzrBhT7D3Q9c3GFnd0JrPVhn/06U=
github.com/huandu/skiplist v1.2.1 h1:dTi93MgjwErA/8idWTzIw4Y1kZsMWx35fmI2c8Rij7w=
github.com/huandu/skiplist v1.2.1/go.mod h1:7v3iFjLcSAzO4fN5B8dvebvo/qsfumiLiDXMrPiHF9w=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo=
golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0=
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
17 changes: 15 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@ import (
"flag"
"fmt"
"log"
"net"
"os"
"os/signal"
"runtime"
"syscall"
"time"

conn "github.com/Snawoot/steady-tun/conn"
"github.com/Snawoot/steady-tun/dnscache"
clog "github.com/Snawoot/steady-tun/log"
"github.com/Snawoot/steady-tun/pool"
"github.com/Snawoot/steady-tun/server"
Expand Down Expand Up @@ -47,6 +49,7 @@ type CLIArgs struct {
tls_servername string
tlsSessionCache bool
tlsEnabled bool
dnsCacheTTL time.Duration
showVersion bool
}

Expand All @@ -72,6 +75,7 @@ func parse_args() CLIArgs {
flag.BoolVar(&args.tlsSessionCache, "tls-session-cache", true, "enable TLS session cache")
flag.BoolVar(&args.showVersion, "version", false, "show program version and exit")
flag.BoolVar(&args.tlsEnabled, "tls-enabled", true, "enable TLS client for pool connections")
flag.DurationVar(&args.dnsCacheTTL, "dns-cache-ttl", 30*time.Second, "DNS cache TTL")
flag.Parse()
if args.showVersion {
return args
Expand Down Expand Up @@ -116,17 +120,26 @@ func main() {
args.verbosity)

var (
dialer conn.ContextDialer
connfactory pool.ConnFactory
err error
)
dialer = (&net.Dialer{
Timeout: args.timeout,
}).DialContext

if args.dnsCacheTTL > 0 {
dialer = dnscache.WrapDialer(dialer, 1, args.dnsCacheTTL)
}

if args.tlsEnabled {
var sessionCache tls.ClientSessionCache
if args.tlsSessionCache {
sessionCache = tls.NewLRUClientSessionCache(2 * int(args.pool_size))
}
connfactory, err = conn.NewTLSConnFactory(args.host,
uint16(args.port),
args.timeout,
dialer,
args.cert,
args.key,
args.cafile,
Expand All @@ -139,7 +152,7 @@ func main() {
panic(err)
}
} else {
connfactory = conn.NewPlainConnFactory(args.host, uint16(args.port), args.timeout)
connfactory = conn.NewPlainConnFactory(args.host, uint16(args.port), dialer)
}
connPool := pool.NewConnPool(args.pool_size, args.ttl, args.backoff, connfactory, poolLogger)
connPool.Start()
Expand Down

0 comments on commit 46bf324

Please sign in to comment.