Skip to content

Commit

Permalink
switch DialContext to Dialer
Browse files Browse the repository at this point in the history
  • Loading branch information
phuslu committed Oct 17, 2024
1 parent 382835f commit 7b60e95
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 25 deletions.
20 changes: 9 additions & 11 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@ package fastdns

import (
"context"
"errors"
"net"
"time"
)

var (
// ErrMaxConns is returned when dns client reaches the max connections limitation.
ErrMaxConns = errors.New("dns client reaches the max connections limitation")
)
type Dialer interface {
DialContext(ctx context.Context, network, addr string) (net.Conn, error)
}

// Client is an UDP client that supports DNS protocol.
type Client struct {
Expand All @@ -23,9 +21,9 @@ type Client struct {
// Timeout
Timeout time.Duration

// DialContext specifies the dial function for creating TCP/UDP connections.
// Dialer specifies the dialer for creating TCP/UDP connections.
// If it is set, Network, AddrPort and Timeout will be ignored.
DialContext func(ctx context.Context, network, addr string) (net.Conn, error)
Dialer Dialer
}

// Exchange executes a single DNS transaction, returning
Expand All @@ -39,12 +37,12 @@ func (c *Client) Exchange(ctx context.Context, req, resp *Message) (err error) {
}

func (c *Client) exchange(ctx context.Context, req, resp *Message) error {
dial := c.DialContext
if dial == nil {
dial = defaultDialer.DialContext
dialer := c.Dialer
if dialer == nil {
dialer = &net.Dialer{Timeout: c.Timeout}
}

conn, err := dial(ctx, c.Network, c.Addr)
conn, err := dialer.DialContext(ctx, c.Network, c.Addr)
if err != nil {
return err
}
Expand Down
6 changes: 0 additions & 6 deletions client_dailer.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,6 @@ import (
"time"
)

var defaultDialer = &NetDialer{
Dialer: &net.Dialer{
Timeout: 5 * time.Second,
},
}

type NetDialer struct {
// MaxIdleConns int
// MaxConns int
Expand Down
12 changes: 6 additions & 6 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ func TestLookupCNAME(t *testing.T) {
Network: "udp",
Addr: "1.1.1.1:53",
Timeout: 1 * time.Second,
DialContext: (&HTTPDialer{
Dialer: &HTTPDialer{
Endpoint: cloudflare,
UserAgent: "fastdns/0.9",
}).DialContext,
},
}

cname, err := client.LookupCNAME(context.Background(), host)
Expand All @@ -72,10 +72,10 @@ func TestLookupTXT(t *testing.T) {
Network: "udp",
Addr: "1.1.1.1:53",
Timeout: 1 * time.Second,
DialContext: (&HTTPDialer{
Dialer: &HTTPDialer{
Endpoint: cloudflare,
UserAgent: "fastdns/0.9",
}).DialContext,
},
}

txt, err := client.LookupTXT(context.Background(), host)
Expand All @@ -90,10 +90,10 @@ func TestLookupNetIP(t *testing.T) {
Network: "udp",
Addr: "1.1.1.1:53",
Timeout: 1 * time.Second,
DialContext: (&HTTPDialer{
Dialer: &HTTPDialer{
Endpoint: cloudflare,
UserAgent: "fastdns/0.9",
}).DialContext,
},
}

ips, err := client.LookupNetIP(context.Background(), "ip", host)
Expand Down
4 changes: 2 additions & 2 deletions cmd/fastdig/fastdig.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ func main() {
fmt.Fprintf(os.Stderr, "client=%+v parse server(\"%s\") error: %+v\n", client, server, err)
os.Exit(1)
}
client.DialContext = (&fastdns.HTTPDialer{
client.Dialer = &fastdns.HTTPDialer{
Endpoint: endpoint,
UserAgent: "fastdig/0.9",
}).DialContext
}
}

req, resp := fastdns.AcquireMessage(), fastdns.AcquireMessage()
Expand Down
2 changes: 2 additions & 0 deletions message.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ var (
ErrInvalidQuestion = errors.New("dns message does not have the expected question size")
// ErrInvalidAnswer is returned when dns message does not have the expected answer size.
ErrInvalidAnswer = errors.New("dns message does not have the expected answer size")
// ErrMaxConns is returned when dns client reaches the max connections limitation.
ErrMaxConns = errors.New("dns client reaches the max connections limitation")
)

// ParseMessage parses dns request from payload into dst and returns the error.
Expand Down

0 comments on commit 7b60e95

Please sign in to comment.