From c10854f30785e0d709ce549162543c2f8115e3f1 Mon Sep 17 00:00:00 2001 From: phuslu Date: Sat, 26 Oct 2024 23:54:12 +0800 Subject: [PATCH] add tls dialer --- client.go | 4 +-- client_dialer.go | 87 ++++++++++++++++++++++++++++++++++++++++++++++++ client_test.go | 49 +++++++++++++++++++-------- 3 files changed, 125 insertions(+), 15 deletions(-) diff --git a/client.go b/client.go index c8e06ca..9693674 100644 --- a/client.go +++ b/client.go @@ -56,13 +56,13 @@ func (c *Client) exchange(ctx context.Context, req, resp *Message) error { _, err = conn.Write(req.Raw) if err != nil { - return nil + return err } resp.Raw = resp.Raw[:cap(resp.Raw)] n, err := conn.Read(resp.Raw) if err != nil { - return nil + return err } resp.Raw = resp.Raw[:n] diff --git a/client_dialer.go b/client_dialer.go index 163fb7a..6f6b461 100644 --- a/client_dialer.go +++ b/client_dialer.go @@ -2,6 +2,7 @@ package fastdns import ( "context" + "crypto/tls" "fmt" "io" "net" @@ -66,6 +67,92 @@ func (d *UDPDialer) put(conn net.Conn) { d.conns <- conn } +// TLSDialer is a custom dialer for creating TLS connections. +// It manages a pool of connections to optimize performance in scenarios +// where multiple TLS connections to the same server are required. +type TLSDialer struct { + // Addr specifies the remote TLS address that the dialer will connect to. + Addr *net.TCPAddr + + TLSConfig *tls.Config + + // Timeout specifies the maximum duration for a query to complete. + // If a query exceeds this duration, it will result in a timeout error. + Timeout time.Duration + + // MaxConns limits the maximum number of TLS connections that can be created + // and reused. Once this limit is reached, no new connections will be made. + // If not set, use 8 as default. + MaxConns uint16 + + once sync.Once + conns chan net.Conn +} + +func (d *TLSDialer) DialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) { + return d.get() +} + +func (d *TLSDialer) get() (_ net.Conn, err error) { + d.once.Do(func() { + if d.MaxConns == 0 { + d.MaxConns = 8 + } + d.conns = make(chan net.Conn, d.MaxConns) + for range d.MaxConns { + d.conns <- &tlsConn{nil, d, make([]byte, 0, 1024)} + } + }) + + if err != nil { + return + } + + c := <-d.conns + + return c, nil +} + +func (d *TLSDialer) put(conn net.Conn) { + d.conns <- conn +} + +type tlsConn struct { + *tls.Conn + dialer *TLSDialer + buffer []byte +} + +func (c *tlsConn) Write(b []byte) (int, error) { + if c.Conn == nil { + conn, err := tls.DialWithDialer(&net.Dialer{Timeout: c.dialer.Timeout}, "tcp", c.dialer.Addr.String(), c.dialer.TLSConfig) + if err != nil { + return 0, err + } + c.Conn = conn + } + + n := len(b) + c.buffer = append(c.buffer[:0], byte(n>>8), byte(n&0xFF)) + c.buffer = append(c.buffer, b...) + _, err := c.Conn.Write(c.buffer) + return n, err +} + +func (c *tlsConn) Read(b []byte) (n int, err error) { + c.buffer = c.buffer[:cap(c.buffer)] + n, err = c.Conn.Read(c.buffer) + if err != nil { + return + } + m := int(c.buffer[0])<<8 | int(c.buffer[1]) + if m+2 != n { + return 0, ErrInvalidAnswer + } + copy(b, c.buffer[2:n]) + return n - 2, nil +} + // HTTPDialer is a custom dialer for creating HTTP connections. // It allows sending HTTP requests with a specified endpoint, user agent, and transport configuration. type HTTPDialer struct { diff --git a/client_test.go b/client_test.go index 4de82cf..5e307d0 100644 --- a/client_test.go +++ b/client_test.go @@ -87,7 +87,14 @@ func TestClientLookup(t *testing.T) { Addr: "1.1.1.1:53", Dialer: &UDPDialer{ Addr: func() (u *net.UDPAddr) { u, _ = net.ResolveUDPAddr("udp", "1.1.1.1:53"); return }(), - MaxConns: 1000, + MaxConns: 16, + }, + }, + { + Addr: "1.1.1.1:853", + Dialer: &TLSDialer{ + Addr: func() (u *net.TCPAddr) { u, _ = net.ResolveTCPAddr("tcp", "1.1.1.1:853"); return }(), + MaxConns: 16, }, }, { @@ -163,18 +170,6 @@ func BenchmarkResolverPureGo(b *testing.B) { }) } -func BenchmarkResolverCGO(b *testing.B) { - resolver := net.Resolver{PreferGo: false} - - b.ReportAllocs() - b.ResetTimer() - b.RunParallel(func(b *testing.PB) { - for b.Next() { - _, _ = resolver.LookupNetIP(context.Background(), "ip4", "www.google.com") - } - }) -} - func BenchmarkResolverFastdnsDefault(b *testing.B) { server := "8.8.8.8:53" if data, err := os.ReadFile("/etc/resolv.conf"); err == nil { @@ -260,6 +255,34 @@ func BenchmarkResolverFastdnsUDPAppend(b *testing.B) { }) } +func BenchmarkResolverFastdnsTLS(b *testing.B) { + server := "1.1.1.1:853" + + resolver := &Client{ + Addr: server, + Dialer: &TLSDialer{ + Addr: func() (u *net.TCPAddr) { u, _ = net.ResolveTCPAddr("tcp", server); return }(), + MaxConns: 8, + TLSConfig: &tls.Config{ + InsecureSkipVerify: false, + ServerName: server, + ClientSessionCache: tls.NewLRUClientSessionCache(1024), + }, + }, + } + + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + ips, err := resolver.LookupNetIP(context.Background(), "ip4", "www.google.com") + if len(ips) == 0 || err != nil { + b.Errorf("fastdns return ips: %+v error: %+v", ips, err) + } + } + }) +} + func BenchmarkResolverFastdnsHTTP(b *testing.B) { server := "1.1.1.1"