Skip to content

Commit

Permalink
Add support for -4 and -6 switches.
Browse files Browse the repository at this point in the history
  • Loading branch information
jsumners committed Jan 18, 2023
1 parent db0257c commit cdf3a6a
Show file tree
Hide file tree
Showing 12 changed files with 348 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .golangci.toml
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@
text = "(tlsFeatureExtensionOID|ocspMustStapleFeature) is a global variable"
[[issues.exclude-rules]]
path = "challenge/dns01/nameserver.go"
text = "(defaultNameservers|recursiveNameservers|fqdnSoaCache|muFqdnSoaCache) is a global variable"
text = "(defaultNameservers|recursiveNameservers|currentNetworkStack|fqdnSoaCache|muFqdnSoaCache) is a global variable"
[[issues.exclude-rules]]
path = "challenge/dns01/nameserver_.+.go"
text = "dnsTimeout is a global variable"
Expand Down
50 changes: 47 additions & 3 deletions challenge/dns01/nameserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,24 @@ var defaultNameservers = []string{
// recursiveNameservers are used to pre-check DNS propagation.
var recursiveNameservers = getNameservers(defaultResolvConf, defaultNameservers)

// NetworkStack is used to indicate which IP stack should be used for DNS
// queries. Valid values are DefaultNetworkStack, IPv4Only, and IPv6Only.
type NetworkStack int

const (
// DefaultNetworkStack indicates that both IPv4 and IPv6 should be allowed.
// This setting lets the OS determine which IP stack to use.
DefaultNetworkStack NetworkStack = iota
// IPv4Only forces DNS queries to only happen over the IPv4 stack.
IPv4Only
// IPv6Only forces DNS queries to only happen over the IPv6 stack.
IPv6Only
)

// currentNetworkStack is used to define which IP stack will be used. The default is
// both IPv4 and IPv6. Set to IPv4Only or IPv6Only to select either version.
var currentNetworkStack = DefaultNetworkStack

// soaCacheEntry holds a cached SOA record (only selected fields).
type soaCacheEntry struct {
zone string // zone apex (a domain name)
Expand Down Expand Up @@ -67,6 +85,11 @@ func AddRecursiveNameservers(nameservers []string) ChallengeOption {
}
}

// SetNetworkStack defines the IP stack that will be used for DNS queries.
func SetNetworkStack(network NetworkStack) {
currentNetworkStack = network
}

// getNameservers attempts to get systems nameservers before falling back to the defaults.
func getNameservers(path string, defaults []string) []string {
config, err := dns.ClientConfigFromFile(path)
Expand Down Expand Up @@ -249,12 +272,33 @@ func createDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg {
return m
}

// getNetwork interprets the NetworkStack setting in relation to the desired
// protocol. The proto value should be either "udp" or "tcp".
func getNetwork(proto string) string {
// The dns client passes whatever value is set in [dns.Client.Net] to
// the [net.Dialer] (https://github.com/miekg/dns/blob/fe20d5d/client.go#L119-L141).
// And the [net.Dialer] accepts strings such as "udp4" or "tcp6"
// (https://cs.opensource.google/go/go/+/refs/tags/go1.18.9:src/net/dial.go;l=167-182).
if currentNetworkStack == IPv4Only {
return proto + "4"
}
if currentNetworkStack == IPv6Only {
return proto + "6"
}
return proto
}

func sendDNSQuery(m *dns.Msg, ns string) (*dns.Msg, error) {
udp := &dns.Client{Net: "udp", Timeout: dnsTimeout}
network := getNetwork("udp")
udp := &dns.Client{Net: network, Timeout: dnsTimeout}
in, _, err := udp.Exchange(m, ns)

if in != nil && in.Truncated {
tcp := &dns.Client{Net: "tcp", Timeout: dnsTimeout}
network = getNetwork("tcp")
// We can encounter a net.OpError if the nameserver is not listening
// on UDP at all, i.e. net.Dial could not make a connection.
_, isOpErr := err.(*net.OpError)
if (in != nil && in.Truncated) || isOpErr {
tcp := &dns.Client{Net: network, Timeout: dnsTimeout}
// If the TCP request succeeds, the err will reset to nil
in, _, err = tcp.Exchange(m, ns)
}
Expand Down
138 changes: 138 additions & 0 deletions challenge/dns01/nameserver_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,151 @@
package dns01

import (
"fmt"
getport "github.com/jsumners/go-getport"
"github.com/miekg/dns"
"net"
"sort"
"sync"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type testDnsHandler struct{}
type testDnsServer struct {
*dns.Server
getport.PortResult
}

func (handler *testDnsHandler) ServeDNS(writer dns.ResponseWriter, reply *dns.Msg) {
msg := dns.Msg{}
msg.SetReply(reply)

switch reply.Question[0].Qtype {
case dns.TypeA:
msg.Authoritative = true
domain := msg.Question[0].Name
msg.Answer = append(
msg.Answer,
&dns.A{
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 60,
},
A: net.ParseIP("127.0.0.1"),
},
)
}

writer.WriteMsg(&msg)
}

// getTestNameserver constructs a new DNS server on a local address, or set
// of addresses, that responds to an `A` query for `example.com`.
func getTestNameserver(t *testing.T, network string) testDnsServer {
server := &dns.Server{
Handler: new(testDnsHandler),
Net: network,
}
testServer := testDnsServer{
Server: server,
}

var protocol getport.Protocol
var address string
switch network {
case "tcp":
protocol = getport.TCP
address = "0.0.0.0"
case "tcp4":
protocol = getport.TCP4
address = "127.0.0.1"
case "tcp6":
protocol = getport.TCP6
address = "::1"
case "udp":
protocol = getport.UDP
address = "0.0.0.0"
case "udp4":
protocol = getport.UDP4
address = "127.0.0.1"
case "udp6":
protocol = getport.UDP6
address = "::1"
}
portResult, portError := getport.GetPort(protocol, address)
if portError != nil {
t.Error(portError)
return testServer
}
testServer.PortResult = portResult
server.Addr = getport.PortResultToAddress(portResult)

waitLock := sync.Mutex{}
waitLock.Lock()
server.NotifyStartedFunc = waitLock.Unlock

fin := make(chan error, 1)
go func() {
fin <- server.ListenAndServe()
}()

waitLock.Lock()
return testServer
}

func TestSendDNSQuery(t *testing.T) {
t.Run("does udp4 only", func(t *testing.T) {
SetNetworkStack(IPv4Only)
nameserver := getTestNameserver(t, getNetwork("udp"))
defer nameserver.Server.Shutdown()

serverAddress := fmt.Sprintf("127.0.0.1:%d", nameserver.PortResult.Port)
recursiveNameservers = ParseNameservers([]string{serverAddress})
msg := createDNSMsg("example.com.", dns.TypeA, true)
result, queryError := sendDNSQuery(msg, serverAddress)
assert.NoError(t, queryError)
assert.Equal(t, result.Answer[0].(*dns.A).A.String(), "127.0.0.1")
})

t.Run("does udp6 only", func(t *testing.T) {
SetNetworkStack(IPv6Only)
nameserver := getTestNameserver(t, getNetwork("udp"))
defer nameserver.Server.Shutdown()

serverAddress := fmt.Sprintf("[::1]:%d", nameserver.PortResult.Port)
recursiveNameservers = ParseNameservers([]string{serverAddress})
msg := createDNSMsg("example.com.", dns.TypeA, true)
result, queryError := sendDNSQuery(msg, serverAddress)
assert.NoError(t, queryError)
assert.Equal(t, result.Answer[0].(*dns.A).A.String(), "127.0.0.1")
})

t.Run("does tcp4 and tcp6", func(t *testing.T) {
SetNetworkStack(DefaultNetworkStack)
nameserver := getTestNameserver(t, getNetwork("tcp"))
defer nameserver.Server.Shutdown()

serverAddress := fmt.Sprintf("[::1]:%d", nameserver.PortResult.Port)
recursiveNameservers = ParseNameservers([]string{serverAddress})
msg := createDNSMsg("example.com.", dns.TypeA, true)
result, queryError := sendDNSQuery(msg, serverAddress)
assert.NoError(t, queryError)
assert.Equal(t, result.Answer[0].(*dns.A).A.String(), "127.0.0.1")

serverAddress = fmt.Sprintf("127.0.0.1:%d", nameserver.PortResult.Port)
recursiveNameservers = ParseNameservers([]string{serverAddress})
msg = createDNSMsg("example.com.", dns.TypeA, true)
result, queryError = sendDNSQuery(msg, serverAddress)
assert.NoError(t, queryError)
assert.Equal(t, result.Answer[0].(*dns.A).A.String(), "127.0.0.1")
})
}

func TestLookupNameserversOK(t *testing.T) {
testCases := []struct {
fqdn string
Expand Down
15 changes: 13 additions & 2 deletions challenge/http01/http_challenge_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@ import (
"github.com/go-acme/lego/v4/log"
)

type ProviderNetwork string

const (
DefaultNetwork = "tcp"
Tcp4Network = "tcp4"
Tcp6Network = "tcp6"
)

// ProviderServer implements ChallengeProvider for `http-01` challenge.
// It may be instantiated without using the NewProviderServer function if
// you want only to use the default values.
Expand All @@ -29,12 +37,15 @@ type ProviderServer struct {
// NewProviderServer creates a new ProviderServer on the selected interface and port.
// Setting iface and / or port to an empty string will make the server fall back to
// the "any" interface and port 80 respectively.
func NewProviderServer(iface, port string) *ProviderServer {
func NewProviderServer(iface, port string, network ProviderNetwork) *ProviderServer {
if port == "" {
port = "80"
}
if network == "" {
network = DefaultNetwork
}

return &ProviderServer{network: "tcp", address: net.JoinHostPort(iface, port), matcher: &hostMatcher{}}
return &ProviderServer{network: string(network), address: net.JoinHostPort(iface, port), matcher: &hostMatcher{}}
}

func NewUnixProviderServer(socketPath string, mode fs.FileMode) *ProviderServer {
Expand Down
22 changes: 16 additions & 6 deletions challenge/http01/http_challenge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,27 @@ func TestProviderServer_GetAddress(t *testing.T) {
}{
{
desc: "TCP default address",
server: NewProviderServer("", ""),
server: NewProviderServer("", "", ""),
expected: ":80",
},
{
desc: "TCP with explicit port",
server: NewProviderServer("", "8080"),
server: NewProviderServer("", "8080", ""),
expected: ":8080",
},
{
desc: "TCP with host and port",
server: NewProviderServer("localhost", "8080"),
server: NewProviderServer("localhost", "8080", ""),
expected: "localhost:8080",
},
{
desc: "TCP4 with host and port",
server: NewProviderServer("localhost", "8080", Tcp4Network),
expected: "localhost:8080",
},
{
desc: "TCP6 with host and port",
server: NewProviderServer("localhost", "8080", Tcp6Network),
expected: "localhost:8080",
},
{
Expand All @@ -70,7 +80,7 @@ func TestProviderServer_GetAddress(t *testing.T) {
func TestChallenge(t *testing.T) {
_, apiURL := tester.SetupFakeAPI(t)

providerServer := NewProviderServer("", "23457")
providerServer := NewProviderServer("", "23457", "")

validate := func(_ *api.Core, _ string, chlng acme.Challenge) error {
uri := "http://localhost" + providerServer.GetAddress() + ChallengePath(chlng.Token)
Expand Down Expand Up @@ -199,7 +209,7 @@ func TestChallengeInvalidPort(t *testing.T) {

validate := func(_ *api.Core, _ string, _ acme.Challenge) error { return nil }

solver := NewChallenge(core, validate, NewProviderServer("", "123456"))
solver := NewChallenge(core, validate, NewProviderServer("", "123456", ""))

authz := acme.Authorization{
Identifier: acme.Identifier{
Expand Down Expand Up @@ -374,7 +384,7 @@ func testServeWithProxy(t *testing.T, header, extra *testProxyHeader, expectErro

_, apiURL := tester.SetupFakeAPI(t)

providerServer := NewProviderServer("localhost", "23457")
providerServer := NewProviderServer("localhost", "23457", "")
if header != nil {
providerServer.SetProxyHeader(header.name)
}
Expand Down
23 changes: 20 additions & 3 deletions challenge/tlsalpn01/tls_alpn_challenge_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@ import (
"github.com/go-acme/lego/v4/log"
)

type ProviderNetwork string

const (
DefaultNetwork = "tcp"
Tcp4Network = "tcp4"
Tcp6Network = "tcp6"
)

const (
// ACMETLS1Protocol is the ALPN Protocol ID for the ACME-TLS/1 Protocol.
ACMETLS1Protocol = "acme-tls/1"
Expand All @@ -26,14 +34,23 @@ const (
type ProviderServer struct {
iface string
port string
network string
listener net.Listener
}

// NewProviderServer creates a new ProviderServer on the selected interface and port.
// Setting iface and / or port to an empty string will make the server fall back to
// the "any" interface and port 443 respectively.
func NewProviderServer(iface, port string) *ProviderServer {
return &ProviderServer{iface: iface, port: port}
func NewProviderServer(iface, port string, network ProviderNetwork) *ProviderServer {
if port == "" {
port = defaultTLSPort
}

if network == "" {
network = DefaultNetwork
}

return &ProviderServer{iface: iface, port: port, network: string(network)}
}

func (s *ProviderServer) GetAddress() string {
Expand Down Expand Up @@ -65,7 +82,7 @@ func (s *ProviderServer) Present(domain, token, keyAuth string) error {
tlsConf.NextProtos = []string{ACMETLS1Protocol}

// Create the listener with the created tls.Config.
s.listener, err = tls.Listen("tcp", s.GetAddress(), tlsConf)
s.listener, err = tls.Listen(s.network, s.GetAddress(), tlsConf)
if err != nil {
return fmt.Errorf("could not start HTTPS server for challenge: %w", err)
}
Expand Down
Loading

0 comments on commit cdf3a6a

Please sign in to comment.