Skip to content

Commit

Permalink
Add -4 and -6 flags (redo)
Browse files Browse the repository at this point in the history
This PR is a redo of go-acme#1802. Since that PR has been idle so long,
the branches have diverged quite a bit and it was easier to start
anew.

The work in this PR includes the work originally done by dmke in go-acme#1802.

This PR is to resolve go-acme#1801.
  • Loading branch information
jsumners committed Jul 11, 2024
1 parent 321cea5 commit c81ec83
Show file tree
Hide file tree
Showing 11 changed files with 337 additions and 13 deletions.
2 changes: 2 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ issues:
text: 'dnsTimeout is a global variable'
- path: challenge/dns01/nameserver_test.go
text: 'findXByFqdnTestCases is a global variable'
- path: challenge/dns01/network.go
text: 'currentNetworkStack is a global variable'
- path: challenge/http01/domain_matcher.go
text: 'string `Host` has \d occurrences, make it a constant'
- path: challenge/http01/domain_matcher.go
Expand Down
14 changes: 10 additions & 4 deletions challenge/dns01/nameserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,8 @@ func createDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg {

func sendDNSQuery(m *dns.Msg, ns string) (*dns.Msg, error) {
if ok, _ := strconv.ParseBool(os.Getenv("LEGO_EXPERIMENTAL_DNS_TCP_ONLY")); ok {
tcp := &dns.Client{Net: "tcp", Timeout: dnsTimeout}
network := currentNetworkStack.Network("tcp")
tcp := &dns.Client{Net: network, Timeout: dnsTimeout}
r, _, err := tcp.Exchange(m, ns)
if err != nil {
return r, &DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err}
Expand All @@ -274,11 +275,16 @@ func sendDNSQuery(m *dns.Msg, ns string) (*dns.Msg, error) {
return r, nil
}

udp := &dns.Client{Net: "udp", Timeout: dnsTimeout}
udpNetwork := currentNetworkStack.Network("udp")
udp := &dns.Client{Net: udpNetwork, Timeout: dnsTimeout}
r, _, err := udp.Exchange(m, ns)

if r != nil && r.Truncated {
tcp := &dns.Client{Net: "tcp", Timeout: dnsTimeout}
// We can encounter a net.OpError if the nameserver is not listening
// on UDP at all, i.e. net.Dial could not make a connection.
var opErr *net.OpError
if (r != nil && r.Truncated) || errors.As(err, &opErr) {
tcpNetwork := currentNetworkStack.Network("tcp")
tcp := &dns.Client{Net: tcpNetwork, Timeout: dnsTimeout}
// If the TCP request succeeds, the "err" will reset to nil
r, _, err = tcp.Exchange(m, ns)
}
Expand Down
125 changes: 123 additions & 2 deletions challenge/dns01/nameserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,133 @@ package dns01

import (
"errors"
"net"
"sort"
"sync"
"testing"

"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func testDNSHandler(writer dns.ResponseWriter, reply *dns.Msg) {
msg := dns.Msg{}
msg.SetReply(reply)

if reply.Question[0].Qtype == dns.TypeA {
msg.Authoritative = true
domain := msg.Question[0].Name
msg.Answer = append(
msg.Answer,
&dns.A{
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 60,
},
A: net.IPv4(127, 0, 0, 1),
},
)
}

_ = writer.WriteMsg(&msg)
}

// getTestNameserver constructs a new DNS server on a local address, or set
// of addresses, that responds to an `A` query for `example.com`.
func getTestNameserver(t *testing.T, network string) *dns.Server {
t.Helper()
server := &dns.Server{
Handler: dns.HandlerFunc(testDNSHandler),
Net: network,
}
switch network {
case "tcp", "udp":
server.Addr = "0.0.0.0:0"
case "tcp4", "udp4":
server.Addr = "127.0.0.1:0"
case "tcp6", "udp6":
server.Addr = "[::1]:0"
}

waitLock := sync.Mutex{}
waitLock.Lock()
server.NotifyStartedFunc = waitLock.Unlock

go func() { _ = server.ListenAndServe() }()

waitLock.Lock()
return server
}

func startTestNameserver(t *testing.T, stack networkStack, proto string) (shutdown func(), addr string) {
t.Helper()
currentNetworkStack = stack
srv := getTestNameserver(t, currentNetworkStack.Network(proto))

shutdown = func() { _ = srv.Shutdown() }
if proto == "tcp" {
addr = srv.Listener.Addr().String()
} else {
addr = srv.PacketConn.LocalAddr().String()
}
return
}

func TestSendDNSQuery(t *testing.T) {
currentNameservers := recursiveNameservers

t.Cleanup(func() {
recursiveNameservers = currentNameservers
currentNetworkStack = dualStack
})

t.Run("does udp4 only", func(t *testing.T) {
stop, addr := startTestNameserver(t, ipv4only, "udp")
defer stop()

recursiveNameservers = ParseNameservers([]string{addr})
msg := createDNSMsg("example.com.", dns.TypeA, true)
result, queryError := sendDNSQuery(msg, addr)
require.NoError(t, queryError)
assert.Equal(t, result.Answer[0].(*dns.A).A.String(), "127.0.0.1")
})

t.Run("does udp6 only", func(t *testing.T) {
stop, addr := startTestNameserver(t, ipv6only, "udp")
defer stop()

recursiveNameservers = ParseNameservers([]string{addr})
msg := createDNSMsg("example.com.", dns.TypeA, true)
result, queryError := sendDNSQuery(msg, addr)
require.NoError(t, queryError)
assert.Equal(t, result.Answer[0].(*dns.A).A.String(), "127.0.0.1")
})

t.Run("does tcp4 and tcp6", func(t *testing.T) {
stop, addr := startTestNameserver(t, dualStack, "tcp")
host, port, _ := net.SplitHostPort(addr)
defer stop()
t.Logf("### port: %s", port)

addr6 := net.JoinHostPort(host, port)
recursiveNameservers = ParseNameservers([]string{addr6})
msg := createDNSMsg("example.com.", dns.TypeA, true)
result, queryError := sendDNSQuery(msg, addr6)
require.NoError(t, queryError)
assert.Equal(t, result.Answer[0].(*dns.A).A.String(), "127.0.0.1")

addr4 := net.JoinHostPort("127.0.0.1", port)
recursiveNameservers = ParseNameservers([]string{addr4})
msg = createDNSMsg("example.com.", dns.TypeA, true)
result, queryError = sendDNSQuery(msg, addr4)
require.NoError(t, queryError)
assert.Equal(t, result.Answer[0].(*dns.A).A.String(), "127.0.0.1")
})
}

func TestLookupNameserversOK(t *testing.T) {
testCases := []struct {
fqdn string
Expand Down Expand Up @@ -123,8 +242,10 @@ var findXByFqdnTestCases = []struct {
fqdn: "mail.google.com.",
zone: "google.com.",
nameservers: []string{":7053", ":8053", ":9053"},
// use only the start of the message because the port changes with each call: 127.0.0.1:XXXXX->127.0.0.1:7053.
expectedError: "[fqdn=mail.google.com.] could not find the start of authority for 'mail.google.com.': DNS call error: read udp ",
// NOTE: On Windows, net.DialContext finds a way down to the ContectEx syscall.
// There a fault is marked as "connectex", not "connect", see
// https://cs.opensource.google/go/go/+/refs/tags/go1.19.5:src/net/fd_windows.go;l=112
expectedError: "could not find the start of authority for 'mail.google.com.':",
},
{
desc: "no nameservers",
Expand Down
41 changes: 41 additions & 0 deletions challenge/dns01/network.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package dns01

// networkStack is used to indicate which IP stack should be used for DNS queries.
type networkStack int

const (
dualStack networkStack = iota
ipv4only
ipv6only
)

// currentNetworkStack is used to define which IP stack will be used. The default is
// both IPv4 and IPv6. Set to IPv4Only or IPv6Only to select either version.
var currentNetworkStack = dualStack

// Network interprets the NetworkStack setting in relation to the desired
// protocol. The proto value should be either "udp" or "tcp".
func (s networkStack) Network(proto string) string {
// The DNS client passes whatever value is set in (*dns.Client).Net to
// the [net.Dialer](https://github.com/miekg/dns/blob/fe20d5d/client.go#L119-L141).
// And the net.Dialer accepts strings such as "udp4" or "tcp6"
// (https://cs.opensource.google/go/go/+/refs/tags/go1.18.9:src/net/dial.go;l=167-182).
switch s {
case ipv4only:
return proto + "4"
case ipv6only:
return proto + "6"
default:
return proto
}
}

// SetIPv4Only forces DNS queries to only happen over the IPv4 stack.
func SetIPv4Only() { currentNetworkStack = ipv4only }

// SetIPv6Only forces DNS queries to only happen over the IPv6 stack.
func SetIPv6Only() { currentNetworkStack = ipv6only }

// SetDualStack indicates that both IPv4 and IPv6 should be allowed.
// This setting lets the OS determine which IP stack to use.
func SetDualStack() { currentNetworkStack = dualStack }
22 changes: 22 additions & 0 deletions challenge/http01/http_challenge_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,28 @@ func NewUnixProviderServer(socketPath string, mode fs.FileMode) *ProviderServer
return &ProviderServer{network: "unix", address: socketPath, socketMode: mode, matcher: &hostMatcher{}}
}

// SetIPv4Only starts the challenge server on an IPv4 address.
//
// Calling this method has no effect if s was created with NewUnixProviderServer.
func (s *ProviderServer) SetIPv4Only() { s.setTCPStack("tcp4") }

// SetIPv6Only starts the challenge server on an IPv6 address.
//
// Calling this method has no effect if s was created with NewUnixProviderServer.
func (s *ProviderServer) SetIPv6Only() { s.setTCPStack("tcp6") }

// SetDualStack indicates that both IPv4 and IPv6 should be allowed.
// This setting lets the OS determine which IP stack to use for the challenge server.
//
// Calling this method has no effect if s was created with NewUnixProviderServer.
func (s *ProviderServer) SetDualStack() { s.setTCPStack("tcp") }

func (s *ProviderServer) setTCPStack(network string) {
if s.network != "unix" {
s.network = network
}
}

// Present starts a web server and makes the token available at `ChallengePath(token)` for web requests.
func (s *ProviderServer) Present(domain, token, keyAuth string) error {
var err error
Expand Down
17 changes: 17 additions & 0 deletions challenge/http01/http_challenge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ func TestProviderServer_GetAddress(t *testing.T) {
testCases := []struct {
desc string
server *ProviderServer
network func(server *ProviderServer)
expected string
}{
{
Expand All @@ -49,6 +50,18 @@ func TestProviderServer_GetAddress(t *testing.T) {
server: NewProviderServer("localhost", "8080"),
expected: "localhost:8080",
},
{
desc: "TCP4 with host and port",
server: NewProviderServer("localhost", "8080"),
network: func(s *ProviderServer) { s.SetIPv4Only() },
expected: "localhost:8080",
},
{
desc: "TCP6 with host and port",
server: NewProviderServer("localhost", "8080"),
network: func(s *ProviderServer) { s.SetIPv6Only() },
expected: "localhost:8080",
},
{
desc: "UDS socket",
server: NewUnixProviderServer(sock, fs.ModeSocket|0o666),
Expand All @@ -60,6 +73,10 @@ func TestProviderServer_GetAddress(t *testing.T) {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()

if test.network != nil {
test.network(test.server)
}

address := test.server.GetAddress()
assert.Equal(t, test.expected, address)
})
Expand Down
18 changes: 16 additions & 2 deletions challenge/tlsalpn01/tls_alpn_challenge_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,30 @@ const (
type ProviderServer struct {
iface string
port string
network string
listener net.Listener
}

// NewProviderServer creates a new ProviderServer on the selected interface and port.
// Setting iface and / or port to an empty string will make the server fall back to
// the "any" interface and port 443 respectively.
func NewProviderServer(iface, port string) *ProviderServer {
return &ProviderServer{iface: iface, port: port}
if port == "" {
port = defaultTLSPort
}
return &ProviderServer{iface: iface, port: port, network: "tcp"}
}

// SetIPv4Only starts the challenge server on an IPv4 address.
func (s *ProviderServer) SetIPv4Only() { s.network = "tcp4" }

// SetIPv6Only starts the challenge server on an IPv6 address.
func (s *ProviderServer) SetIPv6Only() { s.network = "tcp6" }

// SetDualStack indicates that both IPv4 and IPv6 should be allowed.
// This setting lets the OS determine which IP stack to use for the challenge server.
func (s *ProviderServer) SetDualStack() { s.network = "tcp" }

func (s *ProviderServer) GetAddress() string {
return net.JoinHostPort(s.iface, s.port)
}
Expand Down Expand Up @@ -65,7 +79,7 @@ func (s *ProviderServer) Present(domain, token, keyAuth string) error {
tlsConf.NextProtos = []string{ACMETLS1Protocol}

// Create the listener with the created tls.Config.
s.listener, err = tls.Listen("tcp", s.GetAddress(), tlsConf)
s.listener, err = tls.Listen(s.network, s.GetAddress(), tlsConf)
if err != nil {
return fmt.Errorf("could not start HTTPS server for challenge: %w", err)
}
Expand Down
Loading

0 comments on commit c81ec83

Please sign in to comment.