Skip to content

Commit

Permalink
fix: additional records should be used for unrequired records
Browse files Browse the repository at this point in the history
  • Loading branch information
ArcticLampyrid committed Jul 8, 2024
1 parent d605749 commit 61fe298
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 103 deletions.
47 changes: 46 additions & 1 deletion src/dns_parser/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,21 @@ impl<T: MoveTo<Answers>> Builder<T> {

builder
}

pub fn add_answers<'a, 'b>(
self,
name: &Name,
cls: QueryClass,
ttl: u32,
data: impl Iterator<Item = RRData<'b>> + 'a,
) -> Builder<Answers> {
let mut builder = self.move_to::<Answers>();
for item in data {
builder.write_rr(name, cls, ttl, &item);
Header::inc_answers(&mut builder.buf).expect("Too many answers");
}
builder
}
}

impl<T: MoveTo<Nameservers>> Builder<T> {
Expand All @@ -213,10 +228,25 @@ impl<T: MoveTo<Nameservers>> Builder<T> {

builder
}

#[allow(dead_code)]
pub fn add_nameservers<'a, 'b>(
self,
name: &Name,
cls: QueryClass,
ttl: u32,
data: impl Iterator<Item = RRData<'b>> + 'a,
) -> Builder<Nameservers> {
let mut builder = self.move_to::<Nameservers>();
for item in data {
builder.write_rr(name, cls, ttl, &item);
Header::inc_nameservers(&mut builder.buf).expect("Too many nameservers");
}
builder
}
}

impl<T: MoveTo<Additional>> Builder<T> {
#[allow(dead_code)]
pub fn add_additional(
self,
name: &Name,
Expand All @@ -231,6 +261,21 @@ impl<T: MoveTo<Additional>> Builder<T> {

builder
}

pub fn add_additionals<'a, 'b>(
self,
name: &Name,
cls: QueryClass,
ttl: u32,
data: impl Iterator<Item = RRData<'b>> + 'a,
) -> Builder<Additional> {
let mut builder = self.move_to::<Additional>();
for item in data {
builder.write_rr(name, cls, ttl, &item);
Header::inc_additional(&mut builder.buf).expect("Too many additional answers");
}
builder
}
}

#[cfg(test)]
Expand Down
2 changes: 1 addition & 1 deletion src/dns_parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ pub use self::header::Header;
mod rrdata;
pub use self::rrdata::RRData;
mod builder;
pub use self::builder::{Answers, Builder, Questions};
pub use self::builder::*;
193 changes: 116 additions & 77 deletions src/fsm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ use tokio::{net::UdpSocket, sync::mpsc};

use super::{DEFAULT_TTL, MDNS_PORT};
use crate::address_family::AddressFamily;
use crate::services::{ServiceData, Services};
use crate::services::{ServiceData, Services, ServicesInner};

pub type AnswerBuilder = dns_parser::Builder<dns_parser::Answers>;
pub type AdditionalBuilder = dns_parser::Builder<dns_parser::Additional>;

const SERVICE_TYPE_ENUMERATION_NAME: Cow<'static, str> =
Cow::Borrowed("_services._dns-sd._udp.local");
Expand Down Expand Up @@ -104,57 +105,46 @@ impl<AF: AddressFamily> FSM<AF> {
return;
}

let mut unicast_builder = dns_parser::Builder::new_response(packet.header.id, false, true)
.move_to::<dns_parser::Answers>();
let mut multicast_builder =
dns_parser::Builder::new_response(packet.header.id, false, true)
.move_to::<dns_parser::Answers>();
unicast_builder.set_max_size(None);
multicast_builder.set_max_size(None);

for question in packet.questions {
debug!(
"received question: {:?} {}",
question.qclass, question.qname
);

if question.qclass == QueryClass::IN || question.qclass == QueryClass::Any {
let mut builder = dns_parser::Builder::new_response(packet.header.id, false, true)
.move_to::<dns_parser::Answers>();
builder.set_max_size(None);
let builder = self.handle_question(&question, builder);
if builder.is_empty() {
continue;
}
let response = builder.build().unwrap_or_else(|x| x);
if question.qu {
unicast_builder = self.handle_question(&question, unicast_builder);
self.outgoing.push_back((response, addr));
} else {
multicast_builder = self.handle_question(&question, multicast_builder);
let addr = SocketAddr::new(AF::MDNS_GROUP.into(), MDNS_PORT);
self.outgoing.push_back((response, addr));
}
}
}

if !multicast_builder.is_empty() {
let response = multicast_builder.build().unwrap_or_else(|x| x);
let addr = SocketAddr::new(AF::MDNS_GROUP.into(), MDNS_PORT);
self.outgoing.push_back((response, addr));
}

if !unicast_builder.is_empty() {
let response = unicast_builder.build().unwrap_or_else(|x| x);
self.outgoing.push_back((response, addr));
}
}

/// https://www.rfc-editor.org/rfc/rfc6763#section-9
fn handle_service_type_enumeration<'a>(
question: &dns_parser::Question,
services: impl Iterator<Item = &'a ServiceData>,
services: &ServicesInner,
mut builder: AnswerBuilder,
) -> AnswerBuilder {
let service_type_enumeration_name = Name::FromStr(SERVICE_TYPE_ENUMERATION_NAME);
if question.qname == service_type_enumeration_name {
for svc in services {
let svc_type = ServiceData {
name: svc.typ.clone(),
typ: service_type_enumeration_name.clone(),
port: svc.port,
txt: vec![],
};
builder = svc_type.add_ptr_rr(builder, DEFAULT_TTL);
for typ in services.all_types() {
builder = builder.add_answer(
&service_type_enumeration_name,
QueryClass::IN,
DEFAULT_TTL,
&RRData::PTR(typ.clone()),
);
}
}

Expand All @@ -165,93 +155,136 @@ impl<AF: AddressFamily> FSM<AF> {
&self,
question: &dns_parser::Question,
mut builder: AnswerBuilder,
) -> AnswerBuilder {
) -> AdditionalBuilder {
let services = self.services.read().unwrap();
let hostname = services.get_hostname();

match question.qtype {
QueryType::A | QueryType::AAAA if question.qname == *hostname => {
builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL);
}
QueryType::A | QueryType::AAAA if question.qname == *hostname => builder
.add_answers(hostname, QueryClass::IN, DEFAULT_TTL, self.ip_rr())
.move_to(),
QueryType::All => {
let mut include_ip_additionals = false;
// A / AAAA
if question.qname == *hostname {
builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL);
builder =
builder.add_answers(hostname, QueryClass::IN, DEFAULT_TTL, self.ip_rr());
}
// PTR
builder =
Self::handle_service_type_enumeration(question, services.into_iter(), builder);
builder = Self::handle_service_type_enumeration(question, &services, builder);
for svc in services.find_by_type(&question.qname) {
builder = svc.add_ptr_rr(builder, DEFAULT_TTL);
builder = svc.add_srv_rr(hostname, builder, DEFAULT_TTL);
builder = svc.add_txt_rr(builder, DEFAULT_TTL);
builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL);
builder =
builder.add_answer(&svc.typ, QueryClass::IN, DEFAULT_TTL, &svc.ptr_rr());
include_ip_additionals = true;
}
// SRV
if let Some(svc) = services.find_by_name(&question.qname) {
builder = svc.add_srv_rr(hostname, builder, DEFAULT_TTL);
builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL);
builder = builder.add_answer(
&svc.name,
QueryClass::IN,
DEFAULT_TTL,
&svc.srv_rr(hostname),
);
include_ip_additionals = true;
}
let mut builder = builder.move_to::<dns_parser::Additional>();
// PTR (additional)
for svc in services.find_by_type(&question.qname) {
builder = builder
.add_additional(
&svc.name,
QueryClass::IN,
DEFAULT_TTL,
&svc.srv_rr(hostname),
)
.add_additional(&svc.name, QueryClass::IN, DEFAULT_TTL, &svc.txt_rr());
include_ip_additionals = true;
}

if include_ip_additionals {
builder = builder.add_additionals(
hostname,
QueryClass::IN,
DEFAULT_TTL,
self.ip_rr(),
);
}
builder
}
QueryType::PTR => {
builder =
Self::handle_service_type_enumeration(question, services.into_iter(), builder);
let mut builder =
Self::handle_service_type_enumeration(question, &services, builder);
for svc in services.find_by_type(&question.qname) {
builder =
builder.add_answer(&svc.typ, QueryClass::IN, DEFAULT_TTL, &svc.ptr_rr())
}
let mut builder = builder.move_to::<dns_parser::Additional>();
for svc in services.find_by_type(&question.qname) {
builder = svc.add_ptr_rr(builder, DEFAULT_TTL);
builder = svc.add_srv_rr(hostname, builder, DEFAULT_TTL);
builder = svc.add_txt_rr(builder, DEFAULT_TTL);
builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL);
builder = builder
.add_additional(
&svc.name,
QueryClass::IN,
DEFAULT_TTL,
&svc.srv_rr(hostname),
)
.add_additional(&svc.name, QueryClass::IN, DEFAULT_TTL, &svc.txt_rr())
.add_additionals(hostname, QueryClass::IN, DEFAULT_TTL, self.ip_rr());
}
builder
}
QueryType::SRV => {
if let Some(svc) = services.find_by_name(&question.qname) {
builder = svc.add_srv_rr(hostname, builder, DEFAULT_TTL);
builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL);
builder
.add_answer(
&svc.name,
QueryClass::IN,
DEFAULT_TTL,
&svc.srv_rr(hostname),
)
.add_additionals(hostname, QueryClass::IN, DEFAULT_TTL, self.ip_rr())
.move_to()
} else {
builder.move_to()
}
}
QueryType::TXT => {
if let Some(svc) = services.find_by_name(&question.qname) {
builder = svc.add_txt_rr(builder, DEFAULT_TTL);
builder
.add_answer(&svc.name, QueryClass::IN, DEFAULT_TTL, &svc.txt_rr())
.move_to()
} else {
builder.move_to()
}
}
_ => (),
_ => builder.move_to(),
}

builder
}

fn add_ip_rr(&self, hostname: &Name, mut builder: AnswerBuilder, ttl: u32) -> AnswerBuilder {
fn ip_rr(&self) -> impl Iterator<Item = RRData<'static>> + '_ {
let interfaces = match get_if_addrs() {
Ok(interfaces) => interfaces,
Err(err) => {
error!("could not get list of interfaces: {}", err);
return builder;
vec![]
}
};

for iface in interfaces {
interfaces.into_iter().filter_map(move |iface| {
if iface.is_loopback() {
continue;
return None;
}

trace!("found interface {:?}", iface);
if !self.allowed_ip.is_empty() && !self.allowed_ip.contains(&iface.ip()) {
trace!(" -> interface dropped");
continue;
return None;
}

match (iface.ip(), AF::DOMAIN) {
(IpAddr::V4(ip), Domain::IPV4) => {
builder = builder.add_answer(hostname, QueryClass::IN, ttl, &RRData::A(ip))
}
(IpAddr::V6(ip), Domain::IPV6) => {
builder = builder.add_answer(hostname, QueryClass::IN, ttl, &RRData::AAAA(ip))
}
_ => (),
(IpAddr::V4(ip), Domain::IPV4) => Some(RRData::A(ip)),
(IpAddr::V6(ip), Domain::IPV6) => Some(RRData::AAAA(ip)),
_ => None,
}
}

builder
})
}

fn send_unsolicited(&mut self, svc: &ServiceData, ttl: u32, include_ip: bool) {
Expand All @@ -261,11 +294,17 @@ impl<AF: AddressFamily> FSM<AF> {

let services = self.services.read().unwrap();

builder = svc.add_ptr_rr(builder, ttl);
builder = svc.add_srv_rr(services.get_hostname(), builder, ttl);
builder = svc.add_txt_rr(builder, ttl);
builder = builder.add_answer(&svc.typ, QueryClass::IN, ttl, &svc.ptr_rr());
builder = builder.add_answer(
&svc.name,
QueryClass::IN,
ttl,
&svc.srv_rr(services.get_hostname()),
);
builder = builder.add_answer(&svc.name, QueryClass::IN, ttl, &svc.txt_rr());
if include_ip {
builder = self.add_ip_rr(services.get_hostname(), builder, ttl);
builder =
builder.add_answers(services.get_hostname(), QueryClass::IN, ttl, self.ip_rr());
}

if !builder.is_empty() {
Expand Down Expand Up @@ -349,7 +388,7 @@ mod tests {

answer_builder = FSM::<Inet>::handle_service_type_enumeration(
&question,
services.read().unwrap().into_iter(),
&services.read().unwrap(),
answer_builder,
);

Expand Down
Loading

0 comments on commit 61fe298

Please sign in to comment.