diff --git a/Cargo.toml b/Cargo.toml index 9a1bef7..db06d05 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ description = "no_std and no-alloc async implementations of various network prot repository = "https://github.com/ivmarkov/edge-net" license = "MIT OR Apache-2.0" readme = "README.md" -rust-version = "1.77" +rust-version = "1.78" [features] default = ["io"] @@ -109,8 +109,7 @@ embedded-io-async = { version = "0.6", default-features = false } embedded-svc = { version = "0.28", default-features = false } log = { version = "0.4", default-features = false } heapless = { version = "0.8", default-features = false } -domain = { version = "0.9.3", default-features = false, features = ["heapless"] } -octseq = { version = "0.3.2", default-features = false } +domain = { version = "0.10", default-features = false, features = ["heapless"] } edge-captive = { version = "0.2.0", path = "edge-captive", default-features = false } edge-dhcp = { version = "0.2.0", path = "edge-dhcp", default-features = false } diff --git a/edge-captive/Cargo.toml b/edge-captive/Cargo.toml index b62fe81..e6de78b 100644 --- a/edge-captive/Cargo.toml +++ b/edge-captive/Cargo.toml @@ -23,5 +23,4 @@ io = ["edge-nal"] [dependencies] log = { workspace = true } domain = { workspace = true } -octseq = { workspace = true } edge-nal = { workspace = true, optional = true } \ No newline at end of file diff --git a/edge-captive/src/lib.rs b/edge-captive/src/lib.rs index d56cb80..06ab9f7 100644 --- a/edge-captive/src/lib.rs +++ b/edge-captive/src/lib.rs @@ -5,7 +5,7 @@ use core::fmt::{self, Display}; use core::time::Duration; use domain::base::wire::Composer; -use domain::dep::octseq::OctetsBuilder; +use domain::dep::octseq::{OctetsBuilder, Truncate}; use log::debug; use domain::{ @@ -20,7 +20,6 @@ use domain::{ dep::octseq::ShortBuf, rdata::A, }; -use octseq::Truncate; #[cfg(feature = "io")] pub mod io; @@ -80,15 +79,15 @@ pub fn reply( let mut responseb = domain::base::MessageBuilder::from_target(buf)?; - let buf = if matches!(message.header().opcode(), Opcode::Query) { + let buf = if matches!(message.header().opcode(), Opcode::QUERY) { debug!("Message is of type Query, processing all questions"); - let mut answerb = responseb.start_answer(&message, Rcode::NoError)?; + let mut answerb = responseb.start_answer(&message, Rcode::NOERROR)?; for question in message.question() { let question = question?; - if matches!(question.qtype(), Rtype::A) && matches!(question.qclass(), Class::In) { + if matches!(question.qtype(), Rtype::A) && matches!(question.qclass(), Class::IN) { log::info!( "Question {:?} is of type A, answering with IP {:?}, TTL {:?}", question, @@ -98,7 +97,7 @@ pub fn reply( let record = Record::new( question.qname(), - Class::In, + Class::IN, Ttl::from_duration_lossy(ttl), A::from_octets(ip[0], ip[1], ip[2], ip[3]), ); @@ -118,7 +117,7 @@ pub fn reply( headerb.set_id(message.header().id()); headerb.set_opcode(message.header().opcode()); headerb.set_rd(message.header().rd()); - headerb.set_rcode(domain::base::iana::Rcode::NotImp); + headerb.set_rcode(domain::base::iana::Rcode::NOTIMP); responseb.finish() }; diff --git a/edge-mdns/Cargo.toml b/edge-mdns/Cargo.toml index ca9c296..184603e 100644 --- a/edge-mdns/Cargo.toml +++ b/edge-mdns/Cargo.toml @@ -23,8 +23,6 @@ io = ["embassy-futures", "embassy-sync", "embassy-time", "edge-nal"] log = { workspace = true } heapless = { workspace = true } domain = { workspace = true } -octseq = { workspace = true } -heapless07 = { package = "heapless", version = "0.7" } embassy-futures = { workspace = true, optional = true } embassy-sync = { workspace = true, optional = true } embassy-time = { workspace = true, optional = true } diff --git a/edge-mdns/README.md b/edge-mdns/README.md index 2d5b5ce..d446623 100644 --- a/edge-mdns/README.md +++ b/edge-mdns/README.md @@ -13,14 +13,22 @@ For other protocols, look at the [edge-net](https://github.com/ivmarkov/edge-net ## Example ```rust -use core::net::Ipv4Addr; +use core::net::{Ipv4Addr, Ipv6Addr}; -use edge_mdns::io::{self, MdnsIoError, MdnsRunBuffers, DEFAULT_SOCKET}; -use edge_mdns::Host; -use edge_nal::{Multicast, UdpBind, UdpSplit}; +use edge_mdns::buf::BufferAccess; +use edge_mdns::domain::base::Ttl; +use edge_mdns::io::{self, MdnsIoError, DEFAULT_SOCKET}; +use edge_mdns::{host::Host, HostAnswersMdnsHandler}; +use edge_nal::{UdpBind, UdpSplit}; +use embassy_sync::blocking_mutex::raw::NoopRawMutex; + +use embassy_sync::mutex::Mutex; +use embassy_sync::signal::Signal; use log::*; +use rand::{thread_rng, RngCore}; + // Change this to the IP address of the machine where you'll run this example const OUR_IP: Ipv4Addr = Ipv4Addr::new(127, 0, 0, 1); @@ -33,45 +41,59 @@ fn main() { let stack = edge_nal_std::Stack::new(); - let mut buffers = MdnsRunBuffers::new(); + let (recv_buf, send_buf) = ( + Mutex::::new([0; 1500]), + Mutex::::new([0; 1500]), + ); - futures_lite::future::block_on(run::( - &stack, - &mut buffers, - OUR_NAME, - OUR_IP, + futures_lite::future::block_on(run::( + &stack, &recv_buf, &send_buf, OUR_NAME, OUR_IP, )) .unwrap(); } -async fn run( +async fn run( stack: &T, - buffers: &mut MdnsRunBuffers, + recv_buf: RB, + send_buf: SB, our_name: &str, our_ip: Ipv4Addr, ) -> Result<(), MdnsIoError> where T: UdpBind, - for<'a> ::Socket<'a>: Multicast + UdpSplit, + RB: BufferAccess, + SB: BufferAccess, + RB::BufferSurface: AsMut<[u8]>, + SB::BufferSurface: AsMut<[u8]>, { info!("About to run an mDNS responder for our PC. It will be addressable using {our_name}.local, so try to `ping {our_name}.local`."); + let mut socket = io::bind(stack, DEFAULT_SOCKET, Some(Ipv4Addr::UNSPECIFIED), Some(0)).await?; + + let (recv, send) = socket.split(); + let host = Host { - id: 0, hostname: our_name, - ip: our_ip.octets(), - ipv6: None, + ipv4: our_ip, + ipv6: Ipv6Addr::UNSPECIFIED, + ttl: Ttl::from_secs(60), }; - io::run( - &host, + // A way to notify the mDNS responder that the data in `Host` had changed + // We don't use it in this example, because the data is hard-coded + let signal = Signal::new(); + + let mdns = io::Mdns::::new( Some(Ipv4Addr::UNSPECIFIED), Some(0), - [], - stack, - DEFAULT_SOCKET, - buffers, - ) - .await + recv, + send, + recv_buf, + send_buf, + |buf| thread_rng().fill_bytes(buf), + &signal, + ); + + mdns.run(HostAnswersMdnsHandler::new(&host)).await } ``` diff --git a/edge-mdns/src/buf.rs b/edge-mdns/src/buf.rs new file mode 100644 index 0000000..c336574 --- /dev/null +++ b/edge-mdns/src/buf.rs @@ -0,0 +1,106 @@ +use core::ops::{Deref, DerefMut}; + +use embassy_sync::{ + blocking_mutex::raw::RawMutex, + mutex::{Mutex, MutexGuard}, +}; + +/// A trait for getting access to a `&mut T` buffer, potentially awaiting until a buffer becomes available. +pub trait BufferAccess +where + T: ?Sized, +{ + type Buffer<'a>: DerefMut + where + Self: 'a; + + /// Get a reference to a buffer. + /// Might await until a buffer is available, as it might be in use by somebody else. + /// + /// Depending on its internal implementation details, access to a buffer might also be denied + /// immediately, or after a certain amount of time (subject to the concrete implementation of the method). + /// In that case, the method will return `None`. + async fn get(&self) -> Option>; +} + +impl BufferAccess for &B +where + B: BufferAccess, + T: ?Sized, +{ + type Buffer<'a> = B::Buffer<'a> where Self: 'a; + + async fn get(&self) -> Option> { + (*self).get().await + } +} + +pub struct VecBufAccess(Mutex>) +where + M: RawMutex; + +impl VecBufAccess +where + M: RawMutex, +{ + pub const fn new() -> Self { + Self(Mutex::new(heapless::Vec::new())) + } +} + +pub struct VecBuf<'a, M, const N: usize>(MutexGuard<'a, M, heapless::Vec>) +where + M: RawMutex; + +impl<'a, M, const N: usize> Drop for VecBuf<'a, M, N> +where + M: RawMutex, +{ + fn drop(&mut self) { + self.0.clear(); + } +} + +impl<'a, M, const N: usize> Deref for VecBuf<'a, M, N> +where + M: RawMutex, +{ + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl<'a, M, const N: usize> DerefMut for VecBuf<'a, M, N> +where + M: RawMutex, +{ + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl BufferAccess<[u8]> for VecBufAccess +where + M: RawMutex, +{ + type Buffer<'a> = VecBuf<'a, M, N> where Self: 'a; + + async fn get(&self) -> Option> { + let mut guard = self.0.lock().await; + + guard.resize_default(N).unwrap(); + + Some(VecBuf(guard)) + } +} + +impl Default for VecBufAccess +where + M: RawMutex, +{ + fn default() -> Self { + Self::new() + } +} diff --git a/edge-mdns/src/host.rs b/edge-mdns/src/host.rs new file mode 100644 index 0000000..5d28d0a --- /dev/null +++ b/edge-mdns/src/host.rs @@ -0,0 +1,190 @@ +use core::net::{Ipv4Addr, Ipv6Addr}; + +use crate::domain::base::{iana::Class, Record, Ttl}; +use crate::domain::rdata::{Aaaa, AllRecordData, Ptr, Srv, A}; + +use crate::{HostAnswer, HostAnswers, MdnsError, NameSlice, RecordDataChain, Txt, DNS_SD_OWNER}; + +/// A simple representation of a host that can be used to generate mDNS answers. +/// +/// This structure implements the `HostAnswers` trait, which allows it to be used +/// as a responder for mDNS queries coming from other network peers. +#[derive(Debug, Clone)] +pub struct Host<'a> { + /// The name of the host. I.e. a name "foo" will be pingable as "foo.local" + pub hostname: &'a str, + /// The IPv4 address of the host. + /// Leaving it as `Ipv4Addr::UNSPECIFIED` means that the host will not aswer it to A queries. + pub ipv4: Ipv4Addr, + /// The IPv6 address of the host. + /// Leaving it as `Ipv6Addr::UNSPECIFIED` means that the host will not aswer it to AAAA queries. + pub ipv6: Ipv6Addr, + /// The time-to-live of the mDNS answers. + pub ttl: Ttl, +} + +impl<'a> Host<'a> { + fn visit_answers(&self, mut f: F) -> Result<(), E> + where + F: FnMut(HostAnswer) -> Result<(), E>, + E: From, + { + let owner = &[self.hostname, "local"]; + + if !self.ipv4.is_unspecified() { + f(Record::new( + NameSlice::new(owner), + Class::IN, + self.ttl, + RecordDataChain::Next(AllRecordData::A(A::new(domain::base::net::Ipv4Addr::from( + self.ipv4.octets(), + )))), + ))?; + } + + if !self.ipv6.is_unspecified() { + f(Record::new( + NameSlice::new(owner), + Class::IN, + self.ttl, + RecordDataChain::Next(AllRecordData::Aaaa(Aaaa::new( + domain::base::net::Ipv6Addr::from(self.ipv6.octets()), + ))), + ))?; + } + + Ok(()) + } +} + +impl<'a> HostAnswers for Host<'a> { + fn visit(&self, mut f: F) -> Result<(), E> + where + F: FnMut(HostAnswer) -> Result<(), E>, + E: From, + { + self.visit_answers(&mut f) + } +} + +/// A simple representation of a DNS-SD service that can be used to generate mDNS answers. +/// +/// This structure (indirectly - via the `ServiceAnswers` wraper which also provides the hostname) +/// implements the `HostAnswers` trait, which allows it to be used as a responder for mDNS queries +/// coming from other network peers. +#[derive(Debug, Clone)] +pub struct Service<'a> { + /// The name of the service. + pub name: &'a str, + /// The priority of the service. + pub priority: u16, + /// The weight of the service. + pub weight: u16, + /// The service type. I.e. "_http" + pub service: &'a str, + /// The protocol of the service. I.e. "_tcp" or "_udp" + pub protocol: &'a str, + /// The TCP/UDP port where the service listens for incoming requests. + pub port: u16, + /// The subtypes of the service, if any. + pub service_subtypes: &'a [&'a str], + /// The key-value pairs that will be included in the TXT record, as per the DNS-SD spec. + pub txt_kvs: &'a [(&'a str, &'a str)], +} + +impl<'a> Service<'a> { + fn visit_answers(&self, host: &Host, mut f: F) -> Result<(), E> + where + F: FnMut(HostAnswer) -> Result<(), E>, + E: From, + { + let owner = &[self.name, self.service, self.protocol, "local"]; + let stype = &[self.service, self.protocol, "local"]; + let target = &[host.hostname, "local"]; + + f(Record::new( + NameSlice::new(owner), + Class::IN, + host.ttl, + RecordDataChain::Next(AllRecordData::Srv(Srv::new( + self.priority, + self.weight, + self.port, + NameSlice::new(target), + ))), + ))?; + + f(Record::new( + NameSlice::new(owner), + Class::IN, + host.ttl, + RecordDataChain::This(Txt::new(self.txt_kvs)), + ))?; + + f(Record::new( + DNS_SD_OWNER, + Class::IN, + host.ttl, + RecordDataChain::Next(AllRecordData::Ptr(Ptr::new(NameSlice::new(stype)))), + ))?; + + f(Record::new( + NameSlice::new(stype), + Class::IN, + host.ttl, + RecordDataChain::Next(AllRecordData::Ptr(Ptr::new(NameSlice::new(owner)))), + ))?; + + for subtype in self.service_subtypes { + let subtype_owner = &[subtype, self.name, self.service, self.protocol, "local"]; + let subtype = &[subtype, "_sub", self.service, self.protocol, "local"]; + + f(Record::new( + NameSlice::new(subtype_owner), + Class::IN, + host.ttl, + RecordDataChain::Next(AllRecordData::Ptr(Ptr::new(NameSlice::new(owner)))), + ))?; + + f(Record::new( + NameSlice::new(subtype), + Class::IN, + host.ttl, + RecordDataChain::Next(AllRecordData::Ptr(Ptr::new(NameSlice::new(subtype_owner)))), + ))?; + + f(Record::new( + DNS_SD_OWNER, + Class::IN, + host.ttl, + RecordDataChain::Next(AllRecordData::Ptr(Ptr::new(NameSlice::new(subtype)))), + ))?; + } + + Ok(()) + } +} + +/// A wrapper around a `Service` that also provides the Host of the service +/// and thus allows the `HostAnswers` trait contract to be fullfilled for a `Service` instance. +pub struct ServiceAnswers<'a> { + host: &'a Host<'a>, + service: &'a Service<'a>, +} + +impl<'a> ServiceAnswers<'a> { + /// Create a new `ServiceAnswers` instance. + pub const fn new(host: &'a Host<'a>, service: &'a Service<'a>) -> Self { + Self { host, service } + } +} + +impl<'a> HostAnswers for ServiceAnswers<'a> { + fn visit(&self, mut f: F) -> Result<(), E> + where + F: FnMut(HostAnswer) -> Result<(), E>, + E: From, + { + self.service.visit_answers(self.host, &mut f) + } +} diff --git a/edge-mdns/src/io.rs b/edge-mdns/src/io.rs index 06c957b..ce1fe8d 100644 --- a/edge-mdns/src/io.rs +++ b/edge-mdns/src/io.rs @@ -1,32 +1,41 @@ +use core::cell::RefCell; use core::fmt; use core::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use core::pin::pin; +use buf::BufferAccess; use embassy_futures::select::{select, Either}; -use embassy_sync::blocking_mutex::raw::{NoopRawMutex, RawMutex}; +use embassy_sync::blocking_mutex; +use embassy_sync::blocking_mutex::raw::RawMutex; use embassy_sync::mutex::Mutex; -use embassy_time::{Duration, Timer}; +use embassy_sync::signal::Signal; -use edge_nal::{MulticastV4, MulticastV6, UdpBind, UdpReceive, UdpSend, UdpSplit}; +use edge_nal::{MulticastV4, MulticastV6, Readable, UdpBind, UdpReceive, UdpSend}; +use embassy_time::{Duration, Timer}; use log::{info, warn}; use super::*; +/// A quick-and-dirty socket address that binds to a "default" interface. +/// Don't use in production code. pub const DEFAULT_SOCKET: SocketAddr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), PORT); -const IP_BROADCAST_ADDR: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 251); -const IPV6_BROADCAST_ADDR: Ipv6Addr = Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 0x00fb); +/// The IPv4 mDNS broadcast address, as per spec. +pub const IP_BROADCAST_ADDR: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 251); +/// The IPv6 mDNS broadcast address, as per spec. +pub const IPV6_BROADCAST_ADDR: Ipv6Addr = Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 0x00fb); -const PORT: u16 = 5353; - -const MAX_TX_BUF_SIZE: usize = 1280 - 40/*IPV6 header size*/ - 8/*UDP header size*/; -const MAX_RX_BUF_SIZE: usize = 1583; +/// The mDNS port, as per spec. +pub const PORT: u16 = 5353; +/// A wrapper for mDNS and IO errors. #[derive(Debug)] pub enum MdnsIoError { MdnsError(MdnsError), + NoRecvBufError, + NoSendBufError, IoError(E), } @@ -43,6 +52,8 @@ where fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::MdnsError(err) => write!(f, "mDNS error: {}", err), + Self::NoRecvBufError => write!(f, "No recv buf available"), + Self::NoSendBufError => write!(f, "No send buf available"), Self::IoError(err) => write!(f, "IO error: {}", err), } } @@ -51,97 +62,272 @@ where #[cfg(feature = "std")] impl std::error::Error for MdnsIoError where E: std::error::Error {} -pub struct MdnsRunBuffers { - tx_buf: core::mem::MaybeUninit<[u8; MAX_TX_BUF_SIZE]>, - rx_buf: core::mem::MaybeUninit<[u8; MAX_RX_BUF_SIZE]>, -} - -impl MdnsRunBuffers { - #[inline(always)] - pub const fn new() -> Self { - Self { - tx_buf: core::mem::MaybeUninit::uninit(), - rx_buf: core::mem::MaybeUninit::uninit(), - } - } -} - -impl Default for MdnsRunBuffers { - fn default() -> Self { - Self::new() - } -} - -pub async fn run<'s, T, S>( - host: &Host<'_>, +/// A utility method to bind a socket suitable for mDNS, by using the provided +/// stack and address, and optionally joining the provided interfaces via multicast. +/// +/// Note that mDNS is pointless without multicast, so at least one - or both - of the +/// ipv4 and ipv6 interfaces need to be provided. +pub async fn bind( + stack: &S, + addr: SocketAddr, ipv4_interface: Option, ipv6_interface: Option, - services: T, - stack: &S, - socket: SocketAddr, - buffers: &mut MdnsRunBuffers, -) -> Result<(), MdnsIoError> +) -> Result, MdnsIoError> where - T: IntoIterator> + Clone, S: UdpBind, - for<'a> S::Socket<'a>: - MulticastV4 + MulticastV6 + UdpSplit, { - let mut udp = stack.bind(socket).await.map_err(MdnsIoError::IoError)?; + let mut socket = stack.bind(addr).await.map_err(MdnsIoError::IoError)?; if let Some(v4) = ipv4_interface { - udp.join_v4(IP_BROADCAST_ADDR, v4) + socket + .join_v4(IP_BROADCAST_ADDR, v4) .await .map_err(MdnsIoError::IoError)?; } if let Some(v6) = ipv6_interface { - udp.join_v6(IPV6_BROADCAST_ADDR, v6) + socket + .join_v6(IPV6_BROADCAST_ADDR, v6) .await .map_err(MdnsIoError::IoError)?; } - let (recv, send) = udp.split(); + Ok(socket) +} + +/// Represents an mDNS service that can respond to queries using the provided handler. +/// +/// This structure is generic over the mDNS handler, the UDP receiver and sender, and the +/// raw mutex type. +/// +/// The handler is expected to be a type that implements the `MdnsHandler` trait, which +/// allows it to handle mDNS queries and generate responses, as well as to handle mDNS +/// responses to queries which we might have issues using the `query` method. +pub struct Mdns<'a, M, R, S, RB, SB> +where + M: RawMutex, +{ + ipv4_interface: Option, + ipv6_interface: Option, + recv: Mutex, + send: Mutex, + recv_buf: RB, + send_buf: SB, + rand: fn(&mut [u8]), + broadcast_signal: &'a Signal, +} + +impl<'a, M, R, S, RB, SB> Mdns<'a, M, R, S, RB, SB> +where + M: RawMutex, + R: UdpReceive + Readable, + S: UdpSend, + RB: BufferAccess<[u8]>, + SB: BufferAccess<[u8]>, +{ + /// Creates a new mDNS service with the provided handler, interfaces, and UDP receiver and sender. + #[allow(clippy::too_many_arguments)] + pub fn new( + ipv4_interface: Option, + ipv6_interface: Option, + recv: R, + send: S, + recv_buf: RB, + send_buf: SB, + rand: fn(&mut [u8]), + broadcast_signal: &'a Signal, + ) -> Self { + Self { + ipv4_interface, + ipv6_interface, + recv: Mutex::new(recv), + send: Mutex::new(send), + recv_buf, + send_buf, + rand, + broadcast_signal, + } + } + + /// Runs the mDNS service, handling queries and responding to them, as well as broadcasting + /// mDNS answers and handling responses to our own queries. + /// + /// All of the handling logic is expected to be implemented by the provided handler: + /// - I.e. hanbdling responses to our own queries cannot happen, unless the supplied handler + /// is capable of doing that (i.e. it is a `PeerMdnsHandler`, or a chain containing it, or similar). + /// - Ditto for handling queries coming from other peers - this can only happen if the handler + /// is capable of doing that. I.e., it is a `HostMdnsHandler`, or a chain containing it, or similar. + pub async fn run(&self, handler: T) -> Result<(), MdnsIoError> + where + T: MdnsHandler, + { + let handler = blocking_mutex::Mutex::::new(RefCell::new(handler)); + + let mut broadcast = pin!(self.broadcast(&handler)); + let mut respond = pin!(self.respond(&handler)); + + let result = select(&mut broadcast, &mut respond).await; + + match result { + Either::First(result) => result, + Either::Second(result) => result, + } + } + + /// Sends a multicast query with the provided payload. + /// It is assumed that the payload represents a valid mDNS query message. + /// + /// The payload is constructed via a closure, because this way we can provide to + /// the payload-constructing closure a ready-to-use `&mut [u8]` slice, where the + /// closure can arrange the mDNS query message (i.e. we avoid extra memory usage + /// by constructing the mDNS query directly in the `send_buf` buffer that was supplied + /// when the `Mdns` instance was constructed). + pub async fn query(&self, q: Q) -> Result<(), MdnsIoError> + where + Q: FnOnce(&mut [u8]) -> Result, + { + let mut send_buf = self + .send_buf + .get() + .await + .ok_or(MdnsIoError::NoSendBufError)?; + + let mut send_guard = self.send.lock().await; + let send = &mut *send_guard; - let send_buf: &mut [u8] = unsafe { buffers.tx_buf.assume_init_mut() }; - let recv_buf = unsafe { buffers.rx_buf.assume_init_mut() }; + let len = q(send_buf.as_mut())?; - let send = Mutex::::new((send, send_buf)); + if len > 0 { + self.broadcast_once(send, &send_buf.as_mut()[..len], true, true) + .await?; + } - let mut broadcast = pin!(broadcast( - host, - services.clone(), - ipv4_interface.is_some(), - ipv6_interface, - &send - )); - let mut respond = pin!(respond(host, services, recv, recv_buf, &send)); + Ok(()) + } - let result = select(&mut broadcast, &mut respond).await; + async fn broadcast( + &self, + handler: &blocking_mutex::Mutex>, + ) -> Result<(), MdnsIoError> + where + T: MdnsHandler, + { + loop { + { + let mut send_buf = self + .send_buf + .get() + .await + .ok_or(MdnsIoError::NoSendBufError)?; + + let mut send_guard = self.send.lock().await; + let send = &mut *send_guard; + + let response = handler.lock(|handler| { + handler + .borrow_mut() + .handle(MdnsRequest::None, send_buf.as_mut()) + })?; + + if let MdnsResponse::Reply { data, delay } = response { + if delay { + // TODO: Not ideal, as we hold the lock during the delay + self.delay().await; + } + + self.broadcast_once(send, data, true, true).await?; + } + } - match result { - Either::First(result) => result, - Either::Second(result) => result, + self.broadcast_signal.wait().await; + } } -} -async fn broadcast<'s, T, S>( - host: &Host<'_>, - services: T, - ipv4: bool, - ipv6_interface: Option, - send: &Mutex, -) -> Result<(), MdnsIoError> -where - T: IntoIterator> + Clone, - S: UdpSend, -{ - loop { + async fn respond( + &self, + handler: &blocking_mutex::Mutex>, + ) -> Result<(), MdnsIoError> + where + T: MdnsHandler, + { + let mut recv = self.recv.lock().await; + + loop { + recv.readable().await.map_err(MdnsIoError::IoError)?; + + { + let mut recv_buf = self + .recv_buf + .get() + .await + .ok_or(MdnsIoError::NoRecvBufError)?; + let mut send_buf = self + .send_buf + .get() + .await + .ok_or(MdnsIoError::NoSendBufError)?; + + let (len, remote) = recv + .receive(recv_buf.as_mut()) + .await + .map_err(MdnsIoError::IoError)?; + + debug!("Got mDNS query from {remote}"); + + let mut send_guard = self.send.lock().await; + let send = &mut *send_guard; + + let response = match handler.lock(|handler| { + handler.borrow_mut().handle( + MdnsRequest::Request { + data: &recv_buf.as_mut()[..len], + legacy: remote.port() != PORT, + multicast: true, // TODO: Cannot determine this + }, + send_buf.as_mut(), + ) + }) { + Ok(len) => len, + Err(err) => match err { + MdnsError::InvalidMessage => { + warn!("Got invalid message from {remote}, skipping"); + continue; + } + other => Err(other)?, + }, + }; + + if let MdnsResponse::Reply { data, delay } = response { + if delay { + self.delay().await; + } + + info!("Replying to mDNS query from {remote}"); + + self.broadcast_once( + send, + data, + matches!(remote, SocketAddr::V4(_)), + matches!(remote, SocketAddr::V6(_)), + ) + .await?; + } + } + } + } + + async fn broadcast_once( + &self, + send: &mut S, + data: &[u8], + ipv4: bool, + ipv6: bool, + ) -> Result<(), MdnsIoError> { for remote_addr in core::iter::once(SocketAddr::V4(SocketAddrV4::new(IP_BROADCAST_ADDR, PORT))) - .filter(|_| ipv4) + .filter(|_| ipv4 && self.ipv4_interface.is_some()) .chain( - ipv6_interface + self.ipv6_interface .map(|interface| { SocketAddr::V6(SocketAddrV6::new( IPV6_BROADCAST_ADDR, @@ -150,73 +336,29 @@ where interface, )) }) - .into_iter(), + .into_iter() + .filter(|_| ipv6), ) { - let mut guard = send.lock().await; - let (send, send_buf) = &mut *guard; - - let len = host.broadcast(services.clone(), send_buf, 60)?; - - if len > 0 { + if !data.is_empty() { info!("Broadcasting mDNS entry to {remote_addr}"); - let fut = pin!(send.send(remote_addr, &send_buf[..len])); + let fut = pin!(send.send(remote_addr, data)); fut.await.map_err(MdnsIoError::IoError)?; } } - Timer::after(Duration::from_secs(30)).await; + Ok(()) } -} -async fn respond<'s, T, R, S>( - host: &Host<'_>, - services: T, - mut recv: R, - recv_buf: &mut [u8], - send: &Mutex, -) -> Result<(), MdnsIoError> -where - T: IntoIterator> + Clone, - R: UdpReceive, - S: UdpSend, -{ - loop { - let (len, remote) = recv.receive(recv_buf).await.map_err(MdnsIoError::IoError)?; - - let mut guard = send.lock().await; - let (send, send_buf) = &mut *guard; - - let len = match host.respond(services.clone(), &recv_buf[..len], send_buf, 60) { - Ok(len) => len, - Err(err) => match err { - MdnsError::InvalidMessage => { - warn!("Got invalid message from {remote}, skipping"); - continue; - } - other => Err(other)?, - }, - }; + async fn delay(&self) { + let mut b = [0]; + (self.rand)(&mut b); - if len > 0 { - info!("Replying to mDNS query from {remote}"); - - let fut = pin!(send.send(remote, &send_buf[..len])); - - match fut.await { - Ok(_) => (), - Err(err) => { - // Turns out we might receive queries from Ipv6 addresses which are actually unreachable by us - // Still to be investigated why, but it does seem that we are receiving packets which contain - // non-link-local Ipv6 addresses, to which we cannot respond - // - // A possible reason for this might be that we are receiving these packets via the broadcast group - // - yet - it is still unclear how these arrive given that we are only listening on the link-local address - warn!("IO error {err:?} while replying to {remote}"); - } - } - } + // Generate a delay between 20 and 120 ms, as per spec + let delay_ms = 20 + (b[0] as u32 * 100 / 256); + + Timer::after(Duration::from_millis(delay_ms as _)).await; } } diff --git a/edge-mdns/src/lib.rs b/edge-mdns/src/lib.rs index 74ae741..4172b9e 100644 --- a/edge-mdns/src/lib.rs +++ b/edge-mdns/src/lib.rs @@ -1,27 +1,41 @@ #![cfg_attr(not(feature = "std"), no_std)] #![warn(clippy::large_futures)] - -use core::fmt::{self, Display, Write}; - -use domain::{ - base::{ - header::Flags, - iana::Class, - message::ShortMessage, - message_builder::{AnswerBuilder, PushError}, - name::FromStrError, - wire::{Composer, ParseError}, - Dname, Message, MessageBuilder, Rtype, ToDname, - }, - dep::octseq::{OctetsBuilder, ShortBuf}, - rdata::{Aaaa, Ptr, Srv, Txt, A}, +#![allow(async_fn_in_trait)] + +use core::cmp::Ordering; +use core::fmt::{self, Display}; +use core::ops::RangeBounds; + +use domain::base::header::Flags; +use domain::base::iana::{Opcode, Rcode}; +use domain::base::message::ShortMessage; +use domain::base::message_builder::PushError; +use domain::base::name::{FromStrError, Label, ToLabelIter}; +use domain::base::rdata::ComposeRecordData; +use domain::base::wire::{Composer, ParseError}; +use domain::base::{ + Message, MessageBuilder, ParsedName, Question, Record, RecordData, Rtype, ToName, }; -use log::trace; -use octseq::Truncate; +use domain::dep::octseq::{FreezeBuilder, FromBuilder, Octets, OctetsBuilder, ShortBuf, Truncate}; +use domain::rdata::AllRecordData; + +use log::debug; +pub mod buf; // TODO: Maybe move to a generic `edge-buf` crate in future +/// Re-export the domain lib if the user would like to directly +/// assemble / parse mDNS messages. +pub mod domain { + pub use domain::*; +} +pub mod host; #[cfg(feature = "io")] pub mod io; +/// The DNS-SD owner name. +pub const DNS_SD_OWNER: NameSlice = NameSlice::new(&["_services", "_dns-sd", "_udp", "local"]); + +/// A wrapper type for the errors returned by the `domain` library during parsing and +/// constructing mDNS messages. #[derive(Debug)] pub enum MdnsError { ShortBuf, @@ -70,419 +84,258 @@ impl From for MdnsError { } } +/// This newtype struct allows the construction of a `domain` lib Name from +/// a bunch of `&str` labels represented as a slice. +/// +/// Implements the `domain` lib `ToName` trait. #[derive(Debug, Clone)] -pub struct Host<'a> { - pub id: u16, - pub hostname: &'a str, - pub ip: [u8; 4], - pub ipv6: Option<[u8; 16]>, -} - -impl<'a> Host<'a> { - pub fn broadcast<'s, T>( - &self, - services: T, - buf: &mut [u8], - ttl_sec: u32, - ) -> Result - where - T: IntoIterator> + Clone, - { - let buf = Buf(buf, 0); - - let message = MessageBuilder::from_target(buf)?; - - let mut answer = message.answer(); +pub struct NameSlice<'a>(&'a [&'a str]); - self.set_broadcast(services, &mut answer, ttl_sec)?; +impl<'a> NameSlice<'a> { + /// Create a new `NameSlice` instance from a slice of `&str` labels. + pub const fn new(labels: &'a [&'a str]) -> Self { + Self(labels) + } +} - let buf = answer.finish(); +impl<'a> fmt::Display for NameSlice<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + for label in self.0 { + write!(f, "{}.", label)?; + } - Ok(buf.1) + Ok(()) } +} - pub fn respond<'s, T>( - &self, - services: T, - data: &[u8], - buf: &mut [u8], - ttl_sec: u32, - ) -> Result - where - T: IntoIterator> + Clone, - { - let buf = Buf(buf, 0); +impl<'a> ToName for NameSlice<'a> {} - let message = MessageBuilder::from_target(buf)?; +/// An iterator over the labels in a `NameSlice` instance. +#[derive(Clone)] +pub struct NameSliceIter<'a> { + name: &'a NameSlice<'a>, + index: usize, +} - let mut answer = message.answer(); +impl<'a> Iterator for NameSliceIter<'a> { + type Item = &'a Label; - if self.set_response(data, services, &mut answer, ttl_sec)? { - let buf = answer.finish(); + fn next(&mut self) -> Option { + match self.index.cmp(&self.name.0.len()) { + Ordering::Less => { + let label = Label::from_slice(self.name.0[self.index].as_bytes()).unwrap(); + self.index += 1; + Some(label) + } + Ordering::Equal => { + let label = Label::root(); + self.index += 1; + Some(label) + } + Ordering::Greater => None, + } + } +} - Ok(buf.1) +impl<'a> DoubleEndedIterator for NameSliceIter<'a> { + fn next_back(&mut self) -> Option { + if self.index > 0 { + self.index -= 1; + if self.index == self.name.0.len() { + let label = Label::root(); + Some(label) + } else { + let label = Label::from_slice(self.name.0[self.index].as_bytes()).unwrap(); + Some(label) + } } else { - Ok(0) + None } } +} - fn set_broadcast<'s, T, F>( - &self, - services: F, - answer: &mut AnswerBuilder, - ttl_sec: u32, - ) -> Result<(), MdnsError> - where - T: Composer, - F: IntoIterator> + Clone, - { - self.set_header(answer); - - self.add_ipv4(answer, ttl_sec)?; - self.add_ipv6(answer, ttl_sec)?; +impl<'a> ToLabelIter for NameSlice<'a> { + type LabelIter<'t> = NameSliceIter<'t> where Self: 't; - for service in services.clone() { - service.add_service(answer, self.hostname, ttl_sec)?; - service.add_service_type(answer, ttl_sec)?; - service.add_service_subtypes(answer, ttl_sec)?; - service.add_txt(answer, ttl_sec)?; + fn iter_labels(&self) -> Self::LabelIter<'_> { + NameSliceIter { + name: self, + index: 0, } - - Ok(()) } +} - fn set_response<'s, T, F>( - &self, - data: &[u8], - services: F, - answer: &mut AnswerBuilder, - ttl_sec: u32, - ) -> Result - where - T: Composer, - F: IntoIterator> + Clone, - { - self.set_header(answer); - - let message = Message::from_octets(data)?; - - let mut replied = false; - - for question in message.question() { - trace!("Handling question {:?}", question); - - let question = question?; +/// A custom struct for representing a TXT data record off from a slice of +/// key-value `&str` pairs. +#[derive(Debug, Clone)] +pub struct Txt<'a>(&'a [(&'a str, &'a str)]); - match question.qtype() { - Rtype::A - if question - .qname() - .name_eq(&Host::host_fqdn(self.hostname, true)?) => - { - self.add_ipv4(answer, ttl_sec)?; - replied = true; - } - Rtype::Aaaa - if question - .qname() - .name_eq(&Host::host_fqdn(self.hostname, true)?) => - { - self.add_ipv6(answer, ttl_sec)?; - replied = true; - } - Rtype::Srv => { - for service in services.clone() { - if question.qname().name_eq(&service.service_fqdn(true)?) { - self.add_ipv4(answer, ttl_sec)?; - self.add_ipv6(answer, ttl_sec)?; - service.add_service(answer, self.hostname, ttl_sec)?; - replied = true; - } - } - } - Rtype::Ptr => { - for service in services.clone() { - if question.qname().name_eq(&Service::dns_sd_fqdn(true)?) { - service.add_service_type(answer, ttl_sec)?; - replied = true; - } else if question.qname().name_eq(&service.service_type_fqdn(true)?) { - // TODO - self.add_ipv4(answer, ttl_sec)?; - self.add_ipv6(answer, ttl_sec)?; - service.add_service(answer, self.hostname, ttl_sec)?; - service.add_service_type(answer, ttl_sec)?; - service.add_service_subtypes(answer, ttl_sec)?; - service.add_txt(answer, ttl_sec)?; - replied = true; - } - } - } - Rtype::Txt => { - for service in services.clone() { - if question.qname().name_eq(&service.service_fqdn(true)?) { - service.add_txt(answer, ttl_sec)?; - replied = true; - } - } - } - Rtype::Any => { - // A / AAAA - if question - .qname() - .name_eq(&Host::host_fqdn(self.hostname, true)?) - { - self.add_ipv4(answer, ttl_sec)?; - self.add_ipv6(answer, ttl_sec)?; - replied = true; - } +impl<'a> Txt<'a> { + pub const fn new(txt: &'a [(&'a str, &'a str)]) -> Self { + Self(txt) + } +} - // PTR - for service in services.clone() { - if question.qname().name_eq(&Service::dns_sd_fqdn(true)?) { - service.add_service_type(answer, ttl_sec)?; - replied = true; - } else if question.qname().name_eq(&service.service_type_fqdn(true)?) { - // TODO - self.add_ipv4(answer, ttl_sec)?; - self.add_ipv6(answer, ttl_sec)?; - service.add_service(answer, self.hostname, ttl_sec)?; - service.add_service_type(answer, ttl_sec)?; - service.add_service_subtypes(answer, ttl_sec)?; - service.add_txt(answer, ttl_sec)?; - replied = true; - } - } +impl<'a> fmt::Display for Txt<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Txt [")?; - // SRV - for service in services.clone() { - if question.qname().name_eq(&service.service_fqdn(true)?) { - self.add_ipv4(answer, ttl_sec)?; - self.add_ipv6(answer, ttl_sec)?; - service.add_service(answer, self.hostname, ttl_sec)?; - replied = true; - } - } - } - _ => (), + for (i, (k, v)) in self.0.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; } + + write!(f, "{}={}", k, v)?; } - Ok(replied) + write!(f, "]")?; + + Ok(()) } +} - fn set_header(&self, answer: &mut AnswerBuilder) { - let header = answer.header_mut(); - header.set_id(self.id); - header.set_opcode(domain::base::iana::Opcode::Query); - header.set_rcode(domain::base::iana::Rcode::NoError); +impl<'a> RecordData for Txt<'a> { + fn rtype(&self) -> Rtype { + Rtype::TXT + } +} - let mut flags = Flags::new(); - flags.qr = true; - flags.aa = true; - header.set_flags(flags); +impl<'a> ComposeRecordData for Txt<'a> { + fn rdlen(&self, _compress: bool) -> Option { + None } - fn add_ipv4( + fn compose_rdata( &self, - answer: &mut AnswerBuilder, - ttl_sec: u32, - ) -> Result<(), PushError> { - answer.push(( - Self::host_fqdn(self.hostname, false).unwrap(), - Class::In, - ttl_sec, - A::from_octets(self.ip[0], self.ip[1], self.ip[2], self.ip[3]), - )) - } - - fn add_ipv6( - &self, - answer: &mut AnswerBuilder, - ttl_sec: u32, - ) -> Result<(), PushError> { - if let Some(ip) = &self.ipv6 { - answer.push(( - Self::host_fqdn(self.hostname, false).unwrap(), - Class::In, - ttl_sec, - Aaaa::new((*ip).into()), - )) + target: &mut Target, + ) -> Result<(), Target::AppendError> { + if self.0.is_empty() { + target.append_slice(&[0])?; } else { - Ok(()) + // TODO: Will not work for (k, v) pairs larger than 254 bytes in length + for (k, v) in self.0 { + target.append_slice(&[(k.len() + v.len() + 1) as u8])?; + target.append_slice(k.as_bytes())?; + target.append_slice(&[b'='])?; + target.append_slice(v.as_bytes())?; + } } - } - fn host_fqdn(hostname: &str, suffix: bool) -> Result { - let suffix = if suffix { "." } else { "" }; - - let mut host_fqdn = heapless07::String::<60>::new(); - write!(host_fqdn, "{}.local{}", hostname, suffix,).unwrap(); + Ok(()) + } - Dname::>::from_chars(host_fqdn.chars()) + fn compose_canonical_rdata( + &self, + target: &mut Target, + ) -> Result<(), Target::AppendError> { + self.compose_rdata(target) } } +/// A custom struct allowing to chain together multiple custom record data types. +/// Allows e.g. using the custom `Txt` struct from above and chain it with `domain`'s `AllRecordData`, #[derive(Debug, Clone)] -pub struct Service<'a> { - pub name: &'a str, - pub service: &'a str, - pub protocol: &'a str, - pub port: u16, - pub service_subtypes: &'a [&'a str], - pub txt_kvs: &'a [(&'a str, &'a str)], +pub enum RecordDataChain { + This(T), + Next(U), } -impl<'a> Service<'a> { - fn add_service( - &self, - answer: &mut AnswerBuilder, - hostname: &str, - ttl_sec: u32, - ) -> Result<(), PushError> { - answer.push(( - self.service_fqdn(false).unwrap(), - Class::In, - ttl_sec, - Srv::new(0, 0, self.port, Host::host_fqdn(hostname, false).unwrap()), - )) - } - - fn add_service_type( - &self, - answer: &mut AnswerBuilder, - ttl_sec: u32, - ) -> Result<(), PushError> { - answer.push(( - Self::dns_sd_fqdn(false).unwrap(), - Class::In, - ttl_sec, - Ptr::new(self.service_type_fqdn(false).unwrap()), - ))?; - - answer.push(( - self.service_type_fqdn(false).unwrap(), - Class::In, - ttl_sec, - Ptr::new(self.service_fqdn(false).unwrap()), - )) - } - - fn add_service_subtypes( - &self, - answer: &mut AnswerBuilder, - ttl_sec: u32, - ) -> Result<(), PushError> { - for service_subtype in self.service_subtypes { - self.add_service_subtype(answer, service_subtype, ttl_sec)?; +impl fmt::Display for RecordDataChain +where + T: fmt::Display, + U: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::This(data) => write!(f, "{}", data), + Self::Next(data) => write!(f, "{}", data), } + } +} - Ok(()) +impl RecordData for RecordDataChain +where + T: RecordData, + U: RecordData, +{ + fn rtype(&self) -> Rtype { + match self { + Self::This(data) => data.rtype(), + Self::Next(data) => data.rtype(), + } + } +} + +impl ComposeRecordData for RecordDataChain +where + T: ComposeRecordData, + U: ComposeRecordData, +{ + fn rdlen(&self, compress: bool) -> Option { + match self { + Self::This(data) => data.rdlen(compress), + Self::Next(data) => data.rdlen(compress), + } } - fn add_service_subtype( + fn compose_rdata( &self, - answer: &mut AnswerBuilder, - service_subtype: &str, - ttl_sec: u32, - ) -> Result<(), PushError> { - answer.push(( - Self::dns_sd_fqdn(false).unwrap(), - Class::In, - ttl_sec, - Ptr::new(self.service_subtype_fqdn(service_subtype, false).unwrap()), - ))?; - - answer.push(( - self.service_subtype_fqdn(service_subtype, false).unwrap(), - Class::In, - ttl_sec, - Ptr::new(self.service_fqdn(false).unwrap()), - )) - } - - fn add_txt( + target: &mut Target, + ) -> Result<(), Target::AppendError> { + match self { + Self::This(data) => data.compose_rdata(target), + Self::Next(data) => data.compose_rdata(target), + } + } + + fn compose_canonical_rdata( &self, - answer: &mut AnswerBuilder, - ttl_sec: u32, - ) -> Result<(), PushError> { - // only way I found to create multiple parts in a Txt - // each slice is the length and then the data - let mut octets = heapless07::Vec::<_, 256>::new(); - //octets.append_slice(&[1u8, b'X'])?; - //octets.append_slice(&[2u8, b'A', b'B'])?; - //octets.append_slice(&[0u8])?; - for (k, v) in self.txt_kvs { - octets.append_slice(&[(k.len() + v.len() + 1) as u8])?; - octets.append_slice(k.as_bytes())?; - octets.append_slice(&[b'='])?; - octets.append_slice(v.as_bytes())?; + target: &mut Target, + ) -> Result<(), Target::AppendError> { + match self { + Self::This(data) => data.compose_canonical_rdata(target), + Self::Next(data) => data.compose_canonical_rdata(target), } + } +} - let txt = Txt::from_octets(&mut octets).unwrap(); +/// This struct allows one to use a regular `&mut [u8]` slice as an octet buffer +/// with the `domain` library. +/// +/// Useful when a `domain` message needs to be constructed in a `&mut [u8]` slice. +pub struct Buf<'a>(pub &'a mut [u8], pub usize); - answer.push((self.service_fqdn(false).unwrap(), Class::In, ttl_sec, txt)) +impl<'a> Buf<'a> { + /// Create a new `Buf` instance from a mutable slice. + pub fn new(buf: &'a mut [u8]) -> Self { + Self(buf, 0) } +} - fn service_fqdn(&self, suffix: bool) -> Result { - let suffix = if suffix { "." } else { "" }; - - let mut service_fqdn = heapless07::String::<60>::new(); - write!( - service_fqdn, - "{}.{}.{}.local{}", - self.name, self.service, self.protocol, suffix, - ) - .unwrap(); +impl<'a> FreezeBuilder for Buf<'a> { + type Octets = Self; - Dname::>::from_chars(service_fqdn.chars()) + fn freeze(self) -> Self { + self } +} - fn service_type_fqdn(&self, suffix: bool) -> Result { - let suffix = if suffix { "." } else { "" }; - - let mut service_type_fqdn = heapless07::String::<60>::new(); - write!( - service_type_fqdn, - "{}.{}.local{}", - self.service, self.protocol, suffix, - ) - .unwrap(); +impl<'a> Octets for Buf<'a> { + type Range<'r> = &'r [u8] where Self: 'r; - Dname::>::from_chars(service_type_fqdn.chars()) + fn range(&self, range: impl RangeBounds) -> Self::Range<'_> { + self.0[..self.1].range(range) } +} - fn service_subtype_fqdn( - &self, - service_subtype: &str, - suffix: bool, - ) -> Result { - let suffix = if suffix { "." } else { "" }; - - let mut service_subtype_fqdn = heapless07::String::<40>::new(); - write!( - service_subtype_fqdn, - "{}._sub.{}.{}.local{}", - service_subtype, self.service, self.protocol, suffix, - ) - .unwrap(); - - Dname::>::from_chars(service_subtype_fqdn.chars()) - } - - fn dns_sd_fqdn(suffix: bool) -> Result { - Dname::>::from_chars( - if suffix { - "_services._dns-sd._udp.local." - } else { - "_services._dns-sd._udp.local" - } - .chars(), - ) +impl<'a> FromBuilder for Buf<'a> { + type Builder = Buf<'a>; + + fn from_builder(builder: Self::Builder) -> Self { + Buf(&mut builder.0[builder.1..], 0) } } -struct Buf<'a>(pub &'a mut [u8], pub usize); - impl<'a> Composer for Buf<'a> {} impl<'a> OctetsBuilder for Buf<'a> { @@ -518,3 +371,647 @@ impl<'a> AsRef<[u8]> for Buf<'a> { &self.0[..self.1] } } + +/// Type of request for `MdnsHandler::handle`. +#[derive(Debug, Clone, Eq, PartialEq)] +pub enum MdnsRequest<'a> { + /// No incoming mDNS request. Send a broadcast message + None, + /// Incoming mDNS request + Request { + /// Whether it is a legacy request (i.e. UDP packet source port is not 5353, as per spec) + legacy: bool, + /// Whether the request arrived on the multicast address + multicast: bool, + /// The data of the request + data: &'a [u8], + }, +} + +/// Return type for `MdnsHandler::handle`. +#[derive(Debug, Clone, Eq, PartialEq)] +pub enum MdnsResponse<'a> { + None, + Reply { data: &'a [u8], delay: bool }, +} + +/// A trait that abstracts the processing logic for an incoming mDNS message. +/// +/// Handles an incoming mDNS message by parsing it and potentially preparing a response. +/// +/// If request is `None`, the handler should prepare a broadcast message with +/// all its data (i.e. mDNS responder brodcasts on internal state changes). +/// +/// Returns an `MdnsResponse` instance that instructs the caller +/// what data to send as a response (if any) and whether to generate a random delay +/// before sending (as per spec). +pub trait MdnsHandler { + fn handle<'a>( + &mut self, + request: MdnsRequest<'_>, + response_buf: &'a mut [u8], + ) -> Result, MdnsError>; +} + +impl MdnsHandler for &mut T +where + T: MdnsHandler, +{ + fn handle<'a>( + &mut self, + request: MdnsRequest<'_>, + response_buf: &'a mut [u8], + ) -> Result, MdnsError> { + (**self).handle(request, response_buf) + } +} + +/// A structure representing a handler that does not do any processing. +/// +/// Useful only when chaining multiple `MdnsHandler` instances. +pub struct NoHandler; + +impl NoHandler { + /// Chains a `NoHandler` with another handler. + pub fn chain(self, handler: T) -> ChainedHandler { + ChainedHandler::new(handler, self) + } +} + +impl MdnsHandler for NoHandler { + fn handle<'a>( + &mut self, + _request: MdnsRequest<'_>, + _response_buf: &'a mut [u8], + ) -> Result, MdnsError> { + Ok(MdnsResponse::None) + } +} + +/// A composite handler that chains two handlers together. +pub struct ChainedHandler { + first: T, + second: U, +} + +impl ChainedHandler { + /// Create a new `ChainedHandler` instance from two handlers. + pub const fn new(first: T, second: U) -> Self { + Self { first, second } + } + + /// Chains a `ChainedHandler` with another handler, + /// where our instance would be the first one to be called. + /// + /// Chaining works by calling each chained handler from the first to the last, + /// until a handler in the chain returns a non-zero `usize` result. + /// + /// Once that happens, traversing the handlers down the chain stops. + pub fn chain(self, handler: V) -> ChainedHandler { + ChainedHandler::new(handler, self) + } +} + +impl MdnsHandler for ChainedHandler +where + T: MdnsHandler, + U: MdnsHandler, +{ + fn handle<'a>( + &mut self, + request: MdnsRequest<'_>, + response_buf: &'a mut [u8], + ) -> Result, MdnsError> { + match self.first.handle(request.clone(), response_buf)? { + MdnsResponse::None => self.second.handle(request, response_buf), + MdnsResponse::Reply { data, delay } => { + let len = data.len(); + + Ok(MdnsResponse::Reply { + data: &response_buf[..len], + delay, + }) + } + } + } +} + +/// A type alias for the answer which is expected to be returned by instances +/// implementing the `HostAnswers` trait. +pub type HostAnswer<'a> = + Record, RecordDataChain, AllRecordData<&'a [u8], NameSlice<'a>>>>; + +/// A trait that abstracts the logic for providing answers to incoming mDNS queries. +/// +/// The visitor-pattern-with-a-callback is chosen on purpose, as that allows `domain` +/// Names to be constructed on-the-fly, possibly without interim buffer allocations. +/// +/// Look at the implementation of `HostAnswers` for `host::Host` and `host::Service` +/// for examples of this technique. +pub trait HostAnswers { + /// Visits an entity that does have answers to mDNS queries. + /// + /// The answers will be provided to the supplied `f` callback. + /// + /// Note that the entity should provide ALL of its answers, regardless of the + /// concrete questions. + /// + /// The filtering of the answers relevant for the asked questions is done by the caller, + /// and only if necessary (i.e. only if these answers are used to reply to a concrete mDNS query, + /// rather than just broadcasting all answers that the entity has, which is also a valid mDNS + /// operation, that should be done when the entity providing answers has changes ot its internal state). + fn visit(&self, f: F) -> Result<(), E> + where + F: FnMut(HostAnswer) -> Result<(), E>, + E: From; +} + +impl HostAnswers for &T +where + T: HostAnswers, +{ + fn visit(&self, f: F) -> Result<(), E> + where + F: FnMut(HostAnswer) -> Result<(), E>, + E: From, + { + (*self).visit(f) + } +} + +impl HostAnswers for &mut T +where + T: HostAnswers, +{ + fn visit(&self, f: F) -> Result<(), E> + where + F: FnMut(HostAnswer) -> Result<(), E>, + E: From, + { + (**self).visit(f) + } +} + +/// A type alias for the question which is expected to be returned by instances +/// implementing the `HostQuestions` trait. +pub type HostQuestion<'a> = Question>; + +/// A trait that abstracts the logic for providing questions to outgoing mDNS queries. +/// +/// The visitor-pattern-with-a-callback is chosen on purpose, as that allows `domain` +/// Names to be constructed on-the-fly, possibly without interim buffer allocations. +pub trait HostQuestions { + /// Visits an entity that does have questions. + /// + /// The questions will be provided to the supplied `f` callback. + fn visit(&self, f: F) -> Result<(), E> + where + F: FnMut(HostQuestion) -> Result<(), E>, + E: From; + + /// A function that constructs an mDNS query message in a `&mut [u8]` buffer + /// using questions generated by this trait. + fn query(&self, id: u16, buf: &mut [u8]) -> Result { + let buf = Buf(buf, 0); + + let mut mb = MessageBuilder::from_target(buf)?; + + set_header(&mut mb, id, false); + + let mut qb = mb.question(); + + let mut pushed = false; + + self.visit(|question| { + qb.push(question)?; + + pushed = true; + + Ok::<_, MdnsError>(()) + })?; + + let buf = qb.finish(); + + if pushed { + Ok(buf.1) + } else { + Ok(0) + } + } +} + +impl HostQuestions for &T +where + T: HostQuestions, +{ + fn visit(&self, f: F) -> Result<(), E> + where + F: FnMut(HostQuestion) -> Result<(), E>, + E: From, + { + (*self).visit(f) + } +} + +impl HostQuestions for &mut T +where + T: HostQuestions, +{ + fn visit(&self, f: F) -> Result<(), E> + where + F: FnMut(HostQuestion) -> Result<(), E>, + E: From, + { + (**self).visit(f) + } +} + +/// A structure modeling an entity that does not generate any questions. +/// +/// Useful only when chaining multiple `HostQuestions` instances. +pub struct NoHostQuestions; + +impl NoHostQuestions { + /// Chains a `HostQuestions` with another `HostAnswers` instance. + pub fn chain(self, questions: T) -> ChainedHostQuestions { + ChainedHostQuestions::new(questions, self) + } +} + +impl HostQuestions for NoHostQuestions { + fn visit(&self, _f: F) -> Result<(), E> + where + F: FnMut(HostQuestion) -> Result<(), E>, + { + Ok(()) + } +} + +/// A composite `HostQuestions` that chains two `HostQuestions` instances together. +pub struct ChainedHostQuestions { + first: T, + second: U, +} + +impl ChainedHostQuestions { + /// Create a new `ChainedHostQuestions` instance from two `HostQuestions` instances. + pub const fn new(first: T, second: U) -> Self { + Self { first, second } + } + + /// Chains this instance with another `HostQuestions` instance, + pub fn chain(self, answers: V) -> ChainedHostQuestions { + ChainedHostQuestions::new(answers, self) + } +} + +impl HostQuestions for ChainedHostQuestions +where + T: HostQuestions, + U: HostQuestions, +{ + fn visit(&self, mut f: F) -> Result<(), E> + where + F: FnMut(HostQuestion) -> Result<(), E>, + E: From, + { + self.first.visit(&mut f)?; + self.second.visit(f) + } +} + +/// A structure modeling an entity that does not generate any answers. +/// +/// Useful only when chaining multiple `HostAnswers` instances. +pub struct NoHostAnswers; + +impl NoHostAnswers { + /// Chains a `NoHostAnswers` with another `HostAnswers` instance. + pub fn chain(self, answers: T) -> ChainedHostAnswers { + ChainedHostAnswers::new(answers, self) + } +} + +impl HostAnswers for NoHostAnswers { + fn visit(&self, _f: F) -> Result<(), E> + where + F: FnMut(HostAnswer) -> Result<(), E>, + { + Ok(()) + } +} + +/// A composite `HostAnswers` that chains two `HostAnswers` instances together. +pub struct ChainedHostAnswers { + first: T, + second: U, +} + +impl ChainedHostAnswers { + /// Create a new `ChainedHostAnswers` instance from two `HostAnswers` instances. + pub const fn new(first: T, second: U) -> Self { + Self { first, second } + } + + /// Chains this instance with another `HostAnswers` instance, + pub fn chain(self, answers: V) -> ChainedHostAnswers { + ChainedHostAnswers::new(answers, self) + } +} + +impl HostAnswers for ChainedHostAnswers +where + T: HostAnswers, + U: HostAnswers, +{ + fn visit(&self, mut f: F) -> Result<(), E> + where + F: FnMut(HostAnswer) -> Result<(), E>, + E: From, + { + self.first.visit(&mut f)?; + self.second.visit(f) + } +} + +/// An `MdnsHandler` implementation that answers mDNS queries with the answers +/// provided by an entity implementing the `HostAnswers` trait. +/// +/// Typically, this structure will be used to provide answers to other peers that broadcast +/// mDNS queries - i.e. this is the "responder" aspect of the mDNS protocol. +pub struct HostAnswersMdnsHandler { + answers: T, +} + +impl HostAnswersMdnsHandler { + /// Create a new `HostAnswersMdnsHandler` instance from an entity that provides answers. + pub const fn new(answers: T) -> Self { + Self { answers } + } +} + +impl MdnsHandler for HostAnswersMdnsHandler +where + T: HostAnswers, +{ + fn handle<'a>( + &mut self, + request: MdnsRequest<'_>, + response_buf: &'a mut [u8], + ) -> Result, MdnsError> { + let buf = Buf(response_buf, 0); + + let mut mb = MessageBuilder::from_target(buf)?; + + let mut pushed = false; + + let buf = if let MdnsRequest::Request { legacy, data, .. } = request { + let message = Message::from_octets(data)?; + + if !matches!(message.header().opcode(), Opcode::QUERY) + || !matches!(message.header().rcode(), Rcode::NOERROR) + || message.header().qr() + // Not a query but a response + { + return Ok(MdnsResponse::None); + } + + let mut ab = if legacy { + set_header(&mut mb, message.header().id(), true); + + let mut qb = mb.question(); + + // As per spec, for legacy requests we need to fill-in the questions section + for question in message.question() { + qb.push(question?)?; + } + + qb.answer() + } else { + set_header(&mut mb, 0, true); + + mb.answer() + }; + + let mut additional_a = false; + let mut additional_srv_txt = false; + + for question in message.question() { + let question = question?; + + self.answers.visit(|answer| { + if matches!(answer.data(), RecordDataChain::Next(AllRecordData::Srv(_))) { + additional_a = true; + } + + if !answer.owner().name_eq(&DNS_SD_OWNER) + && matches!(answer.data(), RecordDataChain::Next(AllRecordData::Ptr(_))) + { + additional_a = true; + + // Over-simplifying here in that we'll send all our SRV and TXT records, however + // sending only some SRV and PTR records is too complex to implement. + additional_srv_txt = true; + } + + if question.qname().name_eq(&answer.owner()) { + debug!("Answering question [{question}] with: [{answer}]"); + + ab.push(answer)?; + + pushed = true; + } + + Ok::<_, MdnsError>(()) + })?; + } + + if additional_a || additional_srv_txt { + // Fill-in the additional section as well + + let mut aa = ab.additional(); + + self.answers.visit(|answer| { + if matches!( + answer.data(), + RecordDataChain::Next(AllRecordData::A(_)) + | RecordDataChain::Next(AllRecordData::Aaaa(_)) + | RecordDataChain::Next(AllRecordData::Srv(_)) + | RecordDataChain::Next(AllRecordData::Txt(_)) + | RecordDataChain::This(Txt(_)) + ) { + debug!("Additional answer: [{answer}]"); + + aa.push(answer)?; + + pushed = true; + } + + Ok::<_, MdnsError>(()) + })?; + + aa.finish() + } else { + ab.finish() + } + } else { + set_header(&mut mb, 0, true); + + let mut ab = mb.answer(); + + self.answers.visit(|answer| { + ab.push(answer)?; + + pushed = true; + + Ok::<_, MdnsError>(()) + })?; + + ab.finish() + }; + + if pushed { + Ok(MdnsResponse::Reply { + data: &buf.0[..buf.1], + delay: false, + }) + } else { + Ok(MdnsResponse::None) + } + } +} + +/// A type alias for the answer which is expected to be returned by instances +/// implementing the `PeerAnswers` trait. +pub type PeerAnswer<'a> = + Record, AllRecordData<&'a [u8], ParsedName<&'a [u8]>>>; + +/// A trait that abstracts the logic for processing answers from incoming mDNS queries. +/// +/// Rather than dealing with the whole mDNS message, this trait is focused on processing +/// the answers from the message (in the `answer` and `additional` mDNS message sections). +pub trait PeerAnswers { + /// Processes the answers from an incoming mDNS message. + fn answers<'a, T, A>(&self, answers: T, additional: A) -> Result<(), MdnsError> + where + T: IntoIterator, MdnsError>> + Clone + 'a, + A: IntoIterator, MdnsError>> + Clone + 'a; +} + +impl PeerAnswers for &mut T +where + T: PeerAnswers, +{ + fn answers<'a, U, V>(&self, answers: U, additional: V) -> Result<(), MdnsError> + where + U: IntoIterator, MdnsError>> + Clone + 'a, + V: IntoIterator, MdnsError>> + Clone + 'a, + { + (**self).answers(answers, additional) + } +} + +impl PeerAnswers for &T +where + T: PeerAnswers, +{ + fn answers<'a, U, V>(&self, answers: U, additional: V) -> Result<(), MdnsError> + where + U: IntoIterator, MdnsError>> + Clone + 'a, + V: IntoIterator, MdnsError>> + Clone + 'a, + { + (*self).answers(answers, additional) + } +} + +/// A structure implementing the `MdnsHandler` trait by processing all answers from an +/// incoming mDNS message via delegating to an entity implementing the `PeerAnswers` trait. +/// +/// Typically, this structure will be used to process answers which are replies to mDNS +/// queries that we have sent using the `HostQuestions::query` method, i.e., this is the +/// "querying" part of the mDNS protocol. +/// +/// Since the "querying" aspect of the mDNS protocol is modeled here, this handler never +/// answers anything, i.e. it always returns a 0 `usize`, because - unlike the +/// `HostAnswersMdnsHandler` - it does not have any answers to provide, as it - itself - +/// processes answers provided by peers, which were themselves sent because we issued a query +/// using e.g. the `HostQuestions::query` method at an earlier point in time. +pub struct PeerAnswersMdnsHandler { + answers: T, +} + +impl PeerAnswersMdnsHandler { + /// Create a new `PeerAnswersMdnsHandler` instance from an entity that processes answers. + pub const fn new(answers: T) -> Self { + Self { answers } + } +} + +impl MdnsHandler for PeerAnswersMdnsHandler +where + T: PeerAnswers, +{ + fn handle<'a>( + &mut self, + request: MdnsRequest<'_>, + _response_buf: &'a mut [u8], + ) -> Result, MdnsError> { + let MdnsRequest::Request { data, legacy, .. } = request else { + return Ok(MdnsResponse::None); + }; + + if legacy { + // Legacy packets should not contain mDNS answers anyway, per spec + return Ok(MdnsResponse::None); + } + + let message = Message::from_octets(data)?; + + if !matches!(message.header().opcode(), Opcode::QUERY) + || !matches!(message.header().rcode(), Rcode::NOERROR) + || !message.header().qr() + // Not a response but a query + { + return Ok(MdnsResponse::None); + } + + let answers = message.answer()?; + let additional = message.additional()?; + + let answers = answers.filter_map(|answer| { + match answer { + Ok(answer) => answer.into_record::>(), + Err(e) => Err(e), + } + .map_err(|_| MdnsError::InvalidMessage) + .transpose() + }); + + let additional = additional.filter_map(|answer| { + match answer { + Ok(answer) => answer.into_record::>(), + Err(e) => Err(e), + } + .map_err(|_| MdnsError::InvalidMessage) + .transpose() + }); + + self.answers.answers(answers, additional)?; + + Ok(MdnsResponse::None) + } +} + +/// Utility function that sets the header of an mDNS `domain` message builder +/// to be a response or a query. +pub fn set_header(answer: &mut MessageBuilder, id: u16, response: bool) { + let header = answer.header_mut(); + header.set_id(id); + header.set_opcode(Opcode::QUERY); + header.set_rcode(Rcode::NOERROR); + + let mut flags = Flags::new(); + flags.qr = response; + flags.aa = response; + header.set_flags(flags); +} diff --git a/examples/mdns_responder.rs b/examples/mdns_responder.rs index d385b4c..532a6a1 100644 --- a/examples/mdns_responder.rs +++ b/examples/mdns_responder.rs @@ -1,11 +1,18 @@ -use core::net::Ipv4Addr; +use core::net::{Ipv4Addr, Ipv6Addr}; -use edge_mdns::io::{self, MdnsIoError, MdnsRunBuffers, DEFAULT_SOCKET}; -use edge_mdns::Host; -use edge_nal::{MulticastV4, MulticastV6, UdpBind, UdpSplit}; +use edge_mdns::buf::{BufferAccess, VecBufAccess}; +use edge_mdns::domain::base::Ttl; +use edge_mdns::io::{self, MdnsIoError, DEFAULT_SOCKET}; +use edge_mdns::{host::Host, HostAnswersMdnsHandler}; +use edge_nal::{UdpBind, UdpSplit}; + +use embassy_sync::blocking_mutex::raw::NoopRawMutex; +use embassy_sync::signal::Signal; use log::*; +use rand::{thread_rng, RngCore}; + // Change this to the IP address of the machine where you'll run this example const OUR_IP: Ipv4Addr = Ipv4Addr::new(127, 0, 0, 1); @@ -18,45 +25,56 @@ fn main() { let stack = edge_nal_std::Stack::new(); - let mut buffers = MdnsRunBuffers::new(); + let (recv_buf, send_buf) = ( + VecBufAccess::::new(), + VecBufAccess::::new(), + ); - futures_lite::future::block_on(run::( - &stack, - &mut buffers, - OUR_NAME, - OUR_IP, + futures_lite::future::block_on(run::( + &stack, &recv_buf, &send_buf, OUR_NAME, OUR_IP, )) .unwrap(); } -async fn run( +async fn run( stack: &T, - buffers: &mut MdnsRunBuffers, + recv_buf: RB, + send_buf: SB, our_name: &str, our_ip: Ipv4Addr, ) -> Result<(), MdnsIoError> where T: UdpBind, - for<'a> ::Socket<'a>: - MulticastV4 + MulticastV6 + UdpSplit, + RB: BufferAccess<[u8]>, + SB: BufferAccess<[u8]>, { info!("About to run an mDNS responder for our PC. It will be addressable using {our_name}.local, so try to `ping {our_name}.local`."); + let mut socket = io::bind(stack, DEFAULT_SOCKET, Some(Ipv4Addr::UNSPECIFIED), Some(0)).await?; + + let (recv, send) = socket.split(); + let host = Host { - id: 0, hostname: our_name, - ip: our_ip.octets(), - ipv6: None, + ipv4: our_ip, + ipv6: Ipv6Addr::UNSPECIFIED, + ttl: Ttl::from_secs(60), }; - io::run( - &host, + // A way to notify the mDNS responder that the data in `Host` had changed + // We don't use it in this example, because the data is hard-coded + let signal = Signal::new(); + + let mdns = io::Mdns::::new( Some(Ipv4Addr::UNSPECIFIED), Some(0), - [], - stack, - DEFAULT_SOCKET, - buffers, - ) - .await + recv, + send, + recv_buf, + send_buf, + |buf| thread_rng().fill_bytes(buf), + &signal, + ); + + mdns.run(HostAnswersMdnsHandler::new(&host)).await }