Skip to content

Commit

Permalink
add tls dialer
Browse files Browse the repository at this point in the history
  • Loading branch information
phuslu committed Oct 27, 2024
1 parent cf59546 commit c10854f
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 15 deletions.
4 changes: 2 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,13 @@ func (c *Client) exchange(ctx context.Context, req, resp *Message) error {

_, err = conn.Write(req.Raw)
if err != nil {
return nil
return err
}

resp.Raw = resp.Raw[:cap(resp.Raw)]
n, err := conn.Read(resp.Raw)
if err != nil {
return nil
return err
}

resp.Raw = resp.Raw[:n]
Expand Down
87 changes: 87 additions & 0 deletions client_dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package fastdns

import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -66,6 +67,92 @@ func (d *UDPDialer) put(conn net.Conn) {
d.conns <- conn
}

// TLSDialer is a custom dialer for creating TLS connections.
// It manages a pool of connections to optimize performance in scenarios
// where multiple TLS connections to the same server are required.
type TLSDialer struct {
// Addr specifies the remote TLS address that the dialer will connect to.
Addr *net.TCPAddr

TLSConfig *tls.Config

// Timeout specifies the maximum duration for a query to complete.
// If a query exceeds this duration, it will result in a timeout error.
Timeout time.Duration

// MaxConns limits the maximum number of TLS connections that can be created
// and reused. Once this limit is reached, no new connections will be made.
// If not set, use 8 as default.
MaxConns uint16

once sync.Once
conns chan net.Conn
}

func (d *TLSDialer) DialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) {
return d.get()
}

func (d *TLSDialer) get() (_ net.Conn, err error) {
d.once.Do(func() {
if d.MaxConns == 0 {
d.MaxConns = 8
}
d.conns = make(chan net.Conn, d.MaxConns)
for range d.MaxConns {
d.conns <- &tlsConn{nil, d, make([]byte, 0, 1024)}
}
})

if err != nil {
return
}

c := <-d.conns

return c, nil
}

func (d *TLSDialer) put(conn net.Conn) {
d.conns <- conn
}

type tlsConn struct {
*tls.Conn
dialer *TLSDialer
buffer []byte
}

func (c *tlsConn) Write(b []byte) (int, error) {
if c.Conn == nil {
conn, err := tls.DialWithDialer(&net.Dialer{Timeout: c.dialer.Timeout}, "tcp", c.dialer.Addr.String(), c.dialer.TLSConfig)
if err != nil {
return 0, err
}
c.Conn = conn
}

n := len(b)
c.buffer = append(c.buffer[:0], byte(n>>8), byte(n&0xFF))
c.buffer = append(c.buffer, b...)
_, err := c.Conn.Write(c.buffer)
return n, err
}

func (c *tlsConn) Read(b []byte) (n int, err error) {
c.buffer = c.buffer[:cap(c.buffer)]
n, err = c.Conn.Read(c.buffer)
if err != nil {
return
}
m := int(c.buffer[0])<<8 | int(c.buffer[1])
if m+2 != n {
return 0, ErrInvalidAnswer
}
copy(b, c.buffer[2:n])
return n - 2, nil
}

// HTTPDialer is a custom dialer for creating HTTP connections.
// It allows sending HTTP requests with a specified endpoint, user agent, and transport configuration.
type HTTPDialer struct {
Expand Down
49 changes: 36 additions & 13 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,14 @@ func TestClientLookup(t *testing.T) {
Addr: "1.1.1.1:53",
Dialer: &UDPDialer{
Addr: func() (u *net.UDPAddr) { u, _ = net.ResolveUDPAddr("udp", "1.1.1.1:53"); return }(),
MaxConns: 1000,
MaxConns: 16,
},
},
{
Addr: "1.1.1.1:853",
Dialer: &TLSDialer{
Addr: func() (u *net.TCPAddr) { u, _ = net.ResolveTCPAddr("tcp", "1.1.1.1:853"); return }(),
MaxConns: 16,
},
},
{
Expand Down Expand Up @@ -163,18 +170,6 @@ func BenchmarkResolverPureGo(b *testing.B) {
})
}

func BenchmarkResolverCGO(b *testing.B) {
resolver := net.Resolver{PreferGo: false}

b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(b *testing.PB) {
for b.Next() {
_, _ = resolver.LookupNetIP(context.Background(), "ip4", "www.google.com")
}
})
}

func BenchmarkResolverFastdnsDefault(b *testing.B) {
server := "8.8.8.8:53"
if data, err := os.ReadFile("/etc/resolv.conf"); err == nil {
Expand Down Expand Up @@ -260,6 +255,34 @@ func BenchmarkResolverFastdnsUDPAppend(b *testing.B) {
})
}

func BenchmarkResolverFastdnsTLS(b *testing.B) {
server := "1.1.1.1:853"

resolver := &Client{
Addr: server,
Dialer: &TLSDialer{
Addr: func() (u *net.TCPAddr) { u, _ = net.ResolveTCPAddr("tcp", server); return }(),
MaxConns: 8,
TLSConfig: &tls.Config{
InsecureSkipVerify: false,
ServerName: server,
ClientSessionCache: tls.NewLRUClientSessionCache(1024),
},
},
}

b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
ips, err := resolver.LookupNetIP(context.Background(), "ip4", "www.google.com")
if len(ips) == 0 || err != nil {
b.Errorf("fastdns return ips: %+v error: %+v", ips, err)
}
}
})
}

func BenchmarkResolverFastdnsHTTP(b *testing.B) {
server := "1.1.1.1"

Expand Down

0 comments on commit c10854f

Please sign in to comment.