Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: unrequired records should be sent as additional records #57

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 48 additions & 3 deletions src/dns_parser/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,21 @@ impl<T: MoveTo<Answers>> Builder<T> {

builder
}

pub fn add_answers<'a, 'b>(
self,
name: &Name,
cls: QueryClass,
ttl: u32,
data: impl Iterator<Item = RRData<'b>> + 'a,
) -> Builder<Answers> {
let mut builder = self.move_to::<Answers>();
for item in data {
builder.write_rr(name, cls, ttl, &item);
Header::inc_answers(&mut builder.buf).expect("Too many answers");
}
builder
}
}

impl<T: MoveTo<Nameservers>> Builder<T> {
Expand All @@ -213,10 +228,25 @@ impl<T: MoveTo<Nameservers>> Builder<T> {

builder
}
}

impl Builder<Additional> {
#[allow(dead_code)]
pub fn add_nameservers<'a, 'b>(
self,
name: &Name,
cls: QueryClass,
ttl: u32,
data: impl Iterator<Item = RRData<'b>> + 'a,
) -> Builder<Nameservers> {
let mut builder = self.move_to::<Nameservers>();
for item in data {
builder.write_rr(name, cls, ttl, &item);
Header::inc_nameservers(&mut builder.buf).expect("Too many nameservers");
}
builder
}
}

impl<T: MoveTo<Additional>> Builder<T> {
pub fn add_additional(
self,
name: &Name,
Expand All @@ -227,8 +257,23 @@ impl Builder<Additional> {
let mut builder = self.move_to::<Additional>();

builder.write_rr(name, cls, ttl, data);
Header::inc_nameservers(&mut builder.buf).expect("Too many additional answers");
Header::inc_additional(&mut builder.buf).expect("Too many additional answers");

builder
}

pub fn add_additionals<'a, 'b>(
self,
name: &Name,
cls: QueryClass,
ttl: u32,
data: impl Iterator<Item = RRData<'b>> + 'a,
) -> Builder<Additional> {
let mut builder = self.move_to::<Additional>();
for item in data {
builder.write_rr(name, cls, ttl, &item);
Header::inc_additional(&mut builder.buf).expect("Too many additional answers");
}
builder
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/dns_parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ pub use self::header::Header;
mod rrdata;
pub use self::rrdata::RRData;
mod builder;
pub use self::builder::{Answers, Builder, Questions};
pub use self::builder::*;
204 changes: 127 additions & 77 deletions src/fsm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ use tokio::{net::UdpSocket, sync::mpsc};

use super::{DEFAULT_TTL, MDNS_PORT};
use crate::address_family::AddressFamily;
use crate::services::{ServiceData, Services};
use crate::services::{ServiceData, Services, ServicesInner};

pub type AnswerBuilder = dns_parser::Builder<dns_parser::Answers>;
pub type AdditionalBuilder = dns_parser::Builder<dns_parser::Additional>;

const SERVICE_TYPE_ENUMERATION_NAME: Cow<'static, str> =
Cow::Borrowed("_services._dns-sd._udp.local");
Expand Down Expand Up @@ -104,57 +105,46 @@ impl<AF: AddressFamily> FSM<AF> {
return;
}

let mut unicast_builder = dns_parser::Builder::new_response(packet.header.id, false, true)
.move_to::<dns_parser::Answers>();
let mut multicast_builder =
dns_parser::Builder::new_response(packet.header.id, false, true)
.move_to::<dns_parser::Answers>();
unicast_builder.set_max_size(None);
multicast_builder.set_max_size(None);

for question in packet.questions {
debug!(
"received question: {:?} {}",
question.qclass, question.qname
);

if question.qclass == QueryClass::IN || question.qclass == QueryClass::Any {
let mut builder = dns_parser::Builder::new_response(packet.header.id, false, true)
.move_to::<dns_parser::Answers>();
builder.set_max_size(None);
let builder = self.handle_question(&question, builder);
if builder.is_empty() {
continue;
}
let response = builder.build().unwrap_or_else(|x| x);
if question.qu {
unicast_builder = self.handle_question(&question, unicast_builder);
self.outgoing.push_back((response, addr));
} else {
multicast_builder = self.handle_question(&question, multicast_builder);
let addr = SocketAddr::new(AF::MDNS_GROUP.into(), MDNS_PORT);
self.outgoing.push_back((response, addr));
}
}
}

if !multicast_builder.is_empty() {
let response = multicast_builder.build().unwrap_or_else(|x| x);
let addr = SocketAddr::new(AF::MDNS_GROUP.into(), MDNS_PORT);
self.outgoing.push_back((response, addr));
}

if !unicast_builder.is_empty() {
let response = unicast_builder.build().unwrap_or_else(|x| x);
self.outgoing.push_back((response, addr));
}
}

/// https://www.rfc-editor.org/rfc/rfc6763#section-9
fn handle_service_type_enumeration<'a>(
question: &dns_parser::Question,
services: impl Iterator<Item = &'a ServiceData>,
services: &ServicesInner,
mut builder: AnswerBuilder,
) -> AnswerBuilder {
let service_type_enumeration_name = Name::FromStr(SERVICE_TYPE_ENUMERATION_NAME);
if question.qname == service_type_enumeration_name {
for svc in services {
let svc_type = ServiceData {
name: svc.typ.clone(),
typ: service_type_enumeration_name.clone(),
port: svc.port,
txt: vec![],
};
builder = svc_type.add_ptr_rr(builder, DEFAULT_TTL);
for typ in services.all_types() {
builder = builder.add_answer(
&service_type_enumeration_name,
QueryClass::IN,
DEFAULT_TTL,
&RRData::PTR(typ.clone()),
);
}
}

Expand All @@ -165,93 +155,147 @@ impl<AF: AddressFamily> FSM<AF> {
&self,
question: &dns_parser::Question,
mut builder: AnswerBuilder,
) -> AnswerBuilder {
) -> AdditionalBuilder {
let services = self.services.read().unwrap();
let hostname = services.get_hostname();

match question.qtype {
QueryType::A | QueryType::AAAA if question.qname == *hostname => {
builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL);
}
QueryType::A | QueryType::AAAA if question.qname == *hostname => builder
.add_answers(hostname, QueryClass::IN, DEFAULT_TTL, self.ip_rr())
.move_to(),
QueryType::All => {
let mut include_ip_additionals = false;
// A / AAAA
if question.qname == *hostname {
builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL);
builder =
builder.add_answers(hostname, QueryClass::IN, DEFAULT_TTL, self.ip_rr());
}
// PTR
builder =
Self::handle_service_type_enumeration(question, services.into_iter(), builder);
builder = Self::handle_service_type_enumeration(question, &services, builder);
for svc in services.find_by_type(&question.qname) {
builder = svc.add_ptr_rr(builder, DEFAULT_TTL);
builder = svc.add_srv_rr(hostname, builder, DEFAULT_TTL);
builder = svc.add_txt_rr(builder, DEFAULT_TTL);
builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL);
builder =
builder.add_answer(&svc.typ, QueryClass::IN, DEFAULT_TTL, &svc.ptr_rr());
include_ip_additionals = true;
}
// SRV
if let Some(svc) = services.find_by_name(&question.qname) {
builder = svc.add_srv_rr(hostname, builder, DEFAULT_TTL);
builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL);
builder = builder
.add_answer(
&svc.name,
QueryClass::IN,
DEFAULT_TTL,
&svc.srv_rr(hostname),
)
.add_answer(&svc.name, QueryClass::IN, DEFAULT_TTL, &svc.txt_rr());
include_ip_additionals = true;
}
let mut builder = builder.move_to::<dns_parser::Additional>();
// PTR (additional)
for svc in services.find_by_type(&question.qname) {
builder = builder
.add_additional(
&svc.name,
QueryClass::IN,
DEFAULT_TTL,
&svc.srv_rr(hostname),
)
.add_additional(&svc.name, QueryClass::IN, DEFAULT_TTL, &svc.txt_rr());
include_ip_additionals = true;
}

if include_ip_additionals {
builder = builder.add_additionals(
hostname,
QueryClass::IN,
DEFAULT_TTL,
self.ip_rr(),
);
}
builder
}
QueryType::PTR => {
builder =
Self::handle_service_type_enumeration(question, services.into_iter(), builder);
let mut include_ip_additionals = false;
let mut builder =
Self::handle_service_type_enumeration(question, &services, builder);
for svc in services.find_by_type(&question.qname) {
builder = svc.add_ptr_rr(builder, DEFAULT_TTL);
builder = svc.add_srv_rr(hostname, builder, DEFAULT_TTL);
builder = svc.add_txt_rr(builder, DEFAULT_TTL);
builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL);
builder =
builder.add_answer(&svc.typ, QueryClass::IN, DEFAULT_TTL, &svc.ptr_rr())
}
let mut builder = builder.move_to::<dns_parser::Additional>();
for svc in services.find_by_type(&question.qname) {
builder = builder
.add_additional(
&svc.name,
QueryClass::IN,
DEFAULT_TTL,
&svc.srv_rr(hostname),
)
.add_additional(&svc.name, QueryClass::IN, DEFAULT_TTL, &svc.txt_rr());
include_ip_additionals = true;
}
if include_ip_additionals {
builder = builder.add_additionals(
hostname,
QueryClass::IN,
DEFAULT_TTL,
self.ip_rr(),
);
}
builder
}
QueryType::SRV => {
if let Some(svc) = services.find_by_name(&question.qname) {
builder = svc.add_srv_rr(hostname, builder, DEFAULT_TTL);
builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL);
builder
.add_answer(
&svc.name,
QueryClass::IN,
DEFAULT_TTL,
&svc.srv_rr(hostname),
)
.add_additionals(hostname, QueryClass::IN, DEFAULT_TTL, self.ip_rr())
.move_to()
} else {
builder.move_to()
}
}
QueryType::TXT => {
if let Some(svc) = services.find_by_name(&question.qname) {
builder = svc.add_txt_rr(builder, DEFAULT_TTL);
builder
.add_answer(&svc.name, QueryClass::IN, DEFAULT_TTL, &svc.txt_rr())
.move_to()
} else {
builder.move_to()
}
}
_ => (),
_ => builder.move_to(),
}

builder
}

fn add_ip_rr(&self, hostname: &Name, mut builder: AnswerBuilder, ttl: u32) -> AnswerBuilder {
fn ip_rr(&self) -> impl Iterator<Item = RRData<'static>> + '_ {
let interfaces = match get_if_addrs() {
Ok(interfaces) => interfaces,
Err(err) => {
error!("could not get list of interfaces: {}", err);
return builder;
vec![]
}
};

for iface in interfaces {
interfaces.into_iter().filter_map(move |iface| {
if iface.is_loopback() {
continue;
return None;
}

trace!("found interface {:?}", iface);
if !self.allowed_ip.is_empty() && !self.allowed_ip.contains(&iface.ip()) {
trace!(" -> interface dropped");
continue;
return None;
}

match (iface.ip(), AF::DOMAIN) {
(IpAddr::V4(ip), Domain::IPV4) => {
builder = builder.add_answer(hostname, QueryClass::IN, ttl, &RRData::A(ip))
}
(IpAddr::V6(ip), Domain::IPV6) => {
builder = builder.add_answer(hostname, QueryClass::IN, ttl, &RRData::AAAA(ip))
}
_ => (),
(IpAddr::V4(ip), Domain::IPV4) => Some(RRData::A(ip)),
(IpAddr::V6(ip), Domain::IPV6) => Some(RRData::AAAA(ip)),
_ => None,
}
}

builder
})
}

fn send_unsolicited(&mut self, svc: &ServiceData, ttl: u32, include_ip: bool) {
Expand All @@ -261,11 +305,17 @@ impl<AF: AddressFamily> FSM<AF> {

let services = self.services.read().unwrap();

builder = svc.add_ptr_rr(builder, ttl);
builder = svc.add_srv_rr(services.get_hostname(), builder, ttl);
builder = svc.add_txt_rr(builder, ttl);
builder = builder.add_answer(&svc.typ, QueryClass::IN, ttl, &svc.ptr_rr());
builder = builder.add_answer(
&svc.name,
QueryClass::IN,
ttl,
&svc.srv_rr(services.get_hostname()),
);
builder = builder.add_answer(&svc.name, QueryClass::IN, ttl, &svc.txt_rr());
if include_ip {
builder = self.add_ip_rr(services.get_hostname(), builder, ttl);
builder =
builder.add_answers(services.get_hostname(), QueryClass::IN, ttl, self.ip_rr());
}

if !builder.is_empty() {
Expand Down Expand Up @@ -349,7 +399,7 @@ mod tests {

answer_builder = FSM::<Inet>::handle_service_type_enumeration(
&question,
services.read().unwrap().into_iter(),
&services.read().unwrap(),
answer_builder,
);

Expand Down
Loading
Loading