From 85ca3e6c3f78d62b57806b0232e5aa15f9d3f827 Mon Sep 17 00:00:00 2001 From: ArcticLampyrid Date: Mon, 8 Jul 2024 20:23:49 +0800 Subject: [PATCH] fix: additional records should be used for unrequired records Link: https://datatracker.ietf.org/doc/html/rfc6763#section-12 --- src/dns_parser/builder.rs | 47 ++++++++- src/dns_parser/mod.rs | 2 +- src/fsm.rs | 204 ++++++++++++++++++++++++-------------- src/services.rs | 40 +++----- 4 files changed, 190 insertions(+), 103 deletions(-) diff --git a/src/dns_parser/builder.rs b/src/dns_parser/builder.rs index 0ff1a53..2d6be73 100644 --- a/src/dns_parser/builder.rs +++ b/src/dns_parser/builder.rs @@ -195,6 +195,21 @@ impl> Builder { builder } + + pub fn add_answers<'a, 'b>( + self, + name: &Name, + cls: QueryClass, + ttl: u32, + data: impl Iterator> + 'a, + ) -> Builder { + let mut builder = self.move_to::(); + for item in data { + builder.write_rr(name, cls, ttl, &item); + Header::inc_answers(&mut builder.buf).expect("Too many answers"); + } + builder + } } impl> Builder { @@ -213,10 +228,25 @@ impl> Builder { builder } + + #[allow(dead_code)] + pub fn add_nameservers<'a, 'b>( + self, + name: &Name, + cls: QueryClass, + ttl: u32, + data: impl Iterator> + 'a, + ) -> Builder { + let mut builder = self.move_to::(); + for item in data { + builder.write_rr(name, cls, ttl, &item); + Header::inc_nameservers(&mut builder.buf).expect("Too many nameservers"); + } + builder + } } impl> Builder { - #[allow(dead_code)] pub fn add_additional( self, name: &Name, @@ -231,6 +261,21 @@ impl> Builder { builder } + + pub fn add_additionals<'a, 'b>( + self, + name: &Name, + cls: QueryClass, + ttl: u32, + data: impl Iterator> + 'a, + ) -> Builder { + let mut builder = self.move_to::(); + for item in data { + builder.write_rr(name, cls, ttl, &item); + Header::inc_additional(&mut builder.buf).expect("Too many additional answers"); + } + builder + } } #[cfg(test)] diff --git a/src/dns_parser/mod.rs b/src/dns_parser/mod.rs index 23234a3..e4c2709 100644 --- a/src/dns_parser/mod.rs +++ b/src/dns_parser/mod.rs @@ -12,4 +12,4 @@ pub use self::header::Header; mod rrdata; pub use self::rrdata::RRData; mod builder; -pub use self::builder::{Answers, Builder, Questions}; +pub use self::builder::*; diff --git a/src/fsm.rs b/src/fsm.rs index b429de0..b7eef17 100644 --- a/src/fsm.rs +++ b/src/fsm.rs @@ -18,9 +18,10 @@ use tokio::{net::UdpSocket, sync::mpsc}; use super::{DEFAULT_TTL, MDNS_PORT}; use crate::address_family::AddressFamily; -use crate::services::{ServiceData, Services}; +use crate::services::{ServiceData, Services, ServicesInner}; pub type AnswerBuilder = dns_parser::Builder; +pub type AdditionalBuilder = dns_parser::Builder; const SERVICE_TYPE_ENUMERATION_NAME: Cow<'static, str> = Cow::Borrowed("_services._dns-sd._udp.local"); @@ -104,14 +105,6 @@ impl FSM { return; } - let mut unicast_builder = dns_parser::Builder::new_response(packet.header.id, false, true) - .move_to::(); - let mut multicast_builder = - dns_parser::Builder::new_response(packet.header.id, false, true) - .move_to::(); - unicast_builder.set_max_size(None); - multicast_builder.set_max_size(None); - for question in packet.questions { debug!( "received question: {:?} {}", @@ -119,42 +112,39 @@ impl FSM { ); if question.qclass == QueryClass::IN || question.qclass == QueryClass::Any { + let mut builder = dns_parser::Builder::new_response(packet.header.id, false, true) + .move_to::(); + builder.set_max_size(None); + let builder = self.handle_question(&question, builder); + if builder.is_empty() { + continue; + } + let response = builder.build().unwrap_or_else(|x| x); if question.qu { - unicast_builder = self.handle_question(&question, unicast_builder); + self.outgoing.push_back((response, addr)); } else { - multicast_builder = self.handle_question(&question, multicast_builder); + let addr = SocketAddr::new(AF::MDNS_GROUP.into(), MDNS_PORT); + self.outgoing.push_back((response, addr)); } } } - - if !multicast_builder.is_empty() { - let response = multicast_builder.build().unwrap_or_else(|x| x); - let addr = SocketAddr::new(AF::MDNS_GROUP.into(), MDNS_PORT); - self.outgoing.push_back((response, addr)); - } - - if !unicast_builder.is_empty() { - let response = unicast_builder.build().unwrap_or_else(|x| x); - self.outgoing.push_back((response, addr)); - } } /// https://www.rfc-editor.org/rfc/rfc6763#section-9 fn handle_service_type_enumeration<'a>( question: &dns_parser::Question, - services: impl Iterator, + services: &ServicesInner, mut builder: AnswerBuilder, ) -> AnswerBuilder { let service_type_enumeration_name = Name::FromStr(SERVICE_TYPE_ENUMERATION_NAME); if question.qname == service_type_enumeration_name { - for svc in services { - let svc_type = ServiceData { - name: svc.typ.clone(), - typ: service_type_enumeration_name.clone(), - port: svc.port, - txt: vec![], - }; - builder = svc_type.add_ptr_rr(builder, DEFAULT_TTL); + for typ in services.all_types() { + builder = builder.add_answer( + &service_type_enumeration_name, + QueryClass::IN, + DEFAULT_TTL, + &RRData::PTR(typ.clone()), + ); } } @@ -165,93 +155,147 @@ impl FSM { &self, question: &dns_parser::Question, mut builder: AnswerBuilder, - ) -> AnswerBuilder { + ) -> AdditionalBuilder { let services = self.services.read().unwrap(); let hostname = services.get_hostname(); match question.qtype { - QueryType::A | QueryType::AAAA if question.qname == *hostname => { - builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL); - } + QueryType::A | QueryType::AAAA if question.qname == *hostname => builder + .add_answers(hostname, QueryClass::IN, DEFAULT_TTL, self.ip_rr()) + .move_to(), QueryType::All => { + let mut include_ip_additionals = false; // A / AAAA if question.qname == *hostname { - builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL); + builder = + builder.add_answers(hostname, QueryClass::IN, DEFAULT_TTL, self.ip_rr()); } // PTR - builder = - Self::handle_service_type_enumeration(question, services.into_iter(), builder); + builder = Self::handle_service_type_enumeration(question, &services, builder); for svc in services.find_by_type(&question.qname) { - builder = svc.add_ptr_rr(builder, DEFAULT_TTL); - builder = svc.add_srv_rr(hostname, builder, DEFAULT_TTL); - builder = svc.add_txt_rr(builder, DEFAULT_TTL); - builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL); + builder = + builder.add_answer(&svc.typ, QueryClass::IN, DEFAULT_TTL, &svc.ptr_rr()); + include_ip_additionals = true; } // SRV if let Some(svc) = services.find_by_name(&question.qname) { - builder = svc.add_srv_rr(hostname, builder, DEFAULT_TTL); - builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL); + builder = builder + .add_answer( + &svc.name, + QueryClass::IN, + DEFAULT_TTL, + &svc.srv_rr(hostname), + ) + .add_answer(&svc.name, QueryClass::IN, DEFAULT_TTL, &svc.txt_rr()); + include_ip_additionals = true; + } + let mut builder = builder.move_to::(); + // PTR (additional) + for svc in services.find_by_type(&question.qname) { + builder = builder + .add_additional( + &svc.name, + QueryClass::IN, + DEFAULT_TTL, + &svc.srv_rr(hostname), + ) + .add_additional(&svc.name, QueryClass::IN, DEFAULT_TTL, &svc.txt_rr()); + include_ip_additionals = true; + } + + if include_ip_additionals { + builder = builder.add_additionals( + hostname, + QueryClass::IN, + DEFAULT_TTL, + self.ip_rr(), + ); } + builder } QueryType::PTR => { - builder = - Self::handle_service_type_enumeration(question, services.into_iter(), builder); + let mut include_ip_additionals = false; + let mut builder = + Self::handle_service_type_enumeration(question, &services, builder); for svc in services.find_by_type(&question.qname) { - builder = svc.add_ptr_rr(builder, DEFAULT_TTL); - builder = svc.add_srv_rr(hostname, builder, DEFAULT_TTL); - builder = svc.add_txt_rr(builder, DEFAULT_TTL); - builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL); + builder = + builder.add_answer(&svc.typ, QueryClass::IN, DEFAULT_TTL, &svc.ptr_rr()) } + let mut builder = builder.move_to::(); + for svc in services.find_by_type(&question.qname) { + builder = builder + .add_additional( + &svc.name, + QueryClass::IN, + DEFAULT_TTL, + &svc.srv_rr(hostname), + ) + .add_additional(&svc.name, QueryClass::IN, DEFAULT_TTL, &svc.txt_rr()); + include_ip_additionals = true; + } + if include_ip_additionals { + builder = builder.add_additionals( + hostname, + QueryClass::IN, + DEFAULT_TTL, + self.ip_rr(), + ); + } + builder } QueryType::SRV => { if let Some(svc) = services.find_by_name(&question.qname) { - builder = svc.add_srv_rr(hostname, builder, DEFAULT_TTL); - builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL); + builder + .add_answer( + &svc.name, + QueryClass::IN, + DEFAULT_TTL, + &svc.srv_rr(hostname), + ) + .add_additionals(hostname, QueryClass::IN, DEFAULT_TTL, self.ip_rr()) + .move_to() + } else { + builder.move_to() } } QueryType::TXT => { if let Some(svc) = services.find_by_name(&question.qname) { - builder = svc.add_txt_rr(builder, DEFAULT_TTL); + builder + .add_answer(&svc.name, QueryClass::IN, DEFAULT_TTL, &svc.txt_rr()) + .move_to() + } else { + builder.move_to() } } - _ => (), + _ => builder.move_to(), } - - builder } - fn add_ip_rr(&self, hostname: &Name, mut builder: AnswerBuilder, ttl: u32) -> AnswerBuilder { + fn ip_rr(&self) -> impl Iterator> + '_ { let interfaces = match get_if_addrs() { Ok(interfaces) => interfaces, Err(err) => { error!("could not get list of interfaces: {}", err); - return builder; + vec![] } }; - - for iface in interfaces { + interfaces.into_iter().filter_map(move |iface| { if iface.is_loopback() { - continue; + return None; } trace!("found interface {:?}", iface); if !self.allowed_ip.is_empty() && !self.allowed_ip.contains(&iface.ip()) { trace!(" -> interface dropped"); - continue; + return None; } match (iface.ip(), AF::DOMAIN) { - (IpAddr::V4(ip), Domain::IPV4) => { - builder = builder.add_answer(hostname, QueryClass::IN, ttl, &RRData::A(ip)) - } - (IpAddr::V6(ip), Domain::IPV6) => { - builder = builder.add_answer(hostname, QueryClass::IN, ttl, &RRData::AAAA(ip)) - } - _ => (), + (IpAddr::V4(ip), Domain::IPV4) => Some(RRData::A(ip)), + (IpAddr::V6(ip), Domain::IPV6) => Some(RRData::AAAA(ip)), + _ => None, } - } - - builder + }) } fn send_unsolicited(&mut self, svc: &ServiceData, ttl: u32, include_ip: bool) { @@ -261,11 +305,17 @@ impl FSM { let services = self.services.read().unwrap(); - builder = svc.add_ptr_rr(builder, ttl); - builder = svc.add_srv_rr(services.get_hostname(), builder, ttl); - builder = svc.add_txt_rr(builder, ttl); + builder = builder.add_answer(&svc.typ, QueryClass::IN, ttl, &svc.ptr_rr()); + builder = builder.add_answer( + &svc.name, + QueryClass::IN, + ttl, + &svc.srv_rr(services.get_hostname()), + ); + builder = builder.add_answer(&svc.name, QueryClass::IN, ttl, &svc.txt_rr()); if include_ip { - builder = self.add_ip_rr(services.get_hostname(), builder, ttl); + builder = + builder.add_answers(services.get_hostname(), QueryClass::IN, ttl, self.ip_rr()); } if !builder.is_empty() { @@ -349,7 +399,7 @@ mod tests { answer_builder = FSM::::handle_service_type_enumeration( &question, - services.read().unwrap().into_iter(), + &services.read().unwrap(), answer_builder, ); diff --git a/src/services.rs b/src/services.rs index 7800437..2c0c5a9 100644 --- a/src/services.rs +++ b/src/services.rs @@ -1,12 +1,10 @@ -use crate::dns_parser::{self, Name, QueryClass, RRData}; +use crate::dns_parser::{Name, RRData}; use multimap::MultiMap; use rand::{thread_rng, Rng}; use std::collections::HashMap; use std::slice; use std::sync::{Arc, RwLock}; -pub type AnswerBuilder = dns_parser::Builder; - /// A collection of registered services is shared between threads. pub type Services = Arc>; @@ -81,6 +79,10 @@ impl ServicesInner { svc } + + pub fn all_types(&self) -> impl Iterator> { + self.by_type.keys() + } } impl<'a> IntoIterator for &'a ServicesInner { @@ -119,30 +121,20 @@ pub struct ServiceData { /// Packet building helpers for `fsm` to respond with `ServiceData` impl ServiceData { - pub fn add_ptr_rr(&self, builder: AnswerBuilder, ttl: u32) -> AnswerBuilder { - builder.add_answer( - &self.typ, - QueryClass::IN, - ttl, - &RRData::PTR(self.name.clone()), - ) + pub fn ptr_rr(&self) -> RRData { + RRData::PTR(self.name.clone()) } - pub fn add_srv_rr(&self, hostname: &Name, builder: AnswerBuilder, ttl: u32) -> AnswerBuilder { - builder.add_answer( - &self.name, - QueryClass::IN, - ttl, - &RRData::SRV { - priority: 0, - weight: 0, - port: self.port, - target: hostname.clone(), - }, - ) + pub fn srv_rr<'a>(&self, hostname: &'a Name) -> RRData<'a> { + RRData::SRV { + priority: 0, + weight: 0, + port: self.port, + target: hostname.clone(), + } } - pub fn add_txt_rr(&self, builder: AnswerBuilder, ttl: u32) -> AnswerBuilder { - builder.add_answer(&self.name, QueryClass::IN, ttl, &RRData::TXT(&self.txt)) + pub fn txt_rr(&self) -> RRData { + RRData::TXT(&self.txt) } }