Skip to content

Commit

Permalink
dnsdist: Add EDNS to responses generated from raw record data
Browse files Browse the repository at this point in the history
My reasoning is that it makes sense to add EDNS to responses generated
from DNSdist provided that:
- the initial query had EDNS
- `setAddEDNSToSelfGeneratedResponses` has not been set to `false`
- we are only provided part of the response and not a full response
  packet

(cherry picked from commit cae561a)
  • Loading branch information
rgacogne committed Sep 30, 2024
1 parent 02338bb commit 0fa402a
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 22 deletions.
4 changes: 1 addition & 3 deletions pdns/dnsdist-lua-actions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -934,7 +934,6 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dnsquestion, std::string*
static_assert(recordstart.size() == 12, "sizeof(recordstart) must be equal to 12, otherwise the above check is invalid");
memcpy(&recordstart[4], &qclass, sizeof(qclass));
memcpy(&recordstart[6], &ttl, sizeof(ttl));
bool raw = false;

if (qtype == QType::CNAME) {
const auto& wireData = d_cname.getStorage(); // Note! This doesn't do compression!
Expand Down Expand Up @@ -975,7 +974,6 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dnsquestion, std::string*
return true;
});
}
raw = true;
}
else {
for (const auto& addr : addrs) {
Expand Down Expand Up @@ -1007,7 +1005,7 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dnsquestion, std::string*
return true;
});

if (hadEDNS && !raw) {
if (hadEDNS) {
addEDNS(dnsquestion->getMutableData(), dnsquestion->getMaximumSize(), dnssecOK, g_PayloadSizeSelfGenAnswers, 0);
}

Expand Down
84 changes: 65 additions & 19 deletions regression-tests.dnsdist/test_Spoofing.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,33 @@ def testSpoofActionA(self):
sender = getattr(self, method)
(_, receivedResponse) = sender(query, response=None, useQueue=False)
self.assertTrue(receivedResponse)
self.assertEqual(expectedResponse, receivedResponse)
self.checkMessageNoEDNS(expectedResponse, receivedResponse)

def testSpoofActionAWithEDNS(self):
"""
Spoofing: Spoof A via Action (EDNS)
Send an A query to "spoofaction.spoofing.tests.powerdns.com.",
check that dnsdist sends a spoofed result.
"""
name = 'spoofaction.spoofing.tests.powerdns.com.'
query = dns.message.make_query(name, 'A', 'IN', use_edns=True)
# dnsdist set RA = RD for spoofed responses
query.flags &= ~dns.flags.RD
expectedResponse = dns.message.make_response(query)
expectedResponse.use_edns(edns=True, payload=1232)
rrset = dns.rrset.from_text(name,
60,
dns.rdataclass.IN,
dns.rdatatype.A,
'192.0.2.1')
expectedResponse.answer.append(rrset)

for method in ("sendUDPQuery", "sendTCPQuery"):
sender = getattr(self, method)
(_, receivedResponse) = sender(query, response=None, useQueue=False)
self.assertTrue(receivedResponse)
self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)

def testSpoofActionAAAA(self):
"""
Expand Down Expand Up @@ -101,7 +127,7 @@ def testSpoofActionCNAME(self):

def testSpoofActionMultiA(self):
"""
Spoofing: Spoof multiple IPv4 addresses via AddDomainSpoof
Spoofing: Spoof multiple IPv4 addresses
Send an A query for "multispoof.spoofing.tests.powerdns.com.",
check that dnsdist sends a spoofed result.
Expand All @@ -126,7 +152,7 @@ def testSpoofActionMultiA(self):

def testSpoofActionMultiAAAA(self):
"""
Spoofing: Spoof multiple IPv6 addresses via AddDomainSpoof
Spoofing: Spoof multiple IPv6 addresses
Send an AAAA query for "multispoof.spoofing.tests.powerdns.com.",
check that dnsdist sends a spoofed result.
Expand All @@ -151,7 +177,7 @@ def testSpoofActionMultiAAAA(self):

def testSpoofActionMultiANY(self):
"""
Spoofing: Spoof multiple addresses via AddDomainSpoof
Spoofing: Spoof multiple addresses
Send an ANY query for "multispoof.spoofing.tests.powerdns.com.",
check that dnsdist sends a spoofed result.
Expand Down Expand Up @@ -320,7 +346,27 @@ def testSpoofRawAction(self):
sender = getattr(self, method)
(_, receivedResponse) = sender(query, response=None, useQueue=False)
self.assertTrue(receivedResponse)
self.assertEqual(expectedResponse, receivedResponse)
self.checkMessageNoEDNS(expectedResponse, receivedResponse)
self.assertEqual(receivedResponse.answer[0].ttl, 60)

# A with EDNS
query = dns.message.make_query(name, 'A', 'IN', use_edns=True)
query.flags &= ~dns.flags.RD
expectedResponse = dns.message.make_response(query)
expectedResponse.use_edns(edns=True, payload=1232)
expectedResponse.flags &= ~dns.flags.AA
rrset = dns.rrset.from_text(name,
60,
dns.rdataclass.IN,
dns.rdatatype.A,
'192.0.2.1')
expectedResponse.answer.append(rrset)

for method in ("sendUDPQuery", "sendTCPQuery"):
sender = getattr(self, method)
(_, receivedResponse) = sender(query, response=None, useQueue=False)
self.assertTrue(receivedResponse)
self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
self.assertEqual(receivedResponse.answer[0].ttl, 60)

# TXT
Expand All @@ -339,7 +385,7 @@ def testSpoofRawAction(self):
sender = getattr(self, method)
(_, receivedResponse) = sender(query, response=None, useQueue=False)
self.assertTrue(receivedResponse)
self.assertEqual(expectedResponse, receivedResponse)
self.checkMessageNoEDNS(expectedResponse, receivedResponse)
self.assertEqual(receivedResponse.answer[0].ttl, 60)

# SRV
Expand All @@ -359,7 +405,7 @@ def testSpoofRawAction(self):
sender = getattr(self, method)
(_, receivedResponse) = sender(query, response=None, useQueue=False)
self.assertTrue(receivedResponse)
self.assertEqual(expectedResponse, receivedResponse)
self.checkMessageNoEDNS(expectedResponse, receivedResponse)
self.assertEqual(receivedResponse.answer[0].ttl, 3600)

def testSpoofRawChaosAction(self):
Expand All @@ -384,7 +430,7 @@ def testSpoofRawChaosAction(self):
sender = getattr(self, method)
(_, receivedResponse) = sender(query, response=None, useQueue=False)
self.assertTrue(receivedResponse)
self.assertEqual(expectedResponse, receivedResponse)
self.checkMessageNoEDNS(expectedResponse, receivedResponse)
self.assertEqual(receivedResponse.answer[0].ttl, 60)

def testSpoofRawANYAction(self):
Expand All @@ -408,7 +454,7 @@ def testSpoofRawANYAction(self):
sender = getattr(self, method)
(_, receivedResponse) = sender(query, response=None, useQueue=False)
self.assertTrue(receivedResponse)
self.assertEqual(expectedResponse, receivedResponse)
self.checkMessageNoEDNS(expectedResponse, receivedResponse)
self.assertEqual(receivedResponse.answer[0].ttl, 60)

def testSpoofRawActionMulti(self):
Expand All @@ -433,7 +479,7 @@ def testSpoofRawActionMulti(self):
sender = getattr(self, method)
(_, receivedResponse) = sender(query, response=None, useQueue=False)
self.assertTrue(receivedResponse)
self.assertEqual(expectedResponse, receivedResponse)
self.checkMessageNoEDNS(expectedResponse, receivedResponse)
self.assertEqual(receivedResponse.answer[0].ttl, 60)

# TXT
Expand All @@ -452,7 +498,7 @@ def testSpoofRawActionMulti(self):
sender = getattr(self, method)
(_, receivedResponse) = sender(query, response=None, useQueue=False)
self.assertTrue(receivedResponse)
self.assertEqual(expectedResponse, receivedResponse)
self.checkMessageNoEDNS(expectedResponse, receivedResponse)
self.assertEqual(receivedResponse.answer[0].ttl, 60)

class TestSpoofingLuaSpoof(DNSDistTest):
Expand Down Expand Up @@ -617,7 +663,7 @@ def testLuaSpoofRawAction(self):
sender = getattr(self, method)
(_, receivedResponse) = sender(query, response=None, useQueue=False)
self.assertTrue(receivedResponse)
self.assertEqual(expectedResponse, receivedResponse)
self.checkMessageNoEDNS(expectedResponse, receivedResponse)
self.assertEqual(receivedResponse.answer[0].ttl, 60)

# TXT
Expand All @@ -636,7 +682,7 @@ def testLuaSpoofRawAction(self):
sender = getattr(self, method)
(_, receivedResponse) = sender(query, response=None, useQueue=False)
self.assertTrue(receivedResponse)
self.assertEqual(expectedResponse, receivedResponse)
self.checkMessageNoEDNS(expectedResponse, receivedResponse)
self.assertEqual(receivedResponse.answer[0].ttl, 60)

# SRV
Expand All @@ -656,7 +702,7 @@ def testLuaSpoofRawAction(self):
sender = getattr(self, method)
(_, receivedResponse) = sender(query, response=None, useQueue=False)
self.assertTrue(receivedResponse)
self.assertEqual(expectedResponse, receivedResponse)
self.checkMessageNoEDNS(expectedResponse, receivedResponse)
# sorry, we can't set the TTL from the Lua API right now
#self.assertEqual(receivedResponse.answer[0].ttl, 3600)

Expand Down Expand Up @@ -769,7 +815,7 @@ def testLuaSpoofMultiRawAction(self):
sender = getattr(self, method)
(_, receivedResponse) = sender(query, response=None, useQueue=False)
self.assertTrue(receivedResponse)
self.assertEqual(expectedResponse, receivedResponse)
self.checkMessageNoEDNS(expectedResponse, receivedResponse)
self.assertEqual(receivedResponse.answer[0].ttl, 60)

# TXT
Expand All @@ -788,7 +834,7 @@ def testLuaSpoofMultiRawAction(self):
sender = getattr(self, method)
(_, receivedResponse) = sender(query, response=None, useQueue=False)
self.assertTrue(receivedResponse)
self.assertEqual(expectedResponse, receivedResponse)
self.checkMessageNoEDNS(expectedResponse, receivedResponse)
self.assertEqual(receivedResponse.answer[0].ttl, 60)

# SRV
Expand All @@ -808,7 +854,7 @@ def testLuaSpoofMultiRawAction(self):
sender = getattr(self, method)
(_, receivedResponse) = sender(query, response=None, useQueue=False)
self.assertTrue(receivedResponse)
self.assertEqual(expectedResponse, receivedResponse)
self.checkMessageNoEDNS(expectedResponse, receivedResponse)
# sorry, we can't set the TTL from the Lua API right now
#self.assertEqual(receivedResponse.answer[0].ttl, 3600)

Expand Down Expand Up @@ -878,7 +924,7 @@ def testLuaSpoofMultiRawAction(self):
sender = getattr(self, method)
(_, receivedResponse) = sender(query, response=None, useQueue=False)
self.assertTrue(receivedResponse)
self.assertEqual(expectedResponse, receivedResponse)
self.checkMessageNoEDNS(expectedResponse, receivedResponse)
self.assertEqual(receivedResponse.answer[0].ttl, 60)

# TXT
Expand All @@ -897,7 +943,7 @@ def testLuaSpoofMultiRawAction(self):
sender = getattr(self, method)
(_, receivedResponse) = sender(query, response=None, useQueue=False)
self.assertTrue(receivedResponse)
self.assertEqual(expectedResponse, receivedResponse)
self.checkMessageNoEDNS(expectedResponse, receivedResponse)
self.assertEqual(receivedResponse.answer[0].ttl, 60)

class TestSpoofingLuaWithStatistics(DNSDistTest):
Expand Down

0 comments on commit 0fa402a

Please sign in to comment.